diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2024-02-12 16:53:23 -0800 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-02-22 02:08:36 +0000 |
commit | c7bfa58c76e73aac772f714d8ae04cc875715689 (patch) | |
tree | 0491f8466ad6463ec03cbb2c80ccaa416d940b4a /reference_model/src/ops/ewise_binary.h | |
parent | 2c34b4616a10539211e7006bc43f3c71e86c30bb (diff) | |
download | reference_model-c7bfa58c76e73aac772f714d8ae04cc875715689.tar.gz |
Change the shift of mul to tensor type
Right shift result on i32_t data type only, i.e. other data types
don't carry the shift operand.
In the spec, the shift type is a tensor in MT profile and is an
attribute in BI/MI profiles. Currently we treat the shift as tensor
throughout.
In implementation, since `ternaryExpr` is not implemented in Eigen,
decompose the original calculation into multiply and shift operation
seperately, and execute them via `binaryExpr`.
Change-Id: I349f4969545134ac5f13bc83032cd75cca3e7ba0
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 1215c93..8d2e486 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -159,18 +159,33 @@ public: OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_) { - INIT_ATTRIBUTE(Mul); + if constexpr (InDtype == TOSA_REF_TYPE_INT32) + { + // Require `shift` operand. + this->setRequiredOperands(3, 1); + } register_fcn(); } - virtual ~OpMul(); static constexpr int64_t QMin = GetQMin<OutDtype>::value; static constexpr int64_t QMax = GetQMax<OutDtype>::value; - using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<OutDtype>::type; - virtual int register_fcn(); + + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type; + + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + using TShift = Eigen::Tensor<ShiftEigenType, 0>; + + int register_fcn(); + int eval(); + + // Note that INT64 is not natively supported in Dtype system. + std::function<int64_t(InEigenType, InEigenType)> mul_fcn; + std::function<OutEigenType(int64_t, InEigenType)> shr_fcn; protected: - TosaMulAttribute* attribute; + TosaReference::TensorTemplate<TShift>* s; }; template <int Rank, TOSA_REF_TYPE InDtype> |