diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 144 |
1 files changed, 46 insertions, 98 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index aef1ad2..3ab4d56 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -114,8 +114,7 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, return 0; } -int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, - tosa::TosaConvQuantInfo* qinfo, +int check_conv_attribute(tosa::TosaConvAttribute* attribute, uint32_t conv_dimension, std::vector<int32_t> input_shape, std::vector<int32_t> output_shape, @@ -226,18 +225,13 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, return 1; } - if (qinfo) - { - if (InDtype != DType_INT8 && qinfo->input_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } - if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0) - { - msg = "zeropoint only for int8_t"; - return 1; - } + if (InDtype != DType_INT8 && attribute->input_zp() != 0) { + msg = "Input zero point must be zero for non-int8 data"; + return 1; + } + if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) { + msg = "Weight zero point must be zero for non-int8 data"; + return 1; } return 0; @@ -246,7 +240,6 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, template <int Rank, DType Dtype> OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_ARGMAX, id_) { @@ -339,7 +332,6 @@ int OpArgMax<Rank, Dtype>::eval() template <DType Dtype> OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_AVG_POOL2D, id_) { @@ -347,7 +339,6 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Pool); - INIT_QINFO(Unary); } template <DType Dtype> @@ -377,11 +368,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes() in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); std::string msg; if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) @@ -474,9 +462,9 @@ int OpAvgPool2d<Dtype>::eval() pad[3] = std::make_pair(0, 0); ETensor4<InEigenType> input_val = this->in->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -537,7 +525,7 @@ int OpAvgPool2d<Dtype>::eval() { REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str()); } - this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); + this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); } @@ -552,7 +540,6 @@ int OpAvgPool2d<Dtype>::eval() template <DType InDtype, DType WeightDtype> OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) { @@ -560,7 +547,6 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_, setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -568,8 +554,6 @@ OpConv2d<InDtype, WeightDtype>::~OpConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -598,7 +582,7 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv2d: " + msg; @@ -691,10 +675,10 @@ int OpConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -739,7 +723,6 @@ int OpConv2d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) { @@ -747,7 +730,6 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_, setRequiredRank(5); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -755,8 +737,6 @@ OpConv3d<InDtype, WeightDtype>::~OpConv3d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -785,7 +765,7 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpConv3d: " + msg; @@ -858,10 +838,10 @@ int OpConv3d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor5<InEigenType> input_padded = input_val.pad(pad); @@ -931,7 +911,6 @@ int OpConv3d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { @@ -939,7 +918,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(Conv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -947,8 +925,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -977,7 +953,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); std::string msg; - if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg)) { msg = "OpDepthwiseConv2d: " + msg; @@ -1041,10 +1017,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } ETensor4<InEigenType> input_padded = input_val.pad(pad); @@ -1108,21 +1084,20 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() template <DType InDtype, DType WeightDtype> OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); setRequiredRank(2); - INIT_QINFO(Conv); + INIT_ATTRIBUTE(FullyConnected); } template <DType InDtype, DType WeightDtype> OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template <DType InDtype, DType WeightDtype> @@ -1157,17 +1132,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes() output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1190,10 +1156,10 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } this->output->getTensor() = @@ -1211,21 +1177,20 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() template <DType Dtype> OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); setRequiredRank(3); - INIT_QINFO(MatMul); + INIT_ATTRIBUTE(MatMul); } template <DType Dtype> OpMatMul<Dtype>::~OpMatMul() { - if (qinfo) - delete qinfo; + if (attribute) + delete attribute; } template <DType Dtype> @@ -1284,11 +1249,8 @@ int OpMatMul<Dtype>::checkTensorAttributes() } W = b->getShape()[2]; - if (Dtype != DType_INT8 && this->qinfo) - { - ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t"); - } + ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data"); return 0; } @@ -1301,10 +1263,10 @@ int OpMatMul<Dtype>::eval() TIn a_val = this->a->getTensor(); TIn b_val = this->b->getTensor(); - if (this->qinfo) + if (Dtype == DType_INT8) { - a_val = a_val - (InEigenType)this->qinfo->a_zp(); - b_val = b_val - (InEigenType)this->qinfo->b_zp(); + a_val = a_val - (InEigenType)attribute->a_zp(); + b_val = b_val - (InEigenType)attribute->b_zp(); } Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C }); @@ -1351,7 +1313,6 @@ int OpMatMul<Dtype>::eval() template <DType Dtype> OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_MAX_POOL2D, id_) { @@ -1484,7 +1445,6 @@ int OpMaxPool2d<Dtype>::eval() template <DType InDtype, DType WeightDtype> OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { @@ -1492,7 +1452,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg setRequiredRank(4); INIT_ATTRIBUTE(TransposeConv); - INIT_QINFO(Conv); } template <DType InDtype, DType WeightDtype> @@ -1500,8 +1459,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; - if (qinfo) - delete qinfo; } template <DType InDtype, DType WeightDtype> @@ -1595,17 +1552,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); - } - } + ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); return 0; } @@ -1655,10 +1603,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (this->qinfo) + if (InDtype == DType_INT8) { - input_val = input_val - (InEigenType)this->qinfo->input_zp(); - weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + input_val = input_val - (InEigenType)attribute->input_zp(); + weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); } Eigen::array<Eigen::Index, 4> reshape_dim; |