diff options
author | Eric Kunze <eric.kunze@arm.com> | 2022-06-07 05:20:44 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-06-15 11:38:04 -0700 |
commit | b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 (patch) | |
tree | 9c7d946012c7a70a7fcb237daa4376d7b65c6f76 /reference_model/src/ops | |
parent | 24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a (diff) | |
download | reference_model-b5fabec33abeca2d92c20c7b094fa3f113d0ddd8.tar.gz |
Remove quantization info from serialization attributes
Any needed information moves into the attributes for each operator.
New serialization library version removes teh quantization information
attributes from the schema
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa
Diffstat (limited to 'reference_model/src/ops')
23 files changed, 167 insertions, 244 deletions
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index 10385e0..4853971 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -28,7 +28,7 @@ template <int Rank, DType Dtype> class OpClamp : public UnaryNode<Rank, Dtype> { public: - OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); @@ -48,7 +48,7 @@ template <int Rank, DType Dtype> class OpSigmoid : public UnaryNode<Rank, Dtype> { public: - OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode<Rank, Dtype>(sgt_, Op_SIGMOID, id_) { register_fcn(); @@ -64,7 +64,7 @@ template <int Rank, DType Dtype> class OpTanh : public UnaryNode<Rank, Dtype> { public: - OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : UnaryNode<Rank, Dtype>(sgt_, Op_TANH, id_) { register_fcn(); diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h index 5b4d0f1..29e6b5a 100644 --- a/reference_model/src/ops/comparison.h +++ b/reference_model/src/ops/comparison.h @@ -28,8 +28,8 @@ template <int Rank, DType Dtype> class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_) + OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_) { register_fcn(); } @@ -42,8 +42,8 @@ template <int Rank, DType Dtype> class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, qinfo_, id_) + OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_) { register_fcn(); } @@ -56,8 +56,8 @@ template <int Rank, DType Dtype> class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_) + OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_) { register_fcn(); } diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 49f53e8..9fe429b 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -23,7 +23,6 @@ using namespace tosa; template <int Rank, DType Dtype> OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONCAT, id_) { @@ -124,22 +123,18 @@ int OpConcat<Rank, Dtype>::eval() template <int Rank, DType Dtype> OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); - INIT_QINFO(Pad); INIT_ATTRIBUTE(Pad); } template <int Rank, DType Dtype> OpPad<Rank, Dtype>::~OpPad() { - if (qinfo) - delete qinfo; } template <int Rank, DType Dtype> @@ -177,11 +172,6 @@ int OpPad<Rank, Dtype>::checkTensorAttributes() paddings_array[i] = std::make_pair(pad_front, pad_back); } - if (this->qinfo && Dtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0"); - } - return 0; } @@ -206,11 +196,6 @@ int OpPad<Rank, Dtype>::eval() break; } - if (this->qinfo && Dtype == DType_INT8) - { - pad_value += (InEigenType)this->qinfo->input_zp(); - } - this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value); return GraphNode::eval(); @@ -219,7 +204,6 @@ int OpPad<Rank, Dtype>::eval() template <int InRank, int OutRank, DType Dtype> OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESHAPE, id_) { @@ -315,7 +299,6 @@ int OpReshape<InRank, OutRank, Dtype>::eval() template <int Rank, DType Dtype> OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_REVERSE, id_) { @@ -383,7 +366,6 @@ int OpReverse<Rank, Dtype>::eval() template <int Rank, DType Dtype> OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SLICE, id_) { @@ -451,7 +433,6 @@ int OpSlice<Rank, Dtype>::eval() template <int Rank, DType Dtype> OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TILE, id_) { @@ -586,7 +567,6 @@ int OpTile<4, Dtype>::eval() template <int Rank, DType Dtype> OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE, id_) { diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index bad88e4..c6513ae 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -27,7 +27,7 @@ template <int Rank, DType Dtype> class OpConcat : public GraphNode { public: - OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConcat(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <int Rank, DType Dtype> class OpPad : public GraphNode { public: - OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpPad(); virtual int checkTensorAttributes(); virtual int eval(); @@ -63,7 +63,6 @@ protected: Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array; TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; - TosaPadQuantInfo* qinfo; TosaPadAttribute* attribute; }; @@ -71,7 +70,7 @@ template <int InRank, int OutRank, DType Dtype> class OpReshape : public GraphNode { public: - OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpReshape(); virtual int checkTensorAttributes(); @@ -95,7 +94,7 @@ template <int Rank, DType Dtype> class OpReverse : public GraphNode { public: - OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpReverse(); virtual int checkTensorAttributes(); @@ -117,7 +116,7 @@ template <int Rank, DType Dtype> class OpSlice : public GraphNode { public: - OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpSlice(); virtual int checkTensorAttributes(); @@ -140,7 +139,7 @@ template <int Rank, DType Dtype> class OpTileBase : public GraphNode { public: - OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTileBase(); virtual int checkTensorAttributes(); @@ -161,8 +160,8 @@ template <int Rank, DType Dtype> class OpTile : public OpTileBase<Rank, Dtype> { public: - OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpTileBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_) + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_) {} protected: @@ -175,8 +174,8 @@ protected: class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \ { \ public: \ - OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : OpTileBase<N, Dtype>(sgt_, attribute_, qinfo_, id_) \ + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \ {} \ \ protected: \ @@ -194,7 +193,7 @@ template <int Rank, DType Dtype> class OpTranspose : public GraphNode { public: - OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTranspose(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index ec4bc41..30c9511 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -45,7 +45,6 @@ int OpConst::eval() template <int Rank, DType Dtype> OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_IDENTITY, id_) { diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h index 407cf0a..8761a08 100644 --- a/reference_model/src/ops/data_nodes.h +++ b/reference_model/src/ops/data_nodes.h @@ -35,7 +35,7 @@ template <int Rank, DType Dtype> class OpIdentity : public GraphNode { public: - OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpIdentity(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 7f30e30..eadefaa 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -25,7 +25,6 @@ using namespace tosa; template <int Rank, DType InDtype, DType OutDtype> BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -490,7 +489,6 @@ int OpSub<Rank, Dtype>::register_fcn() template <int Rank, DType InDtype> OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TABLE, id_) { diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 373dfb8..b2c92a4 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -42,7 +42,7 @@ template <int Rank, DType InDtype, DType OutDtype> class BinaryNodeBase : public GraphNode { public: - BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); + BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; @@ -71,8 +71,8 @@ template <int Rank, DType InDtype, DType OutDtype> class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype> { public: - BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) + : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, id_) {} virtual ~BinaryNode() {} @@ -90,8 +90,8 @@ template <DType InDtype, DType OutDtype> class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: - BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_) + : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_) {} virtual ~BinaryNode() {} @@ -104,8 +104,8 @@ public: class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, qinfo_, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ @@ -139,9 +139,8 @@ class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> public: OpArithmeticRightShift(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_) { INIT_ATTRIBUTE(ArithmeticRightShift); register_fcn(); @@ -158,8 +157,8 @@ template <int Rank, DType InDtype, DType OutDtype> class OpMul : public BinaryNode<Rank, InDtype, OutDtype> { public: - OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, qinfo_, id_) + OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_) { INIT_ATTRIBUTE(Mul); register_fcn(); @@ -178,7 +177,7 @@ template <int Rank, DType InDtype> class OpTable : public GraphNode { public: - OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index c265077..72fe5a0 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -22,7 +22,6 @@ using namespace tosa; template <int Rank, DType Dtype> OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SELECT, id_) { diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index b80fb23..75a2194 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -33,7 +33,7 @@ template <int Rank, DType Dtype> class OpSelectBase : public GraphNode { public: - OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpSelectBase(); virtual int checkTensorAttributes(); @@ -59,8 +59,8 @@ template <int Rank, DType Dtype> class OpSelect : public OpSelectBase<Rank, Dtype> { public: - OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpSelectBase<Rank, Dtype>(sgt_, attribute_, id_) {} virtual int eval(); int broadcast(); @@ -73,8 +73,8 @@ template <DType Dtype> class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: - OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : OpSelectBase<0, Dtype>(sgt_, attribute_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8b83a50..8ef1e3c 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -29,7 +29,10 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 setRequiredOperands(1, 1); setRequiredRank(0, 6); - fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; + fcn = [](InEigenType a) -> OutEigenType { + ASSERT_MSG(0, "In default UnaryNode function, missing function registration"); + return OutEigenType(); + }; } template <int Rank, DType Dtype> @@ -211,13 +214,28 @@ int OpLogicalNot<Rank, Dtype>::register_fcn() } template <int Rank, DType Dtype> +OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : UnaryNode<Rank, Dtype>(sgt_, Op_NEGATE, id_) +{ + INIT_ATTRIBUTE(Negate); + + register_fcn(); +} + +template <int Rank, DType Dtype> +OpNegate<Rank, Dtype>::~OpNegate() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> int OpNegate<Rank, Dtype>::register_fcn() { - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); switch (Dtype) { @@ -251,11 +269,11 @@ int OpNegate<Rank, Dtype>::register_fcn() break; case DType_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { - int64_t res_in_64 = 0 - (a - this->qinfo->input_zp()); + int64_t res_in_64 = 0 - (a - attribute->input1_zp()); int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); - res_in_64 += this->qinfo->output_zp(); + res_in_64 += attribute->output_zp(); InEigenType result = static_cast<InEigenType>(std::min(std::max(res_in_64, static_cast<int64_t>(QMin)), static_cast<int64_t>(QMax))); return result; }; diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 374c8e4..16a4c88 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -49,7 +49,7 @@ protected: class Op##Opname : public UnaryNode<Rank, Dtype> \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ @@ -61,27 +61,6 @@ protected: virtual int register_fcn(); \ }; -#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \ - template <int Rank, DType Dtype> \ - class Op##Opname : public UnaryNode<Rank, Dtype> \ - { \ - public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \ - { \ - INIT_QINFO(Unary); \ - register_fcn(); \ - } \ - static constexpr int32_t QMin = GetQMin<Dtype>::value; \ - static constexpr int32_t QMax = GetQMax<Dtype>::value; \ - using InEigenType = typename GetEigenType<Dtype>::type; \ - using OutEigenType = typename GetEigenType<Dtype>::type; \ - virtual int register_fcn(); \ - \ - protected: \ - TosaUnaryQuantInfo* qinfo; \ - }; - DEF_TEMPLATE_UNARY_OP(Abs, ABS) DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) @@ -90,12 +69,28 @@ DEF_TEMPLATE_UNARY_OP(Exp, EXP) DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) DEF_TEMPLATE_UNARY_OP(Log, LOG) DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) -DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE) DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) #undef DEF_TEMPLATE_UNARY_OP -#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO + +// Negate is the only unary op with attributes +template <int Rank, DType Dtype> +class OpNegate : public UnaryNode<Rank, Dtype> +{ +public: + OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpNegate(); + + static constexpr int32_t QMin = GetQMin<Dtype>::value; + static constexpr int32_t QMax = GetQMax<Dtype>::value; + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); + +protected: + tosa::TosaNegateAttribute* attribute; +}; }; // namespace TosaReference diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 6dec1bc..701fdfb 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -15,7 +15,6 @@ #include "image.h" #include "arith_util.h" -#include "quant_util.h" using namespace TosaReference; using namespace Eigen; @@ -24,7 +23,6 @@ using namespace tosa; template <DType InDtype, DType OutDtype> OpResize<InDtype, OutDtype>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESIZE, id_) { diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 095dc7d..fea4885 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -27,7 +27,7 @@ template <DType InDtype, DType OutDtype> class OpResize : public GraphNode { public: - OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpResize(); virtual int checkTensorAttributes() final; virtual int eval(); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index f7ded9a..62d5b11 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -36,7 +36,6 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, TosaSerializationHandler* tsh, Op opType, TosaAttributeBase* attribute, - TosaQuantInfoBase* qinfo, uint64_t id, DType inputDType, int inputRank, diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index eaa359c..4cb8178 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -18,61 +18,60 @@ #include "attribute.h" #include "graph_node.h" -#include "quant_info.h" #include "template_types.h" #include "tosa_serialization_handler.h" #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP<RANK, DType_##DTYPE>(sgt, attribute, qinfo, id); + return new OP<RANK, DType_##DTYPE>(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); + return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, qinfo, id); + return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); + return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ { \ case 0: \ - return new OP<0>(sgt, attribute, qinfo, id); \ + return new OP<0>(sgt, attribute, id); \ case 1: \ - return new OP<1>(sgt, attribute, qinfo, id); \ + return new OP<1>(sgt, attribute, id); \ case 2: \ - return new OP<2>(sgt, attribute, qinfo, id); \ + return new OP<2>(sgt, attribute, id); \ case 3: \ - return new OP<3>(sgt, attribute, qinfo, id); \ + return new OP<3>(sgt, attribute, id); \ case 4: \ - return new OP<4>(sgt, attribute, qinfo, id); \ + return new OP<4>(sgt, attribute, id); \ case 5: \ - return new OP<5>(sgt, attribute, qinfo, id); \ + return new OP<5>(sgt, attribute, id); \ case 6: \ - return new OP<6>(sgt, attribute, qinfo, id); \ + return new OP<6>(sgt, attribute, id); \ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ if (inputDType == DType_##DTYPE) \ { \ - return new OP<DType_##DTYPE>(sgt, attribute, qinfo, id); \ + return new OP<DType_##DTYPE>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \ + return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \ + return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ @@ -231,7 +230,6 @@ public: tosa::TosaSerializationHandler* tsh, tosa::Op opType, tosa::TosaAttributeBase* attribute, - tosa::TosaQuantInfoBase* qinfo, uint64_t id, tosa::DType inputDType, int inputRank, diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index f3407f4..6e98a76 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -48,7 +48,7 @@ template <int Rank, DType Dtype> class OpReduceAll : public ReduceNode<Rank, Dtype> { public: - OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); @@ -58,7 +58,7 @@ template <int Rank, DType Dtype> class OpReduceAny : public ReduceNode<Rank, Dtype> { public: - OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); @@ -68,7 +68,7 @@ template <int Rank, DType Dtype> class OpReduceMax : public ReduceNode<Rank, Dtype> { public: - OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MAX, attribute_, id_) {} virtual int eval(); @@ -78,7 +78,7 @@ template <int Rank, DType Dtype> class OpReduceMin : public ReduceNode<Rank, Dtype> { public: - OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MIN, attribute_, id_) {} virtual int eval(); @@ -88,7 +88,7 @@ template <int Rank, DType Dtype> class OpReduceProduct : public ReduceNode<Rank, Dtype> { public: - OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) {} virtual int eval(); @@ -98,7 +98,7 @@ template <int Rank, DType Dtype> class OpReduceSum : public ReduceNode<Rank, Dtype> { public: - OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); @@ -108,7 +108,7 @@ template <int Rank, DType Dtype> class OpReduceSumInt : public ReduceNode<Rank, Dtype> { public: - OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 02ec54f..faf7db9 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -23,7 +23,6 @@ using namespace tosa; template <DType Dtype> OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_GATHER, id_) { @@ -120,7 +119,6 @@ int OpGather<Dtype>::eval() template <DType Dtype> OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_SCATTER, id_) { diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index 66b584a..af09153 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -27,7 +27,7 @@ template <DType Dtype> class OpGather : public GraphNode { public: - OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpGather(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpScatter : public GraphNode { public: - OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpScatter(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index aef1ad2..3ab4d56 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -114,8 +114,7 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, return 0; } -int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, - tosa::TosaConvQuantInfo* qinfo, +int check_conv_attribute(tosa::TosaConvAttribute* attribute, uint32_t conv_dimension, std::vector<int32_t> input_shape, std::vector<int32_t> output_shape, @@ -226,18 +225,13 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, return 1; } - if (qinfo) - { - if (InDtype != DType_INT8 && qinfo->input_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } - if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } + if (InDtype != DType_INT8 && attribute->input_zp() != 0) { + msg = "Input zero point must be zero for non-int8 data"; + return 1; + } + if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) { + msg = "Weight zero point must be zero for non-int8 data"; + return 1; } return 0; @@ -246,7 +240,6 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, template <int Rank, DType Dtype> OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_ARGMAX, id_) { @@ -339,7 +332,6 @@ int OpArgMax<Rank, Dtype>::eval() template <DType Dtype> OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_AVG_POOL2D, id_) { @@ -347,7 +339,6 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Pool); - INIT_QINFO(Unary); } template <DType Dtype> @@ -377,11 +368,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes() in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); std::string msg; if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) @@ -474,9 +462,9 @@ int OpAvgPool2d<Dtype>::eval() pad[3] = std::make_pair(0, 0); ETensor4<InEigenType> input_val = this->in->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -537,7 +525,7 @@ int OpAvgPool2d<Dtype>::eval() { REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str()); } - this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); + this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); } @@ -552,7 +540,6 @@ int OpAvgPool2d<Dtype>::eval() template <DType InDtype, DType WeightDtype> OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) { @@ -560,7 +547,6 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -568,8 +554,6 @@ OpConv2d<InDtype, WeightDtype>::~OpConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -598,7 +582,7 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv2d: " + msg; @@ -691,10 +675,10 @@ int OpConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -739,7 +723,6 @@ int OpConv2d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) { @@ -747,7 +730,6 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_, setRequiredRank(5); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -755,8 +737,6 @@ OpConv3d<InDtype, WeightDtype>::~OpConv3d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -785,7 +765,7 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv3d: " + msg; @@ -858,10 +838,10 @@ int OpConv3d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor5<InEigenType> input_padded = input_val.pad(pad); @@ -931,7 +911,6 @@ int OpConv3d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { @@ -939,7 +918,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -947,8 +925,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -977,7 +953,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpDepthwiseConv2d: " + msg; @@ -1041,10 +1017,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -1108,21 +1084,20 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); setRequiredRank(2); - INIT_QINFO(Conv); + INIT_ATTRIBUTE(FullyConnected); } template <DType InDtype, DType WeightDtype> OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template <DType InDtype, DType WeightDtype> @@ -1157,17 +1132,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1190,10 +1156,10 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } this->output->getTensor() = @@ -1211,21 +1177,20 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() template <DType Dtype> OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); setRequiredRank(3); - INIT_QINFO(MatMul); + INIT_ATTRIBUTE(MatMul); } template <DType Dtype> OpMatMul<Dtype>::~OpMatMul() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template <DType Dtype> @@ -1284,11 +1249,8 @@ int OpMatMul<Dtype>::checkTensorAttributes() } W = b->getShape()[2]; - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data"); return 0; } @@ -1301,10 +1263,10 @@ int OpMatMul<Dtype>::eval() TIn a_val = this->a->getTensor(); TIn b_val = this->b->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - a_val = a_val - (InEigenType)this->qinfo->a_zp(); - b_val = b_val - (InEigenType)this->qinfo->b_zp(); + a_val = a_val - (InEigenType)attribute->a_zp(); + b_val = b_val - (InEigenType)attribute->b_zp(); } Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C }); @@ -1351,7 +1313,6 @@ int OpMatMul<Dtype>::eval() template <DType Dtype> OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MAX_POOL2D, id_) { @@ -1484,7 +1445,6 @@ int OpMaxPool2d<Dtype>::eval() template <DType InDtype, DType WeightDtype> OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { @@ -1492,7 +1452,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(TransposeConv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -1500,8 +1459,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -1595,17 +1552,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1655,10 +1603,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } Eigen::array<Eigen::Index, 4> reshape_dim; diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 05b1ca1..24eadeb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template <int Rank, DType Dtype> class OpArgMax : public GraphNode { public: - OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -69,7 +69,6 @@ protected: TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; tosa::TosaPoolAttribute* attribute; - tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space @@ -80,7 +79,7 @@ template <DType InDtype, DType WeightDtype> class OpConv2d : public GraphNode { public: - OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -105,14 +104,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpConv3d : public GraphNode { public: - OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; @@ -137,14 +135,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -169,14 +166,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpFullyConnected : public GraphNode { public: - OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -199,14 +195,15 @@ protected: TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; - tosa::TosaConvQuantInfo* qinfo; + + tosa::TosaFullyConnectedAttribute* attribute; }; template <DType Dtype> class OpMatMul : public GraphNode { public: - OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -230,14 +227,15 @@ protected: int64_t H; int64_t W; int64_t C; - tosa::TosaMatMulQuantInfo* qinfo; + + tosa::TosaMatMulAttribute* attribute; }; template <DType Dtype> class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -258,7 +256,7 @@ template <DType InDtype, DType WeightDtype> class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; @@ -283,7 +281,6 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; TosaTransposeConvAttribute* attribute; - TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 7ee9692..ac54932 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -25,7 +25,6 @@ using namespace tosa; template <int Rank, DType InDtype, DType OutDtype> OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_RESCALE, id_) { @@ -218,7 +217,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() template <int Rank, DType InDtype, DType OutDtype> OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CAST, id_) { diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 060e14e..53470d1 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -26,7 +26,7 @@ template <int Rank, DType InDtype, DType OutDtype> class OpRescale : public GraphNode { public: - OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpRescale(); virtual int checkTensorAttributes() final; @@ -140,7 +140,7 @@ template <int Rank, DType InDtype, DType OutDtype> class OpCast : public GraphNode { public: - OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpCast(); virtual int checkTensorAttributes() final; |