From c7bfa58c76e73aac772f714d8ae04cc875715689 Mon Sep 17 00:00:00 2001
From: TatWai Chong
Date: Mon, 12 Feb 2024 16:53:23 -0800
Subject: Change the shift of mul to tensor type

Right shift the result for the i32_t data type only, i.e. other data
types don't carry the shift operand.

In the spec, the shift is a tensor in the MT profile and an attribute
in the BI/MI profiles. Currently we treat the shift as a tensor
throughout.

In the implementation, since `ternaryExpr` is not available in Eigen,
decompose the original calculation into separate multiply and shift
operations and execute them via `binaryExpr`.

Change-Id: I349f4969545134ac5f13bc83032cd75cca3e7ba0
Signed-off-by: TatWai Chong
---
 reference_model/src/ops/ewise_binary.cc | 125 +++++++++++++++++++++++---------
 reference_model/src/ops/ewise_binary.h  |  27 +++++--
 2 files changed, 113 insertions(+), 39 deletions(-)

diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b513f9a..ed176f3 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -440,10 +440,101 @@ int OpMinimum<Rank, Dtype>::register_fcn()
     return 0;
 }
 
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::eval()
+{
+    // All cases except in_out_t == int32_t go to the general binary op workflow.
+    if constexpr (InDtype != TOSA_REF_TYPE_INT32)
+    {
+        return BinaryNode<Rank, InDtype, OutDtype>::eval();
+    }
+    else
+    {
+        std::vector<int> calculated_shape;
+        this->broadcast(calculated_shape);
+
+        auto result_shape = this->result->getShape();
+        ERROR_IF(calculated_shape != result_shape,
+                 "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
+        TIn ia = this->a->getTensor().broadcast(this->bcast_a);
+        TIn ib = this->b->getTensor().broadcast(this->bcast_b);
+
+        using TInt64 = Eigen::Tensor<int64_t, Rank>;
+        TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
+
+        // Retrieve the `shift` value and construct an Eigen tensor instance for it.
+        s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+        ASSERT_MEM(s);
+
+        int shift = s->getTensor()(0);
+        TIn is(ia);
+        is.setConstant(shift);
+
+        TOut result = tmp_result.binaryExpr(is, this->shr_fcn);
+        this->result->getTensor() = result;
+
+        return GraphNode::eval();
+    }
+}
+
+// Eigen operators require tensor operands with NumDims > 0, so partially specialize
+// the rank-0 case here, as is done for the base class.
+template <>
+int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
+{
+    Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
+
+    // Retrieve the `shift` value.
+    s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+    ASSERT_MEM(s);
+
+    Eigen::Tensor<InEigenType, 0> shift;
+    shift.setConstant(s->getTensor()(0));
+
+    this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
+
+    return GraphNode::eval();
+}
+
 template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpMul<Rank, InDtype, OutDtype>::register_fcn()
 {
-    int32_t shift = attribute->shift();
+    // Register the evaluation functions for the in_out_t == int32_t case first, as it is
+    // the only case that applies the right shift to the int32_t output.
+    if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+    {
+        // Perform the multiplication on int32_t inputs to produce an int64_t result.
+        this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t {
+            int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+            return result;
+        };
+
+        // Apply the rounded right shift and convert the int64_t intermediate back to int32_t.
+        this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType {
+            int64_t result;
+            if (shift > 0)
+            {
+                int64_t round = INT64_C(1) << (shift - 1);
+                result        = a + round;
+                result        = result >> shift;
+
+                REQUIRE(result >= QMin && result <= QMax,
+                        "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax);
+            }
+            else
+            {
+                result                = a;
+                int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
+                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+                REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
+                return static_cast<OutEigenType>(result);
+            }
+            return static_cast<OutEigenType>(result);
+        };
+
+        return 0;
+    }
 
     switch (InDtype)
     {
@@ -455,31 +546,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
     case TOSA_REF_TYPE_FP64:
         this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
         break;
-    case TOSA_REF_TYPE_INT32:
-        this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
-            int64_t result;
-            if (shift > 0)
-            {
-                int64_t round = INT64_C(1) << (shift - 1);
-                result        = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
-                result        = result >> shift;
-
-                REQUIRE(result >= QMin && result <= QMax,
-                        "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin,
-                        QMax);
-            }
-            else
-            {
-                result                = static_cast<int64_t>(a) * b;
-                int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
-                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
-                REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
-                return static_cast<OutEigenType>(result);
-            }
-
-            return static_cast<OutEigenType>(result);
-        };
-        break;
     case TOSA_REF_TYPE_INT8:
     case TOSA_REF_TYPE_INT16:
         this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType {
@@ -497,13 +563,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
     return 0;
 }
 
-template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-OpMul<Rank, InDtype, OutDtype>::~OpMul()
-{
-    if (attribute)
-        delete attribute;
-}
-
 template <int Rank, TOSA_REF_TYPE Dtype>
 int OpPow<Rank, Dtype>::register_fcn()
 {
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 1215c93..8d2e486 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,18 +159,33 @@ public:
     OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
         : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
     {
-        INIT_ATTRIBUTE(Mul);
+        if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+        {
+            // Require the `shift` operand.
+            this->setRequiredOperands(3, 1);
+        }
         register_fcn();
     }
-    virtual ~OpMul();
     static constexpr int64_t QMin = GetQMin<OutDtype>::value;
     static constexpr int64_t QMax = GetQMax<OutDtype>::value;
-    using InEigenType  = typename GetEigenType<InDtype>::type;
-    using OutEigenType = typename GetEigenType<OutDtype>::type;
-    virtual int register_fcn();
+
+    using InEigenType    = typename GetEigenType<InDtype>::type;
+    using OutEigenType   = typename GetEigenType<OutDtype>::type;
+    using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
+
+    using TIn    = Eigen::Tensor<InEigenType, Rank>;
+    using TOut   = Eigen::Tensor<OutEigenType, Rank>;
+    using TShift = Eigen::Tensor<ShiftEigenType, 1>;
+
+    int register_fcn();
+    int eval();
+
+    // Note that INT64 is not natively supported in the Dtype system.
+    std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
+    std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
 
 protected:
-    TosaMulAttribute* attribute;
+    TosaReference::TensorTemplate<TShift>* s;
 };
 
 template
--
cgit v1.2.1
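
A note on the Eigen pattern the patch relies on: since Eigen's Tensor module offers
binaryExpr but no ternaryExpr, the rounded-shift multiply is split into two
element-wise passes, one for the widening multiply and one for the rounded right
shift. The sketch below is a minimal standalone illustration of that decomposition
and is not part of the patch; the tensor size, the input values, and the shift
amount of 10 are arbitrary choices for the example.

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdint>
#include <iostream>

int main()
{
    // Two int32 operands of the same shape.
    Eigen::Tensor<int32_t, 1> a(4), b(4);
    a.setValues({100000, 200000, 300000, 400000});
    b.setValues({70000, 80000, 90000, 100000});

    // Pass 1: multiply while widening to int64_t so the product cannot overflow.
    Eigen::Tensor<int64_t, 1> prod = a.binaryExpr(b, [](int32_t x, int32_t y) -> int64_t {
        return static_cast<int64_t>(x) * static_cast<int64_t>(y);
    });

    // Pass 2: rounded right shift against a constant-filled shift tensor, narrowing
    // the int64_t intermediate back to int32_t.
    Eigen::Tensor<int32_t, 1> shift(4);
    shift.setConstant(10);
    Eigen::Tensor<int32_t, 1> out = prod.binaryExpr(shift, [](int64_t p, int32_t s) -> int32_t {
        int64_t r = (s > 0) ? ((p + (INT64_C(1) << (s - 1))) >> s) : p;
        return static_cast<int32_t>(r);
    });

    for (int i = 0; i < 4; ++i)
        std::cout << out(i) << "\n"; // first element: (100000 * 70000 + 512) >> 10 = 6835938
    return 0;
}

Widening in the first pass keeps the intermediate product exact, in the spirit of the
patch's mul_fcn, while the second pass mirrors shr_fcn by rounding only for positive
shifts and then narrowing the result back to int32_t.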