path: root/delegate/opaque/src/BroadcastTo.hpp
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{
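    // Queries the configured backends via IsBroadcastToSupported to check whether
    // a BroadcastTo layer with the given tensor infos and descriptor is supported.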
    TfLiteStatus ValidateBroadcastToOperator(DelegateData& delegateData,
                                             TfLiteOpaqueContext *tfLiteContext,
                                             const armnn::TensorInfo& inputInfo,
                                             const armnn::TensorInfo& outputInfo,
                                             const armnn::BroadcastToDescriptor& descriptor)
    {
        bool isSupported = false;
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BROADCAST_TO",
                                          tfLiteContext,
                                          IsBroadcastToSupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo,
                                          outputInfo,
                                          descriptor);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

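    // Handles a TfLite BROADCAST_TO node: validates its input/output tensors, builds the
    // BroadcastToDescriptor from the shape tensor and, when a network is being built,
    // adds and connects the corresponding Arm NN BroadcastTo layer.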
    TfLiteStatus VisitBroadcastToOperator(DelegateData& delegateData,
                                          TfLiteOpaqueContext* tfLiteContext,
                                          TfLiteOpaqueNode* tfLiteNode,
                                          int nodeIndex,
                                          int32_t broadcastToOperatorCode)
    {
        TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
        TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

        // Gather input tensor indices
        auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
        const int* inputTensors;
        if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                nodeIndex);
            return kTfLiteError;
        }

        // Gather output tensor indices
        int numOutputs = 0;
        const int* outputTensors;
        if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors,
                                    &numOutputs) != kTfLiteOk)
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                nodeIndex);
            return kTfLiteError;
        }

        // The first input tensor holds the data to be broadcast
        const TfLiteOpaqueTensor* tfLiteInputTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
        if (IsDynamicTensor(tfLiteInputTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        // The second input tensor holds the target broadcast shape
        const TfLiteOpaqueTensor* tfLiteShapeTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
        if (IsDynamicTensor(tfLiteShapeTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        // The output tensor
        const TfLiteOpaqueTensor* tfLiteOutputTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
        if (IsDynamicTensor(tfLiteOutputTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
                broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
        const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor,
                                                                                       true);

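        // The shape tensor is 1-D; its length gives the output rank and its
        // int32 elements give the target shape used to build the descriptor.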
        auto* shapeData = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteShapeTensor));
        int32_t shapeTensorNum = TfLiteOpaqueTensorDim(tfLiteShapeTensor, 0);

        armnn::BroadcastToDescriptor broadcastToDescriptor;
        broadcastToDescriptor.m_BroadcastToShape = armnn::TensorShape(shapeTensorNum,
                                                                      shapeData);

        // No network pointer indicates that only support for this operator should be checked
        if (!delegateData.m_Network)
        {
            return ValidateBroadcastToOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo,
                                               outputTensorInfo,
                                               broadcastToDescriptor);
        }

        std::string layerName("BroadcastTo");
        armnn::IConnectableLayer* layer = delegateData.m_Network->AddBroadcastToLayer(broadcastToDescriptor,
                                                                                      layerName.c_str());

        if (layer == nullptr)
        {
            return kTfLiteError;
        }

        layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

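        // Process constant inputs (such as the shape tensor), adding constant layers where required.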
        if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
        {
            return kTfLiteError;
        }

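        // Connect the layer's input and output slots to the rest of the graph.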
        return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
    }

} // namespace armnnOpaqueDelegate