aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/ops/tensor_ops.cc10
1 files changed, 5 insertions, 5 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 3ab4d56..03cb9fb 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -675,7 +675,7 @@ int OpConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -838,7 +838,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1017,7 +1017,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1156,7 +1156,7 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1603,7 +1603,7 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8)
+ if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();