aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp42
1 files changed, 38 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 15ca36d906..f51cf508e2 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -2566,13 +2566,47 @@ void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
// Extract target shape from input
auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
auto values = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
- if (!values)
+ if (values)
{
- ARMNN_THROW_PARSE_EXCEPTION("Reshape operator target shape input buffer data is null");
+ for (int i = 0; i < inputs[1]->shape[0]; ++i)
+ {
+ targetShape.push_back(values[i]);
+ }
}
- for (int i=0; i < inputs[1]->shape[0]; ++i)
+ else
{
- targetShape.push_back(values[i]);
+ try
+ {
+ // We attempt to infer during Runtime.
+ TensorShape reshapeShapes = ToTensorInfo(inputs[1]).GetShape();
+ // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
+ if (reshapeShapes[0] > 2)
+ {
+ throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}. "
+ "When inferring during runtime, the parser only supports "
+ "shape (batch, -1) or (-1) for target shape input.",
+ reshapeShapes[0],
+ layerName,
+ CHECK_LOCATION().AsString()));
+ }
+
+ const int32_t numInputElements = inputTensorInfo.GetNumElements();
+ const int32_t inputTensorShape = inputTensorInfo.GetShape()[0];
+ if (reshapeShapes[0] == 1)
+ {
+ targetShape = {numInputElements};
+ }
+ else if (reshapeShapes[0] == 2)
+ {
+ targetShape = {inputTensorShape, numInputElements / inputTensorShape};
+ }
+ }
+ catch (const std::exception& exc)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Failed attempt to infer during runtime the target shape input for "
+ "Reshape operation. Reshape operator target shape input buffer data "
+ "is null. " << exc.what());
+ }
}
}
else