aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2020-03-11 11:42:26 +0000
committerJim Flynn <jim.flynn@arm.com>2020-03-12 11:56:33 +0000
commitc9e52794083eb73dd1bbf15ce7b16bb26394d7f5 (patch)
tree2681e7b7a3509989fdd87024363a7f8346bfd7c2 /src/armnnTfLiteParser
parent431852c95ab89194bac9c9ce57ca011c0ce2f15e (diff)
downloadarmnn-c9e52794083eb73dd1bbf15ce7b16bb26394d7f5.tar.gz
IVGCVSW-4545 Fix segfault parsing reshape layer
Change-Id: Ib4bbab387cec4780c8aae56fdede090fa265e4ba Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp63
1 files changed, 58 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index eab9f4ea30..fc5041bf9a 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -31,6 +31,14 @@
#include <algorithm>
#include <limits>
#include <numeric>
+#include <sstream>
+
+#define ARMNN_THROW_PARSE_EXCEPTION(msg) \
+ { \
+ throw armnn::ParseException( static_cast<const std::stringstream&>( std::stringstream() << msg \
+ << ": " \
+ << CHECK_LOCATION().AsString()).str()); \
+ }
using namespace armnn;
using armnn::CheckLocation;
@@ -1932,8 +1940,51 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+
+ std::vector<int32_t> targetShape;
+ if (inputs.size() > 1 && inputs[1] != nullptr)
+ {
+ if (options != nullptr)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target shape defined in reshape parameters and input tensor. "
+ "Only one method expected");
+ }
+
+ if (inputs[1]->is_variable)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported");
+ }
+
+ if (inputs[1]->shape.size() != 1)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor");
+ }
+
+ if (inputs[1]->type != tflite::TensorType_INT32)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type");
+ }
+
+ auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+ auto vals = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
+ for (int i=0; i < inputs[1]->shape[0]; i++)
+ {
+ targetShape.push_back(vals[i]);
+ }
+ }
+ else
+ {
+ if (options == nullptr)
+ {
+ ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. "
+ "At least one method required");
+ }
+
+ targetShape = options->new_shape;
+ }
+
armnn::TensorInfo reshapeOutputTensorInfo =
- TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
+ TfLiteParser::OutputShapeOfReshape(inputTensorInfo, targetShape);
// Check for valid input size and that reshape parameters equal output shape
const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
@@ -2581,10 +2632,12 @@ TfLiteParser::ModelPtr TfLiteParser::LoadModelFromFile(const char * fileName)
boost::filesystem::path pathToFile(fileName);
if (!boost::filesystem::exists(pathToFile, errorCode))
{
- throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
- fileName %
- errorCode %
- CHECK_LOCATION().AsString()));
+ std::string locationString = CHECK_LOCATION().AsString();
+ std::string msg = boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
+ fileName %
+ errorCode %
+ locationString);
+ throw FileNotFoundException(msg);
}
std::ifstream file(fileName, std::ios::binary);
std::string fileContent((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());