diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 31 |
1 files changed, 26 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 9d9f4fa14b..8e0fae68d1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1083,7 +1083,14 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex desc.m_DataLayout = armnn::DataLayout::NHWC; auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 3); + if (inputs.size() == 4) + { + desc.m_BiasEnabled = true; + } + else + { + CHECK_VALID_SIZE(inputs.size(), 3); + } auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -1143,10 +1150,24 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex armnn::IConnectableLayer* layer = nullptr; auto layerName = fmt::format("TransposeConv:{}:{}", subgraphIndex, operatorIndex); - layer = m_Network->AddTransposeConvolution2dLayer(desc, - filterTensorAndData.first, - EmptyOptional(), - layerName.c_str()); + if (desc.m_BiasEnabled) + { + auto biasTensorInfo = ToTensorInfo(inputs[3]); + auto biasConstTensor = CreateConstTensor(inputs[3], + biasTensorInfo, + armnn::Optional<armnn::PermutationVector&>()); + layer = m_Network->AddTransposeConvolution2dLayer(desc, + filterTensorAndData.first, + biasConstTensor.first, + layerName.c_str()); + } + else + { + layer = m_Network->AddTransposeConvolution2dLayer(desc, + filterTensorAndData.first, + EmptyOptional(), + layerName.c_str()); + } ARMNN_ASSERT(layer != nullptr); |