From 2b922e2a5b5085b47480a4a971d40a1782bbfabd Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Fri, 23 Sep 2022 15:49:24 +0100 Subject: IVGCVSW-7158 TfLiteParser supports reshape when output 'shape_signature' param contains a value of -1. Signed-off-by: Cathal Corbett Change-Id: I538347083e9f22b3f3b6c048aebc2cf5cf4dc786 --- src/armnnTfLiteParser/TfLiteParser.cpp | 42 +++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) (limited to 'src/armnnTfLiteParser') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index e036d0ca1c..0484c6f478 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -631,6 +631,16 @@ bool CheckShape(const armnn::TensorShape& actual, const std::vector& ex return true; } +bool CheckShape(const armnn::TensorShape& actual, const armnn::TensorShape& expected) +{ + std::vector expectedVec; + for (uint32_t i = 0; i < expected.GetNumDimensions(); i++) + { + expectedVec.push_back(expected[i]); + } + return CheckShape(actual, expectedVec); +} + void CheckMatchingQuantization(const TensorInfo& first, const TensorInfo& second, const std::string& descName, @@ -2889,17 +2899,33 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, targetShape); // Check for valid input size and that reshape parameters equal output shape + // The output shape can be provided to us in 2 ways: + // 1. through the normal 'shape' parameter given by outputs[indx]->shape + // 2. through additional parameter 'shape_signature' given by outputs[indx]->buffer. + // This parameter can sometimes contain -1 value not visible in the 'shape' parameter. const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { - std::stringstream ss; - ss << "New shape defined in reshape parameters " - << reshapeOutputTensorShape - << " does not equal output shape " - << actualOutputTensorInfo.GetShape() - << ": " - << CHECK_LOCATION().AsString(); - throw ParseException(ss.str()); + // Attempt to extract output shape from secondary 'shape_signature' + // parameter and try to CheckShape() with this param. + std::vector secondaryOutputTargetShape = outputs[0]->shape_signature; + + // if outputs[0]->shape_signature contain a -1 value, we need to compute its actual value + // from reshape input in order to correctly verify reshape parameters equal output shape + armnn::TensorInfo secondaryReshapeOutputTensorInfo = + TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, secondaryOutputTargetShape); + + if (!CheckShape(reshapeOutputTensorShape, secondaryReshapeOutputTensorInfo.GetShape())) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorShape + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } } ReshapeDescriptor reshapeDesc; -- cgit v1.2.1