aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index bad2504f18..1b93aadc5b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1082,6 +1082,28 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ if (inputs[0])
+ {
+ armnn::TensorInfo tensorInfo = ToTensorInfo(inputs[0]);
+ std::vector<int> output_shape(tensorInfo.GetNumElements());
+ if (tensorInfo.GetDataType() == DataType::Signed32)
+ {
+ ::memcpy(output_shape.data(), GetBuffer(m_Model, inputs[0]->buffer)->data.data(), tensorInfo.GetNumBytes());
+ }
+ if (tensorInfo.GetDataType() == DataType::QAsymmU8)
+ {
+ for(unsigned int i=0; i < tensorInfo.GetNumElements(); i++)
+ {
+ output_shape[i] = GetBuffer(m_Model, inputs[0]->buffer)->data.data()[i];
+ }
+ }
+ // Change from signed to unsigned int to store in TransposeConvolution2dDescriptor.
+ for (int dimension : output_shape)
+ {
+ desc.m_OutputShape.push_back(static_cast<unsigned int>(dimension));
+ }
+ desc.m_OutputShapeEnabled = true;
+ }
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[2]);
armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);