aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/lstm.py
blob: ecec58f4828c24a51a3e7c97c5909526ed904c1e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
from enum import Enum
from typing import Tuple

import numpy as np

from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import create_avg_pool_for_concat
from .operation import ActivationFunction
from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .operation_util import create_add
from .operation_util import create_fullyconnected
from .operation_util import create_fused_activation
from .operation_util import create_mul
from .scaling import elementwise_mul_scale
from .shape4d import Shape4D
from .tensor import QuantizationParameters
from .tensor import Tensor

Q0_15_SCALE = np.float32(2**-15)
"""Q0.15 scale like the reference defines it"""


class Lstm:
    """Lstm graph optimisation.

    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.

    Usage:

    unrolled_op = Lstm(op).get_graph()
    """

    class State(Enum):
        """States (variable tensors)"""

        OUTPUT = 18  # Value = tensor index
        CELL = 19  # Value = tensor index

    def __init__(self, op):
        self.op = op

    def get_graph(self) -> Operation:
        """Return the generated graph implementation"""
        self.op.ofm.ops = []
        if self.time_major:
            output_state = self.get_initial_state(Lstm.State.OUTPUT)
            cell_state = self.get_initial_state(Lstm.State.CELL)
            for time in range(self.n_time):
                feature = self.get_feature(time)
                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
                op = self.put_ofm(output_state, time)
        else:
            for batch in range(self.n_batch):
                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
                for time in range(self.n_time):
                    feature = self.get_feature(time, batch)
                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
                    op = self.put_ofm(output_state, time, batch)
        return op

    def get_feature(self, time: int, batch: int = 0) -> Tensor:
        """Get input feature for provided time and batch"""
        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
        op = Operation(Op.SplitSliceRead, feature.name)
        op.add_input_tensor(self.op.ifm)
        op.set_output_tensor(feature)
        op.set_ifm_ofm_shapes()
        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
        op.read_offsets[0] = Shape4D.from_list(offset, 0)
        op.read_shapes[0] = op.ofm_shapes[0]
        DebugDatabase.add_optimised(self.op, op)
        return feature

    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
        """Get state tensor for provided state type and batch"""
        state = self.state(state_type)
        if self.time_major:
            # For time major just return the 2D state, since all batches
            # are calculated at the same time
            return state
        else:
            # For non time major return one batch of the 2D state
            # by setting the read offset to the provided batch

            # The cloned state tensor will share equivalence id and buffer
            # with the variable state tensor
            n_state = state.shape[-1]
            state_ofm = state.clone(f"_state#{batch}")
            # Set shape to be one batch
            state_ofm.set_all_shapes([1, n_state])
            # Create the op for reading one batch of the state
            # (will be optimised away at a later stage)
            op = Operation(Op.SplitSliceRead, state_ofm.name)
            op.add_input_tensor(state)
            op.set_output_tensor(state_ofm)
            op.set_ifm_ofm_shapes()
            # Set the read offset to the provided batch
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            # Set the read shape to one batch, see above
            op.read_shapes[0] = op.ofm_shapes[0]
            DebugDatabase.add_optimised(self.op, op)
            return state_ofm

    def get_state(self, op: Operation, batch: int = 0) -> Operation:
        """Setup the correct read offset for reading the state from
        a variable tensor state"""
        if not self.time_major and self.n_batch > 1:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(op.ifm.shape)
            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
        return op

    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
        """Save the state for the provided batch by pointing the operations
        ofm to the variable state tensor"""
        # The create op functions always return 4D shape, however the state
        # should have 2D shape for correct operation
        op.ofm.shape = op.ofm.shape[-2:]
        # Get state from type
        state = self.state(state_type)
        # By using the same equivalence_id the backing buffer for the ofm
        # tensor will be the state variable tensor buffer
        op.ofm.equivalence_id = state.equivalence_id
        # Set memory function which will make the tensor be in linear format
        # just as the state variable tensor
        op.memory_function = Op.VariableTensorWrite
        # Set the batch write offset into the state tensor buffer unless
        # time_major mode when all batches are written at once
        if not self.time_major:
            op.write_offset = Shape4D.from_list([batch, 0], 0)
            op.write_shape = Shape4D(op.ofm.shape)
            op.ofm_shapes = [Shape4D(state.shape)]
        DebugDatabase.add_optimised(self.op, op)
        return op

    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
        """Save the output state for the provided batch and time to OFM"""
        name = f"{self.op.ofm.name}#{batch}.{time}"
        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
        # The provided state tensor use the output state tensors buffer, so unless
        # time_major mode we need to set the correct batch read offset
        if not self.time_major:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(state.shape)
            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
        return op

    def lstm_step(
        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
    ) -> Tuple[Tensor, Tensor]:
        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
        input_gate = self.calculate_gate(
            f"input_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_input_weights,
            self.input_bias,
            self.recurrent_to_input_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        forget_gate = self.calculate_gate(
            f"forget_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_forget_weights,
            self.forget_bias,
            self.recurrent_to_forget_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        cell_gate = self.calculate_gate(
            f"cell_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_cell_weights,
            self.cell_bias,
            self.recurrent_to_cell_weights,
            None,
            Op.Tanh,
            batch,
        )
        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
        output_gate = self.calculate_gate(
            f"output_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_output_weights,
            self.output_bias,
            self.recurrent_to_output_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
        return (output_state, cell_state)

    def calculate_gate(
        self,
        name: str,
        input: Tensor,
        state: Tensor,
        input_weights: Tensor,
        input_bias: Tensor,
        recurrent_weights: Tensor,
        recurrent_bias: Tensor,
        activation: Op,
        batch: int = 0,
    ):
        """Generate a gate for the provided input and weights"""
        # Activation( Add( FC(input), FC(output state) ) )
        # Setup fullyconnected quantization
        q_fc = QuantizationParameters()
        q_fc.scale_f32 = np.float32(2**-12)
        q_fc.zero_point = 0
        # Create fullyconnected
        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
        self.get_state(re_fc, batch)
        # Change fullyconnected ofm data type
        in_fc.ofm.dtype = DataType.int16
        re_fc.ofm.dtype = DataType.int16
        # Setup add quantization
        q_add = q_fc.clone()
        q_add.scale_f32 = Q0_15_SCALE
        # Create add + activation
        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
        if activation is Op.Sigmoid:
            # For Sigmoid we need to set the activation min/max values to match the possible range
            # in the reference. The values below are the quantized min/max values that the reference
            # can achive for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
            # due to intermediate higher precision.)
            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
            # fused activation function). This will yield the dequantized min/max values which are later
            # quantized again by the command stream generator.
            add.activation.max = 32757 / 0x3000
            add.activation.min = 11 / 0x3000
        # Add to debug database
        DebugDatabase.add_optimised(self.op, in_fc)
        DebugDatabase.add_optimised(self.op, re_fc)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_cell_state(
        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
    ):
        """Update the cell state from the provided gate output"""
        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
        base_name = f"cell_state#{batch}.{time}"
        # Cell scale
        cell_scale = cell_state.quantization.scale_f32
        # Create mul(cell_state, forget_gate)
        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
        self.get_state(mul_cf, batch)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Create mul(cell_gate, input_gate)
        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Setup cell clip
        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
        if activation:
            activation.max = self.cell_clip
            activation.min = -self.cell_clip
        # Create add + activation
        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
        # Save new state
        self.put_state(add, Lstm.State.CELL, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, mul_cf)
        DebugDatabase.add_optimised(self.op, mul_ci)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
        """Generate the output state from the provided gate output"""
        # Mul( Tanh(cell state), output gate )
        base_name = f"output_state#{batch}.{time}"
        # Setup tanh quantization
        q_out_tanh = QuantizationParameters()
        q_out_tanh.scale_f32 = Q0_15_SCALE
        q_out_tanh.zero_point = 0
        # Create tanh(cell state)
        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
        self.get_state(tanh, batch)
        # Create Mul( Tanh(cell state), output gate )
        q_mul = self.output_state.quantization
        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
        # Use explicit scaling to match reference, the following line would have been the preferred way
        # mul.forced_output_quantization = self.hidden_quantization
        out_scale = self.hidden_quantization.scale_f32
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Save new state
        self.put_state(mul, Lstm.State.OUTPUT, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, tanh)
        DebugDatabase.add_optimised(self.op, mul)
        return mul.ofm

    def state(self, state_type: State) -> Tensor:
        """Get state tensor from type"""
        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state

    # Dimensions
    @property
    def n_feature(self) -> int:
        return self.op.ifm.shape[-1]

    @property
    def n_time(self) -> int:
        return self.op.ifm.shape[0 if self.time_major else 1]

    @property
    def n_batch(self) -> int:
        return self.op.ifm.shape[1 if self.time_major else 0]

    # Attributes
    @property
    def cell_clip(self) -> int:
        return self.op.attrs.get("cell_clip", 0)

    @property
    def projection_clip(self) -> int:
        return self.op.attrs.get("proj_clip", 0)

    @property
    def time_major(self) -> bool:
        return self.op.attrs.get("time_major", False)

    # Hidden (intermediate)
    @property
    def hidden_quantization(self) -> QuantizationParameters:
        return self.op.intermediates[4].quantization

    # Input weights
    @property
    def input_to_input_weights(self) -> Tensor:
        return self.op.inputs[1]

    @property
    def input_to_forget_weights(self) -> Tensor:
        return self.op.inputs[2]

    @property
    def input_to_cell_weights(self) -> Tensor:
        return self.op.inputs[3]

    @property
    def input_to_output_weights(self) -> Tensor:
        return self.op.inputs[4]

    # Recurrent weights
    @property
    def recurrent_to_input_weights(self) -> Tensor:
        return self.op.inputs[5]

    @property
    def recurrent_to_forget_weights(self) -> Tensor:
        return self.op.inputs[6]

    @property
    def recurrent_to_cell_weights(self) -> Tensor:
        return self.op.inputs[7]

    @property
    def recurrent_to_output_weights(self) -> Tensor:
        return self.op.inputs[8]

    # Peephole weights
    @property
    def cell_to_input_weights(self) -> Tensor:
        return self.op.inputs[9]

    @property
    def cell_to_forget_weights(self) -> Tensor:
        return self.op.inputs[10]

    @property
    def cell_to_output_weights(self) -> Tensor:
        return self.op.inputs[11]

    # Bias tensors
    @property
    def input_bias(self) -> Tensor:
        return self.op.inputs[12]

    @property
    def forget_bias(self) -> Tensor:
        return self.op.inputs[13]

    @property
    def cell_bias(self) -> Tensor:
        return self.op.inputs[14]

    @property
    def output_bias(self) -> Tensor:
        return self.op.inputs[15]

    # Projection tensors
    @property
    def projection_weights(self) -> Tensor:
        return self.op.inputs[16]

    @property
    def projection_bias(self) -> Tensor:
        return self.op.inputs[17]

    # State tensors (variable)
    @property
    def output_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.OUTPUT.value]

    @property
    def cell_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.CELL.value]