From aee1facbde25caf27cc34e5ec08eb8bba6af8e18 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 11 Nov 2020 13:54:06 -0800 Subject: Implement and add unit tests for MUL and ARITHMETIC_RIGHT_SHIFT add .clang-format Add expected failure for RESIZE and RESCALE unit tests Signed-off-by: Kevin Cheng Change-Id: I33c8afdc8998e8518f2b0e5fabddd36ce3aa2ee9 --- .clang-format | 34 ++++++++ reference_model/src/ops/ewise_binary.cc | 45 ++++++++-- reference_model/src/ops/ewise_binary.h | 88 +++++++++++-------- reference_model/src/ops/type_conversion.cc | 6 +- reference_model/src/quant_util.h | 3 +- serialization/attribute.def | 6 ++ serialization/operator.def | 4 +- serialization/tosa.fbs | 11 ++- serialization/tosa_generated.h | 126 +++++++++++++++++++--------- verif/tosa/ArithmeticRightShiftAttribute.py | 45 ++++++++++ verif/tosa/Attribute.py | 7 +- verif/tosa/MulAttribute.py | 45 ++++++++++ verif/tosa_serializer.py | 18 ++++ verif/tosa_test_gen.py | 103 ++++++++++++++++++++--- 14 files changed, 437 insertions(+), 104 deletions(-) create mode 100644 .clang-format create mode 100644 verif/tosa/ArithmeticRightShiftAttribute.py create mode 100644 verif/tosa/MulAttribute.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..66b0148 --- /dev/null +++ b/.clang-format @@ -0,0 +1,34 @@ +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AllowShortFunctionsOnASingleLine: None +AlwaysBreakTemplateDeclarations: true +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: true +BreakBeforeBraces: Custom +BreakConstructorInitializersBeforeComma: true +BreakConstructorInitializers: BeforeColon +Cpp11BracedListStyle: false +IndentCaseLabels: true +IndentWidth: 4 +IndentWrappedFunctionNames: true +PointerAlignment: Left +SpacesInContainerLiterals: false +AlignConsecutiveAssignments: true +ColumnLimit: 120 +ReflowComments: false +SpacesBeforeTrailingComments: 4 diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 4d4f8b9..d07790e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -212,6 +212,7 @@ int OpAdd::register_fcn() template int OpArithmeticRightShift::register_fcn() { + bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { @@ -228,13 +229,18 @@ int OpArithmeticRightShift::register_fcn() FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); } - this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { - uint32_t sign = a & (1 << (num_bits - 1)); - uint32_t ones_mask = ONES_MASK(b) << (num_bits - b); - if (sign) - return ones_mask | (a >> b); - else - return (~ones_mask) & (a >> b); + this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { + ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", + (int32_t)b, num_bits); + + InEigenType acc = a >> b; + + if (round && b > 0 && (a >> (b - 1) & 1) != 0) + { + acc++; + } + + return acc; }; return 0; @@ -415,11 +421,34 @@ int OpMinimum::register_fcn() template int OpMul::register_fcn() { + int32_t shift = attribute->shift(); + ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift, + EnumNamesDType()[InDtype]); + switch (InDtype) { case DType_FLOAT: + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; case DType_INT32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { + int64_t result; + if (shift > 0) + { + int64_t round = 1L << (shift - 1); + result = a * b + round; + result = result >> shift; + + ASSERT_MSG_NODE(result >= QMin && result <= QMax, + "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); + } + else + { + result = a * b; + } + + return static_cast(result); + }; break; case DType_INT8: case DType_INT16: diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 00fb3d9..5bc5630 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -104,7 +104,7 @@ public: virtual int eval(); }; -#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \ +#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ template \ class Op##Opname : public BinaryNode \ { \ @@ -121,41 +121,59 @@ public: virtual int register_fcn(); \ }; -#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \ - template \ - class Op##Opname : public BinaryNode \ - { \ - public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode(Op_##OPNAME, qinfo_, id_) \ - { \ - register_fcn(); \ - } \ - static constexpr int32_t QMin = GetQMin::value; \ - static constexpr int32_t QMax = GetQMax::value; \ - using InEigenType = typename GetEigenType::type; \ - using OutEigenType = typename GetEigenType::type; \ - virtual int register_fcn(); \ - }; +DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW) +DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) + +#undef DEF_TEMPLATE_BINARY_OP_DEFAULT + +template +class OpArithmeticRightShift : public BinaryNode +{ +public: + OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + { + INIT_ATTRIBUTE(ArithmeticRightShift); + register_fcn(); + } + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM) -DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW) -DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB) - -#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE -#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE +protected: + TosaArithmeticRightShiftAttribute* attribute; +}; + +template +class OpMul : public BinaryNode +{ +public: + OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(Op_MUL, qinfo_, id_) + { + INIT_ATTRIBUTE(Mul); + register_fcn(); + } + static constexpr int64_t QMin = GetQMin::value; + static constexpr int64_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); + +protected: + TosaMulAttribute* attribute; +}; template class OpTable : public GraphNode diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index a97bc0d..c505e29 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -130,8 +130,8 @@ int OpRescale::eval() curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil::apply_scale_32( - input_zp_shifted, channel_multiplier, channel_shift, double_round); + int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, + channel_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max(out_val, QMin); out_val = std::min(out_val, QMax); @@ -151,7 +151,7 @@ int OpRescale::eval() output_2d = input_reshaped.unaryExpr( [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, + int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max(out_val, QMin); diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 3b58b66..f9ac501 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -34,8 +34,7 @@ public: int32_t& multiplier, int32_t& shift) { - ASSERT_MSG(value > 0, - "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); uint32_t value_u32 = (uint32_t)value; int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1< struct AttributeTraits { static const Attribute enum_value = Attribute_RescaleAttribute; }; -template<> struct AttributeTraits { - static const Attribute enum_value = Attribute_CustomAttribute; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_MulAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ArithmeticRightShiftAttribute; }; template<> struct AttributeTraits { @@ -1457,54 +1466,84 @@ inline flatbuffers::Offset CreateRescaleAttributeDirect( per_channel); } -struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { +struct MulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_IDENTIFIER = 4 + VT_SHIFT = 4 }; - const flatbuffers::String *identifier() const { - return GetPointer(VT_IDENTIFIER); + int32_t shift() const { + return GetField(VT_SHIFT, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_IDENTIFIER) && - verifier.VerifyString(identifier()) && + VerifyField(verifier, VT_SHIFT) && verifier.EndTable(); } }; -struct CustomAttributeBuilder { +struct MulAttributeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_identifier(flatbuffers::Offset identifier) { - fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier); + void add_shift(int32_t shift) { + fbb_.AddElement(MulAttribute::VT_SHIFT, shift, 0); } - explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit MulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - CustomAttributeBuilder &operator=(const CustomAttributeBuilder &); - flatbuffers::Offset Finish() { + MulAttributeBuilder &operator=(const MulAttributeBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateCustomAttribute( +inline flatbuffers::Offset CreateMulAttribute( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset identifier = 0) { - CustomAttributeBuilder builder_(_fbb); - builder_.add_identifier(identifier); + int32_t shift = 0) { + MulAttributeBuilder builder_(_fbb); + builder_.add_shift(shift); return builder_.Finish(); } -inline flatbuffers::Offset CreateCustomAttributeDirect( +struct ArithmeticRightShiftAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ROUND = 4 + }; + bool round() const { + return GetField(VT_ROUND, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ROUND) && + verifier.EndTable(); + } +}; + +struct ArithmeticRightShiftAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_round(bool round) { + fbb_.AddElement(ArithmeticRightShiftAttribute::VT_ROUND, static_cast(round), 0); + } + explicit ArithmeticRightShiftAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArithmeticRightShiftAttributeBuilder &operator=(const ArithmeticRightShiftAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArithmeticRightShiftAttribute( flatbuffers::FlatBufferBuilder &_fbb, - const char *identifier = nullptr) { - auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0; - return tosa::CreateCustomAttribute( - _fbb, - identifier__); + bool round = false) { + ArithmeticRightShiftAttributeBuilder builder_(_fbb); + builder_.add_round(round); + return builder_.Finish(); } struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -2066,8 +2105,11 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const RescaleAttribute *attribute_as_RescaleAttribute() const { return attribute_type() == Attribute_RescaleAttribute ? static_cast(attribute()) : nullptr; } - const CustomAttribute *attribute_as_CustomAttribute() const { - return attribute_type() == Attribute_CustomAttribute ? static_cast(attribute()) : nullptr; + const MulAttribute *attribute_as_MulAttribute() const { + return attribute_type() == Attribute_MulAttribute ? static_cast(attribute()) : nullptr; + } + const ArithmeticRightShiftAttribute *attribute_as_ArithmeticRightShiftAttribute() const { + return attribute_type() == Attribute_ArithmeticRightShiftAttribute ? static_cast(attribute()) : nullptr; } const CondIfAttribute *attribute_as_CondIfAttribute() const { return attribute_type() == Attribute_CondIfAttribute ? static_cast(attribute()) : nullptr; @@ -2163,8 +2205,12 @@ template<> inline const RescaleAttribute *TosaOperator::attribute_as inline const CustomAttribute *TosaOperator::attribute_as() const { - return attribute_as_CustomAttribute(); +template<> inline const MulAttribute *TosaOperator::attribute_as() const { + return attribute_as_MulAttribute(); +} + +template<> inline const ArithmeticRightShiftAttribute *TosaOperator::attribute_as() const { + return attribute_as_ArithmeticRightShiftAttribute(); } template<> inline const CondIfAttribute *TosaOperator::attribute_as() const { @@ -2492,8 +2538,12 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case Attribute_CustomAttribute: { - auto ptr = reinterpret_cast(obj); + case Attribute_MulAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ArithmeticRightShiftAttribute: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case Attribute_CondIfAttribute: { diff --git a/verif/tosa/ArithmeticRightShiftAttribute.py b/verif/tosa/ArithmeticRightShiftAttribute.py new file mode 100644 index 0000000..eaa52ab --- /dev/null +++ b/verif/tosa/ArithmeticRightShiftAttribute.py @@ -0,0 +1,45 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers + +class ArithmeticRightShiftAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsArithmeticRightShiftAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ArithmeticRightShiftAttribute() + x.Init(buf, n + offset) + return x + + # ArithmeticRightShiftAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ArithmeticRightShiftAttribute + def Round(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def ArithmeticRightShiftAttributeStart(builder): builder.StartObject(1) +def ArithmeticRightShiftAttributeAddRound(builder, round): builder.PrependBoolSlot(0, round, 0) +def ArithmeticRightShiftAttributeEnd(builder): return builder.EndObject() diff --git a/verif/tosa/Attribute.py b/verif/tosa/Attribute.py index a4d96e0..5d79a08 100644 --- a/verif/tosa/Attribute.py +++ b/verif/tosa/Attribute.py @@ -30,7 +30,8 @@ class Attribute(object): ResizeAttribute = 9 ClampAttribute = 10 RescaleAttribute = 11 - CustomAttribute = 12 - CondIfAttribute = 13 - WhileLoopAttribute = 14 + MulAttribute = 12 + ArithmeticRightShiftAttribute = 13 + CondIfAttribute = 14 + WhileLoopAttribute = 15 diff --git a/verif/tosa/MulAttribute.py b/verif/tosa/MulAttribute.py new file mode 100644 index 0000000..f45b285 --- /dev/null +++ b/verif/tosa/MulAttribute.py @@ -0,0 +1,45 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers + +class MulAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsMulAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MulAttribute() + x.Init(buf, n + offset) + return x + + # MulAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MulAttribute + def Shift(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def MulAttributeStart(builder): builder.StartObject(1) +def MulAttributeAddShift(builder, shift): builder.PrependInt32Slot(0, shift, 0) +def MulAttributeEnd(builder): return builder.EndObject() diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py index 7ba68c3..07e0e1a 100644 --- a/verif/tosa_serializer.py +++ b/verif/tosa_serializer.py @@ -247,6 +247,24 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.bools.append((a.RescaleAttributeAddPerChannel, per_channel)) + def MulAttribute(self, shift): + from tosa import MulAttribute as a, Attribute + + self.utype = Attribute.Attribute().MulAttribute + self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd) + + self.ints.append((a.MulAttributeAddShift, + shift)) + + def ArithmeticRightShiftAttribute(self, round): + from tosa import ArithmeticRightShiftAttribute as a, Attribute + + self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute + self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd) + + self.bools.append((a.ArithmeticRightShiftAttributeAddRound, + round)) + def CustomAttribute(self, identifier): from tosa import CustomAttribute as a, Attribute diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index dc2d803..302e4f4 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -489,6 +489,30 @@ class TosaArgGen: return arg_list + @staticmethod + def agMul(testGen, opName, shapeList, dtype): + arg_list = [] + + if dtype is DType.INT32: + for p in range(testGen.args.num_rand_permutations): + + shift = testGen.randInt(0, 32) + + arg_list.append(('perm{}_shift{}'.format(p, shift), [shift])) + else: + arg_list.append(('shift0', [0])) + + return arg_list + + @staticmethod + def agArithmeticRightShift(testGen, opName, shapeList, dtype): + arg_list = [] + + arg_list.append(('roundTrue', [True])) + arg_list.append(('roundFalse', [False])) + + return arg_list + # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): @@ -647,7 +671,7 @@ class TosaArgGen: arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1], testGen.typeStr(outputDType), stride[0], stride[1], offset[0], offset[1]), - [m, stride, offset, shift, output_dims, outputDType])) + [m, stride, offset, shift, output_dims, dtype, outputDType])) return arg_list @@ -850,7 +874,16 @@ class TosaTestGen: self.ser.addOperator(op, [a.name, b.name], [result_tens.name]) return result_tens - def build_mul(self, op, a, b): + def build_arithmetic_right_shift(self, op, a, b, round): + result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b) + + attr = ts.TosaSerializerAttribute() + attr.ArithmeticRightShiftAttribute(round) + + self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr) + return result_tens + + def build_mul(self, op, a, b, shift): result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b) # Special for multiply: @@ -858,7 +891,10 @@ class TosaTestGen: if a.dtype != DType.FLOAT: result_tens.setDtype(DType.INT32) - self.ser.addOperator(op, [a.name, b.name], [result_tens.name]) + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift) + + self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr) return result_tens def build_table(self, op, a): @@ -1121,8 +1157,8 @@ class TosaTestGen: return result_tens - def build_resize(self, op, input, mode, stride, offset, shift, output_dims, output_dtype): - result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, output_dtype) + def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype): + result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(output_dims, stride, offset, shift, mode) @@ -1191,6 +1227,8 @@ class TosaTestGen: for i in range(nc): multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(scale_arr[i], scale32) + if shift_arr[i] < 2 or shift_arr[i] > 62: + self.ser.setExpectedFailure(True, 'OpRescale: invalid shift value') #print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) @@ -1413,8 +1451,30 @@ class TosaTestGen: # Build the random tensor operands and the test tens = [] - tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype)) - tens.extend(self.buildConstTensors(shapeList[pCount:], dtype)) + + # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits] + if op['op'] == Op.ARITHMETIC_RIGHT_SHIFT: + assert pCount == 2 and cCount == 0, 'Op.ArithmeticRightShift must have 2 placeholders, 0 consts' + + placeholders = [] + for idx, shape in enumerate(shapeList[:]): + if idx == 1: + if dtype == DType.INT8: + arr = np.int32(self.rng.integers(low=0, high=8, size=shape)) + elif dtype == DType.INT16: + arr = np.int32(self.rng.integers(low=0, high=16, size=shape)) + elif dtype == DType.INT32: + arr = np.int32(self.rng.integers(low=0, high=32, size=shape)) + else: + raise Exception('OpArithmeticRightShift: invalid input dtype') + else: + arr = self.getRandTensor(shapeList[0], dtype) + placeholders.append(self.ser.addPlaceholder(shape, dtype, Usage.ACTIVATION, [], arr)) + + tens.extend(placeholders) + else: + tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype)) + tens.extend(self.buildConstTensors(shapeList[pCount:], dtype)) if qgen is not None: qinfo = qgen(self, op, dtype) @@ -1536,7 +1596,7 @@ class TosaTestGen: 'arithmetic_right_shift': { 'op': Op.ARITHMETIC_RIGHT_SHIFT, 'operands': (2, 0), - 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), + 'build_fcn': (build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agArithmeticRightShift), 'types': TYPE_PURE_INT }, 'bitwise_and': @@ -1602,7 +1662,7 @@ class TosaTestGen: 'mul': { 'op': Op.MUL, 'operands': (2, 0), - 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, None), + 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul), 'types': TYPE_PURE_INT_FP }, 'pow': @@ -2271,13 +2331,36 @@ class OutputShaper: return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat) @staticmethod - def resizeOp(ser, input, mode, stride, offset, shift, output_dims, output_dtype): + def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype): output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] if stride[0] <= 0 or stride[1] <= 0: ser.setExpectedFailure(True, 'Negative or zero stride') + if mode == ResizeMode.BILINEAR: + if input_dtype == DType.INT8: + if output_dtype != DType.INT32: + ser.setExpectedFailure(True, 'Invalid output data type') + elif input_dtype == DType.INT16: + if output_dtype != DType.INT48: + ser.setexpectedfailure(true, 'Invalid output data type') + else: + ser.setexpectedfailure(true, 'Invalid input data type') + + elif mode == ResizeMode.NEAREST: + if input_dtype == DType.INT8: + if output_dtype != DType.INT8: + ser.setExpectedFailure(True, 'Invalid output data type') + elif input_dtype == DType.INT16: + if output_dtype != DType.INT16: + ser.setexpectedfailure(true, 'Invalid output data type') + else: + ser.setexpectedfailure(true, 'Invalid input data type') + + else: + ser.setexpectedfailure(true, 'Invalid resize mode') + return ser.addOutput(output_dims, output_dtype, input.usage, input.dformat) @staticmethod -- cgit v1.2.1