Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--  reference_model/src/ops/tensor_ops.cc | 144
1 file changed, 46 insertions(+), 98 deletions(-)
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index aef1ad2..3ab4d56 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -114,8 +114,7 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
return 0;
}
-int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
- tosa::TosaConvQuantInfo* qinfo,
+int check_conv_attribute(tosa::TosaConvAttribute* attribute,
uint32_t conv_dimension,
std::vector<int32_t> input_shape,
std::vector<int32_t> output_shape,
@@ -226,18 +225,13 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
return 1;
}
- if (qinfo)
- {
- if (InDtype != DType_INT8 && qinfo->input_zp() != 0)
- {
- msg = "zeropoint only for int8_t";
- return 1;
- }
- if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0)
- {
- msg = "zeropoint only for int8_t";
- return 1;
- }
+ if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+ msg = "Input zero point must be zero for non-int8 data";
+ return 1;
+ }
+ if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+ msg = "Weight zero point must be zero for non-int8 data";
+ return 1;
}
return 0;
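The replacement checks above fold the old qinfo guard into the attribute itself: zero points are only meaningful for int8 operands, and every other dtype must carry a zero point of exactly 0. A minimal standalone sketch of that rule, assuming the caller has already resolved the dtype (zero_point_valid is a hypothetical name, not part of the reference model):

    #include <cstdint>

    // TOSA rule enforced above: a non-int8 operand must have zero point 0.
    static bool zero_point_valid(bool is_int8, int32_t zp)
    {
        return is_int8 || zp == 0;
    }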
@@ -246,7 +240,6 @@ int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_ARGMAX, id_)
{
@@ -339,7 +332,6 @@ int OpArgMax<Rank, Dtype>::eval()
template <DType Dtype>
OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
@@ -347,7 +339,6 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
setRequiredRank(4);
INIT_ATTRIBUTE(Pool);
- INIT_QINFO(Unary);
}
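Note for readers unfamiliar with the macros: INIT_ATTRIBUTE(Pool) replaces the removed INIT_QINFO(Unary) because the zero points now travel inside TosaPoolAttribute. A plausible sketch of what such a macro might expand to; the real definition lives elsewhere in the reference model, so treat the exact form and the FATAL_ERROR call as assumptions:

    // Hypothetical expansion: downcast the serialized attribute handed to
    // the constructor and keep a typed copy owned by the node.
    #define INIT_ATTRIBUTE(ATTRIBUTE_NAME)                                       \
        if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
            attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p);                  \
        else                                                                     \
            FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute");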
template <DType Dtype>
@@ -377,11 +368,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->output_zp() != 0, "OpAvgPool2d: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
std::string msg;
if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -474,9 +462,9 @@ int OpAvgPool2d<Dtype>::eval()
pad[3] = std::make_pair(0, 0);
ETensor4<InEigenType> input_val = this->in->getTensor();
- if (this->qinfo)
+ if (Dtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
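Two things are worth noting in this hunk. First, the subtraction happens before .pad(), so the zero-valued padding contributes nothing once the data is in the zero-point-adjusted domain. Second, the guard changed from a runtime null check on qinfo to a test of the template parameter Dtype, which the compiler folds per instantiation. A minimal sketch of that pattern (names hypothetical):

    #include <cstdint>

    // Dtype above is a template parameter, so `if (Dtype == DType_INT8)` is
    // a compile-time constant: float instantiations drop the subtraction.
    template <bool IsInt8>
    static int32_t adjust(int32_t v, int32_t zp)
    {
        if (IsInt8)        // constant-folded per instantiation
            return v - zp;
        return v;          // adjust<false>(v, zp) compiles to just v
    }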
@@ -537,7 +525,7 @@ int OpAvgPool2d<Dtype>::eval()
{
REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
}
- this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
+ this->out->getTensor() = this->out->getTensor() + (OutEigenType)(attribute->output_zp());
this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
}
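The int8 path divides the window sum by the element count with a fixed-point multiply rather than a real division. A rough sketch of the rescale step that apply_scale_32 performs, simplified from the TOSA pseudocode (the real routine adds range checks and a double-rounding option; assumes shift >= 1):

    #include <cstdint>

    // value * multiplier * 2^-shift with round-to-nearest, widened to 64 bits.
    static int32_t rescale(int32_t value, int32_t multiplier, int32_t shift)
    {
        int64_t round  = INT64_C(1) << (shift - 1);
        int64_t result = (static_cast<int64_t>(value) * multiplier + round) >> shift;
        return static_cast<int32_t>(result);
    }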
@@ -552,7 +540,6 @@ int OpAvgPool2d<Dtype>::eval()
template <DType InDtype, DType WeightDtype>
OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
{
@@ -560,7 +547,6 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
setRequiredRank(4);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -568,8 +554,6 @@ OpConv2d<InDtype, WeightDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -598,7 +582,7 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpConv2d: " + msg;
@@ -691,10 +675,10 @@ int OpConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
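Same dtype-gated subtraction as in AvgPool2d, but here both operands are shifted. The casts to InEigenType and WeightEigenType matter: after subtracting a zero point, an int8 sample can reach ±255, so the arithmetic has to happen in a wider type. A self-contained illustration, with AccType standing in for the reference model's int32 accumulator:

    #include <cstdint>

    using AccType = int32_t;

    // One multiply-accumulate term of the convolution in the shifted domain.
    // (x - input_zp) and (w - weight_zp) each span [-255, 255], so the
    // product needs 32 bits even though the operands started as int8.
    static AccType mac_term(int8_t x, int32_t input_zp, int8_t w, int32_t weight_zp)
    {
        return (static_cast<AccType>(x) - input_zp) *
               (static_cast<AccType>(w) - weight_zp);
    }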
@@ -739,7 +723,6 @@ int OpConv2d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
{
@@ -747,7 +730,6 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
setRequiredRank(5);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -755,8 +737,6 @@ OpConv3d<InDtype, WeightDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -785,7 +765,7 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 1 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpConv3d: " + msg;
@@ -858,10 +838,10 @@ int OpConv3d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor5<InEigenType> input_padded = input_val.pad(pad);
@@ -931,7 +911,6 @@ int OpConv3d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
@@ -939,7 +918,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg
setRequiredRank(4);
INIT_ATTRIBUTE(Conv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -947,8 +925,6 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -977,7 +953,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
std::string msg;
- if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
weight->getShape(), 0 /* offset_kernel */, InDtype, WeightDtype, msg))
{
msg = "OpDepthwiseConv2d: " + msg;
@@ -1041,10 +1017,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
ETensor4<InEigenType> input_padded = input_val.pad(pad);
@@ -1108,21 +1084,20 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(2);
- INIT_QINFO(Conv);
+ INIT_ATTRIBUTE(FullyConnected);
}
template <DType InDtype, DType WeightDtype>
OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
{
- if (qinfo)
- delete qinfo;
+ if (attribute)
+ delete attribute;
}
template <DType InDtype, DType WeightDtype>
@@ -1157,17 +1132,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
- }
- }
+ ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1190,10 +1156,10 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
this->output->getTensor() =
@@ -1211,21 +1177,20 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
template <DType Dtype>
OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(3);
- INIT_QINFO(MatMul);
+ INIT_ATTRIBUTE(MatMul);
}
template <DType Dtype>
OpMatMul<Dtype>::~OpMatMul()
{
- if (qinfo)
- delete qinfo;
+ if (attribute)
+ delete attribute;
}
template <DType Dtype>
@@ -1284,11 +1249,8 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- if (Dtype != DType_INT8 && this->qinfo)
- {
- ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
- ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
- }
+ ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1301,10 +1263,10 @@ int OpMatMul<Dtype>::eval()
TIn a_val = this->a->getTensor();
TIn b_val = this->b->getTensor();
- if (this->qinfo)
+ if (Dtype == DType_INT8)
{
- a_val = a_val - (InEigenType)this->qinfo->a_zp();
- b_val = b_val - (InEigenType)this->qinfo->b_zp();
+ a_val = a_val - (InEigenType)attribute->a_zp();
+ b_val = b_val - (InEigenType)attribute->b_zp();
}
Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
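A worked example of the a_zp/b_zp adjustment above, with made-up numbers: for a = 7, a_zp = 2, b = -3, b_zp = -1, the accumulated product is (7 - 2) * (-3 - (-1)) = -10, already in the int32 domain:

    #include <cstdint>
    #include <cstdio>

    // 1x1 MATMUL in the zero-point-adjusted domain; values are illustrative.
    int main()
    {
        int32_t a = 7, a_zp = 2;
        int32_t b = -3, b_zp = -1;
        std::printf("%d\n", (a - a_zp) * (b - b_zp)); // prints -10
        return 0;
    }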
@@ -1351,7 +1313,6 @@ int OpMatMul<Dtype>::eval()
template <DType Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
@@ -1484,7 +1445,6 @@ int OpMaxPool2d<Dtype>::eval()
template <DType InDtype, DType WeightDtype>
OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
@@ -1492,7 +1452,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg
setRequiredRank(4);
INIT_ATTRIBUTE(TransposeConv);
- INIT_QINFO(Conv);
}
template <DType InDtype, DType WeightDtype>
@@ -1500,8 +1459,6 @@ OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
- if (qinfo)
- delete qinfo;
}
template <DType InDtype, DType WeightDtype>
@@ -1595,17 +1552,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
- }
- }
+ ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
return 0;
}
@@ -1655,10 +1603,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (this->qinfo)
+ if (InDtype == DType_INT8)
{
- input_val = input_val - (InEigenType)this->qinfo->input_zp();
- weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ input_val = input_val - (InEigenType)attribute->input_zp();
+ weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
}
Eigen::array<Eigen::Index, 4> reshape_dim;