aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-23 14:00:00 +0000
committerAron Virginas-Tar <aron.virginas-tar@arm.com>2019-01-23 14:08:20 +0000
commit70672f6c52e95256911ca70110d3ad2643b43eaa (patch)
tree1911ca37e37b5efaca67dc84ae32b52178f6f626
parent6392a77094e686fb4b973bddc8408c3faf256ed7 (diff)
downloadarmnn-70672f6c52e95256911ca70110d3ad2643b43eaa.tar.gz
IVGCVSW-2534 Fix bug TfLiteParser::ParseReshape()
Change-Id: I44d63552d2552842f02b2c870466851581f65b1a
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp25
1 files changed, 23 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index c45e794274..affd858d77 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -432,6 +432,26 @@ armnn::LayerBindingId GenerateLayerBindingId(size_t subgraphIndex, size_t tensor
return static_cast<armnn::LayerBindingId>((tensorIndex<<8)+subgraphIndex);
}
+bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& expected)
+{
+ const unsigned int actualSize = actual.GetNumDimensions();
+ if (actualSize != expected.size())
+ {
+ return false;
+ }
+
+ for (unsigned int i = 0u; i < actualSize; i++)
+ {
+ if (expected[i] < 0 ||
+ actual[i] != static_cast<unsigned int>(expected[i]))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
} // <anonymous>
TfLiteParser::TfLiteParser()
@@ -1314,11 +1334,12 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
// Check for valid input size and that reshape parameters equal output shape
- if (inputs.size() > 1 && (options->new_shape != outputs[0]->shape))
+ const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
+ if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape))
{
std::stringstream ss;
ss << "New shape defined in reshape parameters "
- << reshapeOutputTensorInfo.GetShape()
+ << reshapeOutputTensorShape
<< " does not equal output shape "
<< actualOutputTensorInfo.GetShape()
<< ": "