diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-11 13:54:06 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-12 11:47:16 -0800 |
commit | aee1facbde25caf27cc34e5ec08eb8bba6af8e18 (patch) | |
tree | 0ff32b95e6f32444445ca01c1b47835b52fb955f /reference_model/src | |
parent | 99bea145a050e12f1b5f8301979713d9a9b04e12 (diff) | |
download | reference_model-aee1facbde25caf27cc34e5ec08eb8bba6af8e18.tar.gz |
Implement and add unit tests for MUL and ARITHMETIC_RIGHT_SHIFT
add .clang-format
Add expected failure for RESIZE and RESCALE unit tests
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I33c8afdc8998e8518f2b0e5fabddd36ce3aa2ee9
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 45 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 88 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 6 | ||||
-rw-r--r-- | reference_model/src/quant_util.h | 3 |
4 files changed, 94 insertions, 48 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 4d4f8b9..d07790e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -212,6 +212,7 @@ int OpAdd<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpArithmeticRightShift<Rank, Dtype>::register_fcn() { + bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { @@ -228,13 +229,18 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn() FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t sign = a & (1 << (num_bits - 1)); - uint32_t ones_mask = ONES_MASK(b) << (num_bits - b); - if (sign) - return ones_mask | (a >> b); - else - return (~ones_mask) & (a >> b); + this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { + ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", + (int32_t)b, num_bits); + + InEigenType acc = a >> b; + + if (round && b > 0 && (a >> (b - 1) & 1) != 0) + { + acc++; + } + + return acc; }; return 0; @@ -415,11 +421,34 @@ int OpMinimum<Rank, Dtype>::register_fcn() template <int Rank, DType InDtype, DType OutDtype> int OpMul<Rank, InDtype, OutDtype>::register_fcn() { + int32_t shift = attribute->shift(); + ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift, + EnumNamesDType()[InDtype]); + switch (InDtype) { case DType_FLOAT: + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; case DType_INT32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { + int64_t result; + if (shift > 0) + { + int64_t round = 1L << (shift - 1); + result = a * b + round; + result = result >> shift; + + ASSERT_MSG_NODE(result >= QMin && result <= QMax, + "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); + } + else + { + result = a * b; + } + + return static_cast<OutEigenType>(result); + }; break; case DType_INT8: case DType_INT16: diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 00fb3d9..5bc5630 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -104,7 +104,7 @@ public: virtual int eval(); }; -#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \ +#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ template <int Rank, DType Dtype> \ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ @@ -121,41 +121,59 @@ public: virtual int register_fcn(); \ }; -#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \ - template <int Rank, DType InDtype, DType OutDtype> \ - class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \ - { \ - public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \ - { \ - register_fcn(); \ - } \ - static constexpr int32_t QMin = GetQMin<OutDtype>::value; \ - static constexpr int32_t QMax = GetQMax<OutDtype>::value; \ - using InEigenType = typename GetEigenType<InDtype>::type; \ - using OutEigenType = typename GetEigenType<OutDtype>::type; \ - virtual int register_fcn(); \ - }; +DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) + +#undef DEF_TEMPLATE_BINARY_OP_DEFAULT + +template <int Rank, DType Dtype> +class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> +{ +public: + OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + { + INIT_ATTRIBUTE(ArithmeticRightShift); + register_fcn(); + } + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM) -DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB) - -#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE -#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE +protected: + TosaArithmeticRightShiftAttribute* attribute; +}; + +template <int Rank, DType InDtype, DType OutDtype> +class OpMul : public BinaryNode<Rank, InDtype, OutDtype> +{ +public: + OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_) + { + INIT_ATTRIBUTE(Mul); + register_fcn(); + } + static constexpr int64_t QMin = GetQMin<OutDtype>::value; + static constexpr int64_t QMax = GetQMax<OutDtype>::value; + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + virtual int register_fcn(); + +protected: + TosaMulAttribute* attribute; +}; template <int Rank> class OpTable : public GraphNode diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index a97bc0d..c505e29 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -130,8 +130,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil::apply_scale_32( - input_zp_shifted, channel_multiplier, channel_shift, double_round); + int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, + channel_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max<OutEigenType>(out_val, QMin); out_val = std::min<OutEigenType>(out_val, QMax); @@ -151,7 +151,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() output_2d = input_reshaped.unaryExpr( [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, + int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max<OutEigenType>(out_val, QMin); diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 3b58b66..f9ac501 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -34,8 +34,7 @@ public: int32_t& multiplier, int32_t& shift) { - ASSERT_MSG(value > 0, - "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); uint32_t value_u32 = (uint32_t)value; int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k) int64_t numerator = ((1L << 30) + 1) << k; |