aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc45
1 files changed, 37 insertions, 8 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 4d4f8b9..d07790e 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -212,6 +212,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
+ bool round = attribute->round();
int32_t num_bits = 0;
switch (Dtype)
{
@@ -228,13 +229,18 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
}
- this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- uint32_t sign = a & (1 << (num_bits - 1));
- uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
- if (sign)
- return ones_mask | (a >> b);
- else
- return (~ones_mask) & (a >> b);
+ this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
+ (int32_t)b, num_bits);
+
+ InEigenType acc = a >> b;
+
+ if (round && b > 0 && (a >> (b - 1) & 1) != 0)
+ {
+ acc++;
+ }
+
+ return acc;
};
return 0;
@@ -415,11 +421,34 @@ int OpMinimum<Rank, Dtype>::register_fcn()
template <int Rank, DType InDtype, DType OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
+ int32_t shift = attribute->shift();
+ ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift,
+ EnumNamesDType()[InDtype]);
+
switch (InDtype)
{
case DType_FLOAT:
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
case DType_INT32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
+ int64_t result;
+ if (shift > 0)
+ {
+ int64_t round = 1L << (shift - 1);
+ result = a * b + round;
+ result = result >> shift;
+
+ ASSERT_MSG_NODE(result >= QMin && result <= QMax,
+ "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax);
+ }
+ else
+ {
+ result = a * b;
+ }
+
+ return static_cast<OutEigenType>(result);
+ };
break;
case DType_INT8:
case DType_INT16: