diff options
author | Eric Kunze <eric.kunze@arm.com> | 2022-06-07 05:20:44 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-06-15 11:38:04 -0700 |
commit | b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 (patch) | |
tree | 9c7d946012c7a70a7fcb237daa4376d7b65c6f76 /reference_model/src/ops/tensor_ops.h | |
parent | 24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a (diff) | |
download | reference_model-b5fabec33abeca2d92c20c7b094fa3f113d0ddd8.tar.gz |
Remove quantization info from serialization attributes
Any needed information moves into the attributes for each operator.
New serialization library version removes teh quantization information
attributes from the schema
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 05b1ca1..24eadeb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template <int Rank, DType Dtype> class OpArgMax : public GraphNode { public: - OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -69,7 +69,6 @@ protected: TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; tosa::TosaPoolAttribute* attribute; - tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space @@ -80,7 +79,7 @@ template <DType InDtype, DType WeightDtype> class OpConv2d : public GraphNode { public: - OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -105,14 +104,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpConv3d : public GraphNode { public: - OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; @@ -137,14 +135,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -169,14 +166,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpFullyConnected : public GraphNode { public: - OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -199,14 +195,15 @@ protected: TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; - tosa::TosaConvQuantInfo* qinfo; + + tosa::TosaFullyConnectedAttribute* attribute; }; template <DType Dtype> class OpMatMul : public GraphNode { public: - OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -230,14 +227,15 @@ protected: int64_t H; int64_t W; int64_t C; - tosa::TosaMatMulQuantInfo* qinfo; + + tosa::TosaMatMulAttribute* attribute; }; template <DType Dtype> class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -258,7 +256,7 @@ template <DType InDtype, DType WeightDtype> class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; @@ -283,7 +281,6 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; TosaTransposeConvAttribute* attribute; - TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference |