aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/ops/ewise_binary.cc45
-rw-r--r--reference_model/src/ops/ewise_binary.h88
-rw-r--r--reference_model/src/ops/type_conversion.cc6
-rw-r--r--reference_model/src/quant_util.h3
4 files changed, 94 insertions, 48 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 4d4f8b9..d07790e 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -212,6 +212,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
+ bool round = attribute->round();
int32_t num_bits = 0;
switch (Dtype)
{
@@ -228,13 +229,18 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
}
- this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- uint32_t sign = a & (1 << (num_bits - 1));
- uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
- if (sign)
- return ones_mask | (a >> b);
- else
- return (~ones_mask) & (a >> b);
+ this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
+ (int32_t)b, num_bits);
+
+ InEigenType acc = a >> b;
+
+ if (round && b > 0 && (a >> (b - 1) & 1) != 0)
+ {
+ acc++;
+ }
+
+ return acc;
};
return 0;
@@ -415,11 +421,34 @@ int OpMinimum<Rank, Dtype>::register_fcn()
template <int Rank, DType InDtype, DType OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
+ int32_t shift = attribute->shift();
+ ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift,
+ EnumNamesDType()[InDtype]);
+
switch (InDtype)
{
case DType_FLOAT:
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
case DType_INT32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
+ int64_t result;
+ if (shift > 0)
+ {
+ int64_t round = 1L << (shift - 1);
+ result = a * b + round;
+ result = result >> shift;
+
+ ASSERT_MSG_NODE(result >= QMin && result <= QMax,
+ "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax);
+ }
+ else
+ {
+ result = a * b;
+ }
+
+ return static_cast<OutEigenType>(result);
+ };
break;
case DType_INT8:
case DType_INT16:
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 00fb3d9..5bc5630 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -104,7 +104,7 @@ public:
virtual int eval();
};
-#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
template <int Rank, DType Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
@@ -121,41 +121,59 @@ public:
virtual int register_fcn(); \
};
-#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
- template <int Rank, DType InDtype, DType OutDtype> \
- class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
- { \
- public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
- { \
- register_fcn(); \
- } \
- static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
- static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
- virtual int register_fcn(); \
- };
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
+
+#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
+
+template <int Rank, DType Dtype>
+class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
+{
+public:
+ OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(ArithmeticRightShift);
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
-DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
-
-#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
-#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+protected:
+ TosaArithmeticRightShiftAttribute* attribute;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
+{
+public:
+ OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(Mul);
+ register_fcn();
+ }
+ static constexpr int64_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t QMax = GetQMax<OutDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaMulAttribute* attribute;
+};
template <int Rank>
class OpTable : public GraphNode
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index a97bc0d..c505e29 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -130,8 +130,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
double_round](InEigenType in_val) -> OutEigenType {
InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(
- input_zp_shifted, channel_multiplier, channel_shift, double_round);
+ int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+ channel_shift, double_round);
OutEigenType out_val = (OutEigenType)(scaled + output_zp);
out_val = std::max<OutEigenType>(out_val, QMin);
out_val = std::min<OutEigenType>(out_val, QMax);
@@ -151,7 +151,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
output_2d = input_reshaped.unaryExpr(
[input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
+ int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
tensor_shift, double_round);
OutEigenType out_val = (OutEigenType)(scaled + output_zp);
out_val = std::max<OutEigenType>(out_val, QMin);
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 3b58b66..f9ac501 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -34,8 +34,7 @@ public:
int32_t& multiplier,
int32_t& shift)
{
- ASSERT_MSG(value > 0,
- "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
+ ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
uint32_t value_u32 = (uint32_t)value;
int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k)
int64_t numerator = ((1L << 30) + 1) << k;