diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-02-28 13:20:05 +0000 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2024-03-01 13:16:56 -0800 |
commit | 0a042997ac24fee1a338e806caf18bd8dfba28f3 (patch) | |
tree | 1cfe325d7d775b778873a3940407e68d39c80a48 /reference_model | |
parent | 3195a665e3f96809a67b4cb04a57330d2bfeb0de (diff) | |
download | reference_model-0a042997ac24fee1a338e806caf18bd8dfba28f3.tar.gz |
Testing support for MUL with shift as input
Always create the shift as a tensor for all types in testing.
In the reference model, set the shift operand to be available for
all types, but only read in the shift tensor for i32.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/generate/generate_fixed_data.cc | 41 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 22 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 17 |
3 files changed, 49 insertions, 31 deletions
diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc index 3d4ee3e..b0b6c81 100644 --- a/reference_model/src/generate/generate_fixed_data.cc +++ b/reference_model/src/generate/generate_fixed_data.cc @@ -20,8 +20,22 @@ #include <type_traits> #include <vector> +namespace +{ +template <typename OutType> +bool copyFixedData(const int64_t elements, const std::vector<int32_t> inData, OutType* outData) +{ + for (auto t = 0; t < elements; t++) + { + outData[t] = inData[t]; + } + return true; +} +} // namespace + namespace TosaReference { + bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size) { // Check we support the operator @@ -31,22 +45,23 @@ bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size) return false; } + std::vector<int32_t> inData = cfg.fixedDataInfo.data; + const auto T = TosaReference::numElementsFromShape(cfg.shape); + if (T != static_cast<int64_t>(inData.size())) + { + WARNING("[Generator][FD] Given data size %d does not match output size %d.", inData.size(), T); + return false; + } + switch (cfg.dataType) { case DType::DType_SHAPE: { - int32_t* outData = reinterpret_cast<int32_t*>(data); - std::vector<int32_t> inData = cfg.fixedDataInfo.data; - const auto T = TosaReference::numElementsFromShape(cfg.shape); - if (T != static_cast<int64_t>(inData.size())) - { - WARNING("[Generator][FD] Size does not match."); - return false; - } - for (auto t = 0; t < T; t++) - { - outData[t] = inData[t]; - } - return true; + int32_t* outData = reinterpret_cast<int32_t*>(data); + return copyFixedData(T, inData, outData); + } + case DType::DType_INT8: { + int8_t* outData = reinterpret_cast<int8_t*>(data); + return copyFixedData(T, inData, outData); } default: WARNING("[Generator][FD] Unsupported type."); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index ed176f3..8cc1319 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -463,11 +463,18 @@ int OpMul<Rank, InDtype, OutDtype>::eval() using TInt64 = Eigen::Tensor<int64_t, Rank>; TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn); - // Retrieve `shift` value and construct a Eigen tensor instance for it. - s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); - ASSERT_MEM(s); + // Retrieve `shift` value and construct a Eigen tensor instance for it. Shift is stored + // as rank-0 tensor in Flatbuffer. + auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]); - int shift = s->getTensor()(0); + // Get zero element from rank-0 tensor (i.e. shape = (0,)) in Numpy since `class Tensor` + // currenly has no knowledge of the size of rank-0 tensor. Store rank-1 tensor instead + // for testing. + auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]); + + ASSERT_MEM(s0 || s1); + + int shift = s0 ? s0->getTensor()(0) : s1->getTensor()(0); TIn is(ia); is.setConstant(shift); @@ -486,11 +493,12 @@ int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval() Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn); // Retrieve `shift` value. - s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); - ASSERT_MEM(s); + auto s0 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank0>*>(this->inputs[2]); + auto s1 = dynamic_cast<TosaReference::TensorTemplate<TShiftRank1>*>(this->inputs[2]); + ASSERT_MEM(s0 || s1); Eigen::Tensor<int64_t, 0> shift; - shift.setConstant(s->getTensor()(0)); + shift.setConstant(s0 ? s0->getTensor()(0) : s1->getTensor()(0)); this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 8d2e486..7ebd852 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -159,11 +159,8 @@ public: OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_) { - if constexpr (InDtype == TOSA_REF_TYPE_INT32) - { - // Require `shift` operand. - this->setRequiredOperands(3, 1); - } + // Require `shift` operand. + this->setRequiredOperands(3, 1); register_fcn(); } static constexpr int64_t QMin = GetQMin<OutDtype>::value; @@ -173,9 +170,10 @@ public: using OutEigenType = typename GetEigenType<OutDtype>::type; using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type; - using TIn = Eigen::Tensor<InEigenType, Rank>; - using TOut = Eigen::Tensor<OutEigenType, Rank>; - using TShift = Eigen::Tensor<ShiftEigenType, 0>; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + using TShiftRank0 = Eigen::Tensor<ShiftEigenType, 0>; + using TShiftRank1 = Eigen::Tensor<ShiftEigenType, 1>; int register_fcn(); int eval(); @@ -183,9 +181,6 @@ public: // Note that INT64 is not natively supported in Dtype system. std::function<int64_t(InEigenType, InEigenType)> mul_fcn; std::function<OutEigenType(int64_t, InEigenType)> shr_fcn; - -protected: - TosaReference::TensorTemplate<TShift>* s; }; template <int Rank, TOSA_REF_TYPE InDtype> |