From c9e52794083eb73dd1bbf15ce7b16bb26394d7f5 Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Wed, 11 Mar 2020 11:42:26 +0000 Subject: IVGCVSW-4545 Fix segfault parsing reshape layer Change-Id: Ib4bbab387cec4780c8aae56fdede090fa265e4ba Signed-off-by: Derek Lamberti --- src/armnnTfLiteParser/TfLiteParser.cpp | 63 +++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 5 deletions(-) (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index eab9f4ea30..fc5041bf9a 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -31,6 +31,14 @@ #include #include #include +#include + +#define ARMNN_THROW_PARSE_EXCEPTION(msg) \ + { \ + throw armnn::ParseException( static_cast( std::stringstream() << msg \ + << ": " \ + << CHECK_LOCATION().AsString()).str()); \ + } using namespace armnn; using armnn::CheckLocation; @@ -1932,8 +1940,51 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + + std::vector targetShape; + if (inputs.size() > 1 && inputs[1] != nullptr) + { + if (options != nullptr) + { + ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and input tensor. " + "Only one method expected"); + } + + if (inputs[1]->is_variable) + { + ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported"); + } + + if (inputs[1]->shape.size() != 1) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor"); + } + + if (inputs[1]->type != tflite::TensorType_INT32) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type"); + } + + auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + auto vals = reinterpret_cast(bufferPtr->data.data()); + for (int i=0; i < inputs[1]->shape[0]; i++) + { + targetShape.push_back(vals[i]); + } + } + else + { + if (options == nullptr) + { + ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. " + "At least one method required"); + } + + targetShape = options->new_shape; + } + armnn::TensorInfo reshapeOutputTensorInfo = - TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); + TfLiteParser::OutputShapeOfReshape(inputTensorInfo, targetShape); // Check for valid input size and that reshape parameters equal output shape const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); @@ -2581,10 +2632,12 @@ TfLiteParser::ModelPtr TfLiteParser::LoadModelFromFile(const char * fileName) boost::filesystem::path pathToFile(fileName); if (!boost::filesystem::exists(pathToFile, errorCode)) { - throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") % - fileName % - errorCode % - CHECK_LOCATION().AsString())); + std::string locationString = CHECK_LOCATION().AsString(); + std::string msg = boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") % + fileName % + errorCode % + locationString); + throw FileNotFoundException(msg); } std::ifstream file(fileName, std::ios::binary); std::string fileContent((std::istreambuf_iterator(file)), std::istreambuf_iterator()); -- cgit v1.2.1