From 70672f6c52e95256911ca70110d3ad2643b43eaa Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 23 Jan 2019 14:00:00 +0000 Subject: IVGCVSW-2534 Fix bug TfLiteParser::ParseReshape() Change-Id: I44d63552d2552842f02b2c870466851581f65b1a --- src/armnnTfLiteParser/TfLiteParser.cpp | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index c45e794274..affd858d77 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -432,6 +432,26 @@ armnn::LayerBindingId GenerateLayerBindingId(size_t subgraphIndex, size_t tensor return static_cast((tensorIndex<<8)+subgraphIndex); } +bool CheckShape(const armnn::TensorShape& actual, const std::vector& expected) +{ + const unsigned int actualSize = actual.GetNumDimensions(); + if (actualSize != expected.size()) + { + return false; + } + + for (unsigned int i = 0u; i < actualSize; i++) + { + if (expected[i] < 0 || + actual[i] != static_cast(expected[i])) + { + return false; + } + } + + return true; +} + } // TfLiteParser::TfLiteParser() @@ -1314,11 +1334,12 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); // Check for valid input size and that reshape parameters equal output shape - if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape)) + const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); + if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { std::stringstream ss; ss << "New shape defined in reshape parameters " - << reshapeOutputTensorInfo.GetShape() + << reshapeOutputTensorShape << " does not equal output shape " << actualOutputTensorInfo.GetShape() << ": " -- cgit v1.2.1