aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.h
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-02-12 16:53:23 -0800
committerEric Kunze <eric.kunze@arm.com>2024-02-22 02:08:36 +0000
commitc7bfa58c76e73aac772f714d8ae04cc875715689 (patch)
tree0491f8466ad6463ec03cbb2c80ccaa416d940b4a /reference_model/src/ops/ewise_binary.h
parent2c34b4616a10539211e7006bc43f3c71e86c30bb (diff)
downloadreference_model-c7bfa58c76e73aac772f714d8ae04cc875715689.tar.gz
Change the shift of mul to tensor type
Right shift result on i32_t data type only, i.e. other data types don't carry the shift operand. In the spec, the shift type is a tensor in MT profile and is an attribute in BI/MI profiles. Currently we treat the shift as tensor throughout. In implementation, since `ternaryExpr` is not implemented in Eigen, decompose the original calculation into multiply and shift operation seperately, and execute them via `binaryExpr`. Change-Id: I349f4969545134ac5f13bc83032cd75cca3e7ba0 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r--reference_model/src/ops/ewise_binary.h27
1 files changed, 21 insertions, 6 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 1215c93..8d2e486 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,18 +159,33 @@ public:
OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
{
- INIT_ATTRIBUTE(Mul);
+ if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+ {
+ // Require `shift` operand.
+ this->setRequiredOperands(3, 1);
+ }
register_fcn();
}
- virtual ~OpMul();
static constexpr int64_t QMin = GetQMin<OutDtype>::value;
static constexpr int64_t QMax = GetQMax<OutDtype>::value;
- using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<OutDtype>::type;
- virtual int register_fcn();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
+
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using TShift = Eigen::Tensor<ShiftEigenType, 0>;
+
+ int register_fcn();
+ int eval();
+
+ // Note that INT64 is not natively supported in Dtype system.
+ std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
+ std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
protected:
- TosaMulAttribute* attribute;
+ TosaReference::TensorTemplate<TShift>* s;
};
template <int Rank, TOSA_REF_TYPE InDtype>