From 0ad3ef15b7b731e9b722123f8763b2f1e3783cb8 Mon Sep 17 00:00:00 2001 From: Colm Donelan Date: Fri, 3 Jul 2020 15:54:28 +0100 Subject: IVGCVSW-4988 Add handling output shape parameter to TransposeConvolution2d * Add m_OutputShape and m_OutputShapeEnabled to TransposeConvolution2dDescriptor. * Update TfLite parser to populate m_OutputShape if found in the model. Handle both Signed32 from tflite files and QAsymmU8 from test fixtures. * Update TransposeConvolution2dLayer to use m_OutputShape instead of InferOutputShapes if specified. Signed-off-by: Colm Donelan Change-Id: Ia6933065375eb8006c916f1ca67c38dc50bc205c --- src/armnnTfLiteParser/TfLiteParser.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index bad2504f18..1b93aadc5b 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1082,6 +1082,28 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + if (inputs[0]) + { + armnn::TensorInfo tensorInfo = ToTensorInfo(inputs[0]); + std::vector output_shape(tensorInfo.GetNumElements()); + if (tensorInfo.GetDataType() == DataType::Signed32) + { + ::memcpy(output_shape.data(), GetBuffer(m_Model, inputs[0]->buffer)->data.data(), tensorInfo.GetNumBytes()); + } + if (tensorInfo.GetDataType() == DataType::QAsymmU8) + { + for(unsigned int i=0; i < tensorInfo.GetNumElements(); i++) + { + output_shape[i] = GetBuffer(m_Model, inputs[0]->buffer)->data.data()[i]; + } + } + // Change from signed to unsigned int to store in TransposeConvolution2dDescriptor. + for (int dimension : output_shape) + { + desc.m_OutputShape.push_back(static_cast(dimension)); + } + desc.m_OutputShapeEnabled = true; + } armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[2]); armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]); -- cgit v1.2.1