aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r--delegate/src/DelegateUtils.hpp79
1 files changed, 7 insertions, 72 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index 3e74225b15..1aa9029271 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -13,6 +13,7 @@
#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/Permute.hpp>
+#include <armnnUtils/TensorUtils.hpp>
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
@@ -188,91 +189,25 @@ TfLiteStatus Connect(armnn::IConnectableLayer* layer,
return kTfLiteOk;
}
-armnn::IConnectableLayer* BroadcastTensor(const armnn::TensorInfo& inputInfo0,
- const armnn::TensorInfo& inputInfo1,
- armnn::IConnectableLayer* startLayer,
- TfLiteContext* tfLiteContext,
- TfLiteNode* tfLiteNode,
- armnnDelegate::DelegateData& delegateData)
+void ExpandTensorRankToEqual(armnn::TensorInfo& inputInfo0,
+ armnn::TensorInfo& inputInfo1)
{
unsigned int inputDimensions0 = inputInfo0.GetNumDimensions();
unsigned int inputDimensions1 = inputInfo1.GetNumDimensions();
if (inputDimensions0 == inputDimensions1)
{
- auto status = Connect(startLayer, tfLiteNode, delegateData);
- return status == kTfLiteOk ? startLayer : nullptr;
+ return;
}
unsigned int biggerInputDimensions = std::max(inputDimensions0, inputDimensions1);
- unsigned int dimDifference = static_cast<unsigned int>(std::abs(armnn::numeric_cast<int>(inputDimensions0) -
- armnn::numeric_cast<int>(inputDimensions1)));
bool input0IsSmaller = inputDimensions0 < inputDimensions1;
- const armnn::TensorInfo& smallInfo = input0IsSmaller ? inputInfo0 : inputInfo1;
- const armnn::TensorShape& smallShape = smallInfo.GetShape();
-
- std::vector<unsigned int> reshapedDimensions(biggerInputDimensions, 1);
- for (unsigned int i = dimDifference; i < biggerInputDimensions; ++i)
- {
- reshapedDimensions[i] = smallShape[i - dimDifference];
- }
-
- armnn::TensorInfo reshapedInfo = smallInfo;
- reshapedInfo.SetShape(armnn::TensorShape{ armnn::numeric_cast<unsigned int>(reshapedDimensions.size()),
- reshapedDimensions.data() });
-
- armnn::ReshapeDescriptor reshapeDescriptor;
- reshapeDescriptor.m_TargetShape = reshapedInfo.GetShape();
- bool isSupported = false;
- armnn::BackendId setBackend;
- FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
- tfLiteContext,
- IsReshapeSupported,
- delegateData.m_Backends,
- isSupported,
- setBackend,
- smallInfo,
- reshapedInfo,
- reshapeDescriptor);
- if (!isSupported)
- {
- return nullptr;
- }
+ armnn::TensorInfo& smallInfo = input0IsSmaller ? inputInfo0 : inputInfo1;
+ const armnn::TensorShape& newShape = armnnUtils::ExpandDimsToRank(smallInfo.GetShape(), biggerInputDimensions);
- ARMNN_ASSERT(delegateData.m_Network != nullptr);
- // Add Reshape layer
- armnn::IConnectableLayer* reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
- reshapeLayer->SetBackendId(setBackend);
- ARMNN_ASSERT(reshapeLayer != nullptr);
- reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedInfo);
+ smallInfo.SetShape(newShape);
- if (input0IsSmaller)
- {
- delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->inputs->data[0])]
- ->Connect(reshapeLayer->GetInputSlot(0));
- reshapeLayer->GetOutputSlot(0).Connect(startLayer->GetInputSlot(0));
- delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->inputs->data[1])]
- ->Connect(startLayer->GetInputSlot(1));
- }
- else
- {
- delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->inputs->data[1])]
- ->Connect(reshapeLayer->GetInputSlot(0));
- reshapeLayer->GetOutputSlot(0).Connect(startLayer->GetInputSlot(1));
- delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->inputs->data[0])]
- ->Connect(startLayer->GetInputSlot(0));
- }
-
- // Prepare output slots
- for (unsigned int outputIndex = 0; outputIndex < startLayer->GetNumOutputSlots(); ++outputIndex)
- {
- armnn::IOutputSlot& outputSlot = startLayer->GetOutputSlot(outputIndex);
- delegateData.m_OutputSlotForNode
- [static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
- }
-
- return reshapeLayer;
}
TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext,