aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-11-10 01:04:39 +0000
committerEric Kunze <eric.kunze@arm.com>2021-11-11 21:17:42 +0000
commit9fe172483b77dcaa0bfe7e97af4a934d6ef01a16 (patch)
tree0b2d9c16e090ec7d1a6d98d2019ee24a645c0b8e
parent7e9ac9ab74ff2a793e226abf86d2543d1421d3c9 (diff)
downloadreference_model-9fe172483b77dcaa0bfe7e97af4a934d6ef01a16.tar.gz
More ERROR_IF to check attribute for convolution ops
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I49d498dd3d4c069d8d1db07310f939268b9df4b7
-rw-r--r--reference_model/src/ops/tensor_ops.cc207
1 files changed, 121 insertions, 86 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 5494d77..7942a24 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -21,10 +21,10 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
- std::vector<int32_t> input_shape,
- std::vector<int32_t> output_shape,
- std::string& msg)
+int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
+ std::vector<int32_t> input_shape,
+ std::vector<int32_t> output_shape,
+ std::string& msg)
{
if (attribute->padding().size() != 4)
{
@@ -57,7 +57,7 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
{
if (i < 1)
{
- msg = "At least one kernel dimension is smaller than zero";
+ msg = "At least one kernel dimension is smaller than one";
return 1;
}
}
@@ -66,7 +66,7 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
{
if (i < 1)
{
- msg = "At least one stride dimension is smaller than zero";
+ msg = "At least one stride dimension is smaller than one";
return 1;
}
}
@@ -102,6 +102,77 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
return 0;
}
+int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute,
+ tosa::TosaConvQuantInfo* qinfo,
+ uint32_t conv_dimension,
+ std::vector<int32_t> input_shape,
+ std::vector<int32_t> output_shape,
+ DType InDtype,
+ DType WeightDtype,
+ std::string& msg)
+{
+ if (attribute->padding().size() != (2 * conv_dimension))
+ {
+ msg = "Illegal size for attribute padding";
+ return 1;
+ }
+
+ if (attribute->stride().size() != conv_dimension)
+ {
+ msg = "Illegal size for attribute stride";
+ return 1;
+ }
+
+ if (attribute->dilation().size() != conv_dimension)
+ {
+ msg = "Illegal size for attribute dilation";
+ return 1;
+ }
+
+ for (int32_t i : attribute->padding())
+ {
+ if (i < 0)
+ {
+ msg = "At least one pad is smaller than zero";
+ return 1;
+ }
+ }
+
+ for (int32_t i : attribute->stride())
+ {
+ if (i < 1)
+ {
+ msg = "At least one stride dimension is smaller than one";
+ return 1;
+ }
+ }
+
+ for (int32_t i : attribute->dilation())
+ {
+ if (i < 1)
+ {
+ msg = "At least one dilation dimension is smaller than one";
+ return 1;
+ }
+ }
+
+ if (qinfo)
+ {
+ if (InDtype != DType_INT8 && qinfo->input_zp() != 0)
+ {
+ msg = "zeropoint only for int8_t";
+ return 1;
+ }
+ if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0)
+ {
+ msg = "zeropoint only for int8_t";
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -243,7 +314,7 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
}
std::string msg;
- if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
+ if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
{
msg = "OpAvgPool2d: " + msg;
printNodeValidationError(msg.c_str());
@@ -460,36 +531,15 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (attribute->padding().size() != 4)
- {
- printNodeValidationError("OpConv2d: illegal size for attribute padding");
- return 1;
- }
-
- if (attribute->stride().size() != 2)
- {
- printNodeValidationError("OpConv2d: illegal size for attribute stride");
- return 1;
- }
-
- if (attribute->dilation().size() != 2)
+ std::string msg;
+ if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ InDtype, WeightDtype, msg))
{
- printNodeValidationError("OpConv2d: illegal size for attribute dilation");
+ msg = "OpConv2d: " + msg;
+ printNodeValidationError(msg.c_str());
return 1;
}
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
- }
- }
-
return 0;
}
@@ -667,36 +717,15 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (attribute->padding().size() != 6)
- {
- printNodeValidationError("OpConv3d: illegal size for attribute padding");
- return 1;
- }
-
- if (attribute->stride().size() != 3)
- {
- printNodeValidationError("OpConv3d: illegal size for attribute stride");
- return 1;
- }
-
- if (attribute->dilation().size() != 3)
+ std::string msg;
+ if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(),
+ InDtype, WeightDtype, msg))
{
- printNodeValidationError("OpConv3d: illegal size for attribute dilation");
+ msg = "OpConv3d: " + msg;
+ printNodeValidationError(msg.c_str());
return 1;
}
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
- }
- }
-
return 0;
}
@@ -877,36 +906,15 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (attribute->padding().size() != 4)
- {
- printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
- return 1;
- }
-
- if (attribute->stride().size() != 2)
- {
- printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
- return 1;
- }
-
- if (attribute->dilation().size() != 2)
+ std::string msg;
+ if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(),
+ InDtype, WeightDtype, msg))
{
- printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
+ msg = "OpDepthwiseConv2d: " + msg;
+ printNodeValidationError(msg.c_str());
return 1;
}
- if (this->qinfo)
- {
- if (InDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
- }
- if (WeightDtype != DType_INT8)
- {
- ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
- }
- }
-
return 0;
}
@@ -1310,7 +1318,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
- if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg))
+ if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
{
msg = "OpMaxPool2d: " + msg;
printNodeValidationError(msg.c_str());
@@ -1467,6 +1475,33 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ for (int32_t i : attribute->outpad())
+ {
+ if (i < 0)
+ {
+ printNodeValidationError("OpTransposeConv2d: At least one pad is smaller than zero");
+ return 1;
+ }
+ }
+
+ for (int32_t i : attribute->stride())
+ {
+ if (i < 1)
+ {
+ printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one");
+ return 1;
+ }
+ }
+
+ for (int32_t i : attribute->dilation())
+ {
+ if (i < 1)
+ {
+ printNodeValidationError("OpTransposeConv2d: At least one dilation is smaller than one");
+ return 1;
+ }
+ }
+
for (int d = 0; d < 4; d++)
{
if (attribute->output_shape()[d] != this->output->getShape()[d])