diff options
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 31e808fd6e..b9a3522736 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -226,6 +226,24 @@ void CheckBufferSize(TfLiteParser::BufferRawPtr bufferPtr, #define CHECK_BUFFER_SIZE(BUFFER_PTR, TENSOR_INFO, BUFFER_ID) \ CheckBufferSize(BUFFER_PTR, TENSOR_INFO, BUFFER_ID, CHECK_LOCATION()) +uint32_t CheckDilation(const int32_t dilationFactor, + size_t operatorIndex, + const CheckLocation& location) +{ + if (dilationFactor != 1) + { + std::stringstream ss; + ss << "ArmNN only supports convolution layers with dilations [1,1,1,1] for operator with index " + << operatorIndex << location.AsString(); + throw ParseException(ss.str()); + } + + return static_cast<uint32_t>(dilationFactor); +} + +#define CHECK_DILATION(DILATION_FACTOR, OPERATOR_INDEX) \ + CheckDilation(DILATION_FACTOR, OPERATOR_INDEX, CHECK_LOCATION()) + bool IsActivationSupported(tflite::ActivationFunctionType activationType) { switch(activationType) @@ -694,6 +712,9 @@ void TfLiteParser::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h); desc.m_DataLayout = armnn::DataLayout::NHWC; + CHECK_DILATION(options->dilation_h_factor, operatorIndex); + CHECK_DILATION(options->dilation_w_factor, operatorIndex); + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 2, 3); @@ -779,6 +800,9 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + CHECK_DILATION(options->dilation_h_factor, operatorIndex); + CHECK_DILATION(options->dilation_w_factor, operatorIndex); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]); |