diff options
author | Cathal Corbett <cathal.corbett@arm.com> | 2022-09-23 15:49:24 +0100 |
---|---|---|
committer | Cathal Corbett <cathal.corbett@arm.com> | 2022-09-26 11:25:07 +0000 |
commit | 2b922e2a5b5085b47480a4a971d40a1782bbfabd (patch) | |
tree | e6d8121ad833451dca358e31b84cd034ace1f976 | |
parent | e6f30addfea477ab628cfa71cbd7a4044d515d30 (diff) | |
download | armnn-2b922e2a5b5085b47480a4a971d40a1782bbfabd.tar.gz |
IVGCVSW-7158 TfLiteParser supports reshape when output 'shape_signature' param contains a value of -1.
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I538347083e9f22b3f3b6c048aebc2cf5cf4dc786
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index e036d0ca1c..0484c6f478 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -631,6 +631,16 @@ bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& ex return true; } +bool CheckShape(const armnn::TensorShape& actual, const armnn::TensorShape& expected) +{ + std::vector<int32_t> expectedVec; + for (uint32_t i = 0; i < expected.GetNumDimensions(); i++) + { + expectedVec.push_back(expected[i]); + } + return CheckShape(actual, expectedVec); +} + void CheckMatchingQuantization(const TensorInfo& first, const TensorInfo& second, const std::string& descName, @@ -2889,17 +2899,33 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex) TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, targetShape); // Check for valid input size and that reshape parameters equal output shape + // The output shape can be provided to us in 2 ways: + // 1. through the normal 'shape' parameter given by outputs[indx]->shape + // 2. through additional parameter 'shape_signature' given by outputs[indx]->buffer. + // This parameter can sometimes contain -1 value not visible in the 'shape' parameter. const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape)) { - std::stringstream ss; - ss << "New shape defined in reshape parameters " - << reshapeOutputTensorShape - << " does not equal output shape " - << actualOutputTensorInfo.GetShape() - << ": " - << CHECK_LOCATION().AsString(); - throw ParseException(ss.str()); + // Attempt to extract output shape from secondary 'shape_signature' + // parameter and try to CheckShape() with this param. + std::vector<int32_t> secondaryOutputTargetShape = outputs[0]->shape_signature; + + // if outputs[0]->shape_signature contain a -1 value, we need to compute its actual value + // from reshape input in order to correctly verify reshape parameters equal output shape + armnn::TensorInfo secondaryReshapeOutputTensorInfo = + TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, secondaryOutputTargetShape); + + if (!CheckShape(reshapeOutputTensorShape, secondaryReshapeOutputTensorInfo.GetShape())) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorShape + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } } ReshapeDescriptor reshapeDesc; |