aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r--reference_model/src/ops/ewise_binary.h88
1 files changed, 53 insertions, 35 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 00fb3d9..5bc5630 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -104,7 +104,7 @@ public:
virtual int eval();
};
-#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
template <int Rank, DType Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
@@ -121,41 +121,59 @@ public:
virtual int register_fcn(); \
};
-#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
- template <int Rank, DType InDtype, DType OutDtype> \
- class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
- { \
- public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
- { \
- register_fcn(); \
- } \
- static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
- static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
- virtual int register_fcn(); \
- };
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
+
+#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
+
+template <int Rank, DType Dtype>
+class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
+{
+public:
+ OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(ArithmeticRightShift);
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
-DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
-
-#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
-#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+protected:
+ TosaArithmeticRightShiftAttribute* attribute;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
+{
+public:
+ OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(Mul);
+ register_fcn();
+ }
+ static constexpr int64_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t QMax = GetQMax<OutDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaMulAttribute* attribute;
+};
template <int Rank>
class OpTable : public GraphNode