aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-03-26 11:39:19 +0000
committerKevin May <kevin.may@arm.com>2019-03-29 14:02:57 +0000
commit83add2165b680b3cf38403a7ce90ea86febd4cc7 (patch)
tree1cd28c1c375a4ff49022a92bd4e3e97bfb35a319 /src/armnnTfLiteParser
parent10e6be471a925307dc6cae1d80377cd73878c2f2 (diff)
downloadarmnn-83add2165b680b3cf38403a7ce90ea86febd4cc7.tar.gz
MLCE-101 Deeplab v3+ (Add Tf Lite Parser Dilation Check)
* Add Parse Exception for convolutions without default dilation Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: I1b8f75c2d871d81161eb5378ced277438e809ba2
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 31e808fd6e..b9a3522736 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -226,6 +226,24 @@ void CheckBufferSize(TfLiteParser::BufferRawPtr bufferPtr,
#define CHECK_BUFFER_SIZE(BUFFER_PTR, TENSOR_INFO, BUFFER_ID) \
CheckBufferSize(BUFFER_PTR, TENSOR_INFO, BUFFER_ID, CHECK_LOCATION())
+uint32_t CheckDilation(const int32_t dilationFactor,
+ size_t operatorIndex,
+ const CheckLocation& location)
+{
+ if (dilationFactor != 1)
+ {
+ std::stringstream ss;
+ ss << "ArmNN only supports convolution layers with dilations [1,1,1,1] for operator with index "
+ << operatorIndex << location.AsString();
+ throw ParseException(ss.str());
+ }
+
+ return static_cast<uint32_t>(dilationFactor);
+}
+
+#define CHECK_DILATION(DILATION_FACTOR, OPERATOR_INDEX) \
+ CheckDilation(DILATION_FACTOR, OPERATOR_INDEX, CHECK_LOCATION())
+
bool IsActivationSupported(tflite::ActivationFunctionType activationType)
{
switch(activationType)
@@ -694,6 +712,9 @@ void TfLiteParser::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
desc.m_DataLayout = armnn::DataLayout::NHWC;
+ CHECK_DILATION(options->dilation_h_factor, operatorIndex);
+ CHECK_DILATION(options->dilation_w_factor, operatorIndex);
+
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(inputs.size(), 2, 3);
@@ -779,6 +800,9 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ CHECK_DILATION(options->dilation_h_factor, operatorIndex);
+ CHECK_DILATION(options->dilation_w_factor, operatorIndex);
+
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);