From c9e52794083eb73dd1bbf15ce7b16bb26394d7f5 Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Wed, 11 Mar 2020 11:42:26 +0000 Subject: IVGCVSW-4545 Fix segfault parsing reshape layer Change-Id: Ib4bbab387cec4780c8aae56fdede090fa265e4ba Signed-off-by: Derek Lamberti --- src/armnnConverter/ArmnnConverter.cpp | 96 ++++++++++++++++++---------------- src/armnnTfLiteParser/TfLiteParser.cpp | 63 ++++++++++++++++++++-- 2 files changed, 110 insertions(+), 49 deletions(-) (limited to 'src') diff --git a/src/armnnConverter/ArmnnConverter.cpp b/src/armnnConverter/ArmnnConverter.cpp index 70df2c3a5a..e0a659dca3 100644 --- a/src/armnnConverter/ArmnnConverter.cpp +++ b/src/armnnConverter/ArmnnConverter.cpp @@ -420,68 +420,76 @@ int main(int argc, const char* argv[]) ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary); - if (modelFormat.find("caffe") != std::string::npos) + try { -#if defined(ARMNN_CAFFE_PARSER) - if (!converter.CreateNetwork()) + if (modelFormat.find("caffe") != std::string::npos) { - ARMNN_LOG(fatal) << "Failed to load model from file"; - return EXIT_FAILURE; - } +#if defined(ARMNN_CAFFE_PARSER) + if (!converter.CreateNetwork()) + { + ARMNN_LOG(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } #else - ARMNN_LOG(fatal) << "Not built with Caffe parser support."; - return EXIT_FAILURE; -#endif - } - else if (modelFormat.find("onnx") != std::string::npos) - { -#if defined(ARMNN_ONNX_PARSER) - if (!converter.CreateNetwork()) - { - ARMNN_LOG(fatal) << "Failed to load model from file"; + ARMNN_LOG(fatal) << "Not built with Caffe parser support."; return EXIT_FAILURE; - } -#else - ARMNN_LOG(fatal) << "Not built with Onnx parser support."; - return EXIT_FAILURE; #endif - } - else if (modelFormat.find("tensorflow") != std::string::npos) - { -#if defined(ARMNN_TF_PARSER) - if (!converter.CreateNetwork()) + } + else if (modelFormat.find("onnx") != std::string::npos) { - ARMNN_LOG(fatal) << "Failed to load model from file"; +#if defined(ARMNN_ONNX_PARSER) + if (!converter.CreateNetwork()) + { + ARMNN_LOG(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } +#else + ARMNN_LOG(fatal) << "Not built with Onnx parser support."; return EXIT_FAILURE; +#endif } + else if (modelFormat.find("tensorflow") != std::string::npos) + { +#if defined(ARMNN_TF_PARSER) + if (!converter.CreateNetwork()) + { + ARMNN_LOG(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } #else - ARMNN_LOG(fatal) << "Not built with Tensorflow parser support."; - return EXIT_FAILURE; + ARMNN_LOG(fatal) << "Not built with Tensorflow parser support."; + return EXIT_FAILURE; #endif - } - else if (modelFormat.find("tflite") != std::string::npos) - { -#if defined(ARMNN_TF_LITE_PARSER) - if (!isModelBinary) + } + else if (modelFormat.find("tflite") != std::string::npos) { - ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \ - for tflite files"; +#if defined(ARMNN_TF_LITE_PARSER) + if (!isModelBinary) + { + ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \ + for tflite files"; + return EXIT_FAILURE; + } + + if (!converter.CreateNetwork()) + { + ARMNN_LOG(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } +#else + ARMNN_LOG(fatal) << "Not built with TfLite parser support."; return EXIT_FAILURE; +#endif } - - if (!converter.CreateNetwork()) + else { - ARMNN_LOG(fatal) << "Failed to load model from file"; + ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'"; return EXIT_FAILURE; } -#else - ARMNN_LOG(fatal) << "Not built with TfLite parser support."; - return EXIT_FAILURE; -#endif } - else + catch(armnn::Exception& e) { - ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'"; + ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what(); return EXIT_FAILURE; } diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index eab9f4ea30..fc5041bf9a 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -31,6 +31,14 @@ #include #include #include +#include + +#define ARMNN_THROW_PARSE_EXCEPTION(msg) \ + { \ + throw armnn::ParseException( static_cast( std::stringstream() << msg \ + << ": " \ + << CHECK_LOCATION().AsString()).str()); \ + } using namespace armnn; using armnn::CheckLocation; @@ -1932,8 +1940,51 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + + std::vector targetShape; + if (inputs.size() > 1 && inputs[1] != nullptr) + { + if (options != nullptr) + { + ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and input tensor. " + "Only one method expected"); + } + + if (inputs[1]->is_variable) + { + ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported"); + } + + if (inputs[1]->shape.size() != 1) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor"); + } + + if (inputs[1]->type != tflite::TensorType_INT32) + { + ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type"); + } + + auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + auto vals = reinterpret_cast(bufferPtr->data.data()); + for (int i=0; i < inputs[1]->shape[0]; i++) + { + targetShape.push_back(vals[i]); + } + } + else + { + if (options == nullptr) + { + ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. " + "At least one method required"); + } + + targetShape = options->new_shape; + } + armnn::TensorInfo reshapeOutputTensorInfo = - TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape); + TfLiteParser::OutputShapeOfReshape(inputTensorInfo, targetShape); // Check for valid input size and that reshape parameters equal output shape const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); @@ -2581,10 +2632,12 @@ TfLiteParser::ModelPtr TfLiteParser::LoadModelFromFile(const char * fileName) boost::filesystem::path pathToFile(fileName); if (!boost::filesystem::exists(pathToFile, errorCode)) { - throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") % - fileName % - errorCode % - CHECK_LOCATION().AsString())); + std::string locationString = CHECK_LOCATION().AsString(); + std::string msg = boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") % + fileName % + errorCode % + locationString); + throw FileNotFoundException(msg); } std::ifstream file(fileName, std::ios::binary); std::string fileContent((std::istreambuf_iterator(file)), std::istreambuf_iterator()); -- cgit v1.2.1