From f733783db23f3e89a5f518b0bbfe5d5ed0bc0337 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Fri, 17 Jun 2022 08:19:12 -0700 Subject: Fix reference model use of weight zero point In the case of an int16xint8 test, the zero point was not being subtracted from the weights. Signed-off-by: Eric Kunze Change-Id: Ic77119b200b952715870abc11d09d1a646da86b1 --- reference_model/src/ops/tensor_ops.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 3ab4d56..03cb9fb 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -675,7 +675,7 @@ int OpConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -838,7 +838,7 @@ int OpConv3d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1017,7 +1017,7 @@ int OpDepthwiseConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1156,7 +1156,7 @@ int OpFullyConnected::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1603,7 +1603,7 @@ int OpTransposeConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8) + if (InDtype == DType_INT8 || WeightDtype == DType_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); -- cgit v1.2.1