diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 8d2e486..7ebd852 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -159,11 +159,8 @@ public: OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_) { - if constexpr (InDtype == TOSA_REF_TYPE_INT32) - { - // Require `shift` operand. - this->setRequiredOperands(3, 1); - } + // Require `shift` operand. + this->setRequiredOperands(3, 1); register_fcn(); } static constexpr int64_t QMin = GetQMin<OutDtype>::value; @@ -173,9 +170,10 @@ public: using OutEigenType = typename GetEigenType<OutDtype>::type; using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type; - using TIn = Eigen::Tensor<InEigenType, Rank>; - using TOut = Eigen::Tensor<OutEigenType, Rank>; - using TShift = Eigen::Tensor<ShiftEigenType, 0>; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + using TShiftRank0 = Eigen::Tensor<ShiftEigenType, 0>; + using TShiftRank1 = Eigen::Tensor<ShiftEigenType, 1>; int register_fcn(); int eval(); @@ -183,9 +181,6 @@ public: // Note that INT64 is not natively supported in Dtype system. std::function<int64_t(InEigenType, InEigenType)> mul_fcn; std::function<OutEigenType(int64_t, InEigenType)> shr_fcn; - -protected: - TosaReference::TensorTemplate<TShift>* s; }; template <int Rank, TOSA_REF_TYPE InDtype> |