[Groonga-commit] groonga/grnxx at a9ccee3 [master] Implement UnaryNode of Expression. (#104)

Back to archive index

susumu.yata null+****@clear*****
Tue Dec 16 10:41:26 JST 2014


susumu.yata	2014-11-12 18:53:06 +0900 (Wed, 12 Nov 2014)

  New Revision: a9ccee31aefa7b81582e539821d7dc5385e119b0
  https://github.com/groonga/grnxx/commit/a9ccee31aefa7b81582e539821d7dc5385e119b0

  Message:
    Implement UnaryNode of Expression. (#104)

  Modified files:
    lib/grnxx/impl/expression.cpp
    lib/grnxx/impl/expression.hpp

  Modified: lib/grnxx/impl/expression.cpp (+438 -1)
===================================================================
--- lib/grnxx/impl/expression.cpp    2014-11-12 16:41:19 +0900 (a461aca)
+++ lib/grnxx/impl/expression.cpp    2014-11-12 18:53:06 +0900 (7e324fc)
@@ -428,6 +428,299 @@ class ColumnNode<Float> : public TypedNode<Float> {
   const impl::Column<Value> *column_;
 };
 
+// -- OperatorNode --
+
+template <typename T>
+class OperatorNode : public TypedNode<T> {
+ public:
+  using Value = T;
+
+  OperatorNode() = default;
+  virtual ~OperatorNode() = default;
+
+  NodeType node_type() const {
+    return OPERATOR_NODE;
+  }
+};
+
+// Evaluate "*arg" for "records".
+//
+// The evaluation results are stored into "*arg_values".
+//
+// On failure, throws an exception.
+template <typename T>
+void fill_node_arg_values(ArrayCRef<Record> records,
+                          TypedNode<T> *arg,
+                          Array<T> *arg_values) {
+  size_t old_size = arg_values->size();
+  if (old_size < records.size()) {
+    arg_values->resize(records.size());
+  }
+  switch (arg->node_type()) {
+    case CONSTANT_NODE: {
+      if (old_size < records.size()) {
+        arg->evaluate(records.cref(old_size), arg_values->ref(old_size));
+      }
+      break;
+    }
+    default: {
+      arg->evaluate(records, arg_values->ref(0, records.size()));
+      break;
+    }
+  }
+}
+
+// --- UnaryNode ---
+
+template <typename T, typename U>
+class UnaryNode : public OperatorNode<T> {
+ public:
+  using Value = T;
+  using Arg = U;
+
+  explicit UnaryNode(std::unique_ptr<Node> &&arg)
+      : OperatorNode<Value>(),
+        arg_(static_cast<TypedNode<Arg> *>(arg.release())),
+        arg_values_() {}
+  virtual ~UnaryNode() = default;
+
+ protected:
+  std::unique_ptr<TypedNode<Arg>> arg_;
+  Array<Arg> arg_values_;
+
+  void fill_arg_values(ArrayCRef<Record> records) {
+    fill_node_arg_values(records, arg_.get(), &arg_values_);
+  }
+};
+
+// ---- LogicalNotNode ----
+
+class LogicalNotNode : public UnaryNode<Bool, Bool> {
+ public:
+  using Value = Bool;
+  using Arg = Bool;
+
+  explicit LogicalNotNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)),
+        temp_records_() {}
+  ~LogicalNotNode() = default;
+
+  void filter(ArrayCRef<Record> input_records,
+              ArrayRef<Record> *output_records);
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+
+ private:
+  Array<Record> temp_records_;
+};
+
+void LogicalNotNode::filter(ArrayCRef<Record> input_records,
+                            ArrayRef<Record> *output_records) {
+  // TODO: Find the best implementation.
+
+  // Apply an argument filter to "input_records" and store the result to
+  // "temp_records_". Then, appends a sentinel to the end.
+  temp_records_.resize(input_records.size() + 1);
+  ArrayRef<Record> ref = temp_records_.ref();
+  arg_->filter(input_records, &ref);
+  temp_records_[ref.size()].row_id = Int::na();
+
+  // Extract records which appear in "input_records" and don't appear in "ref".
+  size_t count = 0;
+  for (size_t i = 0, j = 0; i < input_records.size(); ++i) {
+    if (input_records[i].row_id == ref[i].row_id) {
+      ++j;
+      continue;
+    }
+    (*output_records)[count] = input_records[i];
+    ++count;
+  }
+  *output_records = output_records->ref(0, count);
+}
+
+void LogicalNotNode::evaluate(ArrayCRef<Record> records,
+                              ArrayRef<Value> results) {
+  arg_->evaluate(records, results);
+  for (size_t i = 0; i < records.size(); ++i) {
+    results[i] = !results[i];
+  }
+}
+
+// ---- BitwiseNotNode ----
+
+template <typename T> class BitwiseNotNode;
+
+template <>
+class BitwiseNotNode<Bool> : public UnaryNode<Bool, Bool> {
+ public:
+  using Value = Bool;
+  using Arg = Bool;
+
+  explicit BitwiseNotNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~BitwiseNotNode() = default;
+
+  void filter(ArrayCRef<Record> input_records,
+              ArrayRef<Record> *output_records);
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+};
+
+void BitwiseNotNode<Bool>::filter(ArrayCRef<Record> input_records,
+                                  ArrayRef<Record> *output_records) {
+  fill_arg_values(input_records);
+  size_t count = 0;
+  for (size_t i = 0; i < input_records.size(); ++i) {
+    if (!arg_values_[i]) {
+      (*output_records)[count] = input_records[i];
+      ++count;
+    }
+  }
+  *output_records = output_records->ref(0, count);
+}
+
+void BitwiseNotNode<Bool>::evaluate(ArrayCRef<Record> records,
+                                    ArrayRef<Value> results) {
+  arg_->evaluate(records, results);
+  // TODO: Should be processed per 8 bytes.
+  //       Check the 64-bit boundary and do it!
+  for (size_t i = 0; i < records.size(); ++i) {
+    results[i] = !results[i];
+  }
+}
+
+template <>
+class BitwiseNotNode<Int> : public UnaryNode<Int, Int> {
+ public:
+  using Value = Int;
+  using Arg = Int;
+
+  explicit BitwiseNotNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~BitwiseNotNode() = default;
+
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+};
+
+void BitwiseNotNode<Int>::evaluate(ArrayCRef<Record> records,
+                                   ArrayRef<Value> results) {
+  arg_->evaluate(records, results);
+  for (size_t i = 0; i < records.size(); ++i) {
+    results[i] = ~results[i];
+  }
+}
+
+// ---- PositiveNode ----
+
+// Nothing to do.
+
+// ---- NegativeNode ----
+
+template <typename T>
+class NegativeNode : public UnaryNode<T, T> {
+ public:
+  using Value = T;
+  using Arg = T;
+
+  explicit NegativeNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~NegativeNode() = default;
+
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+};
+
+template <typename T>
+void NegativeNode<T>::evaluate(ArrayCRef<Record> records,
+                               ArrayRef<Value> results) {
+  this->arg_->evaluate(records, results);
+  for (size_t i = 0; i < records.size(); ++i) {
+    results[i] = -results[i];
+  }
+}
+
+template <>
+class NegativeNode<Float> : public UnaryNode<Float, Float> {
+ public:
+  using Value = Float;
+  using Arg = Float;
+
+  explicit NegativeNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~NegativeNode() = default;
+
+  void adjust(ArrayRef<Record> records);
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+};
+
+void NegativeNode<Float>::adjust(ArrayRef<Record> records) {
+  arg_->adjust(records);
+  for (size_t i = 0; i < records.size(); ++i) {
+    records[i].score = -records[i].score;
+  }
+}
+
+void NegativeNode<Float>::evaluate(ArrayCRef<Record> records,
+                                   ArrayRef<Value> results) {
+  arg_->evaluate(records, results);
+  for (size_t i = 0; i < records.size(); ++i) {
+    results[i] = -results[i];
+  }
+}
+
+// ---- ToIntNode ----
+
+class ToIntNode : public UnaryNode<Int, Float> {
+ public:
+  using Value = Int;
+  using Arg = Float;
+
+  explicit ToIntNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~ToIntNode() = default;
+
+  void evaluate(ArrayCRef<Record> records,
+                ArrayRef<Value> results);
+};
+
+void ToIntNode::evaluate(ArrayCRef<Record> records,
+                         ArrayRef<Value> results) {
+  fill_arg_values(records);
+  for (size_t i = 0; i < records.size(); ++i) {
+    // TODO: Typecast inteface must be provided!
+    results[i] = Value(static_cast<int64_t>(arg_values_[i].value()));
+  }
+}
+
+// ---- ToFloatNode ----
+
+class ToFloatNode : public UnaryNode<Float, Int> {
+ public:
+  using Value = Float;
+  using Arg = Int;
+
+  explicit ToFloatNode(std::unique_ptr<Node> &&arg)
+      : UnaryNode<Value, Arg>(std::move(arg)) {}
+  ~ToFloatNode() = default;
+
+  void adjust(ArrayRef<Record> records);
+  void evaluate(ArrayCRef<Record> records, ArrayRef<Value> results);
+};
+
+void ToFloatNode::adjust(ArrayRef<Record> records) {
+  fill_arg_values(records);
+  for (size_t i = 0; i < records.size(); ++i) {
+    // TODO: Typecast inteface must be provided!
+    records[i].score = Value(static_cast<double>(arg_values_[i].value()));
+  }
+}
+
+void ToFloatNode::evaluate(ArrayCRef<Record> records,
+                           ArrayRef<Value> results) {
+  fill_arg_values(records);
+  for (size_t i = 0; i < records.size(); ++i) {
+    // TODO: Typecast inteface must be provided!
+    results[i] = Value(static_cast<double>(arg_values_[i].value()));
+  }
+}
+
 }  // namespace expression
 
 using namespace expression;
