diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 05b1ca1..24eadeb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template <int Rank, DType Dtype> class OpArgMax : public GraphNode { public: - OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -69,7 +69,6 @@ protected: TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; tosa::TosaPoolAttribute* attribute; - tosa::TosaUnaryQuantInfo* qinfo; protected: // return a 1D [N] tensor that describes a how many valid elements covered in the input space @@ -80,7 +79,7 @@ template <DType InDtype, DType WeightDtype> class OpConv2d : public GraphNode { public: - OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -105,14 +104,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpConv3d : public GraphNode { public: - OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpConv3d(); virtual int checkTensorAttributes() final; @@ -137,14 +135,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -169,14 +166,13 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; tosa::TosaConvAttribute* attribute; - tosa::TosaConvQuantInfo* qinfo; }; template <DType InDtype, DType WeightDtype> class OpFullyConnected : public GraphNode { public: - OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -199,14 +195,15 @@ protected: TosaReference::TensorTemplate<TWeight>* weight; TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; - tosa::TosaConvQuantInfo* qinfo; + + tosa::TosaFullyConnectedAttribute* attribute; }; template <DType Dtype> class OpMatMul : public GraphNode { public: - OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -230,14 +227,15 @@ protected: int64_t H; int64_t W; int64_t C; - tosa::TosaMatMulQuantInfo* qinfo; + + tosa::TosaMatMulAttribute* attribute; }; template <DType Dtype> class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -258,7 +256,7 @@ template <DType InDtype, DType WeightDtype> class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; @@ -283,7 +281,6 @@ protected: TosaReference::TensorTemplate<TBias>* bias; TosaReference::TensorTemplate<TAcc>* output; TosaTransposeConvAttribute* attribute; - TosaConvQuantInfo* qinfo; }; }; // namespace TosaReference |