aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathal Corbett <cathal.corbett@arm.com>2022-09-23 15:49:24 +0100
committerCathal Corbett <cathal.corbett@arm.com>2022-09-26 11:25:07 +0000
commit2b922e2a5b5085b47480a4a971d40a1782bbfabd (patch)
treee6d8121ad833451dca358e31b84cd034ace1f976
parente6f30addfea477ab628cfa71cbd7a4044d515d30 (diff)
downloadarmnn-2b922e2a5b5085b47480a4a971d40a1782bbfabd.tar.gz
IVGCVSW-7158 TfLiteParser supports reshape when output 'shape_signature' param contains a value of -1.
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: I538347083e9f22b3f3b6c048aebc2cf5cf4dc786
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp42
1 files changed, 34 insertions, 8 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index e036d0ca1c..0484c6f478 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -631,6 +631,16 @@ bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& ex
return true;
}
+bool CheckShape(const armnn::TensorShape& actual, const armnn::TensorShape& expected)
+{
+ std::vector<int32_t> expectedVec;
+ for (uint32_t i = 0; i < expected.GetNumDimensions(); i++)
+ {
+ expectedVec.push_back(expected[i]);
+ }
+ return CheckShape(actual, expectedVec);
+}
+
void CheckMatchingQuantization(const TensorInfo& first,
const TensorInfo& second,
const std::string& descName,
@@ -2889,17 +2899,33 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, targetShape);
// Check for valid input size and that reshape parameters equal output shape
+ // The output shape can be provided to us in 2 ways:
+ // 1. through the normal 'shape' parameter given by outputs[indx]->shape
+ // 2. through additional parameter 'shape_signature' given by outputs[indx]->buffer.
+ // This parameter can sometimes contain -1 value not visible in the 'shape' parameter.
const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape))
{
- std::stringstream ss;
- ss << "New shape defined in reshape parameters "
- << reshapeOutputTensorShape
- << " does not equal output shape "
- << actualOutputTensorInfo.GetShape()
- << ": "
- << CHECK_LOCATION().AsString();
- throw ParseException(ss.str());
+ // Attempt to extract output shape from secondary 'shape_signature'
+ // parameter and try to CheckShape() with this param.
+ std::vector<int32_t> secondaryOutputTargetShape = outputs[0]->shape_signature;
+
+ // if outputs[0]->shape_signature contain a -1 value, we need to compute its actual value
+ // from reshape input in order to correctly verify reshape parameters equal output shape
+ armnn::TensorInfo secondaryReshapeOutputTensorInfo =
+ TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, secondaryOutputTargetShape);
+
+ if (!CheckShape(reshapeOutputTensorShape, secondaryReshapeOutputTensorInfo.GetShape()))
+ {
+ std::stringstream ss;
+ ss << "New shape defined in reshape parameters "
+ << reshapeOutputTensorShape
+ << " does not equal output shape "
+ << actualOutputTensorInfo.GetShape()
+ << ": "
+ << CHECK_LOCATION().AsString();
+ throw ParseException(ss.str());
+ }
}
ReshapeDescriptor reshapeDesc;