diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index e036d0ca1c..0484c6f478 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -631,6 +631,16 @@ bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& ex return true; } +bool CheckShape(const armnn::TensorShape& actual, const armnn::TensorShape& expected) +{ + std::vector<int32_t> expectedVec; + for (uint32_t i = 0; i < expected.GetNumDimensions(); i++) + { + expectedVec.push_back(expected[i]); + } + return CheckShape(actual, expectedVec); +} + void CheckMatchingQuantization(const TensorInfo& first, const TensorInfo& second, const std::string& descName, @@ -2889,17 +2899,33 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, targetShape); // Check for valid input size and that reshape parameters equal output shape + // The output shape can be provided to us in 2 ways: + // 1. through the normal 'shape' parameter given by outputs[indx]->shape + // 2. through additional parameter 'shape_signature' given by outputs[indx]->buffer. + // This parameter can sometimes contain -1 value not visible in the 'shape' parameter. const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { - std::stringstream ss; - ss << "New shape defined in reshape parameters " - << reshapeOutputTensorShape - << " does not equal output shape " - << actualOutputTensorInfo.GetShape() - << ": " - << CHECK_LOCATION().AsString(); - throw ParseException(ss.str()); + // Attempt to extract output shape from secondary 'shape_signature' + // parameter and try to CheckShape() with this param. + std::vector<int32_t> secondaryOutputTargetShape = outputs[0]->shape_signature; + + // if outputs[0]->shape_signature contain a -1 value, we need to compute its actual value + // from reshape input in order to correctly verify reshape parameters equal output shape + armnn::TensorInfo secondaryReshapeOutputTensorInfo = + TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, secondaryOutputTargetShape); + + if (!CheckShape(reshapeOutputTensorShape, secondaryReshapeOutputTensorInfo.GetShape())) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorShape + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } } ReshapeDescriptor reshapeDesc; |