diff options
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index 850b279fea..91447576d0 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -367,6 +367,52 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext, return kTfLiteOk; } +armnn::IConnectableLayer* AddReshapeLayer(TfLiteContext* tfLiteContext, + TfLiteNode* tfLiteNode, + armnn::IConnectableLayer* prevLayer, + armnn::TensorInfo reshapedOutputTensorInfo, + armnn::TensorInfo outputTensorInfo, + armnnDelegate::DelegateData& data) +{ + armnn::ReshapeDescriptor desc; + desc.m_TargetShape = outputTensorInfo.GetShape(); + + bool isSupported = false; + armnn::BackendId setBackend; + FORWARD_LAYER_SUPPORT_FUNC("RESHAPE", + tfLiteContext, + IsReshapeSupported, + data.m_Backends, + isSupported, + setBackend, + reshapedOutputTensorInfo, + outputTensorInfo, + desc); + + if (!isSupported) + { + return nullptr; + } + + armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc); + reshapeLayer->SetBackendId(setBackend); + ARMNN_ASSERT(reshapeLayer != nullptr); + + prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo); + reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // Connect and prepare output slots + for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex) + { + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0)); + armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex); + data.m_OutputSlotForNode[static_cast<unsigned long>( + tfLiteNode->outputs->data[outputIndex])] = &outputSlot; + } + return reshapeLayer; +} + armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) { switch (tfLiteTensor.type) |