aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.h
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-28 13:20:05 +0000
committerTatWai Chong <tatwai.chong@arm.com>2024-03-01 13:16:56 -0800
commit0a042997ac24fee1a338e806caf18bd8dfba28f3 (patch)
tree1cfe325d7d775b778873a3940407e68d39c80a48 /reference_model/src/ops/ewise_binary.h
parent3195a665e3f96809a67b4cb04a57330d2bfeb0de (diff)
downloadreference_model-0a042997ac24fee1a338e806caf18bd8dfba28f3.tar.gz
Testing support for MUL with shift as input
Always create the shift as a tensor for all types in testing. In the reference model, set the shift operand to be available for all types, but only read in the shift tensor for i32. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
Diffstat (limited to 'reference_model/src/ops/ewise_binary.h')
-rw-r--r--reference_model/src/ops/ewise_binary.h17
1 files changed, 6 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 8d2e486..7ebd852 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,11 +159,8 @@ public:
OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
{
- if constexpr (InDtype == TOSA_REF_TYPE_INT32)
- {
- // Require `shift` operand.
- this->setRequiredOperands(3, 1);
- }
+ // Require `shift` operand.
+ this->setRequiredOperands(3, 1);
register_fcn();
}
static constexpr int64_t QMin = GetQMin<OutDtype>::value;
@@ -173,9 +170,10 @@ public:
using OutEigenType = typename GetEigenType<OutDtype>::type;
using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
- using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, Rank>;
- using TShift = Eigen::Tensor<ShiftEigenType, 0>;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using TShiftRank0 = Eigen::Tensor<ShiftEigenType, 0>;
+ using TShiftRank1 = Eigen::Tensor<ShiftEigenType, 1>;
int register_fcn();
int eval();
@@ -183,9 +181,6 @@ public:
// Note that INT64 is not natively supported in Dtype system.
std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
-
-protected:
- TosaReference::TensorTemplate<TShift>* s;
};
template <int Rank, TOSA_REF_TYPE InDtype>