diff options
author | kevmay01 <kevin.may@arm.com> | 2018-12-17 14:28:03 +0000 |
---|---|---|
committer | Les Bell <les.bell@arm.com> | 2018-12-17 14:51:56 +0000 |
commit | 71972a85778ad158ed3f471bec6b75a8c40ea3a1 (patch) | |
tree | 132a552b01e918f7db6afb8ad27806b655386e75 /src | |
parent | 69352c1504d9e82c261a639db3ef03087a410f3a (diff) | |
download | armnn-71972a85778ad158ed3f471bec6b75a8c40ea3a1.tar.gz |
IVGCVSW-2395 TfLiteParse::ParseReshape doesn't support reshape input
Change-Id: If2a31a49df3701877ce0287a81c569334a24cd20
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 89c72c52e0..49bc73708f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1065,7 +1065,6 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 1); auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -1074,15 +1073,29 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) const auto * options = operatorPtr->builtin_options.AsReshapeOptions(); armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - armnn::TensorInfo outputTensorInfo = + armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + armnn::TensorInfo reshapeOutputTensorInfo = TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); + // Check for valid input size and that reshape parameters equal output shape + if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape)) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorInfo.GetShape() + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } + ReshapeDescriptor reshapeDesc; - reshapeDesc.m_TargetShape = outputTensorInfo.GetShape(); + reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape(); auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str()); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); |