diff options
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index c45e794274..affd858d77 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -432,6 +432,26 @@ armnn::LayerBindingId GenerateLayerBindingId(size_t subgraphIndex, size_t tensor return static_cast<armnn::LayerBindingId>((tensorIndex<<8)+subgraphIndex); } +bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& expected) +{ + const unsigned int actualSize = actual.GetNumDimensions(); + if (actualSize != expected.size()) + { + return false; + } + + for (unsigned int i = 0u; i < actualSize; i++) + { + if (expected[i] < 0 || + actual[i] != static_cast<unsigned int>(expected[i])) + { + return false; + } + } + + return true; +} + } // <anonymous> TfLiteParser::TfLiteParser() @@ -1314,11 +1334,12 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); // Check for valid input size and that reshape parameters equal output shape - if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape)) + const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); + if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { std::stringstream ss; ss << "New shape defined in reshape parameters " - << reshapeOutputTensorInfo.GetShape() + << reshapeOutputTensorShape << " does not equal output shape " << actualOutputTensorInfo.GetShape() << ": " |