diff options
author | Eric Kunze <eric.kunze@arm.com> | 2022-06-17 08:19:12 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-06-17 08:19:12 -0700 |
commit | f733783db23f3e89a5f518b0bbfe5d5ed0bc0337 (patch) | |
tree | 498d500f9781ce036a8e4eaf9cc6380415287c74 /reference_model | |
parent | 95dfcb8aab41650e9fdf17a6c875dfef09c9f610 (diff) | |
download | reference_model-f733783db23f3e89a5f518b0bbfe5d5ed0bc0337.tar.gz |
In the case of an int16xint8 test, the zero point was not being
subtracted from the weights.
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: Ic77119b200b952715870abc11d09d1a646da86b1
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 3ab4d56..03cb9fb 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -675,7 +675,7 @@ int OpConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -838,7 +838,7 @@ int OpConv3d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1017,7 +1017,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1156,7 +1156,7 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1603,7 +1603,7 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); |