aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp31
1 files changed, 26 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 9d9f4fa14b..8e0fae68d1 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1083,7 +1083,14 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex
desc.m_DataLayout = armnn::DataLayout::NHWC;
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 3);
+ if (inputs.size() == 4)
+ {
+ desc.m_BiasEnabled = true;
+ }
+ else
+ {
+ CHECK_VALID_SIZE(inputs.size(), 3);
+ }
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -1143,10 +1150,24 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex
armnn::IConnectableLayer* layer = nullptr;
auto layerName = fmt::format("TransposeConv:{}:{}", subgraphIndex, operatorIndex);
- layer = m_Network->AddTransposeConvolution2dLayer(desc,
- filterTensorAndData.first,
- EmptyOptional(),
- layerName.c_str());
+ if (desc.m_BiasEnabled)
+ {
+ auto biasTensorInfo = ToTensorInfo(inputs[3]);
+ auto biasConstTensor = CreateConstTensor(inputs[3],
+ biasTensorInfo,
+ armnn::Optional<armnn::PermutationVector&>());
+ layer = m_Network->AddTransposeConvolution2dLayer(desc,
+ filterTensorAndData.first,
+ biasConstTensor.first,
+ layerName.c_str());
+ }
+ else
+ {
+ layer = m_Network->AddTransposeConvolution2dLayer(desc,
+ filterTensorAndData.first,
+ EmptyOptional(),
+ layerName.c_str());
+ }
ARMNN_ASSERT(layer != nullptr);