From b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:20:44 +0000 Subject: Remove quantization info from serialization attributes Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa --- reference_model/src/graph_node.h | 12 --- reference_model/src/main.cpp | 2 +- reference_model/src/ops/activation_funcs.h | 6 +- reference_model/src/ops/comparison.h | 12 +-- reference_model/src/ops/data_layout.cc | 20 ---- reference_model/src/ops/data_layout.h | 23 +++-- reference_model/src/ops/data_nodes.cc | 1 - reference_model/src/ops/data_nodes.h | 2 +- reference_model/src/ops/ewise_binary.cc | 2 - reference_model/src/ops/ewise_binary.h | 23 +++-- reference_model/src/ops/ewise_ternary.cc | 1 - reference_model/src/ops/ewise_ternary.h | 10 +- reference_model/src/ops/ewise_unary.cc | 34 +++++-- reference_model/src/ops/ewise_unary.h | 43 ++++----- reference_model/src/ops/image.cc | 2 - reference_model/src/ops/image.h | 2 +- reference_model/src/ops/op_factory.cc | 1 - reference_model/src/ops/op_factory.h | 30 +++--- reference_model/src/ops/reduction.h | 14 +-- reference_model/src/ops/scatter_gather.cc | 2 - reference_model/src/ops/scatter_gather.h | 4 +- reference_model/src/ops/tensor_ops.cc | 144 +++++++++-------------------- reference_model/src/ops/tensor_ops.h | 29 +++--- reference_model/src/ops/type_conversion.cc | 2 - reference_model/src/ops/type_conversion.h | 4 +- reference_model/src/subgraph_traverser.cc | 2 +- thirdparty/serialization_lib | 2 +- verif/generator/tosa_arg_gen.py | 119 +++++++++++------------- verif/generator/tosa_error_if.py | 7 +- verif/generator/tosa_test_gen.py | 103 +++++++++++---------- 30 files changed, 279 insertions(+), 379 deletions(-) diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index 14a8acc..787e89d 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -17,7 +17,6 @@ #define GRAPH_NODE_H #include "attribute.h" -#include "quant_info.h" #include "subgraph_traverser.h" #include "tensor.h" #include "tosa_generated.h" @@ -126,17 +125,6 @@ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \ } -#define INIT_QINFO(QINFO_NAME) \ - if (auto p = dynamic_cast(qinfo_)) \ - { \ - qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \ - ASSERT_MEM(qinfo); \ - } \ - else \ - { \ - qinfo = nullptr; \ - } - namespace TosaReference { diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index f8a446a..2b70e94 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -26,7 +26,7 @@ #include #define MODEL_VERSION_MAJOR 0 -#define MODEL_VERSION_MINOR 25 +#define MODEL_VERSION_MINOR 30 #define MODEL_VERSION_PATCH 0 #define MODEL_VERSION_DRAFT true diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index 10385e0..4853971 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -28,7 +28,7 @@ template class OpClamp : public UnaryNode { public: - OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); @@ -48,7 +48,7 @@ template class OpSigmoid : public UnaryNode { public: - OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode(sgt_, Op_SIGMOID, id_) { register_fcn(); @@ -64,7 +64,7 @@ template class OpTanh : public UnaryNode { public: - OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode(sgt_, Op_TANH, id_) { register_fcn(); diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h index 5b4d0f1..29e6b5a 100644 --- a/reference_model/src/ops/comparison.h +++ b/reference_model/src/ops/comparison.h @@ -28,8 +28,8 @@ template class OpEqual : public BinaryNode { public: - OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(sgt_, Op_EQUAL, qinfo_, id_) + OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode(sgt_, Op_EQUAL, id_) { register_fcn(); } @@ -42,8 +42,8 @@ template class OpGreater : public BinaryNode { public: - OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(sgt_, Op_GREATER, qinfo_, id_) + OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode(sgt_, Op_GREATER, id_) { register_fcn(); } @@ -56,8 +56,8 @@ template class OpGreaterEqual : public BinaryNode { public: - OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(sgt_, Op_EQUAL, qinfo_, id_) + OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode(sgt_, Op_EQUAL, id_) { register_fcn(); } diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 49f53e8..9fe429b 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -23,7 +23,6 @@ using namespace tosa; template OpConcat::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONCAT, id_) { @@ -124,22 +123,18 @@ int OpConcat::eval() template OpPad::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); - INIT_QINFO(Pad); INIT_ATTRIBUTE(Pad); } template OpPad::~OpPad() { - if (qinfo) - delete qinfo; } template @@ -177,11 +172,6 @@ int OpPad::checkTensorAttributes() paddings_array[i] = std::make_pair(pad_front, pad_back); } - if (this->qinfo && Dtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0"); - } - return 0; } @@ -206,11 +196,6 @@ int OpPad::eval() break; } - if (this->qinfo && Dtype == DType_INT8) - { - pad_value += (InEigenType)this->qinfo->input_zp(); - } - this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value); return GraphNode::eval(); @@ -219,7 +204,6 @@ int OpPad::eval() template OpReshape::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESHAPE, id_) { @@ -315,7 +299,6 @@ int OpReshape::eval() template OpReverse::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_REVERSE, id_) { @@ -383,7 +366,6 @@ int OpReverse::eval() template OpSlice::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SLICE, id_) { @@ -451,7 +433,6 @@ int OpSlice::eval() template OpTileBase::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TILE, id_) { @@ -586,7 +567,6 @@ int OpTile<4, Dtype>::eval() template OpTranspose::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE, id_) { diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index bad88e4..c6513ae 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -27,7 +27,7 @@ template class OpConcat : public GraphNode { public: - OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConcat(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpPad : public GraphNode { public: - OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpPad(); virtual int checkTensorAttributes(); virtual int eval(); @@ -63,7 +63,6 @@ protected: Eigen::array, Rank> paddings_array; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; - TosaPadQuantInfo* qinfo; TosaPadAttribute* attribute; }; @@ -71,7 +70,7 @@ template class OpReshape : public GraphNode { public: - OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpReshape(); virtual int checkTensorAttributes(); @@ -95,7 +94,7 @@ template class OpReverse : public GraphNode { public: - OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpReverse(); virtual int checkTensorAttributes(); @@ -117,7 +116,7 @@ template class OpSlice : public GraphNode { public: - OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpSlice(); virtual int checkTensorAttributes(); @@ -140,7 +139,7 @@ template class OpTileBase : public GraphNode { public: - OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTileBase(); virtual int checkTensorAttributes(); @@ -161,8 +160,8 @@ template class OpTile : public OpTileBase { public: - OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpTileBase(sgt_, attribute_, qinfo_, id_) + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpTileBase(sgt_, attribute_, id_) {} protected: @@ -175,8 +174,8 @@ protected: class OpTile : public OpTileBase \ { \ public: \ - OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : OpTileBase(sgt_, attribute_, qinfo_, id_) \ + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : OpTileBase(sgt_, attribute_, id_) \ {} \ \ protected: \ @@ -194,7 +193,7 @@ template class OpTranspose : public GraphNode { public: - OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTranspose(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index ec4bc41..30c9511 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -45,7 +45,6 @@ int OpConst::eval() template OpIdentity::OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_IDENTITY, id_) { diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h index 407cf0a..8761a08 100644 --- a/reference_model/src/ops/data_nodes.h +++ b/reference_model/src/ops/data_nodes.h @@ -35,7 +35,7 @@ template class OpIdentity : public GraphNode { public: - OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpIdentity(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 7f30e30..eadefaa 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -25,7 +25,6 @@ using namespace tosa; template BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -490,7 +489,6 @@ int OpSub::register_fcn() template OpTable::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 373dfb8..b2c92a4 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -42,7 +42,7 @@ template class BinaryNodeBase : public GraphNode { public: - BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); + BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; @@ -71,8 +71,8 @@ template class BinaryNode : public BinaryNodeBase { public: - BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase(sgt_, op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) + : BinaryNodeBase(sgt_, op_, id_) {} virtual ~BinaryNode() {} @@ -90,8 +90,8 @@ template class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: - BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) + : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_) {} virtual ~BinaryNode() {} @@ -104,8 +104,8 @@ public: class Op##Opname : public BinaryNode \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode(sgt_, Op_##OPNAME, qinfo_, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : BinaryNode(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ @@ -139,9 +139,8 @@ class OpArithmeticRightShift : public BinaryNode public: OpArithmeticRightShift(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + : BinaryNode(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_) { INIT_ATTRIBUTE(ArithmeticRightShift); register_fcn(); @@ -158,8 +157,8 @@ template class OpMul : public BinaryNode { public: - OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(sgt_, Op_MUL, qinfo_, id_) + OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode(sgt_, Op_MUL, id_) { INIT_ATTRIBUTE(Mul); register_fcn(); @@ -178,7 +177,7 @@ template class OpTable : public GraphNode { public: - OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index c265077..72fe5a0 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -22,7 +22,6 @@ using namespace tosa; template OpSelectBase::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SELECT, id_) { diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index b80fb23..75a2194 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -33,7 +33,7 @@ template class OpSelectBase : public GraphNode { public: - OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpSelectBase(); virtual int checkTensorAttributes(); @@ -59,8 +59,8 @@ template class OpSelect : public OpSelectBase { public: - OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase(sgt_, attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpSelectBase(sgt_, attribute_, id_) {} virtual int eval(); int broadcast(); @@ -73,8 +73,8 @@ template class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: - OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpSelectBase<0, Dtype>(sgt_, attribute_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8b83a50..8ef1e3c 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -29,7 +29,10 @@ UnaryNode::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 setRequiredOperands(1, 1); setRequiredRank(0, 6); - fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; + fcn = [](InEigenType a) -> OutEigenType { + ASSERT_MSG(0, "In default UnaryNode function, missing function registration"); + return OutEigenType(); + }; } template @@ -210,14 +213,29 @@ int OpLogicalNot::register_fcn() return 0; } +template +OpNegate::OpNegate(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : UnaryNode(sgt_, Op_NEGATE, id_) +{ + INIT_ATTRIBUTE(Negate); + + register_fcn(); +} + +template +OpNegate::~OpNegate() +{ + if (attribute) + delete attribute; +} + template int OpNegate::register_fcn() { - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); switch (Dtype) { @@ -251,11 +269,11 @@ int OpNegate::register_fcn() break; case DType_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { - int64_t res_in_64 = 0 - (a - this->qinfo->input_zp()); + int64_t res_in_64 = 0 - (a - attribute->input1_zp()); int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); - res_in_64 += this->qinfo->output_zp(); + res_in_64 += attribute->output_zp(); InEigenType result = static_cast(std::min(std::max(res_in_64, static_cast(QMin)), static_cast(QMax))); return result; }; diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 374c8e4..16a4c88 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -49,7 +49,7 @@ protected: class Op##Opname : public UnaryNode \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ : UnaryNode(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ @@ -61,27 +61,6 @@ protected: virtual int register_fcn(); \ }; -#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \ - template \ - class Op##Opname : public UnaryNode \ - { \ - public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode(sgt_, Op_##OPNAME, id_) \ - { \ - INIT_QINFO(Unary); \ - register_fcn(); \ - } \ - static constexpr int32_t QMin = GetQMin::value; \ - static constexpr int32_t QMax = GetQMax::value; \ - using InEigenType = typename GetEigenType::type; \ - using OutEigenType = typename GetEigenType::type; \ - virtual int register_fcn(); \ - \ - protected: \ - TosaUnaryQuantInfo* qinfo; \ - }; - DEF_TEMPLATE_UNARY_OP(Abs, ABS) DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) @@ -90,12 +69,28 @@ DEF_TEMPLATE_UNARY_OP(Exp, EXP) DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) DEF_TEMPLATE_UNARY_OP(Log, LOG) DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) -DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE) DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) #undef DEF_TEMPLATE_UNARY_OP -#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO + +// Negate is the only unary op with attributes +template +class OpNegate : public UnaryNode +{ +public: + OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpNegate(); + + static constexpr int32_t QMin = GetQMin::value; + static constexpr int32_t QMax = GetQMax::value; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + virtual int register_fcn(); + +protected: + tosa::TosaNegateAttribute* attribute; +}; }; // namespace TosaReference diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 6dec1bc..701fdfb 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -15,7 +15,6 @@ #include "image.h" #include "arith_util.h" -#include "quant_util.h" using namespace TosaReference; using namespace Eigen; @@ -24,7 +23,6 @@ using namespace tosa; template OpResize::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESIZE, id_) { diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 095dc7d..fea4885 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -27,7 +27,7 @@ template class OpResize : public GraphNode { public: - OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpResize(); virtual int checkTensorAttributes() final; virtual int eval(); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index f7ded9a..62d5b11 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -36,7 +36,6 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, TosaSerializationHandler* tsh, Op opType, TosaAttributeBase* attribute, - TosaQuantInfoBase* qinfo, uint64_t id, DType inputDType, int inputRank, diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index eaa359c..4cb8178 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -18,61 +18,60 @@ #include "attribute.h" #include "graph_node.h" -#include "quant_info.h" #include "template_types.h" #include "tosa_serialization_handler.h" #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP(sgt, attribute, qinfo, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP(sgt, attribute, qinfo, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP(sgt, attribute, qinfo, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP(sgt, attribute, qinfo, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ { \ case 0: \ - return new OP<0>(sgt, attribute, qinfo, id); \ + return new OP<0>(sgt, attribute, id); \ case 1: \ - return new OP<1>(sgt, attribute, qinfo, id); \ + return new OP<1>(sgt, attribute, id); \ case 2: \ - return new OP<2>(sgt, attribute, qinfo, id); \ + return new OP<2>(sgt, attribute, id); \ case 3: \ - return new OP<3>(sgt, attribute, qinfo, id); \ + return new OP<3>(sgt, attribute, id); \ case 4: \ - return new OP<4>(sgt, attribute, qinfo, id); \ + return new OP<4>(sgt, attribute, id); \ case 5: \ - return new OP<5>(sgt, attribute, qinfo, id); \ + return new OP<5>(sgt, attribute, id); \ case 6: \ - return new OP<6>(sgt, attribute, qinfo, id); \ + return new OP<6>(sgt, attribute, id); \ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ if (inputDType == DType_##DTYPE) \ { \ - return new OP(sgt, attribute, qinfo, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ - return new OP(sgt, attribute, qinfo, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP(sgt, attribute, qinfo, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ @@ -231,7 +230,6 @@ public: tosa::TosaSerializationHandler* tsh, tosa::Op opType, tosa::TosaAttributeBase* attribute, - tosa::TosaQuantInfoBase* qinfo, uint64_t id, tosa::DType inputDType, int inputRank, diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index f3407f4..6e98a76 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -48,7 +48,7 @@ template class OpReduceAll : public ReduceNode { public: - OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); @@ -58,7 +58,7 @@ template class OpReduceAny : public ReduceNode { public: - OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); @@ -68,7 +68,7 @@ template class OpReduceMax : public ReduceNode { public: - OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_MAX, attribute_, id_) {} virtual int eval(); @@ -78,7 +78,7 @@ template class OpReduceMin : public ReduceNode { public: - OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_MIN, attribute_, id_) {} virtual int eval(); @@ -88,7 +88,7 @@ template class OpReduceProduct : public ReduceNode { public: - OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) {} virtual int eval(); @@ -98,7 +98,7 @@ template class OpReduceSum : public ReduceNode { public: - OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); @@ -108,7 +108,7 @@ template class OpReduceSumInt : public ReduceNode { public: - OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 02ec54f..faf7db9 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -23,7 +23,6 @@ using namespace tosa; template OpGather::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_GATHER, id_) { @@ -120,7 +119,6 @@ int OpGather::eval() template OpScatter::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SCATTER, id_) { diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index 66b584a..af09153 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -27,7 +27,7 @@ template class OpGather : public GraphNode { public: - OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpGather(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpScatter : public GraphNode { public: - OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpScatter(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index aef1ad2..3ab4d56 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -114,8 +114,7 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, return 0; } -int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, - tosa::TosaConvQuantInfo* qinfo, +int check_conv_attribute(tosa::TosaConvAttribute* attribute, uint32_t conv_dimension, std::vector input_shape, std::vector output_shape, @@ -226,18 +225,13 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, return 1; } - if (qinfo) - { - if (InDtype != DType_INT8 && qinfo->input_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } - if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } + if (InDtype != DType_INT8 && attribute->input_zp() != 0) { + msg = "Input zero point must be zero for non-int8 data"; + return 1; + } + if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) { + msg = "Weight zero point must be zero for non-int8 data"; + return 1; } return 0; @@ -246,7 +240,6 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, template OpArgMax::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_ARGMAX, id_) { @@ -339,7 +332,6 @@ int OpArgMax::eval() template OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_AVG_POOL2D, id_) { @@ -347,7 +339,6 @@ OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Pool); - INIT_QINFO(Unary); } template @@ -377,11 +368,8 @@ int OpAvgPool2d::checkTensorAttributes() in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); std::string msg; if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) @@ -474,9 +462,9 @@ int OpAvgPool2d::eval() pad[3] = std::make_pair(0, 0); ETensor4 input_val = this->in->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); } ETensor4 input_padded = input_val.pad(pad); @@ -537,7 +525,7 @@ int OpAvgPool2d::eval() { REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str()); } - this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); + this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); } @@ -552,7 +540,6 @@ int OpAvgPool2d::eval() template OpConv2d::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) { @@ -560,7 +547,6 @@ OpConv2d::OpConv2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template @@ -568,8 +554,6 @@ OpConv2d::~OpConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template @@ -598,7 +582,7 @@ int OpConv2d::checkTensorAttributes() output = dynamic_cast*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv2d: " + msg; @@ -691,10 +675,10 @@ int OpConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4 input_padded = input_val.pad(pad); @@ -739,7 +723,6 @@ int OpConv2d::eval() template OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) { @@ -747,7 +730,6 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, setRequiredRank(5); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template @@ -755,8 +737,6 @@ OpConv3d::~OpConv3d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template @@ -785,7 +765,7 @@ int OpConv3d::checkTensorAttributes() output = dynamic_cast*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv3d: " + msg; @@ -858,10 +838,10 @@ int OpConv3d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor5 input_padded = input_val.pad(pad); @@ -931,7 +911,6 @@ int OpConv3d::eval() template OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { @@ -939,7 +918,6 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template @@ -947,8 +925,6 @@ OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template @@ -977,7 +953,7 @@ int OpDepthwiseConv2d::checkTensorAttributes() output = dynamic_cast*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpDepthwiseConv2d: " + msg; @@ -1041,10 +1017,10 @@ int OpDepthwiseConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4 input_padded = input_val.pad(pad); @@ -1108,21 +1084,20 @@ int OpDepthwiseConv2d::eval() template OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); setRequiredRank(2); - INIT_QINFO(Conv); + INIT_ATTRIBUTE(FullyConnected); } template OpFullyConnected::~OpFullyConnected() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template @@ -1157,17 +1132,8 @@ int OpFullyConnected::checkTensorAttributes() output = dynamic_cast*>(outputs[0]); - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1190,10 +1156,10 @@ int OpFullyConnected::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } this->output->getTensor() = @@ -1211,21 +1177,20 @@ int OpFullyConnected::eval() template OpMatMul::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); setRequiredRank(3); - INIT_QINFO(MatMul); + INIT_ATTRIBUTE(MatMul); } template OpMatMul::~OpMatMul() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template @@ -1284,11 +1249,8 @@ int OpMatMul::checkTensorAttributes() } W = b->getShape()[2]; - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data"); return 0; } @@ -1301,10 +1263,10 @@ int OpMatMul::eval() TIn a_val = this->a->getTensor(); TIn b_val = this->b->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - a_val = a_val - (InEigenType)this->qinfo->a_zp(); - b_val = b_val - (InEigenType)this->qinfo->b_zp(); + a_val = a_val - (InEigenType)attribute->a_zp(); + b_val = b_val - (InEigenType)attribute->b_zp(); } Eigen::array a_rank2_shape({ H, C }); @@ -1351,7 +1313,6 @@ int OpMatMul::eval() template OpMaxPool2d::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MAX_POOL2D, id_) { @@ -1484,7 +1445,6 @@ int OpMaxPool2d::eval() template OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { @@ -1492,7 +1452,6 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(TransposeConv); - INIT_QINFO(Conv); } template @@ -1500,8 +1459,6 @@ OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template @@ -1595,17 +1552,8 @@ int OpTransposeConv2d::checkTensorAttributes() return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1655,10 +1603,10 @@ int OpTransposeConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } Eigen::array reshape_dim; diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 05b1ca1..24eadeb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template class OpArgMax : public GraphNode { public: - OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -69,7 +69,6 @@ protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; tosa::TosaPoolAttribute* attribute; - tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space @@ -80,7 +79,7 @@ template class OpConv2d : public GraphNode { public: - OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -105,14 +104,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpConv3d : public GraphNode { public: - OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; @@ -137,14 +135,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -169,14 +166,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpFullyConnected : public GraphNode { public: - OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -199,14 +195,15 @@ protected: TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; - tosa::TosaConvQuantInfo* qinfo; + + tosa::TosaFullyConnectedAttribute* attribute; }; template class OpMatMul : public GraphNode { public: - OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -230,14 +227,15 @@ protected: int64_t H; int64_t W; int64_t C; - tosa::TosaMatMulQuantInfo* qinfo; + + tosa::TosaMatMulAttribute* attribute; }; template class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -258,7 +256,7 @@ template class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; @@ -283,7 +281,6 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; TosaTransposeConvAttribute* attribute; - TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 7ee9692..ac54932 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -25,7 +25,6 @@ using namespace tosa; template OpRescale::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESCALE, id_) { @@ -218,7 +217,6 @@ int OpRescale::eval() template OpCast::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CAST, id_) { diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 060e14e..53470d1 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -26,7 +26,7 @@ template class OpRescale : public GraphNode { public: - OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpRescale(); virtual int checkTensorAttributes() final; @@ -140,7 +140,7 @@ template class OpCast : public GraphNode { public: - OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpCast(); virtual int checkTensorAttributes() final; diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 36e0a63..d0cc6cf 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -197,7 +197,7 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, + GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); if (!node) { diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index a336d54..bdcc3fe 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit a336d54aca08b06953a8b6c49d7e5f6c4899952e +Subproject commit bdcc3fee1b8bf55aac50e060115b92a1ccf9741c diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a27d849..8e00fab 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -4,7 +4,6 @@ import itertools import math import numpy as np -import serializer.tosa_serializer as ts from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from serializer.tosa_serializer import DTypeNames @@ -26,7 +25,7 @@ class TosaQuantGen: pass @staticmethod - def getQinfo(testGen, dtype, error_name=None): + def getZeroPoint(testGen, dtype, error_name=None): if dtype == DType.INT8: return testGen.randInt(-128, 128) @@ -45,27 +44,25 @@ class TosaQuantGen: @staticmethod def qgUnary(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] elif error_name == ErrorIf.OutputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod def qgConv(testGen, op, dtype_or_dtypeList, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if isinstance(dtype_or_dtypeList, list): # a list of [input, weights, accumulator] dtypes dtypeList = dtype_or_dtypeList @@ -74,40 +71,34 @@ class TosaQuantGen: dtypeList = [dtype_or_dtypeList] * 3 if error_name == ErrorIf.InputZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] elif error_name == ErrorIf.WeightZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name), + ] else: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) - - qinfo.ConvQuantInfo(input_zp, weights_zp) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] return qinfo @staticmethod def qgMatmul(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) - return qinfo - - @staticmethod - def qgPad(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() - if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name)) - else: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod @@ -550,7 +541,7 @@ class TosaTensorValuesGen: pass @staticmethod - def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] tens = [] @@ -562,7 +553,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] assert ( @@ -582,11 +573,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: # Make sure the operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer @@ -651,12 +642,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgCondIfWhileLoop( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): if dtypeList[0] in ( DType.INT32, @@ -689,12 +680,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgArithmeticRightShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] @@ -722,16 +713,16 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -765,11 +756,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -839,11 +830,11 @@ class TosaTensorValuesGen: return tens else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None): count = len(shapeList) - testGen.args.num_const_inputs_concat if count < 1: count = 1 @@ -866,9 +857,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgLogicalShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 @@ -886,7 +875,7 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -924,13 +913,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgReduceSum( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32: pCount, cCount = op["operands"] assert ( @@ -949,7 +936,7 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 1967d8a..b331a42 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1003,12 +1003,7 @@ class TosaErrorValidator: Generally input_zp is index 0, output_zp is index 1 """ - if isinstance(qinfo, tuple): - zero_point = qinfo[index] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - zero_point = qinfo.ints[index][1] - return zero_point + return qinfo[index] @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 262a652..b0e7c8c 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -216,20 +216,19 @@ class TosaTestGen: # build_placeholder returns an int, ABS/other ops does not if isinstance(op, int): - self.ser.addOperator(op, a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op, a.name, result_tens.name, None) return result_tens elif op["op"] == Op.IDENTITY: - self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op["op"], a.name, result_tens.name, None) return result_tens # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongOutputType: if result_tens.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, a.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, a.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -255,7 +254,12 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = None + if op["op"] == Op.NEGATE: + attr = ts.TosaSerializerAttribute() + attr.NegateAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None): @@ -542,11 +546,10 @@ class TosaTestGen: # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongInputType: if input.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, input.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, input.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [input.name] @@ -577,10 +580,13 @@ class TosaTestGen: ): return None + if qinfo is None: + qinfo = [0, 0] + attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv2d( @@ -606,11 +612,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -642,9 +647,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv3d( @@ -670,11 +675,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -706,9 +710,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_transpose_conv2d( @@ -734,11 +738,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -769,9 +772,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape) + attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_depthwise_conv2d( @@ -796,11 +799,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -832,9 +834,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( @@ -871,7 +873,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): @@ -905,7 +910,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reduce(self, op, a, axis, validator_fcns, error_name=None): @@ -1164,7 +1172,7 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): @@ -2212,7 +2220,7 @@ class TosaTestGen: else: qinfo = None - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name) + tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) try: if error_if_validators is None: @@ -3425,7 +3433,6 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agPad, ), - "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evWrongInputType, -- cgit v1.2.1