aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorkevmay01 <kevin.may@arm.com>2018-12-17 14:28:03 +0000
committerLes Bell <les.bell@arm.com>2018-12-17 14:51:56 +0000
commit71972a85778ad158ed3f471bec6b75a8c40ea3a1 (patch)
tree132a552b01e918f7db6afb8ad27806b655386e75 /src/armnnTfLiteParser
parent69352c1504d9e82c261a639db3ef03087a410f3a (diff)
downloadarmnn-71972a85778ad158ed3f471bec6b75a8c40ea3a1.tar.gz
IVGCVSW-2395 TfLiteParse::ParseReshape doesn't support reshape input
Change-Id: If2a31a49df3701877ce0287a81c569334a24cd20
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp21
1 files changed, 17 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 89c72c52e0..49bc73708f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1065,7 +1065,6 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 1);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -1074,15 +1073,29 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
const auto * options = operatorPtr->builtin_options.AsReshapeOptions();
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo outputTensorInfo =
+ armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+ armnn::TensorInfo reshapeOutputTensorInfo =
TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
+ // Check for valid input size and that reshape parameters equal output shape
+ if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape))
+ {
+ std::stringstream ss;
+ ss << "New shape defined in reshape parameters "
+ << reshapeOutputTensorInfo.GetShape()
+ << " does not equal output shape "
+ << actualOutputTensorInfo.GetShape()
+ << ": "
+ << CHECK_LOCATION().AsString();
+ throw ParseException(ss.str());
+ }
+
ReshapeDescriptor reshapeDesc;
- reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+ reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape();
auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});