aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.h
blob: b2c92a484b4a913f0153633f598d313c365405cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

// Copyright (c) 2020, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#ifndef OPS_EWISE_BINARY_H
#define OPS_EWISE_BINARY_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

// class BinaryNodeBase: holds functionality common to all binary nodes.
//                       When a binary op is created, the virtual OpXXX::register_fcn() is called
//                       and 'fcn' is registered with a lambda function that takes two inputs.
// class BinaryNode: a level of indirection used to partially specialize the template for rank 0.
//                   eval(), called from the top level, should call .binaryExpr(dims, fcn) here.
//                   This needs to be partially specialized, otherwise the
//                   compiler will fail statically when trying to broadcast a rank-0 tensor.
// class OpXXX: implements the per-element lambda function for each data type.
//              Unlike BinaryNode, this doesn't need to be partially specialized.

// Eigen::Tensor natively supports some binary element-wise operations (e.g. CWiseMax, '+', etc.),
// which might be faster since they could be implemented with SIMD instructions.
// Registering a lambda and using .binaryExpr() might sacrifice some performance here,
// but it avoids partial specialization for every combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}.
// This should be revisited if performance becomes a bottleneck.
// Common base for every elementwise binary node. It holds the per-element
// functor `fcn`, the input/output tensor pointers and the per-dimension
// broadcast factors that the rank-specific eval() implementations work on.
template <int Rank, DType InDtype, DType OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
    BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_);
    virtual ~BinaryNodeBase();

    // Shared validation for all binary ops; 'final' so derived ops cannot replace it.
    virtual int checkTensorAttributes() final;
    // Implemented per rank by BinaryNode (primary template and rank-0 specialization).
    virtual int eval() = 0;
    // Implemented by each concrete OpXXX: installs the per-element operation into
    // `fcn`. Every concrete op in this header calls it from its constructor.
    virtual int register_fcn() = 0;

    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;

protected:
    // Defined out of line; presumably computes bcast_a / bcast_b from the input
    // shapes — TODO confirm.
    int broadcast();

protected:
    // Per-element operation (two inputs -> one output) set by register_fcn().
    std::function<OutEigenType(InEigenType, InEigenType)> fcn;
    // Broadcast factors applied to inputs a and b respectively.
    Eigen::array<int, Rank> bcast_a;
    Eigen::array<int, Rank> bcast_b;
    // Input/output tensors. They are bound outside this header (TODO confirm
    // where); defaulting to nullptr keeps them from being indeterminate
    // before that happens. Ownership is not visible from this header.
    TosaReference::TensorTemplate<TIn>* a       = nullptr;
    TosaReference::TensorTemplate<TIn>* b       = nullptr;
    TosaReference::TensorTemplate<TOut>* result = nullptr;
};

// Primary template, used for every rank other than 0. Its eval() is declared
// here and defined out of line; rank 0 has a separate partial specialization
// so that this eval() is never instantiated with Rank == 0.
template <int Rank, DType InDtype, DType OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;

    BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
        : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, id_)
    {
    }

    ~BinaryNode() override = default;

    int eval() override;
};

// Rank-0 partial specialization. It declares its own eval() (defined out of
// line), so the primary template's eval() is never instantiated for rank 0.
template <DType InDtype, DType OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
    BinaryNode(SubgraphTraverser* sgt_, const Op& op_, const uint64_t id_)
        : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, id_)
    {
    }

    ~BinaryNode() override = default;

    int eval() override;
};

// DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) declares
// `template <int Rank, DType Dtype> class Op##Opname`: a BinaryNode whose input
// and output share the single DType `Dtype` and which is tagged with Op_##OPNAME.
// The constructor does not use its attribute argument; it immediately calls
// register_fcn(), which is only declared here (its definition is outside this header).
#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME)                                                                 \
    template <int Rank, DType Dtype>                                                                                   \
    class Op##Opname : public BinaryNode<Rank, Dtype, Dtype>                                                           \
    {                                                                                                                  \
    public:                                                                                                            \
        static constexpr DType InDtype  = Dtype;                                                                       \
        static constexpr DType OutDtype = Dtype;                                                                       \
        using InEigenType               = typename GetEigenType<InDtype>::type;                                        \
        using OutEigenType              = typename GetEigenType<OutDtype>::type;                                       \
        Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)                               \
            : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_)                                                   \
        {                                                                                                              \
            register_fcn();                                                                                            \
        }                                                                                                              \
        virtual int register_fcn();                                                                                    \
    };

DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
DEF_TEMPLATE_BINARY_OP_DEFAULT(Intdiv, INTDIV)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalOr, LOGICAL_OR)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalXor, LOGICAL_XOR)
DEF_TEMPLATE_BINARY_OP_DEFAULT(Maximum, MAXIMUM)
DEF_TEMPLATE_BINARY_OP_DEFAULT(Minimum, MINIMUM)
DEF_TEMPLATE_BINARY_OP_DEFAULT(Pow, POW)
DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)

