aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp21
1 files changed, 17 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 89c72c52e0..49bc73708f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1065,7 +1065,6 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 1);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -1074,15 +1073,29 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
const auto * options = operatorPtr->builtin_options.AsReshapeOptions();
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo outputTensorInfo =
+ armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+ armnn::TensorInfo reshapeOutputTensorInfo =
TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
+ // Check for valid input size and that reshape parameters equal output shape
+ if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape))
+ {
+ std::stringstream ss;
+ ss << "New shape defined in reshape parameters "
+ << reshapeOutputTensorInfo.GetShape()
+ << " does not equal output shape "
+ << actualOutputTensorInfo.GetShape()
+ << ": "
+ << CHECK_LOCATION().AsString();
+ throw ParseException(ss.str());
+ }
+
ReshapeDescriptor reshapeDesc;
- reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+ reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape();
auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});