diff options
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 207 |
1 files changed, 121 insertions, 86 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 5494d77..7942a24 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -21,10 +21,10 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute, - std::vector<int32_t> input_shape, - std::vector<int32_t> output_shape, - std::string& msg) +int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, + std::vector<int32_t> input_shape, + std::vector<int32_t> output_shape, + std::string& msg) { if (attribute->padding().size() != 4) { @@ -57,7 +57,7 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute, { if (i < 1) { - msg = "At least one kernel dimension is smaller than zero"; + msg = "At least one kernel dimension is smaller than one"; return 1; } } @@ -66,7 +66,7 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute, { if (i < 1) { - msg = "At least one stride dimension is smaller than zero"; + msg = "At least one stride dimension is smaller than one"; return 1; } } @@ -102,6 +102,77 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute, return 0; } +int check_conv_attribute_qinfo(tosa::TosaConvAttribute* attribute, + tosa::TosaConvQuantInfo* qinfo, + uint32_t conv_dimension, + std::vector<int32_t> input_shape, + std::vector<int32_t> output_shape, + DType InDtype, + DType WeightDtype, + std::string& msg) +{ + if (attribute->padding().size() != (2 * conv_dimension)) + { + msg = "Illegal size for attribute padding"; + return 1; + } + + if (attribute->stride().size() != conv_dimension) + { + msg = "Illegal size for attribute stride"; + return 1; + } + + if (attribute->dilation().size() != conv_dimension) + { + msg = "Illegal size for attribute dilation"; + return 1; + } + + for (int32_t i : attribute->padding()) + { + if (i < 0) + { + msg = "At least one pad is smaller than zero"; + return 1; + } + } + + for (int32_t i : attribute->stride()) + { + if (i < 1) + { + msg = "At least one stride dimension is smaller than one"; + return 1; + } + } + + for (int32_t i : attribute->dilation()) + { + if (i < 1) + { + msg = "At least one dilation dimension is smaller than one"; + return 1; + } + } + + if (qinfo) + { + if (InDtype != DType_INT8 && qinfo->input_zp() != 0) + { + msg = "zeropoint only for int8_t"; + return 1; + } + if (WeightDtype != DType_INT8 && qinfo->weight_zp() != 0) + { + msg = "zeropoint only for int8_t"; + return 1; + } + } + + return 0; +} + template <int Rank, DType Dtype> OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -243,7 +314,7 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes() } std::string msg; - if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg)) + if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) { msg = "OpAvgPool2d: " + msg; printNodeValidationError(msg.c_str()); @@ -460,36 +531,15 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (attribute->padding().size() != 4) - { - printNodeValidationError("OpConv2d: illegal size for attribute padding"); - return 1; - } - - if (attribute->stride().size() != 2) - { - printNodeValidationError("OpConv2d: illegal size for attribute stride"); - return 1; - } - - if (attribute->dilation().size() != 2) + std::string msg; + if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + InDtype, WeightDtype, msg)) { - printNodeValidationError("OpConv2d: illegal size for attribute dilation"); + msg = "OpConv2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t"); - } - } - return 0; } @@ -667,36 +717,15 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes() bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (attribute->padding().size() != 6) - { - printNodeValidationError("OpConv3d: illegal size for attribute padding"); - return 1; - } - - if (attribute->stride().size() != 3) - { - printNodeValidationError("OpConv3d: illegal size for attribute stride"); - return 1; - } - - if (attribute->dilation().size() != 3) + std::string msg; + if (check_conv_attribute_qinfo(attribute, qinfo, 3 /* conv_dimension */, input->getShape(), output->getShape(), + InDtype, WeightDtype, msg)) { - printNodeValidationError("OpConv3d: illegal size for attribute dilation"); + msg = "OpConv3d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t"); - } - } - return 0; } @@ -877,36 +906,15 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (attribute->padding().size() != 4) - { - printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding"); - return 1; - } - - if (attribute->stride().size() != 2) - { - printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride"); - return 1; - } - - if (attribute->dilation().size() != 2) + std::string msg; + if (check_conv_attribute_qinfo(attribute, qinfo, 2 /* conv_dimension */, input->getShape(), output->getShape(), + InDtype, WeightDtype, msg)) { - printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation"); + msg = "OpDepthwiseConv2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } - if (this->qinfo) - { - if (InDtype != DType_INT8) - { - ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); - } - if (WeightDtype != DType_INT8) - { - ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); - } - } - return 0; } @@ -1310,7 +1318,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes() out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); std::string msg; - if (check_pool2d_attribute_common(attribute, in->getShape(), out->getShape(), msg)) + if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) { msg = "OpMaxPool2d: " + msg; printNodeValidationError(msg.c_str()); @@ -1467,6 +1475,33 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } + for (int32_t i : attribute->outpad()) + { + if (i < 0) + { + printNodeValidationError("OpTransposeConv2d: At least one pad is smaller than zero"); + return 1; + } + } + + for (int32_t i : attribute->stride()) + { + if (i < 1) + { + printNodeValidationError("OpTransposeConv2d: At least one stride is smaller than one"); + return 1; + } + } + + for (int32_t i : attribute->dilation()) + { + if (i < 1) + { + printNodeValidationError("OpTransposeConv2d: At least one dilation is smaller than one"); + return 1; + } + } + for (int d = 0; d < 4; d++) { if (attribute->output_shape()[d] != this->output->getShape()[d]) |