aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-11-01 11:14:13 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-11-01 14:43:56 -0700
commit80794804907bfe73367f40616f3b6c41deacaca3 (patch)
treeb0dfba412a665594ecb997866647e92062628fc0
parentc0b24f010813a8d01bd6820b8b86ed2011596020 (diff)
downloadreference_model-80794804907bfe73367f40616f3b6c41deacaca3.tar.gz
Fix for tensor_ops.cc
- MATMUL: only check a_zp/b_zp valid when this->qinfo exists - Fix typo in debug message Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I2cedcb25e4f57fcaec2caa1b850ea1232a023340
-rw-r--r--reference_model/src/ops/tensor_ops.cc16
1 files changed, 8 insertions, 8 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 059638a..5494d77 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -92,8 +92,8 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
return 1;
}
- if ( OH != (IH + pad_top + pad_bottom + stride_y - kernel_y) / stride_y ||
- OW != (IW + pad_left + pad_right + stride_x - kernel_x) / stride_x )
+ if ((OH != (IH + pad_top + pad_bottom + stride_y - kernel_y) / stride_y) ||
+ (OW != (IW + pad_left + pad_right + stride_x - kernel_x) / stride_x))
{
msg = "Mismatch between output shape provided and expected output shape";
return 1;
@@ -453,7 +453,7 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
@@ -660,7 +660,7 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpConv3d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
@@ -870,7 +870,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
@@ -1161,7 +1161,7 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpMatMul: Output data type not supported for this configuration of operator");
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
@@ -1205,7 +1205,7 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- if (Dtype != DType_INT8)
+ if (Dtype != DType_INT8 && this->qinfo)
{
ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
@@ -1436,7 +1436,7 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpTransposeConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);