From b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:20:44 +0000 Subject: Remove quantization info from serialization attributes Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa --- reference_model/src/ops/tensor_ops.h | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) (limited to 'reference_model/src/ops/tensor_ops.h') diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 05b1ca1..24eadeb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template class OpArgMax : public GraphNode { public: - OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -69,7 +69,6 @@ protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; tosa::TosaPoolAttribute* attribute; - tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space @@ -80,7 +79,7 @@ template class OpConv2d : public GraphNode { public: - OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -105,14 +104,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpConv3d : public GraphNode { public: - OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; @@ -137,14 +135,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -169,14 +166,13 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template class OpFullyConnected : public GraphNode { public: - OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -199,14 +195,15 @@ protected: TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; - tosa::TosaConvQuantInfo* qinfo; + + tosa::TosaFullyConnectedAttribute* attribute; }; template class OpMatMul : public GraphNode { public: - OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -230,14 +227,15 @@ protected: int64_t H; int64_t W; int64_t C; - tosa::TosaMatMulQuantInfo* qinfo; + + tosa::TosaMatMulAttribute* attribute; }; template class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -258,7 +256,7 @@ template class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; @@ -283,7 +281,6 @@ protected: TosaReference::TensorTemplate* bias; TosaReference::TensorTemplate* output; TosaTransposeConvAttribute* attribute; - TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference -- cgit v1.2.1