Diffstat (limited to 'ethosu/vela/lstm.py')
-rw-r--r-- | ethosu/vela/lstm.py | 447 |
1 files changed, 447 insertions, 0 deletions
diff --git a/ethosu/vela/lstm.py b/ethosu/vela/lstm.py
new file mode 100644
index 00000000..5a50788b
--- /dev/null
+++ b/ethosu/vela/lstm.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
+from enum import Enum
+from typing import Tuple
+
+import numpy as np
+
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .graph_optimiser_util import create_avg_pool_for_concat
+from .operation import ActivationFunction
+from .operation import ExplicitScaling
+from .operation import Op
+from .operation import Operation
+from .operation_util import create_add
+from .operation_util import create_fullyconnected
+from .operation_util import create_fused_activation
+from .operation_util import create_mul
+from .scaling import elementwise_mul_scale
+from .shape4d import Shape4D
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+
+Q0_15_SCALE = np.float32(0.00003051757)
+"""Q0.15 scale like the reference defines it"""
+
+
+class Lstm:
+    """Lstm graph optimisation.
+
+    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.
+
+    Usage:
+
+        unrolled_op = Lstm(op).get_graph()
+    """
+
+    class State(Enum):
+        """States (variable tensors)"""
+
+        OUTPUT = 18  # Value = tensor index
+        CELL = 19  # Value = tensor index
+
+    def __init__(self, op):
+        self.op = op
+
+    def get_graph(self) -> Operation:
+        """Return the generated graph implementation"""
+        self.op.ofm.ops = []
+        if self.time_major:
+            output_state = self.get_initial_state(Lstm.State.OUTPUT)
+            cell_state = self.get_initial_state(Lstm.State.CELL)
+            for time in range(self.n_time):
+                feature = self.get_feature(time)
+                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
+                op = self.put_ofm(output_state, time)
+        else:
+            for batch in range(self.n_batch):
+                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
+                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
+                for time in range(self.n_time):
+                    feature = self.get_feature(time, batch)
+                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
+                    op = self.put_ofm(output_state, time, batch)
+        return op
+
+    def get_feature(self, time: int, batch: int = 0) -> Tensor:
+        """Get input feature for provided time and batch"""
+        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
+        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
+        op = Operation(Op.SplitSliceRead, feature.name)
+        op.add_input_tensor(self.op.ifm)
+        op.set_output_tensor(feature)
+        op.set_ifm_ofm_shapes()
+        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
+        op.read_offsets[0] = Shape4D.from_list(offset, 0)
+        op.read_shapes[0] = op.ofm_shapes[0]
+        DebugDatabase.add_optimised(self.op, op)
+        return feature
+
+    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
+        """Get state tensor for provided state type and batch"""
+        state = self.state(state_type)
+        if self.time_major:
+            # For time major just return the 2D state, since all batches
+            # are calculated at the same time
+            return state
+        else:
+            # For non time major return one batch of the 2D state
+            # by setting the read offset to the provided batch
+
+            # The cloned state tensor will share equivalence id and buffer
+            # with the variable state tensor
+            n_state = state.shape[-1]
+            state_ofm = state.clone(f"_state#{batch}")
+            # Set shape to be one batch
+            state_ofm.set_all_shapes([1, n_state])
+            # Create the op for reading one batch of the state
+            # (will be optimised away at a later stage)
+            op = Operation(Op.SplitSliceRead, state_ofm.name)
+            op.add_input_tensor(state)
+            op.set_output_tensor(state_ofm)
+            op.set_ifm_ofm_shapes()
+            # Set the read offset to the provided batch
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            # Set the read shape to one batch, see above
+            op.read_shapes[0] = op.ofm_shapes[0]
+            DebugDatabase.add_optimised(self.op, op)
+            return state_ofm
+
+    def get_state(self, op: Operation, batch: int = 0) -> Operation:
+        """Set up the correct read offset for reading the state from
+        a variable state tensor"""
+        if not self.time_major and self.n_batch > 1:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(op.ifm.shape)
+            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
+        return op
+
+    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
+        """Save the state for the provided batch by pointing the operation's
+        ofm to the variable state tensor"""
+        # The create op functions always return 4D shape, however the state
+        # should have 2D shape for correct operation
+        op.ofm.shape = op.ofm.shape[-2:]
+        # Get state from type
+        state = self.state(state_type)
+        # By using the same equivalence_id the backing buffer for the ofm
+        # tensor will be the state variable tensor buffer
+        op.ofm.equivalence_id = state.equivalence_id
+        # Set memory function which will make the tensor be in linear format
+        # just as the state variable tensor
+        op.memory_function = Op.VariableTensorWrite
+        # Set the batch write offset into the state tensor buffer unless
+        # time_major mode when all batches are written at once
+        if not self.time_major:
+            op.write_offset = Shape4D.from_list([batch, 0], 0)
+            op.write_shape = Shape4D(op.ofm.shape)
+            op.ofm_shapes = [Shape4D(state.shape)]
+        DebugDatabase.add_optimised(self.op, op)
+        return op
+
+    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
+        """Save the output state for the provided batch and time to OFM"""
+        name = f"{self.op.ofm.name}#{batch}.{time}"
+        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
+        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
+        # The provided state tensor uses the output state tensor's buffer, so unless in
+        # time_major mode we need to set the correct batch read offset
+        if not self.time_major:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(state.shape)
+            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
+        return op
+
+    def lstm_step(
+        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
+    ) -> Tuple[Tensor, Tensor]:
+        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
+        input_gate = self.calculate_gate(
+            f"input_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_input_weights,
+            self.input_bias,
+            self.recurrent_to_input_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        forget_gate = self.calculate_gate(
+            f"forget_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_forget_weights,
+            self.forget_bias,
+            self.recurrent_to_forget_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        cell_gate = self.calculate_gate(
+            f"cell_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_cell_weights,
+            self.cell_bias,
+            self.recurrent_to_cell_weights,
+            None,
+            Op.Tanh,
+            batch,
+        )
+        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
+        output_gate = self.calculate_gate(
+            f"output_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_output_weights,
+            self.output_bias,
+            self.recurrent_to_output_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
+        return (output_state, cell_state)
+
+    def calculate_gate(
+        self,
+        name: str,
+        input: Tensor,
+        state: Tensor,
+        input_weights: Tensor,
+        input_bias: Tensor,
+        recurrent_weights: Tensor,
+        recurrent_bias: Tensor,
+        activation: Op,
+        batch: int = 0,
+    ):
+        """Generate a gate for the provided input and weights"""
+        # Activation( Add( FC(input), FC(output state) ) )
+        # Setup fullyconnected quantization
+        q_fc = QuantizationParameters()
+        q_fc.scale_f32 = np.float32(2**-12)
+        q_fc.zero_point = 0
+        # Create fullyconnected
+        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
+        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
+        self.get_state(re_fc, batch)
+        # Change fullyconnected ofm data type
+        in_fc.ofm.dtype = DataType.int16
+        re_fc.ofm.dtype = DataType.int16
+        # Setup add quantization
+        q_add = q_fc.clone()
+        q_add.scale_f32 = np.float32(2**-15)
+        # Create add + activation
+        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
+        if activation is Op.Sigmoid:
+            # For Sigmoid we need to set the activation min/max values to match the possible range
+            # in the reference. The values below are the quantized min/max values that the reference
+            # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
+            # due to intermediate higher precision.)
+            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
+            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
+            # fused activation function). This will yield the dequantized min/max values which are later
+            # quantized again by the command stream generator.
+            add.activation.max = 32757 / 0x3000
+            add.activation.min = 11 / 0x3000
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, in_fc)
+        DebugDatabase.add_optimised(self.op, re_fc)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
+    def calculate_cell_state(
+        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
+    ):
+        """Update the cell state from the provided gate output"""
+        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
+        base_name = f"cell_state#{batch}.{time}"
+        # Cell scale
+        cell_scale = cell_state.quantization.scale_f32
+        # Create mul(cell_state, forget_gate)
+        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
+        self.get_state(mul_cf, batch)
+        # Calculate explicit scales to match reference
+        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Create mul(cell_gate, input_gate)
+        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
+        # Calculate explicit scales to match reference
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Setup cell clip
+        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
+        if activation:
+            activation.max = self.cell_clip
+            activation.min = -self.cell_clip
+        # Create add + activation
+        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
+        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+        # Save new state
+        self.put_state(add, Lstm.State.CELL, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, mul_cf)
+        DebugDatabase.add_optimised(self.op, mul_ci)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
+    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
+        """Generate the output state from the provided gate output"""
+        # Mul( Tanh(cell state), output gate )
+        base_name = f"output_state#{batch}.{time}"
+        # Setup tanh quantization
+        q_out_tanh = QuantizationParameters()
+        q_out_tanh.scale_f32 = np.float32(2**-15)
+        q_out_tanh.zero_point = 0
+        # Create tanh(cell state)
+        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
+        self.get_state(tanh, batch)
+        # Create Mul( Tanh(cell state), output gate )
+        q_mul = self.output_state.quantization
+        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
+        # Use explicit scaling to match the reference; the following line would have been the preferred way
+        # mul.forced_output_quantization = self.hidden_quantization
+        out_scale = self.hidden_quantization.scale_f32
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
+        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Save new state
+        self.put_state(mul, Lstm.State.OUTPUT, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, tanh)
+        DebugDatabase.add_optimised(self.op, mul)
+        return mul.ofm
+
+    def state(self, state_type: State) -> Tensor:
+        """Get state tensor from type"""
+        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state
+
+    # Dimensions
+    @property
+    def n_feature(self) -> int:
+        return self.op.ifm.shape[-1]
+
+    @property
+    def n_time(self) -> int:
+        return self.op.ifm.shape[0 if self.time_major else 1]
+
+    @property
+    def n_batch(self) -> int:
+        return self.op.ifm.shape[1 if self.time_major else 0]
+
+    # Attributes
+    @property
+    def cell_clip(self) -> int:
+        return self.op.attrs.get("cell_clip", 0)
+
+    @property
+    def projection_clip(self) -> int:
+        return self.op.attrs.get("proj_clip", 0)
+
+    @property
+    def time_major(self) -> bool:
+        return self.op.attrs.get("time_major", False)
+
+    # Hidden (intermediate)
+    @property
+    def hidden_quantization(self) -> QuantizationParameters:
+        return self.op.intermediates[4].quantization
+
+    # Input weights
+    @property
+    def input_to_input_weights(self) -> Tensor:
+        return self.op.inputs[1]
+
+    @property
+    def input_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[2]
+
+    @property
+    def input_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[3]
+
+    @property
+    def input_to_output_weights(self) -> Tensor:
+        return self.op.inputs[4]
+
+    # Recurrent weights
+    @property
+    def recurrent_to_input_weights(self) -> Tensor:
+        return self.op.inputs[5]
+
+    @property
+    def recurrent_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[6]
+
+    @property
+    def recurrent_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[7]
+
+    @property
+    def recurrent_to_output_weights(self) -> Tensor:
+        return self.op.inputs[8]
+
+    # Peephole weights
+    @property
+    def cell_to_input_weights(self) -> Tensor:
+        return self.op.inputs[9]
+
+    @property
+    def cell_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[10]
+
+    @property
+    def cell_to_output_weights(self) -> Tensor:
+        return self.op.inputs[11]
+
+    # Bias tensors
+    @property
+    def input_bias(self) -> Tensor:
+        return self.op.inputs[12]
+
+    @property
+    def forget_bias(self) -> Tensor:
+        return self.op.inputs[13]
+
+    @property
+    def cell_bias(self) -> Tensor:
+        return self.op.inputs[14]
+
+    @property
+    def output_bias(self) -> Tensor:
+        return self.op.inputs[15]
+
+    # Projection tensors
+    @property
+    def projection_weights(self) -> Tensor:
+        return self.op.inputs[16]
+
+    @property
+    def projection_bias(self) -> Tensor:
+        return self.op.inputs[17]
+
+    # State tensors (variable)
+    @property
+    def output_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.OUTPUT.value]
+
+    @property
+    def cell_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.CELL.value]
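
For reference, the unrolling is self-contained: wrap the UNIDIRECTIONAL_SEQUENCE_LSTM operation in Lstm and call get_graph(), which returns the last operation of the unrolled sub-graph, with the original op's OFM now produced by the new basic operations. A minimal sketch of how a graph-optimiser rewrite hook could invoke it is shown below; the convert_lstm name and the (op, arch, nng) signature are assumptions made for illustration and are not part of this patch, only Lstm(op).get_graph() comes from the file above.

    # Hypothetical rewrite hook, not part of this patch; the function name, its
    # (op, arch, nng) signature and the Op.UnidirectionalSequenceLstm check are
    # assumed for illustration only.
    from .lstm import Lstm
    from .operation import Op

    def convert_lstm(op, arch, nng):
        # Replace the sequence LSTM with its unrolled basic operations;
        # all other operations pass through unchanged.
        if op.type == Op.UnidirectionalSequenceLstm:
            op = Lstm(op).get_graph()
        return op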