aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorRyan OShea <ryan.oshea3@arm.com>2023-02-21 18:32:30 +0000
committerryan.oshea3 <ryan.oshea3@arm.com>2023-02-22 12:01:57 +0000
commitf0a35b8552ffcc39c5ebe2efc1ced15f813d8c09 (patch)
tree501d02dfd15ec2f913ea6f9c3a257ca82811e469 /src/armnnTfLiteParser/TfLiteParser.cpp
parentdf15c4e8a03423188d0598dc1a503c2d7a6d9f4e (diff)
downloadarmnn-f0a35b8552ffcc39c5ebe2efc1ced15f813d8c09.tar.gz
Fix segfault in ParseTransposeConv2d when output_shape is not constant
We currently check that output_shape is an input for transpose conv2d. If that is an input, we assume that it is constant and attempt to copy the data into the descriptor. When this data is not constant and instead comes from the output of another layer we segfault. When not constant we will use infer output shapes. * Adds a check into ParseTransposeConv2d that inputs[0] is constant Signed-off-by: Ryan OShea <ryan.oshea3@arm.com> Change-Id: I01176ae22974767a2306a3db749a029ed220d88b
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 9e8af66b49..279f804a03 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1606,7 +1606,11 @@ void TfLiteParserImpl::ParseTransposeConv(size_t subgraphIndex, size_t operatorI
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- if (inputs[0])
+ // This block determines the output shape of the transpose convolution. If the output shape tensor ptr is not null
+ // And the tensor is a constant, we can access the data at load time and set the output shape of the
+ // layer. If this is not constant, We do not have access to the shape data, so we have to use
+ // infer output shape and skip this code block.
+ if (inputs[0] && IsConstTensor(inputs[0]))
{
armnn::TensorInfo tensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
std::vector<int> output_shape(tensorInfo.GetNumElements());