aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/ewise_binary.cc22
-rw-r--r--reference_model/src/ops/ewise_binary.h17
2 files changed, 21 insertions, 18 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index ed176f3..8cc1319 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -463,11 +463,18 @@ int OpMul<Rank, InDtype, OutDtype>::eval()
using TInt64 = Eigen::Tensor<int64_t, Rank>;
TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
- // Retrieve `shift` value and construct a Eigen tensor instance for it.
- s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
- ASSERT_MEM(s);
+ // Retrieve `shift` value and construct a Eigen tensor instance for it. Shift is stored
+ // as rank-0 tensor in Flatbuffer.
+ auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]);
- int shift = s->getTensor()(0);
+ // Get zero element from rank-0 tensor (i.e. shape = (0,)) in Numpy since `class Tensor`
+ // currenly has no knowledge of the size of rank-0 tensor. Store rank-1 tensor instead
+ // for testing.
+ auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]);
+
+ ASSERT_MEM(s0 || s1);
+
+ int shift = s0 ? s0->getTensor()(0) : s1->getTensor()(0);
TIn is(ia);
is.setConstant(shift);
@@ -486,11 +493,12 @@ int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
// Retrieve `shift` value.
- s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
- ASSERT_MEM(s);
+ auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]);
+ auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]);
+ ASSERT_MEM(s0 || s1);
Eigen::Tensor<int64_t, 0> shift;
- shift.setConstant(s->getTensor()(0));
+ shift.setConstant(s0 ? s0->getTensor()(0) : s1->getTensor()(0));
this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 8d2e486..7ebd852 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,11 +159,8 @@ public:
OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
{
- if constexpr (InDtype == TOSA_REF_TYPE_INT32)
- {
- // Require `shift` operand.
- this->setRequiredOperands(3, 1);
- }
+ // Require `shift` operand.
+ this->setRequiredOperands(3, 1);
register_fcn();
}
static constexpr int64_t QMin = GetQMin<OutDtype>::value;
@@ -173,9 +170,10 @@ public:
using OutEigenType = typename GetEigenType<OutDtype>::type;
using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
- using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, Rank>;
- using TShift = Eigen::Tensor<ShiftEigenType, 0>;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using TShiftRank0 = Eigen::Tensor<ShiftEigenType, 0>;
+ using TShiftRank1 = Eigen::Tensor<ShiftEigenType, 1>;
int register_fcn();
int eval();
@@ -183,9 +181,6 @@ public:
// Note that INT64 is not natively supported in Dtype system.
std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
-
-protected:
- TosaReference::TensorTemplate<TShift>* s;
};
template <int Rank, TOSA_REF_TYPE InDtype>