From d2f7323b7ddf8f811f19ba7ae9987dcc6bf672a6 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Fri, 10 Dec 2021 13:38:52 +0000 Subject: IVGCVSW-6252 Armnn Error: Failed to parse operator #1 within subgraph #0 error: Operator not supported * Added missing support for reshape operator in tflite parser when the target shape is dynamic and batch size is unknown * Added corresponding unit test Change-Id: I35e159c9c70440168c6092d2ad02828bb2b81cd9 Signed-off-by: Cathal Corbett --- src/armnnTfLiteParser/TfLiteParser.cpp | 42 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 15ca36d906..f51cf508e2 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2566,13 +2566,47 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex) // Extract target shape from input auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); auto values = reinterpret_cast(bufferPtr->data.data()); - if (!values) + if (values) { - ARMNN_THROW_PARSE_EXCEPTION("Reshape operator target shape input buffer data is null"); + for (int i = 0; i < inputs[1]->shape[0]; ++i) + { + targetShape.push_back(values[i]); + } } - for (int i=0; i < inputs[1]->shape[0]; ++i) + else { - targetShape.push_back(values[i]); + try + { + // We attempt to infer during Runtime. + TensorShape reshapeShapes = ToTensorInfo(inputs[1]).GetShape(); + // The parser only supports shape (batch, -1) or (-1) for non-constant shape input. + if (reshapeShapes[0] > 2) + { + throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}. " + "When inferring during runtime, the parser only supports " + "shape (batch, -1) or (-1) for target shape input.", + reshapeShapes[0], + layerName, + CHECK_LOCATION().AsString())); + } + + const int32_t numInputElements = inputTensorInfo.GetNumElements(); + const int32_t inputTensorShape = inputTensorInfo.GetShape()[0]; + if (reshapeShapes[0] == 1) + { + targetShape = {numInputElements}; + } + else if (reshapeShapes[0] == 2) + { + targetShape = {inputTensorShape, numInputElements / inputTensorShape}; + } + } + catch (const std::exception& exc) + { + ARMNN_THROW_PARSE_EXCEPTION("Failed attempt to infer during runtime the target shape input for " + "Reshape operation. Reshape operator target shape input buffer data " + "is null. " << exc.what()); + } } } else -- cgit v1.2.1