commit    b5fabec33abeca2d92c20c7b094fa3f113d0ddd8
tree      9c7d946012c7a70a7fcb237daa4376d7b65c6f76
parent    24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a
author    Eric Kunze <eric.kunze@arm.com>  2022-06-07 05:20:44 +0000
committer Eric Kunze <eric.kunze@arm.com>  2022-06-15 11:38:04 -0700
Remove quantization info from serialization attributes
Any needed information moves into the attributes for each operator. The new serialization library version removes the quantization information attributes from the schema.

Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa
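The same pattern repeats across every operator touched by this patch. As an illustrative sketch only (taken from the OpNegate change shown in full in the diff below), the qinfo_ constructor parameter disappears and zero points are read from the operator's own attribute:

    // Before: zero points delivered through a separate quantization info object
    OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_,
             TosaQuantInfoBase* qinfo_, uint64_t id_);
    // ...inside register_fcn():
    int64_t res_in_64 = 0 - (a - this->qinfo->input_zp());

    // After: zero points are fields of the operator's attribute
    OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
    // ...constructor now does INIT_ATTRIBUTE(Negate) to populate TosaNegateAttribute,
    // and register_fcn() reads the zero point from it:
    int64_t res_in_64 = 0 - (a - attribute->input1_zp());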
-rw-r--r--  reference_model/src/graph_node.h           |  12
-rw-r--r--  reference_model/src/main.cpp               |   2
-rw-r--r--  reference_model/src/ops/activation_funcs.h |   6
-rw-r--r--  reference_model/src/ops/comparison.h       |  12
-rw-r--r--  reference_model/src/ops/data_layout.cc     |  20
-rw-r--r--  reference_model/src/ops/data_layout.h      |  23
-rw-r--r--  reference_model/src/ops/data_nodes.cc      |   1
-rw-r--r--  reference_model/src/ops/data_nodes.h       |   2
-rw-r--r--  reference_model/src/ops/ewise_binary.cc    |   2
-rw-r--r--  reference_model/src/ops/ewise_binary.h     |  23
-rw-r--r--  reference_model/src/ops/ewise_ternary.cc   |   1
-rw-r--r--  reference_model/src/ops/ewise_ternary.h    |  10
-rw-r--r--  reference_model/src/ops/ewise_unary.cc     |  34
-rw-r--r--  reference_model/src/ops/ewise_unary.h      |  43
-rw-r--r--  reference_model/src/ops/image.cc           |   2
-rw-r--r--  reference_model/src/ops/image.h            |   2
-rw-r--r--  reference_model/src/ops/op_factory.cc      |   1
-rw-r--r--  reference_model/src/ops/op_factory.h       |  30
-rw-r--r--  reference_model/src/ops/reduction.h        |  14
-rw-r--r--  reference_model/src/ops/scatter_gather.cc  |   2
-rw-r--r--  reference_model/src/ops/scatter_gather.h   |   4
-rw-r--r--  reference_model/src/ops/tensor_ops.cc      | 144
-rw-r--r--  reference_model/src/ops/tensor_ops.h       |  29
-rw-r--r--  reference_model/src/ops/type_conversion.cc |   2
-rw-r--r--  reference_model/src/ops/type_conversion.h  |   4
-rw-r--r--  reference_model/src/subgraph_traverser.cc  |   2
m---------  thirdparty/serialization_lib               |   0
-rw-r--r--  verif/generator/tosa_arg_gen.py            | 119
-rw-r--r--  verif/generator/tosa_error_if.py           |   7
-rw-r--r--  verif/generator/tosa_test_gen.py           | 103
30 files changed, 278 insertions, 378 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 14a8acc..787e89d 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -17,7 +17,6 @@
#define GRAPH_NODE_H
#include "attribute.h"
-#include "quant_info.h"
#include "subgraph_traverser.h"
#include "tensor.h"
#include "tosa_generated.h"
@@ -126,17 +125,6 @@
FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \
}
-#define INIT_QINFO(QINFO_NAME) \
- if (auto p = dynamic_cast<Tosa##QINFO_NAME##QuantInfo*>(qinfo_)) \
- { \
- qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \
- ASSERT_MEM(qinfo); \
- } \
- else \
- { \
- qinfo = nullptr; \
- }
-
namespace TosaReference
{
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index f8a446a..2b70e94 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -26,7 +26,7 @@
#include <nlohmann/json.hpp>
#define MODEL_VERSION_MAJOR 0
-#define MODEL_VERSION_MINOR 25
+#define MODEL_VERSION_MINOR 30
#define MODEL_VERSION_PATCH 0
#define MODEL_VERSION_DRAFT true
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 10385e0..4853971 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -28,7 +28,7 @@ template <int Rank, DType Dtype>
class OpClamp : public UnaryNode<Rank, Dtype>
{
public:
- OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_)
{
INIT_ATTRIBUTE(Clamp);
@@ -48,7 +48,7 @@ template <int Rank, DType Dtype>
class OpSigmoid : public UnaryNode<Rank, Dtype>
{
public:
- OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: UnaryNode<Rank, Dtype>(sgt_, Op_SIGMOID, id_)
{
register_fcn();
@@ -64,7 +64,7 @@ template <int Rank, DType Dtype>
class OpTanh : public UnaryNode<Rank, Dtype>
{
public:
- OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: UnaryNode<Rank, Dtype>(sgt_, Op_TANH, id_)
{
register_fcn();
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
index 5b4d0f1..29e6b5a 100644
--- a/reference_model/src/ops/comparison.h
+++ b/reference_model/src/ops/comparison.h
@@ -28,8 +28,8 @@ template <int Rank, DType Dtype>
class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_)
+ OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
@@ -42,8 +42,8 @@ template <int Rank, DType Dtype>
class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, qinfo_, id_)
+ OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_)
{
register_fcn();
}
@@ -56,8 +56,8 @@ template <int Rank, DType Dtype>
class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_)
+ OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 49f53e8..9fe429b 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -23,7 +23,6 @@ using namespace tosa;
template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CONCAT, id_)
{
@@ -124,22 +123,18 @@ int OpConcat<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
- INIT_QINFO(Pad);
INIT_ATTRIBUTE(Pad);
}
template <int Rank, DType Dtype>
OpPad<Rank, Dtype>::~OpPad()
{
- if (qinfo)
- delete qinfo;
}
template <int Rank, DType Dtype>
@@ -177,11 +172,6 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
paddings_array[i] = std::make_pair(pad_front, pad_back);
}
- if (this->qinfo && Dtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
- }
-
return 0;
}
@@ -206,11 +196,6 @@ int OpPad<Rank, Dtype>::eval()
break;
}
- if (this->qinfo && Dtype == DType_INT8)
- {
- pad_value += (InEigenType)this->qinfo->input_zp();
- }
-
this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
return GraphNode::eval();
@@ -219,7 +204,6 @@ int OpPad<Rank, Dtype>::eval()
template <int InRank, int OutRank, DType Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_RESHAPE, id_)
{
@@ -315,7 +299,6 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
template <int Rank, DType Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_REVERSE, id_)
{
@@ -383,7 +366,6 @@ int OpReverse<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_SLICE, id_)
{
@@ -451,7 +433,6 @@ int OpSlice<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_TILE, id_)
{
@@ -586,7 +567,6 @@ int OpTile<4, Dtype>::eval()
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index bad88e4..c6513ae 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -27,7 +27,7 @@ template <int Rank, DType Dtype>
class OpConcat : public GraphNode
{
public:
- OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpConcat();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <int Rank, DType Dtype>
class OpPad : public GraphNode
{
public:
- OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpPad();
virtual int checkTensorAttributes();
virtual int eval();
@@ -63,7 +63,6 @@ protected:
Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
- TosaPadQuantInfo* qinfo;
TosaPadAttribute* attribute;
};
@@ -71,7 +70,7 @@ template <int InRank, int OutRank, DType Dtype>
class OpReshape : public GraphNode
{
public:
- OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpReshape();
virtual int checkTensorAttributes();
@@ -95,7 +94,7 @@ template <int Rank, DType Dtype>
class OpReverse : public GraphNode
{
public:
- OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpReverse();
virtual int checkTensorAttributes();
@@ -117,7 +116,7 @@ template <int Rank, DType Dtype>
class OpSlice : public GraphNode
{
public:
- OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpSlice();
virtual int checkTensorAttributes();
@@ -140,7 +139,7 @@ template <int Rank, DType Dtype>
class OpTileBase : public GraphNode
{
public:
- OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpTileBase();
virtual int checkTensorAttributes();
@@ -161,8 +160,8 @@ template <int Rank, DType Dtype>
class OpTile : public OpTileBase<Rank, Dtype>
{
public:
- OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpTileBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_)
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpTileBase<Rank, Dtype>(sgt_, attribute_, id_)
{}
protected:
@@ -175,8 +174,8 @@ protected:
class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
{ \
public: \
- OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : OpTileBase<N, Dtype>(sgt_, attribute_, qinfo_, id_) \
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
{} \
\
protected: \
@@ -194,7 +193,7 @@ template <int Rank, DType Dtype>
class OpTranspose : public GraphNode
{
public:
- OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpTranspose();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index ec4bc41..30c9511 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -45,7 +45,6 @@ int OpConst::eval()
template <int Rank, DType Dtype>
OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_IDENTITY, id_)
{
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index 407cf0a..8761a08 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -35,7 +35,7 @@ template <int Rank, DType Dtype>
class OpIdentity : public GraphNode
{
public:
- OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpIdentity();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 7f30e30..eadefaa 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -25,7 +25,6 @@ using namespace tosa;
template <int Rank, DType InDtype, DType OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
const Op& op_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -490,7 +489,6 @@ int OpSub<Rank, Dtype>::register_fcn()
template <int Rank, DType InDtype>
OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_TABLE, id_)
{
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 373dfb8..b2c92a4 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -42,7 +42,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
- BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+ BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_);
virtual ~BinaryNodeBase();
virtual int checkTensorAttributes() final;
@@ -71,8 +71,8 @@ template <int Rank, DType InDtype, DType OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
- BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
- : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, qinfo_, id_)
+ BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
+ : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, id_)
{}
virtual ~BinaryNode()
{}
@@ -90,8 +90,8 @@ template <DType InDtype, DType OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
- BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
- : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_)
+ BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
+ : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_)
{}
virtual ~BinaryNode()
{}
@@ -104,8 +104,8 @@ public:
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, qinfo_, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
} \
@@ -139,9 +139,8 @@ class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
public:
OpArithmeticRightShift(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
- : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_)
{
INIT_ATTRIBUTE(ArithmeticRightShift);
register_fcn();
@@ -158,8 +157,8 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
- OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, qinfo_, id_)
+ OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
{
INIT_ATTRIBUTE(Mul);
register_fcn();
@@ -178,7 +177,7 @@ template <int Rank, DType InDtype>
class OpTable : public GraphNode
{
public:
- OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpTable();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index c265077..72fe5a0 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -22,7 +22,6 @@ using namespace tosa;
template <int Rank, DType Dtype>
OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_SELECT, id_)
{
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index b80fb23..75a2194 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -33,7 +33,7 @@ template <int Rank, DType Dtype>
class OpSelectBase : public GraphNode
{
public:
- OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpSelectBase();
virtual int checkTensorAttributes();
@@ -59,8 +59,8 @@ template <int Rank, DType Dtype>
class OpSelect : public OpSelectBase<Rank, Dtype>
{
public:
- OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpSelectBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_)
+ OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpSelectBase<Rank, Dtype>(sgt_, attribute_, id_)
{}
virtual int eval();
int broadcast();
@@ -73,8 +73,8 @@ template <DType Dtype>
class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
{
public:
- OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_)
+ OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpSelectBase<0, Dtype>(sgt_, attribute_, id_)
{}
virtual int eval();
};
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 8b83a50..8ef1e3c 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -29,7 +29,10 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
- fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+ fcn = [](InEigenType a) -> OutEigenType {
+ ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
+ return OutEigenType();
+ };
}
template <int Rank, DType Dtype>
@@ -211,13 +214,28 @@ int OpLogicalNot<Rank, Dtype>::register_fcn()
}
template <int Rank, DType Dtype>
+OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_NEGATE, id_)
+{
+ INIT_ATTRIBUTE(Negate);
+
+ register_fcn();
+}
+
+template <int Rank, DType Dtype>
+OpNegate<Rank, Dtype>::~OpNegate()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
switch (Dtype)
{
@@ -251,11 +269,11 @@ int OpNegate<Rank, Dtype>::register_fcn()
break;
case DType_INT8:
this->fcn = [this](InEigenType a) -> OutEigenType {
- int64_t res_in_64 = 0 - (a - this->qinfo->input_zp());
+ int64_t res_in_64 = 0 - (a - attribute->input1_zp());
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)");
- res_in_64 += this->qinfo->output_zp();
+ res_in_64 += attribute->output_zp();
InEigenType result = static_cast<InEigenType>(std::min(std::max(res_in_64, static_cast<int64_t>(QMin)), static_cast<int64_t>(QMax)));
return result;
};
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 374c8e4..16a4c88 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -49,7 +49,7 @@ protected:
class Op##Opname : public UnaryNode<Rank, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
: UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
@@ -61,27 +61,6 @@ protected:
virtual int register_fcn(); \
};
-#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
- class Op##Opname : public UnaryNode<Rank, Dtype> \
- { \
- public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
- { \
- INIT_QINFO(Unary); \
- register_fcn(); \
- } \
- static constexpr int32_t QMin = GetQMin<Dtype>::value; \
- static constexpr int32_t QMax = GetQMax<Dtype>::value; \
- using InEigenType = typename GetEigenType<Dtype>::type; \
- using OutEigenType = typename GetEigenType<Dtype>::type; \
- virtual int register_fcn(); \
- \
- protected: \
- TosaUnaryQuantInfo* qinfo; \
- };
-
DEF_TEMPLATE_UNARY_OP(Abs, ABS)
DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
@@ -90,12 +69,28 @@ DEF_TEMPLATE_UNARY_OP(Exp, EXP)
DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
DEF_TEMPLATE_UNARY_OP(Log, LOG)
DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
-DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE)
DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
#undef DEF_TEMPLATE_UNARY_OP
-#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO
+
+// Negate is the only unary op with attributes
+template <int Rank, DType Dtype>
+class OpNegate : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpNegate();
+
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ tosa::TosaNegateAttribute* attribute;
+};
}; // namespace TosaReference
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 6dec1bc..701fdfb 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -15,7 +15,6 @@
#include "image.h"
#include "arith_util.h"
-#include "quant_util.h"
using namespace TosaReference;
using namespace Eigen;
@@ -24,7 +23,6 @@ using namespace tosa;
template <DType InDtype, DType OutDtype>
OpResize<InDtype, OutDtype>::OpResize(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_RESIZE, id_)
{
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 095dc7d..fea4885 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -27,7 +27,7 @@ template <DType InDtype, DType OutDtype>
class OpResize : public GraphNode
{
public:
- OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpResize();
virtual int checkTensorAttributes() final;
virtual int eval();
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index f7ded9a..62d5b11 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -36,7 +36,6 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
TosaSerializationHandler* tsh,
Op opType,
TosaAttributeBase* attribute,
- TosaQuantInfoBase* qinfo,
uint64_t id,
DType inputDType,
int inputRank,
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index eaa359c..4cb8178 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -18,61 +18,60 @@
#include "attribute.h"
#include "graph_node.h"
-#include "quant_info.h"
#include "template_types.h"
#include "tosa_serialization_handler.h"
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(sgt, attribute, qinfo, id);
+ return new OP<RANK, DType_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id);
+ return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, qinfo, id);
+ return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id);
+ return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
{ \
case 0: \
- return new OP<0>(sgt, attribute, qinfo, id); \
+ return new OP<0>(sgt, attribute, id); \
case 1: \
- return new OP<1>(sgt, attribute, qinfo, id); \
+ return new OP<1>(sgt, attribute, id); \
case 2: \
- return new OP<2>(sgt, attribute, qinfo, id); \
+ return new OP<2>(sgt, attribute, id); \
case 3: \
- return new OP<3>(sgt, attribute, qinfo, id); \
+ return new OP<3>(sgt, attribute, id); \
case 4: \
- return new OP<4>(sgt, attribute, qinfo, id); \
+ return new OP<4>(sgt, attribute, id); \
case 5: \
- return new OP<5>(sgt, attribute, qinfo, id); \
+ return new OP<5>(sgt, attribute, id); \
case 6: \
- return new OP<6>(sgt, attribute, qinfo, id); \
+ return new OP<6>(sgt, attribute, id); \
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
if (inputDType == DType_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(sgt, attribute, qinfo, id); \
+ return new OP<DType_##DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
@@ -231,7 +230,6 @@ public:
tosa::TosaSerializationHandler* tsh,
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
- tosa::TosaQuantInfoBase* qinfo,
uint64_t id,
tosa::DType inputDType,
int inputRank,
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
index f3407f4..6e98a76 100644
--- a/reference_model/src/ops/reduction.h
+++ b/reference_model/src/ops/reduction.h
@@ -48,7 +48,7 @@ template <int Rank, DType Dtype>
class OpReduceAll : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_)
{}
virtual int eval();
@@ -58,7 +58,7 @@ template <int Rank, DType Dtype>
class OpReduceAny : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_)
{}
virtual int eval();
@@ -68,7 +68,7 @@ template <int Rank, DType Dtype>
class OpReduceMax : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MAX, attribute_, id_)
{}
virtual int eval();
@@ -78,7 +78,7 @@ template <int Rank, DType Dtype>
class OpReduceMin : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MIN, attribute_, id_)
{}
virtual int eval();
@@ -88,7 +88,7 @@ template <int Rank, DType Dtype>
class OpReduceProduct : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
{}
virtual int eval();
@@ -98,7 +98,7 @@ template <int Rank, DType Dtype>
class OpReduceSum : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
{}
virtual int eval();
@@ -108,7 +108,7 @@ template <int Rank, DType Dtype>
class OpReduceSumInt : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ OpReduceSumInt(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
{}
virtual int eval();
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index 02ec54f..faf7db9 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -23,7 +23,6 @@ using namespace tosa;
template <DType Dtype>
OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_GATHER, id_)
{
@@ -120,7 +119,6 @@ int OpGather<Dtype>::eval()
template <DType Dtype>
OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_SCATTER, id_)
{
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index 66b584a..af09153 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -27,7 +27,7 @@ template <DType Dtype>
class OpGather : public GraphNode
{
public:
- OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpGather();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <DType Dtype>
class OpScatter : public GraphNode
{
public:
- OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpScatter();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index aef1ad2..3ab4d56 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -114,8 +114,7 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
return 0;
}
-int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
- tosa::TosaConvQuantInfo* qinfo,
+int check_conv_attribute(tosa::TosaConvAttribute* attribute,
uint32_t conv_dimension,
std::vector<int32_t> input_shape,
std::vector<int32_t> output_shape,
@@ -226,18 +225,13 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
return 1;
}
- if (qinfo)
- {
- if (InDtype != DType_INT8 && qinfo->input_zp() != 0)
- {
- msg = "zeropoint only for int8_t";
- return 1;
- }
- if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0)
- {
- msg = "zeropoint only for int8_t";
- return 1;
- }
+ if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+ msg = "Input zero point must be zero for non-int8 data";
+ return 1;
+ }
+ if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+ msg = "Weight zero point must be zero for non-int8 data";
+ return 1;
}
return 0;
@@ -246,7 +240,6 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_ARGMAX, id_)
{
@@ -339,7 +332,6 @@ int OpArgMax<Rank, Dtype>::eval()
template <DType Dtype>
OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
@@ -347,7 +339,6 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
setRequiredRank(4);
INIT_ATTRIBUTE(Pool);
- INIT_QINFO(Unary);
}
template <DType Dtype>
@@ -377,11 +368,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
std::string msg;
if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -474,9 +462,9 @@ int OpAvgPool2d<Dtype>::eval()
pad[3] = std::make_pair(0, 0);
ETensor4<InEigenType> input_val = this->in->getTensor();
- if (this->qinfo)
+ if (Dtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
@@ -537,7 +525,7 @@ int OpAvgPool2d<Dtype>::eval()
{
REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
}
- this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
+ this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
}
@@ -552,7 +540,6 @@ int OpAvgPool2d<Dtype>::eval()
template <DType InDtype, DType WeightDtype>
OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
{
@@ -560,7 +547,6 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
setRequiredRank(4);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -568,8 +554,6 @@ OpConv2d<InDtype, WeightDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -598,7 +582,7 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpConv2d: " + msg;
@@ -691,10 +675,10 @@ int OpConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
@@ -739,7 +723,6 @@ int OpConv2d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
{
@@ -747,7 +730,6 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
setRequiredRank(5);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -755,8 +737,6 @@ OpConv3d<InDtype, WeightDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -785,7 +765,7 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpConv3d: " + msg;
@@ -858,10 +838,10 @@ int OpConv3d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor5<InEigenType> input_padded = input_val.pad(pad);
@@ -931,7 +911,6 @@ int OpConv3d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
@@ -939,7 +918,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg
setRequiredRank(4);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -947,8 +925,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -977,7 +953,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpDepthwiseConv2d: " + msg;
@@ -1041,10 +1017,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
@@ -1108,21 +1084,20 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(2);
- INIT_QINFO(Conv);
+ INIT_ATTRIBUTE(FullyConnected);
}
template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
{
- if (qinfo)
- delete qinfo;
+ if (attribute)
+ delete attribute;
}
template <DType InDtype, DType WeightDtype>
@@ -1157,17 +1132,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
- }
- }
+ ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1190,10 +1156,10 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
this->output->getTensor() =
@@ -1211,21 +1177,20 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
template <DType Dtype>
OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(3);
- INIT_QINFO(MatMul);
+ INIT_ATTRIBUTE(MatMul);
}
template <DType Dtype>
OpMatMul<Dtype>::~OpMatMul()
{
- if (qinfo)
- delete qinfo;
+ if (attribute)
+ delete attribute;
}
template <DType Dtype>
@@ -1284,11 +1249,8 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1301,10 +1263,10 @@ int OpMatMul<Dtype>::eval()
TIn a_val = this->a->getTensor();
TIn b_val = this->b->getTensor();
- if (this->qinfo)
+ if (Dtype == DType_INT8)
{
- a_val = a_val - (InEigenType)this->qinfo->a_zp();
- b_val = b_val - (InEigenType)this->qinfo->b_zp();
+ a_val = a_val - (InEigenType)attribute->a_zp();
+ b_val = b_val - (InEigenType)attribute->b_zp();
}
Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
@@ -1351,7 +1313,6 @@ int OpMatMul<Dtype>::eval()
template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
@@ -1484,7 +1445,6 @@ int OpMaxPool2d<Dtype>::eval()
template <DType InDtype, DType WeightDtype>
OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
@@ -1492,7 +1452,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg
setRequiredRank(4);
INIT_ATTRIBUTE(TransposeConv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -1500,8 +1459,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -1595,17 +1552,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
- }
- }
+ ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1655,10 +1603,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
Eigen::array<Eigen::Index, 4> reshape_dim;
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 05b1ca1..24eadeb 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -28,7 +28,7 @@ template <int Rank, DType Dtype>
class OpArgMax : public GraphNode
{
public:
- OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpArgMax();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <DType Dtype>
class OpAvgPool2d : public GraphNode
{
public:
- OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpAvgPool2d();
virtual int checkTensorAttributes();
@@ -69,7 +69,6 @@ protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
tosa::TosaPoolAttribute* attribute;
- tosa::TosaUnaryQuantInfo* qinfo;
protected:
// return a 1D [N] tensor that describes a how many valid elements covered in the input space
@@ -80,7 +79,7 @@ template <DType InDtype, DType WeightDtype>
class OpConv2d : public GraphNode
{
public:
- OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpConv2d();
virtual int checkTensorAttributes() final;
@@ -105,14 +104,13 @@ protected:
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
tosa::TosaConvAttribute* attribute;
- tosa::TosaConvQuantInfo* qinfo;
};
template <DType InDtype, DType WeightDtype>
class OpConv3d : public GraphNode
{
public:
- OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpConv3d();
virtual int checkTensorAttributes() final;
@@ -137,14 +135,13 @@ protected:
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
tosa::TosaConvAttribute* attribute;
- tosa::TosaConvQuantInfo* qinfo;
};
template <DType InDtype, DType WeightDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
- OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpDepthwiseConv2d();
virtual int checkTensorAttributes() final;
@@ -169,14 +166,13 @@ protected:
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
tosa::TosaConvAttribute* attribute;
- tosa::TosaConvQuantInfo* qinfo;
};
template <DType InDtype, DType WeightDtype>
class OpFullyConnected : public GraphNode
{
public:
- OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpFullyConnected();
virtual int checkTensorAttributes() final;
@@ -199,14 +195,15 @@ protected:
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
- tosa::TosaConvQuantInfo* qinfo;
+
+ tosa::TosaFullyConnectedAttribute* attribute;
};
template <DType Dtype>
class OpMatMul : public GraphNode
{
public:
- OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpMatMul();
virtual int checkTensorAttributes() final;
@@ -230,14 +227,15 @@ protected:
int64_t H;
int64_t W;
int64_t C;
- tosa::TosaMatMulQuantInfo* qinfo;
+
+ tosa::TosaMatMulAttribute* attribute;
};
template <DType Dtype>
class OpMaxPool2d : public GraphNode
{
public:
- OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpMaxPool2d();
virtual int checkTensorAttributes();
@@ -258,7 +256,7 @@ template <DType InDtype, DType WeightDtype>
class OpTransposeConv2d : public GraphNode
{
public:
- OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpTransposeConv2d();
virtual int checkTensorAttributes() final;
@@ -283,7 +281,6 @@ protected:
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
TosaTransposeConvAttribute* attribute;
- TosaConvQuantInfo* qinfo;
};
}; // namespace TosaReference
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 7ee9692..ac54932 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -25,7 +25,6 @@ using namespace tosa;
template <int Rank, DType InDtype, DType OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_RESCALE, id_)
{
@@ -218,7 +217,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
template <int Rank, DType InDtype, DType OutDtype>
OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CAST, id_)
{
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 060e14e..53470d1 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -26,7 +26,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpRescale : public GraphNode
{
public:
- OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpRescale();
virtual int checkTensorAttributes() final;
@@ -140,7 +140,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpCast : public GraphNode
{
public:
- OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpCast();
virtual int checkTensorAttributes() final;
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 36e0a63..d0cc6cf 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -197,7 +197,7 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
+ GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
if (!node)
{
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject a336d54aca08b06953a8b6c49d7e5f6c4899952
+Subproject bdcc3fee1b8bf55aac50e060115b92a1ccf9741
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a27d849..8e00fab 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -4,7 +4,6 @@ import itertools
import math
import numpy as np
-import serializer.tosa_serializer as ts
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
@@ -26,7 +25,7 @@ class TosaQuantGen:
pass
@staticmethod
- def getQinfo(testGen, dtype, error_name=None):
+ def getZeroPoint(testGen, dtype, error_name=None):
if dtype == DType.INT8:
return testGen.randInt(-128, 128)
@@ -45,27 +44,25 @@ class TosaQuantGen:
@staticmethod
def qgUnary(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
elif error_name == ErrorIf.OutputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if isinstance(dtype_or_dtypeList, list):
# a list of [input, weights, accumulator] dtypes
dtypeList = dtype_or_dtypeList
@@ -74,40 +71,34 @@ class TosaQuantGen:
dtypeList = [dtype_or_dtypeList] * 3
if error_name == ErrorIf.InputZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
elif error_name == ErrorIf.WeightZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
+ ]
else:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
-
- qinfo.ConvQuantInfo(input_zp, weights_zp)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- return qinfo
-
- @staticmethod
- def qgPad(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
- else:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
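With this change the generator no longer builds a TosaSerializerQuantInfo object; qinfo becomes a plain two-element list of zero points. A minimal sketch of the new convention, assuming illustrative names only (get_zero_point is a hypothetical stand-in for TosaQuantGen.getZeroPoint, dtypes are shown as strings, and the UINT8 range is an assumption not shown in this diff):

    import random

    def get_zero_point(dtype):
        # Hypothetical stand-in for TosaQuantGen.getZeroPoint:
        # INT8 zero points span [-128, 127]; UINT8 is assumed to span [0, 255];
        # all other dtypes get a zero point of 0.
        if dtype == "INT8":
            return random.randint(-128, 127)
        if dtype == "UINT8":
            return random.randint(0, 255)
        return 0

    # A unary op's qinfo is now simply [input_zp, output_zp]:
    qinfo = [get_zero_point("INT8"), get_zero_point("INT8")]
    input_zp, output_zp = qinfo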
@@ -550,7 +541,7 @@ class TosaTensorValuesGen:
pass
@staticmethod
- def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
@@ -562,7 +553,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -582,11 +573,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
# Make sure the operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
@@ -651,12 +642,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgCondIfWhileLoop(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
if dtypeList[0] in (
DType.INT32,
@@ -689,12 +680,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgArithmeticRightShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
pCount, cCount = op["operands"]
# Force value of operand[1] to be within [0, num_bits]
@@ -722,16 +713,16 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
# Set datatype of condition tensor to boolean
dtypeList[0] = DType.BOOL
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -765,11 +756,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -839,11 +830,11 @@ class TosaTensorValuesGen:
return tens
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
count = len(shapeList) - testGen.args.num_const_inputs_concat
if count < 1:
count = 1
@@ -866,9 +857,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgLogicalShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
@@ -886,7 +875,7 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -924,13 +913,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgReduceSum(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32:
pCount, cCount = op["operands"]
assert (
@@ -949,7 +936,7 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
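The tensor-values generators (tvg*) all lose their qinfo parameter, since zero points now travel with the operator attribute instead. A hedged sketch of the resulting signature, with make_placeholder/make_const as hypothetical helpers standing in for TosaTestGen's real tensor builders:

    def make_placeholder(testGen, shape, dtype):
        # Hypothetical helper: the real generator registers a placeholder
        # tensor through testGen's serializer; here we only record the request.
        return ("placeholder", shape, dtype)

    def make_const(testGen, shape, dtype):
        return ("const", shape, dtype)

    def tvg_default(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Note the signature: no qinfo parameter any more.
        pCount, cCount = op["operands"]
        assert len(shapeList) == pCount + cCount
        tens = [make_placeholder(testGen, s, d)
                for s, d in zip(shapeList[:pCount], dtypeList[:pCount])]
        tens += [make_const(testGen, s, d)
                 for s, d in zip(shapeList[pCount:], dtypeList[pCount:])]
        return tens

    # Example call: two placeholder inputs, no const inputs.
    tensors = tvg_default(None, {"operands": (2, 0)}, ["INT8", "INT8"],
                          [[1, 4], [1, 4]], None)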
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 1967d8a..b331a42 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1003,12 +1003,7 @@ class TosaErrorValidator:
Generally input_zp is index 0, output_zp is index 1
"""
- if isinstance(qinfo, tuple):
- zero_point = qinfo[index]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- zero_point = qinfo.ints[index][1]
- return zero_point
+ return qinfo[index]
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
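Because qinfo is now a list, the error validator can read zero points by position instead of digging into serializer internals. A trivial sketch (function name is illustrative, indices follow the docstring above):

    def extract_zero_point(qinfo, index):
        # index 0 -> input_zp, index 1 -> output_zp
        return qinfo[index]

    input_zp = extract_zero_point([5, 0], 0)   # 5
    output_zp = extract_zero_point([5, 0], 1)  # 0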
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 262a652..b0e7c8c 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -216,20 +216,19 @@ class TosaTestGen:
        # build_placeholder returns an int, ABS/other ops do not
if isinstance(op, int):
- self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op, a.name, result_tens.name, None)
return result_tens
elif op["op"] == Op.IDENTITY:
- self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op["op"], a.name, result_tens.name, None)
return result_tens
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongOutputType:
if result_tens.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, a.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, a.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
@@ -255,7 +254,12 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = None
+ if op["op"] == Op.NEGATE:
+ attr = ts.TosaSerializerAttribute()
+ attr.NegateAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
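NEGATE is the one unary op that still needs zero points, so the generator now folds them into a NegateAttribute instead of passing a separate qinfo argument to addOperator. A minimal sketch, assuming the serializer package from the pinned serialization_lib submodule is importable; ser and negate_op are placeholders for a TosaSerializer instance and the NEGATE op enum value:

    import serializer.tosa_serializer as ts

    def serialize_negate(ser, negate_op, input_name, output_name, qinfo):
        # qinfo is the [input_zp, output_zp] list produced by qgUnary above.
        input_zp, output_zp = qinfo
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(input_zp, output_zp)
        ser.addOperator(negate_op, [input_name], [output_name], attr)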
@@ -542,11 +546,10 @@ class TosaTestGen:
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType:
if input.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, input.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, input.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
@@ -577,10 +580,13 @@ class TosaTestGen:
):
return None
+ if qinfo is None:
+ qinfo = [0, 0]
+
attr = ts.TosaSerializerAttribute()
- attr.PoolAttribute(kernel, stride, pad)
+ attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv2d(
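For pooling, the zero points move into PoolAttribute, defaulting to [0, 0] when no quantization applies. A hedged sketch using the attribute signature shown in this hunk; ser and pool_op are placeholders as before:

    import serializer.tosa_serializer as ts

    def serialize_pool(ser, pool_op, input_name, output_name,
                       kernel, stride, pad, qinfo=None):
        if qinfo is None:
            qinfo = [0, 0]                     # [input_zp, output_zp]
        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
        ser.addOperator(pool_op, [input_name], [output_name], attr)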
@@ -606,11 +612,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -642,9 +647,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv3d(
@@ -670,11 +675,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -706,9 +710,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose_conv2d(
@@ -734,11 +738,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -769,9 +772,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.TransposeConvAttribute(out_pad, stride, output_shape)
+ attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_depthwise_conv2d(
@@ -796,11 +799,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -832,9 +834,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_fully_connected(
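CONV2D, CONV3D and DEPTHWISE_CONV2D share the same pattern: ConvAttribute gains the input and weight zero points, and addOperator drops its qinfo argument (TRANSPOSE_CONV2D does the same through TransposeConvAttribute). A hedged sketch with ser and conv_op as placeholders:

    import serializer.tosa_serializer as ts

    def serialize_conv(ser, conv_op, input_names, output_name,
                       padding, strides, dilations, qinfo):
        input_zp, weight_zp = qinfo            # from qgConv: [input_zp, weights_zp]
        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations, input_zp, weight_zp)
        ser.addOperator(conv_op, input_names, [output_name], attr)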
@@ -871,7 +873,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
@@ -905,7 +910,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.MatMulAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
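FULLY_CONNECTED and MATMUL previously carried no attribute at all; they now get one whose only job is to hold the zero points. A hedged sketch of the MATMUL case, with ser and matmul_op as placeholders:

    import serializer.tosa_serializer as ts

    def serialize_matmul(ser, matmul_op, a_name, b_name, output_name, qinfo):
        a_zp, b_zp = qinfo
        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(a_zp, b_zp)
        ser.addOperator(matmul_op, [a_name, b_name], [output_name], attr)

    # FULLY_CONNECTED is analogous: attr.FullyConnectedAttribute(input_zp, weight_zp)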
@@ -1164,7 +1172,7 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
@@ -2212,7 +2220,7 @@ class TosaTestGen:
else:
qinfo = None
- tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)
+ tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
try:
if error_if_validators is None:
@@ -3425,7 +3433,6 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPad,
),
- "qgen": TosaQuantGen.qgPad,
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,