diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 45 |
1 files changed, 37 insertions, 8 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 4d4f8b9..d07790e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -212,6 +212,7 @@ int OpAdd<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpArithmeticRightShift<Rank, Dtype>::register_fcn() { + bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { @@ -228,13 +229,18 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn() FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t sign = a & (1 << (num_bits - 1)); - uint32_t ones_mask = ONES_MASK(b) << (num_bits - b); - if (sign) - return ones_mask | (a >> b); - else - return (~ones_mask) & (a >> b); + this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { + ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", + (int32_t)b, num_bits); + + InEigenType acc = a >> b; + + if (round && b > 0 && (a >> (b - 1) & 1) != 0) + { + acc++; + } + + return acc; }; return 0; @@ -415,11 +421,34 @@ int OpMinimum<Rank, Dtype>::register_fcn() template <int Rank, DType InDtype, DType OutDtype> int OpMul<Rank, InDtype, OutDtype>::register_fcn() { + int32_t shift = attribute->shift(); + ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift, + EnumNamesDType()[InDtype]); + switch (InDtype) { case DType_FLOAT: + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; case DType_INT32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { + int64_t result; + if (shift > 0) + { + int64_t round = 1L << (shift - 1); + result = a * b + round; + result = result >> shift; + + ASSERT_MSG_NODE(result >= QMin && result <= QMax, + "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); + } + else + { + result = a * b; + } + + return static_cast<OutEigenType>(result); + }; break; case DType_INT8: case DType_INT16: |