aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2020-11-11 13:54:06 -0800
committerKevin Cheng <kevin.cheng@arm.com>2020-11-12 11:47:16 -0800
commitaee1facbde25caf27cc34e5ec08eb8bba6af8e18 (patch)
tree0ff32b95e6f32444445ca01c1b47835b52fb955f
parent99bea145a050e12f1b5f8301979713d9a9b04e12 (diff)
downloadreference_model-aee1facbde25caf27cc34e5ec08eb8bba6af8e18.tar.gz
Implement and add unit tests for MUL and ARITHMETIC_RIGHT_SHIFT
add .clang-format Add expected failure for RESIZE and RESCALE unit tests Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I33c8afdc8998e8518f2b0e5fabddd36ce3aa2ee9
-rw-r--r--.clang-format34
-rw-r--r--reference_model/src/ops/ewise_binary.cc45
-rw-r--r--reference_model/src/ops/ewise_binary.h88
-rw-r--r--reference_model/src/ops/type_conversion.cc6
-rw-r--r--reference_model/src/quant_util.h3
-rw-r--r--serialization/attribute.def6
-rw-r--r--serialization/operator.def4
-rw-r--r--serialization/tosa.fbs11
-rw-r--r--serialization/tosa_generated.h126
-rw-r--r--verif/tosa/ArithmeticRightShiftAttribute.py45
-rw-r--r--verif/tosa/Attribute.py7
-rw-r--r--verif/tosa/MulAttribute.py45
-rw-r--r--verif/tosa_serializer.py18
-rw-r--r--verif/tosa_test_gen.py103
14 files changed, 437 insertions, 104 deletions
diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..66b0148
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,34 @@
+BasedOnStyle: LLVM
+AccessModifierOffset: -4
+AllowShortFunctionsOnASingleLine: None
+AlwaysBreakTemplateDeclarations: true
+BinPackParameters: false
+BraceWrapping:
+ AfterClass: true
+ AfterControlStatement: true
+ AfterEnum: true
+ AfterFunction: true
+ AfterNamespace: true
+ AfterObjCDeclaration: true
+ AfterStruct: true
+ AfterUnion: true
+ AfterExternBlock: true
+ BeforeCatch: true
+ BeforeElse: true
+ IndentBraces: false
+ SplitEmptyFunction: false
+ SplitEmptyRecord: false
+ SplitEmptyNamespace: true
+BreakBeforeBraces: Custom
+BreakConstructorInitializersBeforeComma: true
+BreakConstructorInitializers: BeforeColon
+Cpp11BracedListStyle: false
+IndentCaseLabels: true
+IndentWidth: 4
+IndentWrappedFunctionNames: true
+PointerAlignment: Left
+SpacesInContainerLiterals: false
+AlignConsecutiveAssignments: true
+ColumnLimit: 120
+ReflowComments: false
+SpacesBeforeTrailingComments: 4
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 4d4f8b9..d07790e 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -212,6 +212,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
+ bool round = attribute->round();
int32_t num_bits = 0;
switch (Dtype)
{
@@ -228,13 +229,18 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
}
- this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- uint32_t sign = a & (1 << (num_bits - 1));
- uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
- if (sign)
- return ones_mask | (a >> b);
- else
- return (~ones_mask) & (a >> b);
+ this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
+ (int32_t)b, num_bits);
+
+ InEigenType acc = a >> b;
+
+ if (round && b > 0 && (a >> (b - 1) & 1) != 0)
+ {
+ acc++;
+ }
+
+ return acc;
};
return 0;
@@ -415,11 +421,34 @@ int OpMinimum<Rank, Dtype>::register_fcn()
template <int Rank, DType InDtype, DType OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
+ int32_t shift = attribute->shift();
+ ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift,
+ EnumNamesDType()[InDtype]);
+
switch (InDtype)
{
case DType_FLOAT:
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
case DType_INT32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
+ int64_t result;
+ if (shift > 0)
+ {
+ int64_t round = 1L << (shift - 1);
+ result = a * b + round;
+ result = result >> shift;
+
+ ASSERT_MSG_NODE(result >= QMin && result <= QMax,
+ "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax);
+ }
+ else
+ {
+ result = a * b;
+ }
+
+ return static_cast<OutEigenType>(result);
+ };
break;
case DType_INT8:
case DType_INT16:
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 00fb3d9..5bc5630 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -104,7 +104,7 @@ public:
virtual int eval();
};
-#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
template <int Rank, DType Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
@@ -121,41 +121,59 @@ public:
virtual int register_fcn(); \
};
-#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
- template <int Rank, DType InDtype, DType OutDtype> \
- class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
- { \
- public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
- { \
- register_fcn(); \
- } \
- static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
- static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
- virtual int register_fcn(); \
- };
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
+
+#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
+
+template <int Rank, DType Dtype>
+class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
+{
+public:
+ OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(ArithmeticRightShift);
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
-DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
-DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
-
-#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
-#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+protected:
+ TosaArithmeticRightShiftAttribute* attribute;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
+{
+public:
+ OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_)
+ {
+ INIT_ATTRIBUTE(Mul);
+ register_fcn();
+ }
+ static constexpr int64_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t QMax = GetQMax<OutDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaMulAttribute* attribute;
+};
template <int Rank>
class OpTable : public GraphNode
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index a97bc0d..c505e29 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -130,8 +130,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
double_round](InEigenType in_val) -> OutEigenType {
InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(
- input_zp_shifted, channel_multiplier, channel_shift, double_round);
+ int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+ channel_shift, double_round);
OutEigenType out_val = (OutEigenType)(scaled + output_zp);
out_val = std::max<OutEigenType>(out_val, QMin);
out_val = std::min<OutEigenType>(out_val, QMax);
@@ -151,7 +151,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
output_2d = input_reshaped.unaryExpr(
[input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
+ int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
tensor_shift, double_round);
OutEigenType out_val = (OutEigenType)(scaled + output_zp);
out_val = std::max<OutEigenType>(out_val, QMin);
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 3b58b66..f9ac501 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -34,8 +34,7 @@ public:
int32_t& multiplier,
int32_t& shift)
{
- ASSERT_MSG(value > 0,
- "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
+ ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
uint32_t value_u32 = (uint32_t)value;
int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k)
int64_t numerator = ((1L << 30) + 1) << k;
diff --git a/serialization/attribute.def b/serialization/attribute.def
index 88e8c81..d937395 100644
--- a/serialization/attribute.def
+++ b/serialization/attribute.def
@@ -81,6 +81,12 @@ DEF_ATTRIBUTE(Rescale, 7,
bool, S, double_round,
bool, S, per_channel)
+DEF_ATTRIBUTE(Mul, 1,
+ int32_t, S, shift)
+
+DEF_ATTRIBUTE(ArithmeticRightShift, 1,
+ bool, S, round)
+
DEF_ATTRIBUTE(CondIf, 2,
string, S, then_branch,
string, S, else_branch)
diff --git a/serialization/operator.def b/serialization/operator.def
index 66d3784..267976c 100644
--- a/serialization/operator.def
+++ b/serialization/operator.def
@@ -45,7 +45,7 @@ DEF_OPERATOR(tanh, TANH, Tanh,
/* elementwise - binary */
DEF_OPERATOR(add, ADD, Add, None, None)
-DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, None, None)
+DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, ArithmeticRightShift, None)
DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None)
DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None)
DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None)
@@ -56,7 +56,7 @@ DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr,
DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None)
DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None)
DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None)
-DEF_OPERATOR(mul, MUL, Mul, None, None)
+DEF_OPERATOR(mul, MUL, Mul, Mul, None)
DEF_OPERATOR(pow, POW, Pow, None, None)
DEF_OPERATOR(sub, SUB, Sub, None, None)
DEF_OPERATOR(table, TABLE, Table, None, None)
diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs
index 841cf3d..f57d9dc 100644
--- a/serialization/tosa.fbs
+++ b/serialization/tosa.fbs
@@ -167,7 +167,8 @@ union Attribute {
ResizeAttribute,
ClampAttribute,
RescaleAttribute,
- CustomAttribute,
+ MulAttribute,
+ ArithmeticRightShiftAttribute,
CondIfAttribute,
WhileLoopAttribute,
}
@@ -238,8 +239,12 @@ table RescaleAttribute {
per_channel: bool;
}
-table CustomAttribute {
- identifier: string;
+table MulAttribute {
+ shift: int32;
+}
+
+table ArithmeticRightShiftAttribute {
+ round: bool;
}
table CondIfAttribute {
diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h
index 5bb21f3..5140f7b 100644
--- a/serialization/tosa_generated.h
+++ b/serialization/tosa_generated.h
@@ -45,7 +45,9 @@ struct ClampAttribute;
struct RescaleAttribute;
-struct CustomAttribute;
+struct MulAttribute;
+
+struct ArithmeticRightShiftAttribute;
struct CondIfAttribute;
@@ -478,14 +480,15 @@ enum Attribute {
Attribute_ResizeAttribute = 9,
Attribute_ClampAttribute = 10,
Attribute_RescaleAttribute = 11,
- Attribute_CustomAttribute = 12,
- Attribute_CondIfAttribute = 13,
- Attribute_WhileLoopAttribute = 14,
+ Attribute_MulAttribute = 12,
+ Attribute_ArithmeticRightShiftAttribute = 13,
+ Attribute_CondIfAttribute = 14,
+ Attribute_WhileLoopAttribute = 15,
Attribute_MIN = Attribute_NONE,
Attribute_MAX = Attribute_WhileLoopAttribute
};
-inline const Attribute (&EnumValuesAttribute())[15] {
+inline const Attribute (&EnumValuesAttribute())[16] {
static const Attribute values[] = {
Attribute_NONE,
Attribute_Pool2dAttribute,
@@ -499,7 +502,8 @@ inline const Attribute (&EnumValuesAttribute())[15] {
Attribute_ResizeAttribute,
Attribute_ClampAttribute,
Attribute_RescaleAttribute,
- Attribute_CustomAttribute,
+ Attribute_MulAttribute,
+ Attribute_ArithmeticRightShiftAttribute,
Attribute_CondIfAttribute,
Attribute_WhileLoopAttribute
};
@@ -520,7 +524,8 @@ inline const char * const *EnumNamesAttribute() {
"ResizeAttribute",
"ClampAttribute",
"RescaleAttribute",
- "CustomAttribute",
+ "MulAttribute",
+ "ArithmeticRightShiftAttribute",
"CondIfAttribute",
"WhileLoopAttribute",
nullptr
@@ -582,8 +587,12 @@ template<> struct AttributeTraits<RescaleAttribute> {
static const Attribute enum_value = Attribute_RescaleAttribute;
};
-template<> struct AttributeTraits<CustomAttribute> {
- static const Attribute enum_value = Attribute_CustomAttribute;
+template<> struct AttributeTraits<MulAttribute> {
+ static const Attribute enum_value = Attribute_MulAttribute;
+};
+
+template<> struct AttributeTraits<ArithmeticRightShiftAttribute> {
+ static const Attribute enum_value = Attribute_ArithmeticRightShiftAttribute;
};
template<> struct AttributeTraits<CondIfAttribute> {
@@ -1457,54 +1466,84 @@ inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttributeDirect(
per_channel);
}
-struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+struct MulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_IDENTIFIER = 4
+ VT_SHIFT = 4
};
- const flatbuffers::String *identifier() const {
- return GetPointer<const flatbuffers::String *>(VT_IDENTIFIER);
+ int32_t shift() const {
+ return GetField<int32_t>(VT_SHIFT, 0);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_IDENTIFIER) &&
- verifier.VerifyString(identifier()) &&
+ VerifyField<int32_t>(verifier, VT_SHIFT) &&
verifier.EndTable();
}
};
-struct CustomAttributeBuilder {
+struct MulAttributeBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_identifier(flatbuffers::Offset<flatbuffers::String> identifier) {
- fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier);
+ void add_shift(int32_t shift) {
+ fbb_.AddElement<int32_t>(MulAttribute::VT_SHIFT, shift, 0);
}
- explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ explicit MulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
- CustomAttributeBuilder &operator=(const CustomAttributeBuilder &);
- flatbuffers::Offset<CustomAttribute> Finish() {
+ MulAttributeBuilder &operator=(const MulAttributeBuilder &);
+ flatbuffers::Offset<MulAttribute> Finish() {
const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<CustomAttribute>(end);
+ auto o = flatbuffers::Offset<MulAttribute>(end);
return o;
}
};
-inline flatbuffers::Offset<CustomAttribute> CreateCustomAttribute(
+inline flatbuffers::Offset<MulAttribute> CreateMulAttribute(
flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::String> identifier = 0) {
- CustomAttributeBuilder builder_(_fbb);
- builder_.add_identifier(identifier);
+ int32_t shift = 0) {
+ MulAttributeBuilder builder_(_fbb);
+ builder_.add_shift(shift);
return builder_.Finish();
}
-inline flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect(
+struct ArithmeticRightShiftAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ROUND = 4
+ };
+ bool round() const {
+ return GetField<uint8_t>(VT_ROUND, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_ROUND) &&
+ verifier.EndTable();
+ }
+};
+
+struct ArithmeticRightShiftAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_round(bool round) {
+ fbb_.AddElement<uint8_t>(ArithmeticRightShiftAttribute::VT_ROUND, static_cast<uint8_t>(round), 0);
+ }
+ explicit ArithmeticRightShiftAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArithmeticRightShiftAttributeBuilder &operator=(const ArithmeticRightShiftAttributeBuilder &);
+ flatbuffers::Offset<ArithmeticRightShiftAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArithmeticRightShiftAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArithmeticRightShiftAttribute> CreateArithmeticRightShiftAttribute(
flatbuffers::FlatBufferBuilder &_fbb,
- const char *identifier = nullptr) {
- auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0;
- return tosa::CreateCustomAttribute(
- _fbb,
- identifier__);
+ bool round = false) {
+ ArithmeticRightShiftAttributeBuilder builder_(_fbb);
+ builder_.add_round(round);
+ return builder_.Finish();
}
struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
@@ -2066,8 +2105,11 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const RescaleAttribute *attribute_as_RescaleAttribute() const {
return attribute_type() == Attribute_RescaleAttribute ? static_cast<const RescaleAttribute *>(attribute()) : nullptr;
}
- const CustomAttribute *attribute_as_CustomAttribute() const {
- return attribute_type() == Attribute_CustomAttribute ? static_cast<const CustomAttribute *>(attribute()) : nullptr;
+ const MulAttribute *attribute_as_MulAttribute() const {
+ return attribute_type() == Attribute_MulAttribute ? static_cast<const MulAttribute *>(attribute()) : nullptr;
+ }
+ const ArithmeticRightShiftAttribute *attribute_as_ArithmeticRightShiftAttribute() const {
+ return attribute_type() == Attribute_ArithmeticRightShiftAttribute ? static_cast<const ArithmeticRightShiftAttribute *>(attribute()) : nullptr;
}
const CondIfAttribute *attribute_as_CondIfAttribute() const {
return attribute_type() == Attribute_CondIfAttribute ? static_cast<const CondIfAttribute *>(attribute()) : nullptr;
@@ -2163,8 +2205,12 @@ template<> inline const RescaleAttribute *TosaOperator::attribute_as<RescaleAttr
return attribute_as_RescaleAttribute();
}
-template<> inline const CustomAttribute *TosaOperator::attribute_as<CustomAttribute>() const {
- return attribute_as_CustomAttribute();
+template<> inline const MulAttribute *TosaOperator::attribute_as<MulAttribute>() const {
+ return attribute_as_MulAttribute();
+}
+
+template<> inline const ArithmeticRightShiftAttribute *TosaOperator::attribute_as<ArithmeticRightShiftAttribute>() const {
+ return attribute_as_ArithmeticRightShiftAttribute();
}
template<> inline const CondIfAttribute *TosaOperator::attribute_as<CondIfAttribute>() const {
@@ -2492,8 +2538,12 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At
auto ptr = reinterpret_cast<const RescaleAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
- case Attribute_CustomAttribute: {
- auto ptr = reinterpret_cast<const CustomAttribute *>(obj);
+ case Attribute_MulAttribute: {
+ auto ptr = reinterpret_cast<const MulAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ArithmeticRightShiftAttribute: {
+ auto ptr = reinterpret_cast<const ArithmeticRightShiftAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
case Attribute_CondIfAttribute: {
diff --git a/verif/tosa/ArithmeticRightShiftAttribute.py b/verif/tosa/ArithmeticRightShiftAttribute.py
new file mode 100644
index 0000000..eaa52ab
--- /dev/null
+++ b/verif/tosa/ArithmeticRightShiftAttribute.py
@@ -0,0 +1,45 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ArithmeticRightShiftAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsArithmeticRightShiftAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ArithmeticRightShiftAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # ArithmeticRightShiftAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ArithmeticRightShiftAttribute
+ def Round(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+def ArithmeticRightShiftAttributeStart(builder): builder.StartObject(1)
+def ArithmeticRightShiftAttributeAddRound(builder, round): builder.PrependBoolSlot(0, round, 0)
+def ArithmeticRightShiftAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/Attribute.py b/verif/tosa/Attribute.py
index a4d96e0..5d79a08 100644
--- a/verif/tosa/Attribute.py
+++ b/verif/tosa/Attribute.py
@@ -30,7 +30,8 @@ class Attribute(object):
ResizeAttribute = 9
ClampAttribute = 10
RescaleAttribute = 11
- CustomAttribute = 12
- CondIfAttribute = 13
- WhileLoopAttribute = 14
+ MulAttribute = 12
+ ArithmeticRightShiftAttribute = 13
+ CondIfAttribute = 14
+ WhileLoopAttribute = 15
diff --git a/verif/tosa/MulAttribute.py b/verif/tosa/MulAttribute.py
new file mode 100644
index 0000000..f45b285
--- /dev/null
+++ b/verif/tosa/MulAttribute.py
@@ -0,0 +1,45 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class MulAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsMulAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = MulAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # MulAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # MulAttribute
+ def Shift(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def MulAttributeStart(builder): builder.StartObject(1)
+def MulAttributeAddShift(builder, shift): builder.PrependInt32Slot(0, shift, 0)
+def MulAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index 7ba68c3..07e0e1a 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -247,6 +247,24 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.RescaleAttributeAddPerChannel,
per_channel))
+ def MulAttribute(self, shift):
+ from tosa import MulAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().MulAttribute
+ self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)
+
+ self.ints.append((a.MulAttributeAddShift,
+ shift))
+
+ def ArithmeticRightShiftAttribute(self, round):
+ from tosa import ArithmeticRightShiftAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
+ self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd)
+
+ self.bools.append((a.ArithmeticRightShiftAttributeAddRound,
+ round))
+
def CustomAttribute(self, identifier):
from tosa import CustomAttribute as a, Attribute
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index dc2d803..302e4f4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -489,6 +489,30 @@ class TosaArgGen:
return arg_list
+ @staticmethod
+ def agMul(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ if dtype is DType.INT32:
+ for p in range(testGen.args.num_rand_permutations):
+
+ shift = testGen.randInt(0, 32)
+
+ arg_list.append(('perm{}_shift{}'.format(p, shift), [shift]))
+ else:
+ arg_list.append(('shift0', [0]))
+
+ return arg_list
+
+ @staticmethod
+ def agArithmeticRightShift(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ arg_list.append(('roundTrue', [True]))
+ arg_list.append(('roundFalse', [False]))
+
+ return arg_list
+
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):
@@ -647,7 +671,7 @@ class TosaArgGen:
arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
testGen.typeStr(outputDType), stride[0], stride[1],
offset[0], offset[1]),
- [m, stride, offset, shift, output_dims, outputDType]))
+ [m, stride, offset, shift, output_dims, dtype, outputDType]))
return arg_list
@@ -850,7 +874,16 @@ class TosaTestGen:
self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
return result_tens
- def build_mul(self, op, a, b):
+ def build_arithmetic_right_shift(self, op, a, b, round):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.ArithmeticRightShiftAttribute(round)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_mul(self, op, a, b, shift):
result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
# Special for multiply:
@@ -858,7 +891,10 @@ class TosaTestGen:
if a.dtype != DType.FLOAT:
result_tens.setDtype(DType.INT32)
- self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ attr = ts.TosaSerializerAttribute()
+ attr.MulAttribute(shift)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
return result_tens
def build_table(self, op, a):
@@ -1121,8 +1157,8 @@ class TosaTestGen:
return result_tens
- def build_resize(self, op, input, mode, stride, offset, shift, output_dims, output_dtype):
- result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, output_dtype)
+ def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
+ result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype)
attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(output_dims, stride, offset, shift, mode)
@@ -1191,6 +1227,8 @@ class TosaTestGen:
for i in range(nc):
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(scale_arr[i], scale32)
+ if shift_arr[i] < 2 or shift_arr[i] > 62:
+ self.ser.setExpectedFailure(True, 'OpRescale: invalid shift value')
#print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
@@ -1413,8 +1451,30 @@ class TosaTestGen:
# Build the random tensor operands and the test
tens = []
- tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
- tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
+
+ # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
+ if op['op'] == Op.ARITHMETIC_RIGHT_SHIFT:
+ assert pCount == 2 and cCount == 0, 'Op.ArithmeticRightShift must have 2 placeholders, 0 consts'
+
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ if idx == 1:
+ if dtype == DType.INT8:
+ arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
+ elif dtype == DType.INT16:
+ arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
+ elif dtype == DType.INT32:
+ arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
+ else:
+ raise Exception('OpArithmeticRightShift: invalid input dtype')
+ else:
+ arr = self.getRandTensor(shapeList[0], dtype)
+ placeholders.append(self.ser.addPlaceholder(shape, dtype, Usage.ACTIVATION, [], arr))
+
+ tens.extend(placeholders)
+ else:
+ tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
+ tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
if qgen is not None:
qinfo = qgen(self, op, dtype)
@@ -1536,7 +1596,7 @@ class TosaTestGen:
'arithmetic_right_shift':
{ 'op': Op.ARITHMETIC_RIGHT_SHIFT,
'operands': (2, 0),
- 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'build_fcn': (build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agArithmeticRightShift),
'types': TYPE_PURE_INT },
'bitwise_and':
@@ -1602,7 +1662,7 @@ class TosaTestGen:
'mul':
{ 'op': Op.MUL,
'operands': (2, 0),
- 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, None),
+ 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
'types': TYPE_PURE_INT_FP },
'pow':
@@ -2271,13 +2331,36 @@ class OutputShaper:
return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat)
@staticmethod
- def resizeOp(ser, input, mode, stride, offset, shift, output_dims, output_dtype):
+ def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
if stride[0] <= 0 or stride[1] <= 0:
ser.setExpectedFailure(True, 'Negative or zero stride')
+ if mode == ResizeMode.BILINEAR:
+ if input_dtype == DType.INT8:
+ if output_dtype != DType.INT32:
+ ser.setExpectedFailure(True, 'Invalid output data type')
+ elif input_dtype == DType.INT16:
+ if output_dtype != DType.INT48:
+ ser.setexpectedfailure(true, 'Invalid output data type')
+ else:
+ ser.setexpectedfailure(true, 'Invalid input data type')
+
+ elif mode == ResizeMode.NEAREST:
+ if input_dtype == DType.INT8:
+ if output_dtype != DType.INT8:
+ ser.setExpectedFailure(True, 'Invalid output data type')
+ elif input_dtype == DType.INT16:
+ if output_dtype != DType.INT16:
+ ser.setexpectedfailure(true, 'Invalid output data type')
+ else:
+ ser.setexpectedfailure(true, 'Invalid input data type')
+
+ else:
+ ser.setexpectedfailure(true, 'Invalid resize mode')
+
return ser.addOutput(output_dims, output_dtype, input.usage, input.dformat)
@staticmethod