diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 125 |
1 files changed, 92 insertions, 33 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index b513f9a..ed176f3 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -441,9 +441,100 @@ int OpMinimum<Rank, Dtype>::register_fcn() } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> +int OpMul<Rank, InDtype, OutDtype>::eval() +{ + // All cases except in_out_t == int32_t go to the general binary op workflow. + if constexpr (InDtype != TOSA_REF_TYPE_INT32) + { + return BinaryNode<Rank, InDtype, OutDtype>::eval(); + } + else + { + std::vector<int> calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->result->getShape(); + ERROR_IF(calculated_shape != result_shape, + "Broadcast_shape failure, calculated_shape and result_shape don't match"); + + TIn ia = this->a->getTensor().broadcast(this->bcast_a); + TIn ib = this->b->getTensor().broadcast(this->bcast_b); + + using TInt64 = Eigen::Tensor<int64_t, Rank>; + TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn); + + // Retrieve `shift` value and construct a Eigen tensor instance for it. + s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); + ASSERT_MEM(s); + + int shift = s->getTensor()(0); + TIn is(ia); + is.setConstant(shift); + + TOut result = tmp_result.binaryExpr(is, this->shr_fcn); + this->result->getTensor() = result; + + return GraphNode::eval(); + } +} + +// Eigen operators requires tensor operands meet NumDims > 0, partial specialize +// this like we did for the base class. +template <> +int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval() +{ + Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn); + + // Retrieve `shift` value. + s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]); + ASSERT_MEM(s); + + Eigen::Tensor<int64_t, 0> shift; + shift.setConstant(s->getTensor()(0)); + + this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn); + + return GraphNode::eval(); +} + +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpMul<Rank, InDtype, OutDtype>::register_fcn() { - int32_t shift = attribute->shift(); + // Register evaluation function for in_out_t == int32_t case first as it supports shift + // right to int32_t output. + if constexpr (InDtype == TOSA_REF_TYPE_INT32) + { + // Perform multiplication on int32_t inputs to product int64_t result. + this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t { + int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b); + return result; + }; + + // Convert data from int64_t to int32_t. + this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType { + int64_t result; + if (shift > 0) + { + int64_t round = INT64_C(1) << (shift - 1); + result = a + round; + result = result >> shift; + + REQUIRE(result >= QMin && result <= QMax, + "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax); + } + else + { + result = a; + int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); + REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); + return static_cast<InEigenType>(result); + } + return static_cast<OutEigenType>(result); + }; + + return 0; + } switch (InDtype) { @@ -455,31 +546,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() case TOSA_REF_TYPE_FP64: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; - case TOSA_REF_TYPE_INT32: - this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { - int64_t result; - if (shift > 0) - { - int64_t round = INT64_C(1) << (shift - 1); - result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round; - result = result >> shift; - - REQUIRE(result >= QMin && result <= QMax, - "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, - QMax); - } - else - { - result = static_cast<int64_t>(a) * b; - int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); - int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); - REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); - return static_cast<InEigenType>(result); - } - - return static_cast<OutEigenType>(result); - }; - break; case TOSA_REF_TYPE_INT8: case TOSA_REF_TYPE_INT16: this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { @@ -497,13 +563,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() return 0; } -template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> -OpMul<Rank, InDtype, OutDtype>::~OpMul() -{ - if (attribute) - delete attribute; -} - template <int Rank, TOSA_REF_TYPE Dtype> int OpPow<Rank, Dtype>::register_fcn() { |