aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2022-06-17 08:19:12 -0700
committerEric Kunze <eric.kunze@arm.com>2022-06-17 08:19:12 -0700
commitf733783db23f3e89a5f518b0bbfe5d5ed0bc0337 (patch)
tree498d500f9781ce036a8e4eaf9cc6380415287c74
parent95dfcb8aab41650e9fdf17a6c875dfef09c9f610 (diff)
downloadreference_model-v0.30.tar.gz
Fix reference model use of weight zero pointv0.30.0v0.30
In the case of an int16xint8 test, the zero point was not being subtracted from the weights. Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: Ic77119b200b952715870abc11d09d1a646da86b1
-rw-r--r--reference_model/src/ops/tensor_ops.cc10
1 files changed, 5 insertions, 5 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 3ab4d56..03cb9fb 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -675,7 +675,7 @@ int OpConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -838,7 +838,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1017,7 +1017,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1156,7 +1156,7 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1603,7 +1603,7 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();