aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2020-03-11 11:42:26 +0000
committerJim Flynn <jim.flynn@arm.com>2020-03-12 11:56:33 +0000
commitc9e52794083eb73dd1bbf15ce7b16bb26394d7f5 (patch)
tree2681e7b7a3509989fdd87024363a7f8346bfd7c2 /src
parent431852c95ab89194bac9c9ce57ca011c0ce2f15e (diff)
downloadarmnn-c9e52794083eb73dd1bbf15ce7b16bb26394d7f5.tar.gz
IVGCVSW-4545 Fix segfault parsing reshape layer
Change-Id: Ib4bbab387cec4780c8aae56fdede090fa265e4ba Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/armnnConverter/ArmnnConverter.cpp96
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp63
2 files changed, 110 insertions, 49 deletions
diff --git a/src/armnnConverter/ArmnnConverter.cpp b/src/armnnConverter/ArmnnConverter.cpp
index 70df2c3a5a..e0a659dca3 100644
--- a/src/armnnConverter/ArmnnConverter.cpp
+++ b/src/armnnConverter/ArmnnConverter.cpp
@@ -420,68 +420,76 @@ int main(int argc, const char* argv[])
ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
- if (modelFormat.find("caffe") != std::string::npos)
+ try
{
-#if defined(ARMNN_CAFFE_PARSER)
- if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
+ if (modelFormat.find("caffe") != std::string::npos)
{
- ARMNN_LOG(fatal) << "Failed to load model from file";
- return EXIT_FAILURE;
- }
+#if defined(ARMNN_CAFFE_PARSER)
+ if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
+ {
+ ARMNN_LOG(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
#else
- ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
- return EXIT_FAILURE;
-#endif
- }
- else if (modelFormat.find("onnx") != std::string::npos)
- {
-#if defined(ARMNN_ONNX_PARSER)
- if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
- {
- ARMNN_LOG(fatal) << "Failed to load model from file";
+ ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
return EXIT_FAILURE;
- }
-#else
- ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
- return EXIT_FAILURE;
#endif
- }
- else if (modelFormat.find("tensorflow") != std::string::npos)
- {
-#if defined(ARMNN_TF_PARSER)
- if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
+ }
+ else if (modelFormat.find("onnx") != std::string::npos)
{
- ARMNN_LOG(fatal) << "Failed to load model from file";
+#if defined(ARMNN_ONNX_PARSER)
+ if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
+ {
+ ARMNN_LOG(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+#else
+ ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
return EXIT_FAILURE;
+#endif
}
+ else if (modelFormat.find("tensorflow") != std::string::npos)
+ {
+#if defined(ARMNN_TF_PARSER)
+ if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
+ {
+ ARMNN_LOG(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
#else
- ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
- return EXIT_FAILURE;
+ ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
+ return EXIT_FAILURE;
#endif
- }
- else if (modelFormat.find("tflite") != std::string::npos)
- {
-#if defined(ARMNN_TF_LITE_PARSER)
- if (!isModelBinary)
+ }
+ else if (modelFormat.find("tflite") != std::string::npos)
{
- ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
- for tflite files";
+#if defined(ARMNN_TF_LITE_PARSER)
+ if (!isModelBinary)
+ {
+ ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
+ for tflite files";
+ return EXIT_FAILURE;
+ }
+
+ if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
+ {
+ ARMNN_LOG(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+#else
+ ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
return EXIT_FAILURE;
+#endif
}
-
- if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
+ else
{
- ARMNN_LOG(fatal) << "Failed to load model from file";
+ ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
return EXIT_FAILURE;
}
-#else
- ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
- return EXIT_FAILURE;
-#endif
}
- else
+ catch(armnn::Exception& e)
{
- ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
+ ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
return EXIT_FAILURE;
}
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index eab9f4ea30..fc5041bf9a 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -31,6 +31,14 @@
#include <algorithm>
#include <limits>
#include <numeric>
+#include <sstream>
+
+#define ARMNN_THROW_PARSE_EXCEPTION(msg) \
+ { \
+ throw armnn::ParseException( static_cast<const std::stringstream&>( std::stringstream() << msg \
+ << ": " \
+ << CHECK_LOCATION().AsString()).str()); \
+ }
using namespace armnn;
using armnn::CheckLocation;
@@ -1932,8 +1940,51 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+
+ std::vector<int32_t> targetShape;
+ if (inputs.size() > 1 && inputs[1] != nullptr)
+ {
+ if (options != nullptr)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and input tensor. "
+ "Only one method expected");
+ }
+
+ if (inputs[1]->is_variable)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported");
+ }
+
+ if (inputs[1]->shape.size() != 1)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor");
+ }
+
+ if (inputs[1]->type != tflite::TensorType_INT32)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type");
+ }
+
+ auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+ auto vals = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
+ for (int i=0; i < inputs[1]->shape[0]; i++)
+ {
+ targetShape.push_back(vals[i]);
+ }
+ }
+ else
+ {
+ if (options == nullptr)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. "
+ "At least one method required");
+ }
+
+ targetShape = options->new_shape;
+ }
+
armnn::TensorInfo reshapeOutputTensorInfo =
- TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
+ TfLiteParser::OutputShapeOfReshape(inputTensorInfo, targetShape);
// Check for valid input size and that reshape parameters equal output shape
const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
@@ -2581,10 +2632,12 @@ TfLiteParser::ModelPtr TfLiteParser::LoadModelFromFile(const char * fileName)
boost::filesystem::path pathToFile(fileName);
if (!boost::filesystem::exists(pathToFile, errorCode))
{
- throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
- fileName %
- errorCode %
- CHECK_LOCATION().AsString()));
+ std::string locationString = CHECK_LOCATION().AsString();
+ std::string msg = boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
+ fileName %
+ errorCode %
+ locationString);
+ throw FileNotFoundException(msg);
}
std::ifstream file(fileName, std::ios::binary);
std::string fileContent((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());