diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index ed176f3..8cc1319 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -463,11 +463,18 @@ int OpMul<Rank, InDtype, OutDtype>::eval() using TInt64 = Eigen::Tensor<int64_t, Rank>; TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn); - // Retrieve `shift` value and construct a Eigen tensor instance for it. - s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); - ASSERT_MEM(s); + // Retrieve `shift` value and construct a Eigen tensor instance for it. Shift is stored + // as rank-0 tensor in Flatbuffer. + auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]); - int shift = s->getTensor()(0); + // Get zero element from rank-0 tensor (i.e. shape = (0,)) in Numpy since `class Tensor` + // currenly has no knowledge of the size of rank-0 tensor. Store rank-1 tensor instead + // for testing. + auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]); + + ASSERT_MEM(s0 || s1); + + int shift = s0 ? s0->getTensor()(0) : s1->getTensor()(0); TIn is(ia); is.setConstant(shift); @@ -486,11 +493,12 @@ int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval() Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn); // Retrieve `shift` value. - s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); - ASSERT_MEM(s); + auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]); + auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]); + ASSERT_MEM(s0 || s1); Eigen::Tensor<int64_t, 0> shift; - shift.setConstant(s->getTensor()(0)); + shift.setConstant(s0 ? s0->getTensor()(0) : s1->getTensor()(0)); this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn); |