#undef DEF_TEMPLATE_BINARY_OP_DEFAULT

// Elementwise arithmetic right shift. Unlike the attribute-free ops declared via
// DEF_TEMPLATE_BINARY_OP_DEFAULT, this op carries a TosaArithmeticRightShiftAttribute.
template <int Rank, DType Dtype>
class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
{
public:
    OpArithmeticRightShift(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
        : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, id_)
    {
        // The attribute is set up before register_fcn() runs, presumably because
        // the registered lambda consults it (register_fcn() is defined elsewhere).
        INIT_ATTRIBUTE(ArithmeticRightShift);
        register_fcn();
    }
    using InEigenType  = typename GetEigenType<Dtype>::type;
    using OutEigenType = typename GetEigenType<Dtype>::type;
    virtual int register_fcn();

protected:
    // Presumably assigned by INIT_ATTRIBUTE in the constructor; defaults to
    // nullptr so it is never indeterminate.
    // NOTE(review): no destructor is declared for this class, so it never
    // releases `attribute`. If INIT_ATTRIBUTE allocates it (its definition is
    // not in this header), this leaks — confirm and add a destructor if so.
    TosaArithmeticRightShiftAttribute* attribute = nullptr;
};

// Elementwise multiply. Input and output DTypes may differ, and the op carries
// a TosaMulAttribute. QMin/QMax expose the representable range of OutDtype.
template <int Rank, DType InDtype, DType OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
    OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
        : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
    {
        // The attribute is set up before register_fcn() runs, presumably because
        // the registered lambda consults it (register_fcn() is defined elsewhere).
        INIT_ATTRIBUTE(Mul);
        register_fcn();
    }
    static constexpr int64_t QMin = GetQMin<OutDtype>::value;
    static constexpr int64_t QMax = GetQMax<OutDtype>::value;
    using InEigenType             = typename GetEigenType<InDtype>::type;
    using OutEigenType            = typename GetEigenType<OutDtype>::type;
    virtual int register_fcn();

protected:
    // Presumably assigned by INIT_ATTRIBUTE in the constructor; defaults to
    // nullptr so it is never indeterminate.
    // NOTE(review): no destructor is declared for this class, so it never
    // releases `attribute`. If INIT_ATTRIBUTE allocates it (its definition is
    // not in this header), this leaks — confirm and add a destructor if so.
    TosaMulAttribute* attribute = nullptr;
};

// Table-lookup op. An INT8 input uses a 256-entry INT8 table and produces INT8.
// Any other input DType uses a 513-entry INT16 table and produces INT32.
template <int Rank, DType InDtype>
class OpTable : public GraphNode
{
public:
    OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
    virtual ~OpTable();

    virtual int checkTensorAttributes();
    virtual int eval();

    static constexpr DType TableDtype         = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
    static constexpr DType OutDtype           = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
    static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
    using InEigenType                         = typename GetEigenType<InDtype>::type;
    using TableEigenType                      = typename GetEigenType<TableDtype>::type;
    using OutEigenType                        = typename GetEigenType<OutDtype>::type;
    using TIn                                 = Eigen::Tensor<InEigenType, Rank>;
    using TTable                              = Eigen::Tensor<TableEigenType, 1>;
    using TOut                                = Eigen::Tensor<OutEigenType, Rank>;
    // IntegerBits + FractionBits == 16, and NumTableEntries == 1 << IntegerBits == 512.
    // Presumably a 16-bit input is split into a table index and an interpolation
    // fraction — TODO confirm against eval().
    static constexpr int32_t IntegerBits      = 9;
    static constexpr int32_t FractionBits     = 7;
    static constexpr int32_t NumTableEntries  = (1 << IntegerBits);
    static constexpr int32_t QInMin           = GetQMin<InDtype>::value;
    static constexpr int32_t QInMax           = GetQMax<InDtype>::value;
    static constexpr int32_t QOutMin          = GetQMin<OutDtype>::value;
    static constexpr int32_t QOutMax          = GetQMax<OutDtype>::value;

protected:
    // Bound outside this header (TODO confirm where); nullptr until then so they
    // are never indeterminate.
    TosaReference::TensorTemplate<TIn>* in   = nullptr;
    TosaReference::TensorTemplate<TOut>* out = nullptr;
    TosaTableAttribute* attribute            = nullptr;
    // Lookup table contents, value-initialized (all zeros) until populated
    // (presumably from `attribute`) — TODO confirm.
    std::array<TableEigenType, TableNumEntries> table{};
};

};    // namespace TosaReference

#endif