diff options
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 89c72c52e0..49bc73708f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1065,7 +1065,6 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 1); auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -1074,15 +1073,29 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) const auto * options = operatorPtr->builtin_options.AsReshapeOptions(); armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - armnn::TensorInfo outputTensorInfo = + armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + armnn::TensorInfo reshapeOutputTensorInfo = TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); + // Check for valid input size and that reshape parameters equal output shape + if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape)) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorInfo.GetShape() + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } + ReshapeDescriptor reshapeDesc; - reshapeDesc.m_TargetShape = outputTensorInfo.GetShape(); + reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape(); auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str()); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); |