diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 88 |
1 files changed, 53 insertions, 35 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 00fb3d9..5bc5630 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -104,7 +104,7 @@ public: virtual int eval(); }; -#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \ +#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ template <int Rank, DType Dtype> \ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ @@ -121,41 +121,59 @@ public: virtual int register_fcn(); \ }; -#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \ - template <int Rank, DType InDtype, DType OutDtype> \ - class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \ - { \ - public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \ - { \ - register_fcn(); \ - } \ - static constexpr int32_t QMin = GetQMin<OutDtype>::value; \ - static constexpr int32_t QMax = GetQMax<OutDtype>::value; \ - using InEigenType = typename GetEigenType<InDtype>::type; \ - using OutEigenType = typename GetEigenType<OutDtype>::type; \ - virtual int register_fcn(); \ - }; +DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) + +#undef DEF_TEMPLATE_BINARY_OP_DEFAULT + +template <int Rank, DType Dtype> +class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> +{ +public: + OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + { + INIT_ATTRIBUTE(ArithmeticRightShift); + register_fcn(); + } + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM) -DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB) - -#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE -#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE +protected: + TosaArithmeticRightShiftAttribute* attribute; +}; + +template <int Rank, DType InDtype, DType OutDtype> +class OpMul : public BinaryNode<Rank, InDtype, OutDtype> +{ +public: + OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_) + { + INIT_ATTRIBUTE(Mul); + register_fcn(); + } + static constexpr int64_t QMin = GetQMin<OutDtype>::value; + static constexpr int64_t QMax = GetQMax<OutDtype>::value; + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + virtual int register_fcn(); + +protected: + TosaMulAttribute* attribute; +}; template <int Rank> class OpTable : public GraphNode |