aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2022-12-14 10:16:27 +0000
committerTeresaARM <teresa.charlinreyes@arm.com>2022-12-15 12:21:16 +0000
commitda6bf9e2eac374cd92147d3c60a8af8bd6bc5a37 (patch)
tree9999b8d92c2b14b4bb349cbfd250dc33af252fb7 /src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp
parentfc9d5e7d1e0c1a4d7fed4ebc363832e03c3e2543 (diff)
downloadarmnn-da6bf9e2eac374cd92147d3c60a8af8bd6bc5a37.tar.gz
IVGCVSW-7168 Support simple model in the TOSA Reference Backend
* Fixed issue where duplicate tensors where being created. * Fixed issue where output name could be generated with the wrong id. * Updated bias tensor for Conv2d, so the size matches the channel. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: I1de6947e036b3e629ec6446d24d69e50603a5593
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp')
-rw-r--r--src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp3
1 files changed, 1 insertions, 2 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp b/src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp
index a0d58e2fa8..1ad8c9562f 100644
--- a/src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp
+++ b/src/backends/tosaCommon/operatorMappings/TransposeConv2dOperator.cpp
@@ -94,8 +94,7 @@ TosaSerializationBasicBlock* ConvertTransposeConv2dToTosaOperator(const Layer* l
{
// If bias is disabled, create a constant bias tensor of 0's as three inputs are required.
// The size of the bias must match the channels dimension, so get the correct index.
- unsigned int index = (descriptor->m_DataLayout == DataLayout::NHWC) ?
- outputs[0]->GetShape()[3] : outputs[0]->GetShape()[1];
+ unsigned int index = (descriptor->m_DataLayout == DataLayout::NHWC) ? 3 : 1;
std::vector<uint8_t> uint8Data;
std::vector<float> data(outputs[0]->GetShape()[index], 0.0f);