diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 50 |
1 files changed, 28 insertions, 22 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 8578527..3fa4194 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -23,9 +23,7 @@ using namespace Eigen; using namespace tosa; template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> -BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, - const Op& op_, - uint64_t id_) +BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) : GraphNode(sgt_, op_, id_) { setRequiredOperands(2, 1); @@ -100,11 +98,16 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calcula // calculates the broadcasted output shape calculated_shape = a_shape; - for (size_t i = 0; i < calculated_shape.size(); i++) { - if (calculated_shape[i] == 1) { + for (size_t i = 0; i < calculated_shape.size(); i++) + { + if (calculated_shape[i] == 1) + { calculated_shape[i] = b_shape[i]; - } else { - ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + else + { + ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], + "Broadcast_shape failure, input shapes are not compatible"); } } @@ -118,7 +121,8 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval() this->broadcast(calculated_shape); auto result_shape = this->result->getShape(); - ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); + ERROR_IF(calculated_shape != result_shape, + "Broadcast_shape failure, calculated_shape and result_shape don't match"); Eigen::array<int, Rank> reshaper; reshaper.fill(1); @@ -210,7 +214,8 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn() template <int Rank, TOSA_REF_TYPE Dtype> OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift() { - if (attribute) delete attribute; + if (attribute) + delete attribute; } template <int Rank, TOSA_REF_TYPE Dtype> @@ -309,21 +314,21 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int8_t>(a << b)); }; break; case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int16_t>(a << b)); }; break; case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int32_t>(a << b)); }; break; @@ -342,21 +347,21 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn() case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b); }; break; case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b); }; break; case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", - (int32_t)b); + (int32_t)b); return static_cast<OutEigenType>(static_cast<int32_t>(a) >> b); }; break; @@ -494,7 +499,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpMul<Rank, InDtype, OutDtype>::~OpMul() { - if (attribute) delete attribute; + if (attribute) + delete attribute; } template <int Rank, TOSA_REF_TYPE Dtype> @@ -547,9 +553,7 @@ int OpSub<Rank, Dtype>::register_fcn() } template <int Rank, TOSA_REF_TYPE InDtype> -OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { setRequiredOperands(1, 1); @@ -561,7 +565,8 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::~OpTable() { - if (attribute) delete attribute; + if (attribute) + delete attribute; } template <int Rank, TOSA_REF_TYPE InDtype> @@ -624,10 +629,11 @@ int OpTable<Rank, InDtype>::eval() int32_t base = table[index]; int32_t next = table[index + 1]; int32_t slope = next - base; - REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), "OpTable: slope out of int16_t range"); + REQUIRE(slope <= std::numeric_limits<int16_t>::max() && slope >= std::numeric_limits<int16_t>::min(), + "OpTable: slope out of int16_t range"); // 4. interpolate, generate 16.7 (23-bit) output - int32_t value = (base << 7) + (slope) * frac; + int32_t value = (base << 7) + (slope)*frac; return value; }); |