diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 1215c93..8d2e486 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -159,18 +159,33 @@ public: OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_) { - INIT_ATTRIBUTE(Mul); + if constexpr (InDtype == TOSA_REF_TYPE_INT32) + { + // Require `shift` operand. + this->setRequiredOperands(3, 1); + } register_fcn(); } - virtual ~OpMul(); static constexpr int64_t QMin = GetQMin<OutDtype>::value; static constexpr int64_t QMax = GetQMax<OutDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<OutDtype>::type; - virtual int register_fcn(); + + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type; + + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + using TShift = Eigen::Tensor<ShiftEigenType, 0>; + + int register_fcn(); + int eval(); + + // Note that INT64 is not natively supported in Dtype system. + std::function<int64_t(InEigenType, InEigenType)> mul_fcn; + std::function<OutEigenType(int64_t, InEigenType)> shr_fcn; protected: - TosaMulAttribute* attribute; + TosaReference::TensorTemplate<TShift>* s; }; template <int Rank, TOSA_REF_TYPE InDtype> |