aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/AdditionOperator.hpp
blob: 98c01e2cb8b23724d2094fba99f90e771df9fd05 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <Layer.hpp>

#include <tosa_serialization_handler.h>
#include "TosaOperatorUtils.hpp"

using namespace armnn;
using namespace tosa;

/// Builds a TOSA basic block holding a single ADD operator that maps an ArmNN addition.
///
/// The function is defined in a header, so it is marked inline: without it every
/// translation unit that includes this file would emit its own external definition
/// and the program would fail to link with duplicate symbols.
///
/// @param inputs  Tensor infos of the two operands. Elements 0 and 1 are dereferenced
///                without any size or null check, so the vector must hold at least two
///                valid pointers.
/// @param outputs Tensor infos of the result. Element 0 is dereferenced without any size
///                or null check, so the vector must hold at least one valid pointer.
/// @return A block allocated with new; the caller receives the raw pointer and is
///         responsible for its lifetime. NOTE(review): the operator and tensors are also
///         allocated with new and handed to the block - presumably the block releases
///         them on destruction; confirm against the TOSA serialization library.
inline TosaSerializationBasicBlock* ConvertAdditionToTosaOperator(
    const std::vector<const TensorInfo*>& inputs,
    const std::vector<const TensorInfo*>& outputs)
{
    // A helper function with static global variables ensures uniqueness
    // for dynamically generating input, output and block names.
    // Each name is built from its own call to the helper.
    std::string input0Name = std::string("Op_ADD_input0_")  + GetUniqueTosaMappingID();
    std::string input1Name = std::string("Op_ADD_input1_")  + GetUniqueTosaMappingID();
    std::string outputName = std::string("Op_ADD_output0_") + GetUniqueTosaMappingID();
    std::string blockName  = std::string("Op_ADD_block_")   + GetUniqueTosaMappingID();

    // ADD carries no attribute, hence Attribute_NONE with a null attribute pointer.
    TosaSerializationOperator* op = new TosaSerializationOperator(Op_ADD,
                                                                  Attribute_NONE,
                                                                  nullptr,
                                                                  {input0Name, input1Name},
                                                                  {outputName});

    // Translate the ArmNN shape and data type of each operand and of the result.
    std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
    DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType());

    std::vector<int32_t> inputShape1 = GetTosaTensorShape(inputs[1]->GetShape());
    DType inputDType1 = ArmNNToDType(inputs[1]->GetDataType());

    std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
    DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());

    // The trailing {} is passed as an empty final argument for each tensor.
    TosaSerializationTensor* inputTensor0  = new TosaSerializationTensor(input0Name, inputShape0, inputDType0, {});
    TosaSerializationTensor* inputTensor1  = new TosaSerializationTensor(input1Name, inputShape1, inputDType1, {});
    TosaSerializationTensor* outputTensor0 = new TosaSerializationTensor(outputName, outputShape0, outputDType0, {});

    // operatorInputNames/operatorOutputNames ends up being the same as
    // blockInputNames/blockOutputNames for one-to-one ArmNN to Tosa mappings
    return new TosaSerializationBasicBlock(blockName, // name
                                           {op}, // operators
                                           {inputTensor0, inputTensor1, outputTensor0}, // tensors
                                           {input0Name, input1Name}, // inputs
                                           {outputName}); // outputs
}