diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 15ca36d906..f51cf508e2 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2566,13 +2566,47 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex) // Extract target shape from input auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); auto values = reinterpret_cast<const int32_t*>(bufferPtr->data.data()); - if (!values) + if (values) { - ARMNN_THROW_PARSE_EXCEPTION("Reshape operator target shape input buffer data is null"); + for (int i = 0; i < inputs[1]->shape[0]; ++i) + { + targetShape.push_back(values[i]); + } } - for (int i=0; i < inputs[1]->shape[0]; ++i) + else { - targetShape.push_back(values[i]); + try + { + // We attempt to infer during Runtime. + TensorShape reshapeShapes = ToTensorInfo(inputs[1]).GetShape(); + // The parser only supports shape (batch, -1) or (-1) for non-constant shape input. + if (reshapeShapes[0] > 2) + { + throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}. " + "When inferring during runtime, the parser only supports " + "shape (batch, -1) or (-1) for target shape input.", + reshapeShapes[0], + layerName, + CHECK_LOCATION().AsString())); + } + + const int32_t numInputElements = inputTensorInfo.GetNumElements(); + const int32_t inputTensorShape = inputTensorInfo.GetShape()[0]; + if (reshapeShapes[0] == 1) + { + targetShape = {numInputElements}; + } + else if (reshapeShapes[0] == 2) + { + targetShape = {inputTensorShape, numInputElements / inputTensorShape}; + } + } + catch (const std::exception& exc) + { + ARMNN_THROW_PARSE_EXCEPTION("Failed attempt to infer during runtime the target shape input for " + "Reshape operation. Reshape operator target shape input buffer data " + "is null. " << exc.what()); + } } } else |