diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-01-23 14:00:00 +0000 |
---|---|---|
committer | Aron Virginas-Tar <aron.virginas-tar@arm.com> | 2019-01-23 14:08:20 +0000 |
commit | 70672f6c52e95256911ca70110d3ad2643b43eaa (patch) | |
tree | 1911ca37e37b5efaca67dc84ae32b52178f6f626 /src | |
parent | 6392a77094e686fb4b973bddc8408c3faf256ed7 (diff) | |
download | armnn-70672f6c52e95256911ca70110d3ad2643b43eaa.tar.gz |
IVGCVSW-2534 Fix bug TfLiteParser::ParseReshape()
Change-Id: I44d63552d2552842f02b2c870466851581f65b1a
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index c45e794274..affd858d77 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -432,6 +432,26 @@ armnn::LayerBindingId GenerateLayerBindingId(size_t subgraphIndex, size_t tensor return static_cast<armnn::LayerBindingId>((tensorIndex<<8)+subgraphIndex); } +bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& expected) +{ + const unsigned int actualSize = actual.GetNumDimensions(); + if (actualSize != expected.size()) + { + return false; + } + + for (unsigned int i = 0u; i < actualSize; i++) + { + if (expected[i] < 0 || + actual[i] != static_cast<unsigned int>(expected[i])) + { + return false; + } + } + + return true; +} + } // <anonymous> TfLiteParser::TfLiteParser() @@ -1314,11 +1334,12 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); // Check for valid input size and that reshape parameters equal output shape - if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape)) + const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); + if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { std::stringstream ss; ss << "New shape defined in reshape parameters " - << reshapeOutputTensorInfo.GetShape() + << reshapeOutputTensorShape << " does not equal output shape " << actualOutputTensorInfo.GetShape() << ": " |