From 0ac0804e76e098695ee2b8a9e24e2f0a1efc324f Mon Sep 17 00:00:00 2001
From: Fredrik Svedberg
Date: Tue, 11 Apr 2023 22:35:04 +0200
Subject: MLBEDSW-7196 Add LSTM support

Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.

The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation

This change also:
* Removed the unused Op.BlockLSTM operation type.
* Removed the only-one-consumer limitation on putting the SplitSliceRead
  on the tensor consumer(s), if all consumers fulfill the requirements.
* Added Op.VariableTensorWrite as an Operation.memory_function to make
  sure that writes to variable tensors:
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges

Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg
---
 ethosu/vela/graph_optimiser_util.py                 |  25 +-
 ethosu/vela/live_range.py                           |   4 +-
 ethosu/vela/lstm.py                                 | 447 +++++++++++++++++++++
 ethosu/vela/operation.py                            |   8 +-
 ethosu/vela/operation_util.py                       |  51 +++
 ethosu/vela/pass_packing.py                         |   2 -
 ethosu/vela/scheduler.py                            |   7 +-
 ethosu/vela/test/test_tflite_model_semantic.py      |  34 ++
 .../vela/test/test_tflite_supported_operators.py    |  46 +++
 ethosu/vela/test/testutil.py                        |  62 ++-
 ethosu/vela/tflite_graph_optimiser.py               |  50 +--
 ethosu/vela/tflite_model_semantic.py                |  45 +++
 ethosu/vela/tflite_supported_operators.py           |  44 +-
 13 files changed, 783 insertions(+), 42 deletions(-)
 create mode 100644 ethosu/vela/lstm.py

diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e1341d82..82790364 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,6 +27,7 @@ from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
+from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
@@ -101,6 +102,10 @@ def check_format_restrictions(tens: Tensor, arch):
     ):
         return
 
+    # Writing to the buffer of a variable tensor needs to be in linear format
+    if tens.ops[0].memory_function == Op.VariableTensorWrite:
+        return
+
     # Check if any of the producers/consumers is run on CPU
     if not all(cons.run_on_npu for cons in tens.consumer_list):
         return
@@ -222,7 +227,8 @@ def move_splitsliceread_to_consumer(op, cons_op):
         cons_op.ifm_shapes[1] = op.ifm_shapes[0]
     op.ofm.consumer_list.remove(cons_op)
     op.ofm.ops = []
-    op.ifm.consumer_list.remove(op)
+    if op in op.ifm.consumer_list:
+        op.ifm.consumer_list.remove(op)
 
 
 def check_memory_only_removed(op, arch):
@@ -357,3 +363,20 @@ def convert_to_lut(op, lut_values, lut_name):
     op.set_ifm_ofm_shapes()
     DebugDatabase.add_optimised(op, op)
     return op
+
+
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an average pool for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    avgpool_op = create_avgpool_nop(name)
+    avgpool_op.inputs = [ifm]
+    avgpool_op.outputs = [ofm]
+
+    avgpool_op.write_offset = write_offset
+    avgpool_op.write_shape = ifm_shape
+    ofm.ops.append(avgpool_op)
+    avgpool_op.ifm_shapes.append(ifm_shape)
+    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    avgpool_op.memory_function = Op.ConcatSliceWrite
+    DebugDatabase.add_optimised(concat_op, avgpool_op)
+    return avgpool_op
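Note: create_avg_pool_for_concat is moved here from tflite_graph_optimiser.py so that the new lstm.py can reuse it. It wraps one concat input in a no-op AvgPool whose write_offset/write_shape place that input inside the shared output buffer. A minimal sketch of how a concat lowering might drive it (names, axis and shapes are illustrative, not part of this change):

    # Hypothetical two-input concat along the height axis (NHWC)
    offset = 0
    for idx, ifm in enumerate(concat_op.inputs):
        write_offset = Shape4D([0, offset, 0, 0])
        create_avg_pool_for_concat(concat_op, f"{concat_op.name}_avgpool{idx}", ifm, Shape4D(ifm.shape), write_offset)
        offset += ifm.shape[-3]  # advance the write window by this input's height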
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 995a0ccb..3abcfcf0 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -166,9 +166,9 @@ def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
 
 def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
     ifm_tens = None
-    if sched_op.op_type.is_elementwise_op():
+    elem_op = sched_op.parent_op
+    if sched_op.op_type.is_elementwise_op() and elem_op.memory_function is not Op.VariableTensorWrite:
         # Check if possible to merge ifm/ofm live ranges of elementwise op
-        elem_op = sched_op.parent_op
         if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
             # Check if overwriting the inputs can be allowed
             OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
diff --git a/ethosu/vela/lstm.py b/ethosu/vela/lstm.py
new file mode 100644
index 00000000..5a50788b
--- /dev/null
+++ b/ethosu/vela/lstm.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
+from enum import Enum
+from typing import Tuple
+
+import numpy as np
+
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .graph_optimiser_util import create_avg_pool_for_concat
+from .operation import ActivationFunction
+from .operation import ExplicitScaling
+from .operation import Op
+from .operation import Operation
+from .operation_util import create_add
+from .operation_util import create_fullyconnected
+from .operation_util import create_fused_activation
+from .operation_util import create_mul
+from .scaling import elementwise_mul_scale
+from .shape4d import Shape4D
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+
+Q0_15_SCALE = np.float32(0.00003051757)
+"""Q0.15 scale as the reference defines it"""
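+# (For reference: Q0.15 is a signed 16-bit fixed-point format with 15
+# fractional bits, so one LSB is 2**-15 = 0.000030517578125; the constant
+# above is that value truncated to the digit precision the TFLite reference
+# kernels use. The int16 gate outputs below therefore represent values in
+# the range [-1, 1).)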
+
+
+class Lstm:
+    """Lstm graph optimisation.
+
+    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.
+
+    Usage:
+
+        unrolled_op = Lstm(op).get_graph()
+    """
+
+    class State(Enum):
+        """States (variable tensors)"""
+
+        OUTPUT = 18  # Value = tensor index
+        CELL = 19  # Value = tensor index
+
+    def __init__(self, op):
+        self.op = op
+
+    def get_graph(self) -> Operation:
+        """Return the generated graph implementation"""
+        self.op.ofm.ops = []
+        if self.time_major:
+            output_state = self.get_initial_state(Lstm.State.OUTPUT)
+            cell_state = self.get_initial_state(Lstm.State.CELL)
+            for time in range(self.n_time):
+                feature = self.get_feature(time)
+                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
+                op = self.put_ofm(output_state, time)
+        else:
+            for batch in range(self.n_batch):
+                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
+                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
+                for time in range(self.n_time):
+                    feature = self.get_feature(time, batch)
+                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
+                    op = self.put_ofm(output_state, time, batch)
+        return op
+
+    def get_feature(self, time: int, batch: int = 0) -> Tensor:
+        """Get input feature for provided time and batch"""
+        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
+        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
+        op = Operation(Op.SplitSliceRead, feature.name)
+        op.add_input_tensor(self.op.ifm)
+        op.set_output_tensor(feature)
+        op.set_ifm_ofm_shapes()
+        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
+        op.read_offsets[0] = Shape4D.from_list(offset, 0)
+        op.read_shapes[0] = op.ofm_shapes[0]
+        DebugDatabase.add_optimised(self.op, op)
+        return feature
+
+    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
+        """Get state tensor for provided state type and batch"""
+        state = self.state(state_type)
+        if self.time_major:
+            # For time major just return the 2D state, since all batches
+            # are calculated at the same time
+            return state
+        else:
+            # For non time major return one batch of the 2D state
+            # by setting the read offset to the provided batch
+
+            # The cloned state tensor will share equivalence id and buffer
+            # with the variable state tensor
+            n_state = state.shape[-1]
+            state_ofm = state.clone(f"_state#{batch}")
+            # Set shape to be one batch
+            state_ofm.set_all_shapes([1, n_state])
+            # Create the op for reading one batch of the state
+            # (will be optimised away at a later stage)
+            op = Operation(Op.SplitSliceRead, state_ofm.name)
+            op.add_input_tensor(state)
+            op.set_output_tensor(state_ofm)
+            op.set_ifm_ofm_shapes()
+            # Set the read offset to the provided batch
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            # Set the read shape to one batch, see above
+            op.read_shapes[0] = op.ofm_shapes[0]
+            DebugDatabase.add_optimised(self.op, op)
+        return state_ofm
+
+    def get_state(self, op: Operation, batch: int = 0) -> Operation:
+        """Set up the correct read offset for reading the state from
+        the state variable tensor"""
+        if not self.time_major and self.n_batch > 1:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(op.ifm.shape)
+            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
+        return op
+
+    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
+        """Save the state for the provided batch by pointing the operation's
+        ofm to the variable state tensor"""
+        # The create op functions always return a 4D shape; however, the state
+        # should have a 2D shape for correct operation
+        op.ofm.shape = op.ofm.shape[-2:]
+        # Get state from type
+        state = self.state(state_type)
+        # By using the same equivalence_id the backing buffer for the ofm
+        # tensor will be the state variable tensor buffer
+        op.ofm.equivalence_id = state.equivalence_id
+        # Set the memory function, which will make the tensor be in linear format
+        # just as the state variable tensor
+        op.memory_function = Op.VariableTensorWrite
+        # Set the batch write offset into the state tensor buffer, unless in
+        # time_major mode when all batches are written at once
+        if not self.time_major:
+            op.write_offset = Shape4D.from_list([batch, 0], 0)
+            op.write_shape = Shape4D(op.ofm.shape)
+            op.ofm_shapes = [Shape4D(state.shape)]
+        DebugDatabase.add_optimised(self.op, op)
+        return op
+
+    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
+        """Save the output state for the provided batch and time to the OFM"""
+        name = f"{self.op.ofm.name}#{batch}.{time}"
+        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
+        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
+        # The provided state tensor uses the output state tensor's buffer, so
+        # unless in time_major mode we need to set the correct batch read offset
+        if not self.time_major:
+            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+            op.read_shapes[0] = Shape4D(state.shape)
+            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
+        return op
+
+    def lstm_step(
+        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
+    ) -> Tuple[Tensor, Tensor]:
+        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
+        input_gate = self.calculate_gate(
+            f"input_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_input_weights,
+            self.input_bias,
+            self.recurrent_to_input_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        forget_gate = self.calculate_gate(
+            f"forget_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_forget_weights,
+            self.forget_bias,
+            self.recurrent_to_forget_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        cell_gate = self.calculate_gate(
+            f"cell_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_cell_weights,
+            self.cell_bias,
+            self.recurrent_to_cell_weights,
+            None,
+            Op.Tanh,
+            batch,
+        )
+        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
+        output_gate = self.calculate_gate(
+            f"output_gate#{batch}.{time}",
+            feature,
+            output_state,
+            self.input_to_output_weights,
+            self.output_bias,
+            self.recurrent_to_output_weights,
+            None,
+            Op.Sigmoid,
+            batch,
+        )
+        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
+        return (output_state, cell_state)
+
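+    # For reference, one unrolled step computes, in floating-point terms
+    # (W = input weights, R = recurrent weights, b = biases):
+    #
+    #   i = sigmoid(x @ Wi + b_i + h @ Ri)    input gate
+    #   f = sigmoid(x @ Wf + b_f + h @ Rf)    forget gate
+    #   g = tanh(x @ Wc + b_c + h @ Rc)       cell gate
+    #   o = sigmoid(x @ Wo + b_o + h @ Ro)    output gate
+    #   c' = f * c + i * g                    new cell state (clipped if cell_clip != 0)
+    #   h' = o * tanh(c')                     new output state
+    #
+    # calculate_gate() below builds each gate as
+    # Activation(Add(FullyConnected(x), FullyConnected(h))) with int16 Q0.15
+    # gate outputs.
+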
+    def calculate_gate(
+        self,
+        name: str,
+        input: Tensor,
+        state: Tensor,
+        input_weights: Tensor,
+        input_bias: Tensor,
+        recurrent_weights: Tensor,
+        recurrent_bias: Tensor,
+        activation: Op,
+        batch: int = 0,
+    ):
+        """Generate a gate for the provided input and weights"""
+        # Activation( Add( FC(input), FC(output state) ) )
+        # Setup fullyconnected quantization
+        q_fc = QuantizationParameters()
+        q_fc.scale_f32 = np.float32(2**-12)
+        q_fc.zero_point = 0
+        # Create fullyconnected
+        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
+        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
+        self.get_state(re_fc, batch)
+        # Change fullyconnected ofm data type
+        in_fc.ofm.dtype = DataType.int16
+        re_fc.ofm.dtype = DataType.int16
+        # Setup add quantization
+        q_add = q_fc.clone()
+        q_add.scale_f32 = np.float32(2**-15)
+        # Create add + activation
+        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
+        if activation is Op.Sigmoid:
+            # For Sigmoid we need to set the activation min/max values to match the possible range
+            # in the reference. The values below are the quantized min/max values that the reference
+            # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
+            # due to intermediate higher precision.)
+            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
+            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
+            # fused activation function). This will yield the dequantized min/max values, which are later
+            # quantized again by the command stream generator.
+            add.activation.max = 32757 / 0x3000
+            add.activation.min = 11 / 0x3000
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, in_fc)
+        DebugDatabase.add_optimised(self.op, re_fc)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
+    def calculate_cell_state(
+        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
+    ):
+        """Update the cell state from the provided gate output"""
+        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
+        base_name = f"cell_state#{batch}.{time}"
+        # Cell scale
+        cell_scale = cell_state.quantization.scale_f32
+        # Create mul(cell_state, forget_gate)
+        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
+        self.get_state(mul_cf, batch)
+        # Calculate explicit scales to match the reference
+        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Create mul(cell_gate, input_gate)
+        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
+        # Calculate explicit scales to match the reference
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
+        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Setup cell clip
+        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
+        if activation:
+            activation.max = self.cell_clip
+            activation.min = -self.cell_clip
+        # Create add + activation
+        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
+        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+        # Save new state
+        self.put_state(add, Lstm.State.CELL, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, mul_cf)
+        DebugDatabase.add_optimised(self.op, mul_ci)
+        DebugDatabase.add_optimised(self.op, add)
+        return add.ofm
+
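+    # Note on the explicit scaling above: elementwise_mul_scale(s1, s2, s_out)
+    # returns a (multiplier, shift) pair such that multiplier / 2**shift
+    # approximates the effective scale s1 * s2 / s_out, i.e. the rescale
+    # applied after the integer multiply. Pinning the pair explicitly, instead
+    # of letting the compiler derive it from the tensor quantizations, keeps
+    # the unrolled graph bit-exact with the TFLite reference kernel.
+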
+    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
+        """Generate the output state from the provided gate output"""
+        # Mul( Tanh(cell state), output gate )
+        base_name = f"output_state#{batch}.{time}"
+        # Setup tanh quantization
+        q_out_tanh = QuantizationParameters()
+        q_out_tanh.scale_f32 = np.float32(2**-15)
+        q_out_tanh.zero_point = 0
+        # Create tanh(cell state)
+        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
+        self.get_state(tanh, batch)
+        # Create Mul( Tanh(cell state), output gate )
+        q_mul = self.output_state.quantization
+        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
+        # Use explicit scaling to match the reference; the following line would
+        # have been the preferred way:
+        # mul.forced_output_quantization = self.hidden_quantization
+        out_scale = self.hidden_quantization.scale_f32
+        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
+        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+        # Save new state
+        self.put_state(mul, Lstm.State.OUTPUT, batch)
+        # Add to debug database
+        DebugDatabase.add_optimised(self.op, tanh)
+        DebugDatabase.add_optimised(self.op, mul)
+        return mul.ofm
+
+    def state(self, state_type: State) -> Tensor:
+        """Get state tensor from type"""
+        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state
+
+    # Dimensions
+    @property
+    def n_feature(self) -> int:
+        return self.op.ifm.shape[-1]
+
+    @property
+    def n_time(self) -> int:
+        return self.op.ifm.shape[0 if self.time_major else 1]
+
+    @property
+    def n_batch(self) -> int:
+        return self.op.ifm.shape[1 if self.time_major else 0]
+
+    # Attributes
+    @property
+    def cell_clip(self) -> int:
+        return self.op.attrs.get("cell_clip", 0)
+
+    @property
+    def projection_clip(self) -> int:
+        return self.op.attrs.get("proj_clip", 0)
+
+    @property
+    def time_major(self) -> bool:
+        return self.op.attrs.get("time_major", False)
+
+    # Hidden (intermediate)
+    @property
+    def hidden_quantization(self) -> QuantizationParameters:
+        return self.op.intermediates[4].quantization
+
+    # Input weights
+    @property
+    def input_to_input_weights(self) -> Tensor:
+        return self.op.inputs[1]
+
+    @property
+    def input_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[2]
+
+    @property
+    def input_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[3]
+
+    @property
+    def input_to_output_weights(self) -> Tensor:
+        return self.op.inputs[4]
+
+    # Recurrent weights
+    @property
+    def recurrent_to_input_weights(self) -> Tensor:
+        return self.op.inputs[5]
+
+    @property
+    def recurrent_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[6]
+
+    @property
+    def recurrent_to_cell_weights(self) -> Tensor:
+        return self.op.inputs[7]
+
+    @property
+    def recurrent_to_output_weights(self) -> Tensor:
+        return self.op.inputs[8]
+
+    # Peephole weights
+    @property
+    def cell_to_input_weights(self) -> Tensor:
+        return self.op.inputs[9]
+
+    @property
+    def cell_to_forget_weights(self) -> Tensor:
+        return self.op.inputs[10]
+
+    @property
+    def cell_to_output_weights(self) -> Tensor:
+        return self.op.inputs[11]
+
+    # Bias tensors
+    @property
+    def input_bias(self) -> Tensor:
+        return self.op.inputs[12]
+
+    @property
+    def forget_bias(self) -> Tensor:
+        return self.op.inputs[13]
+
+    @property
+    def cell_bias(self) -> Tensor:
+        return self.op.inputs[14]
+
+    @property
+    def output_bias(self) -> Tensor:
+        return self.op.inputs[15]
+
+    # Projection tensors
+    @property
+    def projection_weights(self) -> Tensor:
+        return self.op.inputs[16]
+
+    @property
+    def projection_bias(self) -> Tensor:
+        return self.op.inputs[17]
+
+    # State tensors (variable)
+    @property
+    def output_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.OUTPUT.value]
+
+    @property
+    def cell_state(self) -> Tensor:
+        return self.op.inputs[Lstm.State.CELL.value]
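Note: later in this patch, tflite_graph_optimiser.py wires this class into the optimiser through a convert_lstm rewrite that is essentially:

    def convert_lstm(op, arch, nng):
        if op.type == Op.UnidirectionalSequenceLstm:
            op = Lstm(op).get_graph()
        return op

For sizing intuition: each unrolled (batch, time) step expands to roughly 17 NPU operations (four gates of two FullyConnected plus one fused-activation Add each, two Mul and one Add for the cell state, and a Tanh plus a Mul for the output state), before counting the SplitSliceRead and AvgPool ops that slice the inputs and write the states back, so the generated graph grows linearly with batch count and sequence length.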
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 67717104..d1670536 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -37,6 +37,7 @@ from .shape4d import Shape4D
 
 # Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
 if TYPE_CHECKING:
+    from .tensor import QuantizationParameters
     from .tensor import Tensor
 
 PointXY = namedtuple("PointXY", "x y")
@@ -142,8 +143,6 @@ class Op(Enum):
     BatchToSpaceND = OperatorInfo()
     BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
     BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
-    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)
-
     CLZ = OperatorInfo(
         block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
     )  # NPU specific operation
@@ -297,6 +296,7 @@ class Op(Enum):
     Unique = OperatorInfo()
     Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
+    VariableTensorWrite = OperatorInfo()
     Where = OperatorInfo()
     While = OperatorInfo()
     ZerosLike = OperatorInfo()
@@ -516,8 +516,8 @@ class Operation:
         self.memory_function: Optional[Op] = None
         # If not none: contains QuantizationParameters to be used as output quantization
         # (which overrides the ofm tensor's quantization), used in LUT
-        self.forced_input_quantization = None
-        self.forced_output_quantization = None
+        self.forced_input_quantization: Optional[QuantizationParameters] = None
+        self.forced_output_quantization: Optional[QuantizationParameters] = None
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
         self.activation_lut = None
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 74836eb2..ef4949fa 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -27,8 +27,10 @@ from .operation import ActivationFunction
 from .operation import Op
 from .operation import Operation
 from .operation import Padding
+from .reader_util import clone_and_reshape_tensor
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -117,6 +119,55 @@ def create_cast_op(
     return op
 
 
+def create_fused_activation(op_type: Op, name: str, ifm: Tensor, quantization: QuantizationParameters) -> Operation:
+    assert op_type.is_activation_op()
+    op = create_avgpool_nop(name)
+    op.activation = ActivationFunction(op_type)
+    ofm = Tensor(ifm.shape, ifm.dtype, f"{op.name}_tens0")
+    ofm.quantization = quantization
+    op.add_input_tensor(ifm)
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
+    return op
+
+
+def create_fullyconnected(
+    name: str,
+    ifm: Tensor,
+    weights: Tensor,
+    bias: Optional[Tensor],
+    quantization: QuantizationParameters,
+    vela_weight_order: bool = True,
+) -> Operation:
+    # Reshape weights if needed
+    if not vela_weight_order:
+        weights = clone_and_reshape_tensor(weights, (1, 0), False)
+
+    n_ofm = weights.shape[-1]
+
+    # Setup bias if needed
+    if not bias:
+        bias_values = [0] * n_ofm
+        dtype = DataType.int64 if ifm.dtype == DataType.int16 else DataType.int32
+        bias = create_const_tensor(f"{name}_bias", [n_ofm], dtype, bias_values)
+        # Set equivalence_id based on values to avoid placing duplicate data in flash
+        bias.equivalence_id = create_equivalence_id(tuple(bias_values))
+        bias.value_id = bias.equivalence_id
+
+    # Setup ofm
+    ofm = Tensor([ifm.shape[0], n_ofm], ifm.dtype, f"{name}_tens0")
+    ofm.quantization = quantization
+
+    # Create op and add tensors
+    op = Operation(Op.FullyConnected, name)
+    op.add_input_tensor(ifm)
+    op.add_input_tensor(weights)
+    op.add_input_tensor(bias)
+    op.set_output_tensor(ofm)
+    op.set_ifm_ofm_shapes()
+    return op
+
+
 def create_depthwise_maxpool(
     name: str,
     ifm: Tensor,
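For illustration, lstm.py (above) drives the new create_fullyconnected helper roughly like this when building a gate projection (the ifm, weights and bias tensors here are placeholders, not part of this change):

    q_fc = QuantizationParameters()
    q_fc.scale_f32 = np.float32(2**-12)
    q_fc.zero_point = 0
    # The weights come in TFLite order, so vela_weight_order=False makes the
    # helper clone and transpose them; passing bias=None would instead make it
    # synthesise a zero bias of the correct type (int64 for an int16 IFM).
    fc = create_fullyconnected("input_gate_fc", ifm, weights, bias, q_fc, False)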
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index e43a9191..932f701b 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -55,8 +55,6 @@ mac_main_ops = set(
         Op.QuantizedMatMul,
         Op.MatMul,
         Op.FullyConnected,
-        # RNN/LSTM/GRU
-        Op.BlockLSTM,
         # pooling
         Op.QuantizedMaxPool,
         Op.QuantizedAvgPool,
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 6fcb6c1d..cbd7ce44 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1242,7 +1242,11 @@ class Scheduler:
             cost = schedule.cost_map[sched_op]
             if cost.cascade == 0 and sched_op.get_dependants():
                 ofm_tens = sched_op.ofm.connection.parent_tens
-                if not any(cons is None for cons in ofm_tens.consumer_list):
+                # Do not move subgraph outputs or Variable Tensor Writes
+                if (
+                    not any(cons is None for cons in ofm_tens.consumer_list)
+                    and sched_op.parent_op.memory_function is not Op.VariableTensorWrite
+                ):
                     if ofm_tens not in self.scratched_fms:
                         # Remember default mem area and mem type, only done once
                         self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
@@ -1260,6 +1264,7 @@ class Scheduler:
                         mem_type_set,
                         lr_graph,
                     )
+
                     max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
 
                     # If max_mem_usage does not exceed staging limit at any point all lrs fit and can stay in fast storage
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index fd23d042..d4c92553 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -576,3 +576,37 @@ def test_matching_in_out_quant():
     dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
     op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
     assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_lstm_semantics():
+    # Test valid configurations
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert semantic_checker.is_operator_semantic_valid(op)
+    assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16))
+    # Test invalid datatype
+    assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8))
+    # Test invalid shape
+    ifm_shape = op.ifm.shape
+    ofm_shape = op.ofm.shape
+    op.ifm.shape = [12, 24]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ifm.shape = ifm_shape
+    op.ofm.shape = [12, 20]
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.ofm.shape = ofm_shape
+    # Test invalid number of intermediates
+    intermediate = op.intermediates.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.append(intermediate)
+    op.intermediates.append(intermediate)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.intermediates.pop()
+    # Test invalid number of inputs
+    input = op.inputs.pop()
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.append(input)
+    op.inputs.append(input)
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op.inputs.pop()
+    # Test restored valid configuration
+    assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 2713adf9..04f10e9a 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool():
     assert support.is_operator_supported(op)
     op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)
+
+
+def test_lstm_support():
+    # Test valid configuration
+    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+    assert support.is_operator_supported(op)
+    # Test CIFG not supported
+    input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5]
+    op.inputs[1] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[1] = input_to_input_weights
+    op.inputs[5] = None
+    assert not support.is_operator_supported(op)
+    op.inputs[5] = recurrent_to_input_weights
+    # Test Peephole not supported
+    op.inputs[9] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[9] = None
+    op.inputs[10] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[10] = None
+    op.inputs[11] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[11] = None
+    # Test Projection not supported
+    op.inputs[16] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[16] = None
+    op.inputs[17] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[17] = None
+    # Test Normalisation not supported
+    op.inputs[20] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[20] = None
+    op.inputs[21] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[21] = None
+    op.inputs[22] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[22] = None
+    op.inputs[23] = input_to_input_weights
+    assert not support.is_operator_supported(op)
+    op.inputs[23] = None
+    # Test restored valid configuration
+    assert support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 88fc8747..e08bde24 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -103,7 +103,10 @@ def create_op_with_quant_tensors(
 def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     op = Operation(op_type, output.name + "_op")
     for input in inputs:
-        op.add_input_tensor(input)
+        if input:  # Add regular tensor input
+            op.add_input_tensor(input)
+        else:  # Add optional (None) inputs for operators with sparse input positioning
+            op.inputs.append(input)
     op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
@@ -112,6 +115,63 @@ def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     return op
 
 
+def create_lstm_op(batches, times, features, outputs, datatype):
+    input_shape = [batches, times, features]
+    output_shape = [batches, times, outputs]
+    weight_shape = [features, outputs]
+    state_shape = [batches, outputs]
+    bias_shape = [outputs]
+    ifm = Tensor(input_shape, datatype, "in")
+    ifm.quantization = default_quant_params()
+    ofm = Tensor(output_shape, datatype, "out")
+    ofm.quantization = default_quant_params()
+    bias_dtype = DataType.int64 if datatype == DataType.int16 else DataType.int32
+    bias = create_const_tensor("bias", bias_shape, bias_dtype, [0] * outputs)
+    weight_q = default_quant_params()
+    weight = create_const_tensor("weight", weight_shape, DataType.int8, np.ones(weight_shape), quantization=weight_q)
+    output_state = Tensor(state_shape, datatype, "output_state")
+    output_state.quantization = default_quant_params()
+    output_state.is_variable = True
+    cell_state = Tensor(state_shape, DataType.int16, "cell_state")
+    cell_state.quantization = default_quant_params()
+    cell_state.is_variable = True
+    intermediate = Tensor([], DataType.float32, "intermediate")
+    hidden_scale_intermediate = Tensor([], datatype, "effective_hidden_scale_intermediate")
+    hidden_scale_intermediate.quantization = default_quant_params()
+    peephole = None
+    projection = None
+    normalisation = None
+    inputs = [
+        ifm,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        weight,
+        peephole,
+        peephole,
+        peephole,
+        bias,
+        bias,
+        bias,
+        bias,
+        projection,
+        projection,
+        output_state,
+        cell_state,
+        normalisation,
+        normalisation,
+        normalisation,
+        normalisation,
+    ]
+    op = create_op(Op.UnidirectionalSequenceLstm, inputs, ofm)
+    op.intermediates = [intermediate, intermediate, intermediate, intermediate, hidden_scale_intermediate]
+    return op
+
+
 def create_subgraph(op_list):
     # Creates subgraph using the given list of operations
     sg = Subgraph()
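The 24-entry inputs list above follows the TFLite UNIDIRECTIONAL_SEQUENCE_LSTM tensor layout, which the lstm.py properties and the constraints below index directly:

    #  0     ifm
    #  1-4   input-to-{input,forget,cell,output} weights
    #  5-8   recurrent-to-{input,forget,cell,output} weights
    #  9-11  cell-to-{input,forget,output} (peephole) weights
    # 12-15  {input,forget,cell,output} gate biases
    # 16-17  projection weights and bias
    # 18-19  output state and cell state (variable tensors)
    # 20-23  {input,forget,cell,output} layer normalisation coefficients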
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 077f4afa..478d0189 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -35,11 +35,13 @@ from .graph_optimiser_util import bypass_memory_only_ops
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
 from .graph_optimiser_util import convert_to_lut
+from .graph_optimiser_util import create_avg_pool_for_concat
 from .graph_optimiser_util import memory_only_ops
 from .graph_optimiser_util import move_splitsliceread_to_consumer
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
+from .lstm import Lstm
 from .numeric_util import clamp_sigmoid
 from .numeric_util import full_shape
 from .numeric_util import round_away_zero
@@ -69,23 +71,6 @@ from .tflite_mapping import optype_to_builtintype
 passthrough_nodes = (Op.Identity,)
 
 
-def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
-    """Creates an average pool for the given concat op/input feature map"""
-    ofm = concat_op.ofm
-    avgpool_op = create_avgpool_nop(name)
-    avgpool_op.inputs = [ifm]
-    avgpool_op.outputs = [ofm]
-
-    avgpool_op.write_offset = write_offset
-    avgpool_op.write_shape = ifm_shape
-    ofm.ops.append(avgpool_op)
-    avgpool_op.ifm_shapes.append(ifm_shape)
-    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
-    avgpool_op.memory_function = Op.ConcatSliceWrite
-    DebugDatabase.add_optimised(concat_op, avgpool_op)
-    return avgpool_op
-
-
 def remove_passthrough_tensor(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
         assert len(tens.ops[0].inputs) == 1
@@ -196,17 +181,15 @@ def rewrite_split_ops(tens, arch, nng):
 
 def remove_SplitSliceRead(op, arch):
     if op.type == Op.SplitSliceRead:
-        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
-        if (
-            len(op.ofm.consumer_list) == 1
-            and op.ofm.consumer_list[0] is not None
-            and op.ofm.consumer_list[0].run_on_npu
-            and op.ofm.consumer_list[0].type not in memory_only_ops
-            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+        # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
+        # or if an avgpool needs to be inserted
+        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
+            consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
+            for consumer in op.ofm.consumer_list
         ):
-            # SplitSliceRead can be performed by tensor consumer
-            cons_op = op.ofm.consumer_list[0]
-            move_splitsliceread_to_consumer(op, cons_op)
+            # SplitSliceRead can be performed by tensor consumer(s)
+            for cons_op in list(op.ofm.consumer_list):
+                move_splitsliceread_to_consumer(op, cons_op)
         else:
             avgpool_op = create_avgpool_nop(op.name + "_avgpool")
             avgpool_op.add_input_tensor(op.ifm)
@@ -801,8 +784,9 @@ def convert_nop_split_to_identity(op, arch, nng):
 
 
 def rewrite_fully_connected_input(op: Operation, arch, nng):
-
-    if op.type == Op.FullyConnected:
+    # If the operation already has a read shape do not modify
+    # the ifm shape, since that will already be correct
+    if op.type == Op.FullyConnected and not op.read_shapes[0]:
         new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
         assert new_shape is not None, "Tensor can not be reshaped to 2D"
         op.ifm_shapes[0] = new_shape
@@ -1080,6 +1064,13 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
     return op
+
+
+def convert_lstm(op, arch, nng):
+    if op.type == Op.UnidirectionalSequenceLstm:
+        lstm = Lstm(op)
+        op = lstm.get_graph()
+    return op
 
 
 def convert_softmax(op, arch, nng):
     if op.type == Op.Softmax and op.run_on_npu:
         softmax = SoftMax(op)
@@ -2144,6 +2135,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
         convert_mean_to_depthwise_conv_or_avgpool,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
+        convert_lstm,
        convert_softmax,
         convert_prelu,
         convert_mul_max_to_abs_or_lrelu,
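The remove_SplitSliceRead change above is the commit message's removal of the only-one-consumer limitation: instead of inserting an AvgPool copy whenever a split output had more than one consumer, the read offset is now pushed onto every consumer, provided all of them run on the NPU and none is a memory-only op. Conceptually:

    # Before: SplitSliceRead -> AvgPool copy -> {consumer A, consumer B}
    # After:  consumer A and consumer B each read the ifm directly, with
    #         read_offsets[0]/read_shapes[0] set by
    #         move_splitsliceread_to_consumer(); the copy is only inserted
    #         when some consumer does not qualify.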
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 5661f36e..6ba7b835 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -193,6 +193,14 @@ class TFLiteSemantic:
         self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
         self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
 
+        # UnidirectionalSequenceLstm specific checks:
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_input_signed)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_matching_in_out_types)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_dimensions)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_inputs)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates)
+        self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables)
+
     def is_operator_semantic_valid(self, op):
         ext_type = optype_to_builtintype(op.type)
 
@@ -627,6 +635,13 @@ class TFLiteSemantic:
         valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
         return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
 
+    @staticmethod
+    def constraint_input_signed(op):
+        "IFM must be int8 or int16"
+        ifm_dtype = op.ifm.dtype
+        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16)
+        return valid, f"Op has ifm_dtype={ifm_dtype}"
+
     @staticmethod
     def constraint_input_8bit(op):
         "IFM must be int8 or uint8"
@@ -689,6 +704,36 @@ class TFLiteSemantic:
             return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
         return True, "IFM and OFM number of elements are equal."
 
+    @staticmethod
+    def constraint_lstm_dimensions(op):
+        "IFM and OFM must have 3D shape"
+        valid = len(op.ifm.shape) == len(op.ofm.shape) == 3
+        return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"
+
+    @staticmethod
+    def constraint_lstm_inputs(op):
+        "Must have 24 input tensors"
+        n_inputs = len(op.inputs)
+        return n_inputs == 24, f"Op has {n_inputs} inputs"
+
+    @staticmethod
+    def constraint_lstm_intermediates(op):
+        "Must have 5 intermediate tensors"
+        n_intermediates = len(op.intermediates)
+        return n_intermediates == 5, f"Op has {n_intermediates} intermediates"
+
+    @staticmethod
+    def constraint_lstm_variables(op):
+        "State tensors must be variable"
+        valid = True
+        extra = []
+        for tens in op.inputs[18:20]:
+            if not tens.is_variable:
+                valid = False
+                extra.append(tens.name)
+        extra = ", ".join(extra)
+        return valid, f"Op has non-variable state tensor(s): {extra}"
+
 
 def tflite_semantic_checker(nng):
     semantic_checker = TFLiteSemantic()
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25f19b77..457c35eb 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -69,8 +69,8 @@ class TFLiteSupportedOperators:
         )
     )
     mac_main_ops = (
-        # RNN/LSTM/GRU
-        set((Op.BlockLSTM,))
+        # LSTM
+        set((Op.UnidirectionalSequenceLstm,))
         # conv/depthwiseconv/transposeconv
         | convolution_like_ops
         # pooling
@@ -320,6 +320,14 @@ class TFLiteSupportedOperators:
         self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
         self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
 
+        # UnidirectionalSequenceLstm specific checks:
+        op_type = Op.UnidirectionalSequenceLstm
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
+        self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -888,3 +896,35 @@ class TFLiteSupportedOperators:
         "IFM depth must be no greater than 127"
         ifm_depth = op.inputs[0].shape[-1]
         return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
+
+    @staticmethod
+    def constraint_lstm_no_cifg(op):
+        "Must not use CIFG"
+        cifg = None not in op.inputs[2:5] + op.inputs[6:9]
+        cifg = cifg and op.inputs[1] is None
+        cifg = cifg and op.inputs[5] is None
+        return not cifg, "Op uses CIFG"
+
+    @staticmethod
+    def constraint_lstm_no_peep_hole(op):
+        "Must not use Peephole"
+        valid = all([tens is None for tens in op.inputs[9:12]])
+        return valid, "Op uses peephole"
+
+    @staticmethod
+    def constraint_lstm_no_projection(op):
+        "Must not use Projection"
+        valid = all([tens is None for tens in op.inputs[16:18]])
+        return valid, "Op uses projection"
+
+    @staticmethod
+    def constraint_lstm_no_normalisation(op):
+        "Must not use Normalisation"
+        valid = all([tens is None for tens in op.inputs[20:24]])
+        return valid, "Op uses normalisation"
+
+    @staticmethod
+    def constraint_lstm_weights(op):
+        "All input and recurrent weights must be available"
+        valid = None not in op.inputs[1:9]
+        return valid, "Op has missing weights"
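A closing note on the CIFG check above: in the TFLite schema, CIFG (coupled input and forget gate) is signalled by omitting the input-gate weight tensors, so the two weight constraints interact as follows:

    # inputs[1:9] all present                      -> supported
    # inputs[1] and inputs[5] both None, while
    # inputs[2:5] and inputs[6:9] are all present  -> rejected as CIFG
    # any other missing weight                     -> rejected by
    #                                                 constraint_lstm_weights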