@@ -574,6 +867,7 @@ void Expression::_evaluate(ArrayCRef<Record> records, Array<T> *results) {
 
 template <typename T>
 void Expression::_evaluate(ArrayCRef<Record> records, ArrayRef<T> results) {
+std::cout << "TEST!" << std::endl;
   if (T::type() != data_type()) {
     throw "Data type conflict";  // TODO
   }
@@ -640,7 +934,38 @@ void ExpressionBuilder::push_operator(OperatorType operator_type) {
   if (subexpression_builder_) {
     subexpression_builder_->push_operator(operator_type);
   } else {
-    // TODO
+    switch (operator_type) {
+      case LOGICAL_NOT_OPERATOR:
+      case BITWISE_NOT_OPERATOR:
+      case POSITIVE_OPERATOR:
+      case NEGATIVE_OPERATOR:
+      case TO_INT_OPERATOR:
+      case TO_FLOAT_OPERATOR: {
+        return push_unary_operator(operator_type);
+      }
+      case LOGICAL_AND_OPERATOR:
+      case LOGICAL_OR_OPERATOR:
+      case EQUAL_OPERATOR:
+      case NOT_EQUAL_OPERATOR:
+      case LESS_OPERATOR:
+      case LESS_EQUAL_OPERATOR:
+      case GREATER_OPERATOR:
+      case GREATER_EQUAL_OPERATOR:
+      case BITWISE_AND_OPERATOR:
+      case BITWISE_OR_OPERATOR:
+      case BITWISE_XOR_OPERATOR:
+      case PLUS_OPERATOR:
+      case MINUS_OPERATOR:
+      case MULTIPLICATION_OPERATOR:
+      case DIVISION_OPERATOR:
+      case MODULUS_OPERATOR:
+      case SUBSCRIPT_OPERATOR: {
+        return push_binary_operator(operator_type);
+      }
+      default: {
+        throw "Not supported yet";  // TODO
+      }
+    }
   }
 }
 
@@ -698,6 +1023,29 @@ std::unique_ptr<ExpressionInterface> ExpressionBuilder::release(
   throw "Memory allocation failed";  // TODO
 }
 
+void ExpressionBuilder::push_unary_operator(OperatorType operator_type) {
+  if (node_stack_.size() == 0) {
+    throw "No operand";  // TODO
+  }
+  std::unique_ptr<Node> arg = std::move(node_stack_.back());
+  node_stack_.pop_back();
+  std::unique_ptr<Node> node(
+      create_unary_node(operator_type, std::move(arg)));
+  node_stack_.push_back(std::move(node));
+}
+
+void ExpressionBuilder::push_binary_operator(OperatorType operator_type) {
+  if (node_stack_.size() < 2) {
+    throw "Not enough operands";  // TODO
+  }
+  std::unique_ptr<Node> arg1 = std::move(node_stack_[node_stack_.size() - 2]);
+  std::unique_ptr<Node> arg2 = std::move(node_stack_[node_stack_.size() - 1]);
+  node_stack_.resize(node_stack_.size() - 2);
+  std::unique_ptr<Node> node(
+      create_binary_node(operator_type, std::move(arg1), std::move(arg2)));
+  node_stack_.push_back(std::move(node));
+}
+
 void ExpressionBuilder::push_dereference(const ExpressionOptions &options) {
   throw "Not supported yet";  // TODO
 }
@@ -788,5 +1136,94 @@ Node *ExpressionBuilder::create_column_node(
   throw "Memory allocation failed";  // TODO
 }
 
+Node *ExpressionBuilder::create_unary_node(
+    OperatorType operator_type,
+    std::unique_ptr<Node> &&arg) try {
+  switch (operator_type) {
+    case LOGICAL_NOT_OPERATOR: {
+      switch (arg->data_type()) {
+        case BOOL_DATA: {
+          return new LogicalNotNode(std::move(arg));
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    case BITWISE_NOT_OPERATOR: {
+      switch (arg->data_type()) {
+        case BOOL_DATA: {
+          return new BitwiseNotNode<Bool>(std::move(arg));
+        }
+        case INT_DATA: {
+          return new BitwiseNotNode<Int>(std::move(arg));
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    case POSITIVE_OPERATOR: {
+      switch (arg->data_type()) {
+        case INT_DATA:
+        case FLOAT_DATA: {
+          // A positive operator does nothing.
+          return arg.release();
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    case NEGATIVE_OPERATOR: {
+      switch (arg->data_type()) {
+        case INT_DATA: {
+          return new NegativeNode<Int>(std::move(arg));
+        }
+        case FLOAT_DATA: {
+          return new NegativeNode<Float>(std::move(arg));
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    case TO_INT_OPERATOR: {
+      switch (arg->data_type()) {
+        case FLOAT_DATA: {
+          return new ToIntNode(std::move(arg));
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    case TO_FLOAT_OPERATOR: {
+      switch (arg->data_type()) {
+        case INT_DATA: {
+          return new ToFloatNode(std::move(arg));
+        }
+        default: {
+          throw "Invalid data type";  // TODO
+        }
+      }
+    }
+    default: {
+      throw "Not supported yet";
+    }
+  }
+} catch (const std::bad_alloc &) {
+  throw "Memory allocation failed";  // TODO
+}
+
+Node *ExpressionBuilder::create_binary_node(
+    OperatorType operator_type,
+    std::unique_ptr<Node> &&arg1,
+    std::unique_ptr<Node> &&arg2) try {
+  throw "Not supported yet";  // TODO
+} catch (const std::bad_alloc &) {
+  throw "Memory allocation failed";  // TODO
+}
+
 }  // namespace impl
 }  // namespace grnxx

  Modified: lib/grnxx/impl/expression.hpp (+24 -0)
===================================================================
--- lib/grnxx/impl/expression.hpp    2014-11-12 16:41:19 +0900 (82f222e)
+++ lib/grnxx/impl/expression.hpp    2014-11-12 18:53:06 +0900 (08fc2d8)
@@ -100,6 +100,16 @@ class ExpressionBuilder : public ExpressionBuilderInterface {
   Array<std::unique_ptr<Node>> node_stack_;
   std::unique_ptr<ExpressionBuilder> subexpression_builder_;
 
+  // Push a node associated with a unary operator.
+  //
+  // On failure, throws an exception.
+  void push_unary_operator(OperatorType operator_type);
+
+  // Push a node associated with a binary operator.
+  //
+  // On failure, throws an exception.
+  void push_binary_operator(OperatorType operator_type);
+
   // Push a node associated with the dereference operator.
   //
   // On failure, throws an exception.
@@ -115,6 +125,20 @@ class ExpressionBuilder : public ExpressionBuilderInterface {
   // On failure, throws an exception.
   Node *create_column_node(const String &name);
 
+  // Create a node associated with a unary operator.
+  //
+  // On failure, throws an exception.
+  Node *create_unary_node(
+      OperatorType operator_type,
+      std::unique_ptr<Node> &&arg);
+
+  // Create a node associated with a binary operator.
+  //
+  // On failure, throws an exception.
+  Node *create_binary_node(
+      OperatorType operator_type,
+      std::unique_ptr<Node> &&arg1,
+      std::unique_ptr<Node> &&arg2);
 };
 
 }  // namespace impl
-------------- next part --------------
HTML����������������������������...
다운로드 



More information about the Groonga-commit mailing list
Back to archive index