authorFredrik Svedberg <fredrik.svedberg@arm.com>2023-04-11 22:35:04 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-04-17 14:16:44 +0200
commit0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (patch)
tree9ccb766221987a415244079ed6c596a47d693b20
parentc1ad80b3a581dd39b39a112d6c2026f6560207a4 (diff)
downloadethos-u-vela-0ac0804e76e098695ee2b8a9e24e2f0a1efc324f.tar.gz
MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.

The implementation does not include support for:
* CIFG
* Peephole
* Projection
* Normalisation

This change also:
* Removed the unused Op.BlockLSTM operation type.
* Removed the single-consumer limitation on moving the SplitSliceRead
  to the tensor consumer(s), provided all consumers fulfil the
  requirements.
* Added Op.VariableTensorWrite as an Operation.memory_function to make
  sure writes to variable tensors:
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges

Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r--SUPPORTED_OPS.md17
-rw-r--r--ethosu/vela/graph_optimiser_util.py25
-rw-r--r--ethosu/vela/live_range.py4
-rw-r--r--ethosu/vela/lstm.py447
-rw-r--r--ethosu/vela/operation.py8
-rw-r--r--ethosu/vela/operation_util.py51
-rw-r--r--ethosu/vela/pass_packing.py2
-rw-r--r--ethosu/vela/scheduler.py7
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py34
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py46
-rw-r--r--ethosu/vela/test/testutil.py62
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py50
-rw-r--r--ethosu/vela/tflite_model_semantic.py45
-rw-r--r--ethosu/vela/tflite_supported_operators.py44
14 files changed, 800 insertions, 42 deletions
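
For orientation, the new lstm.py unrolls each time step (and, for non time major input, each batch) into FullyConnected/Add/Mul NPU operations. The float sketch below shows the standard non-CIFG LSTM step being lowered; it is illustrative only and not part of the patch, which operates on quantized tensors:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, W, R, b):
    # x: input feature, h: output state, c: cell state
    # W/R/b: input weights, recurrent weights and bias per gate
    i = sigmoid(x @ W["i"] + h @ R["i"] + b["i"])  # input gate
    f = sigmoid(x @ W["f"] + h @ R["f"] + b["f"])  # forget gate
    g = np.tanh(x @ W["g"] + h @ R["g"] + b["g"])  # cell gate
    o = sigmoid(x @ W["o"] + h @ R["o"] + b["o"])  # output gate
    c = f * c + i * g            # new cell state (clipped if cell_clip is set)
    h = o * np.tanh(c)           # new output state
    return h, c
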
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 08c63e7c..f641d3f2 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -55,6 +55,7 @@ Please check the supported operator list for your chosen runtime for further inf
| SUB | [Generic](#tflite-generic-constraints), [Specific](#tflite-sub-constraints) |
| TANH | [Generic](#tflite-generic-constraints) |
| TRANSPOSE_CONV | [Generic](#tflite-generic-constraints), [Specific](#tflite-transpose_conv-constraints) |
+| UNIDIRECTIONAL_SEQUENCE_LSTM | [Generic](#tflite-generic-constraints), [Specific](#tflite-unidirectional_sequence_lstm-constraints) |
| UNPACK | [Generic](#tflite-generic-constraints) |
### TFLite Generic Constraints
@@ -356,3 +357,19 @@ This is a list of constraints that the TRANSPOSE_CONV operator must satisfy in o
- SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride
- VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
minus difference between kernel size and stride
+
+### TFLite UNIDIRECTIONAL_SEQUENCE_LSTM Constraints
+
+This is a list of constraints that the UNIDIRECTIONAL_SEQUENCE_LSTM operator must satisfy in order to be scheduled on the NPU.
+
+- IFM must be int8 or int16
+- IFM and OFM data types must match
+- IFM and OFM must have 3D shape
+- Must have 24 input tensors
+- Must have 5 intermediate tensors
+- State tensors must be variable
+- Must not use CIFG
+- Must not use Peephole
+- Must not use Projection
+- Must not use Normalisation
+- All input and recurrent weights must be available
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e1341d82..82790364 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,6 +27,7 @@ from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .operation import Op
+from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import QuantizationParameters
@@ -101,6 +102,10 @@ def check_format_restrictions(tens: Tensor, arch):
):
return
+ # Writes to the buffer of a variable tensor need to use linear format
+ if tens.ops[0].memory_function == Op.VariableTensorWrite:
+ return
+
# Check if any of the producers/consumers is run on CPU
if not all(cons.run_on_npu for cons in tens.consumer_list):
return
@@ -222,7 +227,8 @@ def move_splitsliceread_to_consumer(op, cons_op):
cons_op.ifm_shapes[1] = op.ifm_shapes[0]
op.ofm.consumer_list.remove(cons_op)
op.ofm.ops = []
- op.ifm.consumer_list.remove(op)
+ if op in op.ifm.consumer_list:
+ op.ifm.consumer_list.remove(op)
def check_memory_only_removed(op, arch):
@@ -357,3 +363,20 @@ def convert_to_lut(op, lut_values, lut_name):
op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, op)
return op
+
+
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+ """Creates an average pool for the given concat op/input feature map"""
+ ofm = concat_op.ofm
+ avgpool_op = create_avgpool_nop(name)
+ avgpool_op.inputs = [ifm]
+ avgpool_op.outputs = [ofm]
+
+ avgpool_op.write_offset = write_offset
+ avgpool_op.write_shape = ifm_shape
+ ofm.ops.append(avgpool_op)
+ avgpool_op.ifm_shapes.append(ifm_shape)
+ avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+ avgpool_op.memory_function = Op.ConcatSliceWrite
+ DebugDatabase.add_optimised(concat_op, avgpool_op)
+ return avgpool_op
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 995a0ccb..3abcfcf0 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -166,9 +166,9 @@ def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
ifm_tens = None
- if sched_op.op_type.is_elementwise_op():
+ elem_op = sched_op.parent_op
+ if sched_op.op_type.is_elementwise_op() and elem_op.memory_function is not Op.VariableTensorWrite:
# Check if possible to merge ifm/ofm live ranges of elementwise op
- elem_op = sched_op.parent_op
if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
# Check if overwriting the inputs can be allowed
OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
diff --git a/ethosu/vela/lstm.py b/ethosu/vela/lstm.py
new file mode 100644
index 00000000..5a50788b
--- /dev/null
+++ b/ethosu/vela/lstm.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
+from enum import Enum
+from typing import Tuple
+
+import numpy as np
+
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .graph_optimiser_util import create_avg_pool_for_concat
+from .operation import ActivationFunction
+from .operation import ExplicitScaling
+from .operation import Op
+from .operation import Operation
+from .operation_util import create_add
+from .operation_util import create_fullyconnected
+from .operation_util import create_fused_activation
+from .operation_util import create_mul
+from .scaling import elementwise_mul_scale
+from .shape4d import Shape4D
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+
+Q0_15_SCALE = np.float32(0.00003051757)
+"""Q0.15 scale like the reference defines it"""
+
+
+class Lstm:
+ """Lstm graph optimisation.
+
+ Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.
+
+ Usage:
+
+ unrolled_op = Lstm(op).get_graph()
+ """
+
+ class State(Enum):
+ """States (variable tensors)"""
+
+ OUTPUT = 18 # Value = tensor index
+ CELL = 19 # Value = tensor index
+
+ def __init__(self, op):
+ self.op = op
+
+ def get_graph(self) -> Operation:
+ """Return the generated graph implementation"""
+ self.op.ofm.ops = []
+ if self.time_major:
+ output_state = self.get_initial_state(Lstm.State.OUTPUT)
+ cell_state = self.get_initial_state(Lstm.State.CELL)
+ for time in range(self.n_time):
+ feature = self.get_feature(time)
+ output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
+ op = self.put_ofm(output_state, time)
+ else:
+ for batch in range(self.n_batch):
+ output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
+ cell_state = self.get_initial_state(Lstm.State.CELL, batch)
+ for time in range(self.n_time):
+ feature = self.get_feature(time, batch)
+ output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
+ op = self.put_ofm(output_state, time, batch)
+ return op
+
+ def get_feature(self, time: int, batch: int = 0) -> Tensor:
+ """Get input feature for provided time and batch"""
+ feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
+ feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
+ op = Operation(Op.SplitSliceRead, feature.name)
+ op.add_input_tensor(self.op.ifm)
+ op.set_output_tensor(feature)
+ op.set_ifm_ofm_shapes()
+ offset = [time, 0, 0] if self.time_major else [batch, time, 0]
+ op.read_offsets[0] = Shape4D.from_list(offset, 0)
+ op.read_shapes[0] = op.ofm_shapes[0]
+ DebugDatabase.add_optimised(self.op, op)
+ return feature
+
+ def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
+ """Get state tensor for provided state type and batch"""
+ state = self.state(state_type)
+ if self.time_major:
+ # For time major just return the 2D state, since all batches
+ # are calculated at the same time
+ return state
+ else:
+ # For non time major return one batch of the 2D state
+ # by setting the read offset to the provided batch
+
+ # The cloned state tensor will share equivalence id and buffer
+ # with the variable state tensor
+ n_state = state.shape[-1]
+ state_ofm = state.clone(f"_state#{batch}")
+ # Set shape to be one batch
+ state_ofm.set_all_shapes([1, n_state])
+ # Create the op for reading one batch of the state
+ # (will be optimised away at a later stage)
+ op = Operation(Op.SplitSliceRead, state_ofm.name)
+ op.add_input_tensor(state)
+ op.set_output_tensor(state_ofm)
+ op.set_ifm_ofm_shapes()
+ # Set the read offset to the provided batch
+ op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+ # Set the read shape to one batch, see above
+ op.read_shapes[0] = op.ofm_shapes[0]
+ DebugDatabase.add_optimised(self.op, op)
+ return state_ofm
+
+ def get_state(self, op: Operation, batch: int = 0) -> Operation:
+ """Setup the correct read offset for reading the state from
+ a variable tensor state"""
+ if not self.time_major and self.n_batch > 1:
+ op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+ op.read_shapes[0] = Shape4D(op.ifm.shape)
+ op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
+ return op
+
+ def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
+ """Save the state for the provided batch by pointing the operations
+ ofm to the variable state tensor"""
+ # The create op functions always return 4D shape, however the state
+ # should have 2D shape for correct operation
+ op.ofm.shape = op.ofm.shape[-2:]
+ # Get state from type
+ state = self.state(state_type)
+ # By using the same equivalence_id the backing buffer for the ofm
+ # tensor will be the state variable tensor buffer
+ op.ofm.equivalence_id = state.equivalence_id
+ # Set memory function which will make the tensor be in linear format
+ # just as the state variable tensor
+ op.memory_function = Op.VariableTensorWrite
+ # Set the batch write offset into the state tensor buffer, unless
+ # in time_major mode, when all batches are written at once
+ if not self.time_major:
+ op.write_offset = Shape4D.from_list([batch, 0], 0)
+ op.write_shape = Shape4D(op.ofm.shape)
+ op.ofm_shapes = [Shape4D(state.shape)]
+ DebugDatabase.add_optimised(self.op, op)
+ return op
+
+ def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
+ """Save the output state for the provided batch and time to OFM"""
+ name = f"{self.op.ofm.name}#{batch}.{time}"
+ offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
+ op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
+ # The provided state tensor uses the output state tensor's buffer, so
+ # unless in time_major mode we need to set the correct batch read offset
+ if not self.time_major:
+ op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
+ op.read_shapes[0] = Shape4D(state.shape)
+ op.ifm_shapes[0] = Shape4D(self.output_state.shape)
+ return op
+
+ def lstm_step(
+ self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
+ ) -> Tuple[Tensor, Tensor]:
+ """Generate one step of the LSTM implementation for the provided feature, batch and time"""
+ input_gate = self.calculate_gate(
+ f"input_gate#{batch}.{time}",
+ feature,
+ output_state,
+ self.input_to_input_weights,
+ self.input_bias,
+ self.recurrent_to_input_weights,
+ None,
+ Op.Sigmoid,
+ batch,
+ )
+ forget_gate = self.calculate_gate(
+ f"forget_gate#{batch}.{time}",
+ feature,
+ output_state,
+ self.input_to_forget_weights,
+ self.forget_bias,
+ self.recurrent_to_forget_weights,
+ None,
+ Op.Sigmoid,
+ batch,
+ )
+ cell_gate = self.calculate_gate(
+ f"cell_gate#{batch}.{time}",
+ feature,
+ output_state,
+ self.input_to_cell_weights,
+ self.cell_bias,
+ self.recurrent_to_cell_weights,
+ None,
+ Op.Tanh,
+ batch,
+ )
+ cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
+ output_gate = self.calculate_gate(
+ f"output_gate#{batch}.{time}",
+ feature,
+ output_state,
+ self.input_to_output_weights,
+ self.output_bias,
+ self.recurrent_to_output_weights,
+ None,
+ Op.Sigmoid,
+ batch,
+ )
+ output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
+ return (output_state, cell_state)
+
+ def calculate_gate(
+ self,
+ name: str,
+ input: Tensor,
+ state: Tensor,
+ input_weights: Tensor,
+ input_bias: Tensor,
+ recurrent_weights: Tensor,
+ recurrent_bias: Tensor,
+ activation: Op,
+ batch: int = 0,
+ ):
+ """Generate a gate for the provided input and weights"""
+ # Activation( Add( FC(input), FC(output state) ) )
+ # Setup fullyconnected quantization
+ q_fc = QuantizationParameters()
+ q_fc.scale_f32 = np.float32(2**-12)
+ q_fc.zero_point = 0
+ # Create fullyconnected
+ in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
+ re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
+ self.get_state(re_fc, batch)
+ # Change fullyconnected ofm data type
+ in_fc.ofm.dtype = DataType.int16
+ re_fc.ofm.dtype = DataType.int16
+ # Setup add quantization
+ q_add = q_fc.clone()
+ q_add.scale_f32 = np.float32(2**-15)
+ # Create add + activation
+ add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
+ if activation is Op.Sigmoid:
+ # For Sigmoid we need to set the activation min/max values to match the possible range
+ # in the reference. The values below are the quantized min/max values that the reference
+ # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
+ # due to intermediate higher precision.)
+ # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
+ # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
+ # fused activation function). This will yield the dequantized min/max values which are later
+ # quantized again by the command stream generator.
+ add.activation.max = 32757 / 0x3000
+ add.activation.min = 11 / 0x3000
+ # Add to debug database
+ DebugDatabase.add_optimised(self.op, in_fc)
+ DebugDatabase.add_optimised(self.op, re_fc)
+ DebugDatabase.add_optimised(self.op, add)
+ return add.ofm
+
+ def calculate_cell_state(
+ self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
+ ):
+ """Update the cell state from the provided gate output"""
+ # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
+ base_name = f"cell_state#{batch}.{time}"
+ # Cell scale
+ cell_scale = cell_state.quantization.scale_f32
+ # Create mul(cell_state, forget_gate)
+ mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
+ self.get_state(mul_cf, batch)
+ # Calculate explicit scales to match reference
+ multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
+ mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+ # Create mul(cell_gate, input_gate)
+ mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
+ # Calculate explicit scales to match reference
+ multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
+ mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+ # Setup cell clip
+ activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
+ if activation:
+ activation.max = self.cell_clip
+ activation.min = -self.cell_clip
+ # Create add + activation
+ add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
+ add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+ # Save new state
+ self.put_state(add, Lstm.State.CELL, batch)
+ # Add to debug database
+ DebugDatabase.add_optimised(self.op, mul_cf)
+ DebugDatabase.add_optimised(self.op, mul_ci)
+ DebugDatabase.add_optimised(self.op, add)
+ return add.ofm
+
+ def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
+ """Generate the output state from the provided gate output"""
+ # Mul( Tanh(cell state), output gate )
+ base_name = f"output_state#{batch}.{time}"
+ # Setup tanh quantization
+ q_out_tanh = QuantizationParameters()
+ q_out_tanh.scale_f32 = np.float32(2**-15)
+ q_out_tanh.zero_point = 0
+ # Create tanh(cell state)
+ tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
+ self.get_state(tanh, batch)
+ # Create Mul( Tanh(cell state), output gate )
+ q_mul = self.output_state.quantization
+ mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
+ # Use explicit scaling to match the reference; the following line would have been the preferred way:
+ # mul.forced_output_quantization = self.hidden_quantization
+ out_scale = self.hidden_quantization.scale_f32
+ multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
+ mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
+ # Save new state
+ self.put_state(mul, Lstm.State.OUTPUT, batch)
+ # Add to debug database
+ DebugDatabase.add_optimised(self.op, tanh)
+ DebugDatabase.add_optimised(self.op, mul)
+ return mul.ofm
+
+ def state(self, state_type: State) -> Tensor:
+ """Get state tensor from type"""
+ return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state
+
+ # Dimensions
+ @property
+ def n_feature(self) -> int:
+ return self.op.ifm.shape[-1]
+
+ @property
+ def n_time(self) -> int:
+ return self.op.ifm.shape[0 if self.time_major else 1]
+
+ @property
+ def n_batch(self) -> int:
+ return self.op.ifm.shape[1 if self.time_major else 0]
+
+ # Attributes
+ @property
+ def cell_clip(self) -> int:
+ return self.op.attrs.get("cell_clip", 0)
+
+ @property
+ def projection_clip(self) -> int:
+ return self.op.attrs.get("proj_clip", 0)
+
+ @property
+ def time_major(self) -> bool:
+ return self.op.attrs.get("time_major", False)
+
+ # Hidden (intermediate)
+ @property
+ def hidden_quantization(self) -> QuantizationParameters:
+ return self.op.intermediates[4].quantization
+
+ # Input weights
+ @property
+ def input_to_input_weights(self) -> Tensor:
+ return self.op.inputs[1]
+
+ @property
+ def input_to_forget_weights(self) -> Tensor:
+ return self.op.inputs[2]
+
+ @property
+ def input_to_cell_weights(self) -> Tensor:
+ return self.op.inputs[3]
+
+ @property
+ def input_to_output_weights(self) -> Tensor:
+ return self.op.inputs[4]
+
+ # Recurrent weights
+ @property
+ def recurrent_to_input_weights(self) -> Tensor:
+ return self.op.inputs[5]
+
+ @property
+ def recurrent_to_forget_weights(self) -> Tensor:
+ return self.op.inputs[6]
+
+ @property
+ def recurrent_to_cell_weights(self) -> Tensor:
+ return self.op.inputs[7]
+
+ @property
+ def recurrent_to_output_weights(self) -> Tensor:
+ return self.op.inputs[8]
+
+ # Peephole weights
+ @property
+ def cell_to_input_weights(self) -> Tensor:
+ return self.op.inputs[9]
+
+ @property
+ def cell_to_forget_weights(self) -> Tensor:
+ return self.op.inputs[10]
+
+ @property
+ def cell_to_output_weights(self) -> Tensor:
+ return self.op.inputs[11]
+
+ # Bias tensors
+ @property
+ def input_bias(self) -> Tensor:
+ return self.op.inputs[12]
+
+ @property
+ def forget_bias(self) -> Tensor:
+ return self.op.inputs[13]
+
+ @property
+ def cell_bias(self) -> Tensor:
+ return self.op.inputs[14]
+
+ @property
+ def output_bias(self) -> Tensor:
+ return self.op.inputs[15]
+
+ # Projection tensors
+ @property
+ def projection_weights(self) -> Tensor:
+ return self.op.inputs[16]
+
+ @property
+ def projection_bias(self) -> Tensor:
+ return self.op.inputs[17]
+
+ # State tensors (variable)
+ @property
+ def output_state(self) -> Tensor:
+ return self.op.inputs[Lstm.State.OUTPUT.value]
+
+ @property
+ def cell_state(self) -> Tensor:
+ return self.op.inputs[Lstm.State.CELL.value]
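
As a sanity check on the fixed-point constants above (a sketch using the patch's own values, not additional code in the change): Q0_15_SCALE approximates 2^-15, the gate FullyConnected outputs use scale 2^-12, the Add outputs use 2^-15 (Q0.15), and the sigmoid clamp bounds dequantize against the 0x3000 (3 << 12) elementwise output scale:

import numpy as np

print(np.float32(0.00003051757), 2.0**-15)      # Q0_15_SCALE vs exact 2^-15
print(np.float32(2**-12), np.float32(2**-15))   # q_fc and q_add scales
print(11 / 0x3000, 32757 / 0x3000)              # dequantized sigmoid min/max
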
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 67717104..d1670536 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -37,6 +37,7 @@ from .shape4d import Shape4D
# Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
if TYPE_CHECKING:
+ from .tensor import QuantizationParameters
from .tensor import Tensor
PointXY = namedtuple("PointXY", "x y")
@@ -142,8 +143,6 @@ class Op(Enum):
BatchToSpaceND = OperatorInfo()
BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
- BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)
-
CLZ = OperatorInfo(
block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
) # NPU specific operation
@@ -297,6 +296,7 @@ class Op(Enum):
Unique = OperatorInfo()
Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
+ VariableTensorWrite = OperatorInfo()
Where = OperatorInfo()
While = OperatorInfo()
ZerosLike = OperatorInfo()
@@ -516,8 +516,8 @@ class Operation:
self.memory_function: Optional[Op] = None
# If not none: contains QuantizationParameters to be used as output quantization
# (which overrides the ofm tensor's quantization), used in LUT
- self.forced_input_quantization = None
- self.forced_output_quantization = None
+ self.forced_input_quantization: Optional[QuantizationParameters] = None
+ self.forced_output_quantization: Optional[QuantizationParameters] = None
self.scheduled_pass = None
self.op_index = None # input network operator index
self.activation_lut = None
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 74836eb2..ef4949fa 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -27,8 +27,10 @@ from .operation import ActivationFunction
from .operation import Op
from .operation import Operation
from .operation import Padding
+from .reader_util import clone_and_reshape_tensor
from .shape4d import Shape4D
from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
@@ -117,6 +119,55 @@ def create_cast_op(
return op
+def create_fused_activation(op_type: Op, name: str, ifm: Tensor, quantization: QuantizationParameters) -> Operation:
+ assert op_type.is_activation_op()
+ op = create_avgpool_nop(name)
+ op.activation = ActivationFunction(op_type)
+ ofm = Tensor(ifm.shape, ifm.dtype, f"{op.name}_tens0")
+ ofm.quantization = quantization
+ op.add_input_tensor(ifm)
+ op.set_output_tensor(ofm)
+ op.set_ifm_ofm_shapes()
+ return op
+
+
+def create_fullyconnected(
+ name: str,
+ ifm: Tensor,
+ weights: Tensor,
+ bias: Optional[Tensor],
+ quantization: QuantizationParameters,
+ vela_weight_order: bool = True,
+) -> Operation:
+ # Reshape weights if needed
+ if not vela_weight_order:
+ weights = clone_and_reshape_tensor(weights, (1, 0), False)
+
+ n_ofm = weights.shape[-1]
+
+ # Setup bias if needed
+ if not bias:
+ bias_values = [0] * n_ofm
+ dtype = DataType.int64 if ifm.dtype == DataType.int16 else DataType.int32
+ bias = create_const_tensor(f"{name}_bias", [n_ofm], dtype, bias_values)
+ # Set equivalence_id based on values to avoid placing duplicate data in flash
+ bias.equivalence_id = create_equivalence_id(tuple(bias_values))
+ bias.value_id = bias.equivalence_id
+
+ # Setup ofm
+ ofm = Tensor([ifm.shape[0], n_ofm], ifm.dtype, f"{name}_tens0")
+ ofm.quantization = quantization
+
+ # Create op and add tensors
+ op = Operation(Op.FullyConnected, name)
+ op.add_input_tensor(ifm)
+ op.add_input_tensor(weights)
+ op.add_input_tensor(bias)
+ op.set_output_tensor(ofm)
+ op.set_ifm_ofm_shapes()
+ return op
+
+
def create_depthwise_maxpool(
name: str,
ifm: Tensor,
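
A usage sketch of the new create_fullyconnected helper, mirroring how lstm.py's calculate_gate calls it (the ifm/weights tensors and "gate_fc" name here are illustrative assumptions, not part of the patch):

# Assuming ifm and weights are existing int8 tensors from the graph:
q_fc = QuantizationParameters()
q_fc.scale_f32 = np.float32(2**-12)
q_fc.zero_point = 0
# vela_weight_order=False transposes TFLite-ordered weights; a zero bias
# is created automatically and deduplicated via its equivalence_id
fc = create_fullyconnected("gate_fc", ifm, weights, None, q_fc, False)
fc.ofm.dtype = DataType.int16  # gates are computed in int16
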
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index e43a9191..932f701b 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -55,8 +55,6 @@ mac_main_ops = set(
Op.QuantizedMatMul,
Op.MatMul,
Op.FullyConnected,
- # RNN/LSTM/GRU
- Op.BlockLSTM,
# pooling
Op.QuantizedMaxPool,
Op.QuantizedAvgPool,
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 6fcb6c1d..cbd7ce44 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1242,7 +1242,11 @@ class Scheduler:
cost = schedule.cost_map[sched_op]
if cost.cascade == 0 and sched_op.get_dependants():
ofm_tens = sched_op.ofm.connection.parent_tens
- if not any(cons is None for cons in ofm_tens.consumer_list):
+ # Do not move subgraph outputs or Variable Tensor Writes
+ if (
+ not any(cons is None for cons in ofm_tens.consumer_list)
+ and sched_op.parent_op.memory_function is not Op.VariableTensorWrite
+ ):
if ofm_tens not in self.scratched_fms:
# Remember default mem area and mem type, only done once
self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
@@ -1260,6 +1264,7 @@ class Scheduler:
mem_type_set,
lr_graph,
)
+
max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
# If max_mem_usage does not exceed staging limit at any point all lrs fit and can stay in fast storage
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index fd23d042..d4c92553 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -576,3 +576,37 @@ def test_matching_in_out_quant():
dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
assert not semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_lstm_semantics():
+ # Test valid configurations
+ op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ assert semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.int16))
+ # Test invalid datatype
+ assert not semantic_checker.is_operator_semantic_valid(testutil.create_lstm_op(3, 12, 24, 20, DataType.uint8))
+ # Test invalid shape
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ op.ifm.shape = [12, 24]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ifm.shape = ifm_shape
+ op.ofm.shape = [12, 20]
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.ofm.shape = ofm_shape
+ # Test invalid number of intermediates
+ intermediate = op.intermediates.pop()
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.intermediates.append(intermediate)
+ op.intermediates.append(intermediate)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.intermediates.pop()
+ # Test invalid number of inputs
+ input = op.inputs.pop()
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.inputs.append(input)
+ op.inputs.append(input)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ op.inputs.pop()
+ # Test restored valid configuration
+ assert semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 2713adf9..04f10e9a 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,3 +623,49 @@ def test_mean_hw_product_avgpool():
assert support.is_operator_supported(op)
op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+
+
+def test_lstm_support():
+ # Test valid configuration
+ op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
+ assert support.is_operator_supported(op)
+ # Test CIFG not supported
+ input_to_input_weights, recurrent_to_input_weights = op.inputs[1], op.inputs[5]
+ op.inputs[1] = None
+ assert not support.is_operator_supported(op)
+ op.inputs[1] = input_to_input_weights
+ op.inputs[5] = None
+ assert not support.is_operator_supported(op)
+ op.inputs[5] = recurrent_to_input_weights
+ # Test Peephole not supported
+ op.inputs[9] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[9] = None
+ op.inputs[10] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[10] = None
+ op.inputs[11] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[11] = None
+ # Test Projection not supported
+ op.inputs[16] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[16] = None
+ op.inputs[17] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[17] = None
+ # Test Normalisation not supported
+ op.inputs[20] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[20] = None
+ op.inputs[21] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[21] = None
+ op.inputs[22] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[22] = None
+ op.inputs[23] = input_to_input_weights
+ assert not support.is_operator_supported(op)
+ op.inputs[23] = None
+ # Test restored valid configuration
+ assert support.is_operator_supported(op)
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 88fc8747..e08bde24 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -103,7 +103,10 @@ def create_op_with_quant_tensors(
def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
op = Operation(op_type, output.name + "_op")
for input in inputs:
- op.add_input_tensor(input)
+ if input: # Add regular tensor input
+ op.add_input_tensor(input)
+ else: # Add optional (None) inputs for operators with sparse input positioning
+ op.inputs.append(input)
op.set_output_tensor(output)
if attrs is not None:
op.attrs = attrs
@@ -112,6 +115,63 @@ def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
return op
+def create_lstm_op(batches, times, features, outputs, datatype):
+ input_shape = [batches, times, features]
+ output_shape = [batches, times, outputs]
+ weight_shape = [features, outputs]
+ state_shape = [batches, outputs]
+ bias_shape = [outputs]
+ ifm = Tensor(input_shape, datatype, "in")
+ ifm.quantization = default_quant_params()
+ ofm = Tensor(output_shape, datatype, "out")
+ ofm.quantization = default_quant_params()
+ bias_dtype = DataType.int64 if datatype == DataType.int16 else DataType.int32
+ bias = create_const_tensor("bias", bias_shape, bias_dtype, [0] * outputs)
+ weight_q = default_quant_params()
+ weight = create_const_tensor("weight", weight_shape, DataType.int8, np.ones(weight_shape), quantization=weight_q)
+ output_state = Tensor(state_shape, datatype, "output_state")
+ output_state.quantization = default_quant_params()
+ output_state.is_variable = True
+ cell_state = Tensor(state_shape, DataType.int16, "cell_state")
+ cell_state.quantization = default_quant_params()
+ cell_state.is_variable = True
+ intermediate = Tensor([], DataType.float32, "intermediate")
+ hidden_scale_intermediate = Tensor([], datatype, "effective_hidden_scale_intermediate")
+ hidden_scale_intermediate.quantization = default_quant_params()
+ peephole = None
+ projection = None
+ normalisation = None
+ inputs = [
+ ifm,
+ weight,
+ weight,
+ weight,
+ weight,
+ weight,
+ weight,
+ weight,
+ weight,
+ peephole,
+ peephole,
+ peephole,
+ bias,
+ bias,
+ bias,
+ bias,
+ projection,
+ projection,
+ output_state,
+ cell_state,
+ normalisation,
+ normalisation,
+ normalisation,
+ normalisation,
+ ]
+ op = create_op(Op.UnidirectionalSequenceLstm, inputs, ofm)
+ op.intermediates = [intermediate, intermediate, intermediate, intermediate, hidden_scale_intermediate]
+ return op
+
+
def create_subgraph(op_list):
# Creates subgraph using the given list of operations
sg = Subgraph()
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 077f4afa..478d0189 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -35,11 +35,13 @@ from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
+from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
+from .lstm import Lstm
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
@@ -69,23 +71,6 @@ from .tflite_mapping import optype_to_builtintype
passthrough_nodes = (Op.Identity,)
-def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
- """Creates an average pool for the given concat op/input feature map"""
- ofm = concat_op.ofm
- avgpool_op = create_avgpool_nop(name)
- avgpool_op.inputs = [ifm]
- avgpool_op.outputs = [ofm]
-
- avgpool_op.write_offset = write_offset
- avgpool_op.write_shape = ifm_shape
- ofm.ops.append(avgpool_op)
- avgpool_op.ifm_shapes.append(ifm_shape)
- avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
- avgpool_op.memory_function = Op.ConcatSliceWrite
- DebugDatabase.add_optimised(concat_op, avgpool_op)
- return avgpool_op
-
-
def remove_passthrough_tensor(tens, arch, nng):
if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
assert len(tens.ops[0].inputs) == 1
@@ -196,17 +181,15 @@ def rewrite_split_ops(tens, arch, nng):
def remove_SplitSliceRead(op, arch):
if op.type == Op.SplitSliceRead:
- # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
- if (
- len(op.ofm.consumer_list) == 1
- and op.ofm.consumer_list[0] is not None
- and op.ofm.consumer_list[0].run_on_npu
- and op.ofm.consumer_list[0].type not in memory_only_ops
- and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+ # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
+ # or if an avgpool needs to be inserted
+ if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
+ consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
+ for consumer in op.ofm.consumer_list
):
- # SplitSliceRead can be performed by tensor consumer
- cons_op = op.ofm.consumer_list[0]
- move_splitsliceread_to_consumer(op, cons_op)
+ # SplitSliceRead can be performed by tensor consumer(s)
+ for cons_op in list(op.ofm.consumer_list):
+ move_splitsliceread_to_consumer(op, cons_op)
else:
avgpool_op = create_avgpool_nop(op.name + "_avgpool")
avgpool_op.add_input_tensor(op.ifm)
@@ -801,8 +784,9 @@ def convert_nop_split_to_identity(op, arch, nng):
def rewrite_fully_connected_input(op: Operation, arch, nng):
-
- if op.type == Op.FullyConnected:
+ # If the operation already has a read shape, do not modify
+ # the ifm shape, since that will already be correct
+ if op.type == Op.FullyConnected and not op.read_shapes[0]:
new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
assert new_shape is not None, "Tensor can not be reshaped to 2D"
op.ifm_shapes[0] = new_shape
@@ -1080,6 +1064,13 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
return op
+def convert_lstm(op, arch, nng):
+ if op.type == Op.UnidirectionalSequenceLstm:
+ lstm = Lstm(op)
+ op = lstm.get_graph()
+ return op
+
+
def convert_softmax(op, arch, nng):
if op.type == Op.Softmax and op.run_on_npu:
softmax = SoftMax(op)
@@ -2144,6 +2135,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_mean_to_depthwise_conv_or_avgpool,
convert_depthwise_to_conv,
convert_conv_to_fc,
+ convert_lstm,
convert_softmax,
convert_prelu,
convert_mul_max_to_abs_or_lrelu,
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 5661f36e..6ba7b835 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -193,6 +193,14 @@ class TFLiteSemantic:
self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_input_8bit)
self.specific_constraints[Op.ArgMax].append(TFLiteSemantic.constraint_argmax_output)
+ # UnidirectionalSequenceLstm specific checks:
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_input_signed)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_matching_in_out_types)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_dimensions)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_inputs)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_intermediates)
+ self.specific_constraints[Op.UnidirectionalSequenceLstm].append(TFLiteSemantic.constraint_lstm_variables)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -628,6 +636,13 @@ class TFLiteSemantic:
return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
@staticmethod
+ def constraint_input_signed(op):
+ "IFM must be int8 or int16"
+ ifm_dtype = op.ifm.dtype
+ valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16)
+ return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+ @staticmethod
def constraint_input_8bit(op):
"IFM must be int8 or uint8"
ifm_dtype = op.ifm.dtype
@@ -689,6 +704,36 @@ class TFLiteSemantic:
return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
return True, "IFM and OFM number of elements are equal."
+ @staticmethod
+ def constraint_lstm_dimensions(op):
+ "IFM and OFM must have 3D shape"
+ valid = len(op.ifm.shape) == len(op.ofm.shape) == 3
+ return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"
+
+ @staticmethod
+ def constraint_lstm_inputs(op):
+ "Must have 24 input tensors"
+ n_inputs = len(op.inputs)
+ return n_inputs == 24, f"Op has {n_inputs} inputs"
+
+ @staticmethod
+ def constraint_lstm_intermediates(op):
+ "Must have 5 intermediate tensors"
+ n_intermediates = len(op.intermediates)
+ return n_intermediates == 5, f"Op has {n_intermediates} intermediates"
+
+ @staticmethod
+ def constraint_lstm_variables(op):
+ "State tensors must be variable"
+ valid = True
+ extra = []
+ for tens in op.inputs[18:20]:
+ if not tens.is_variable:
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has non-variable state tensor(s): {extra}"
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25f19b77..457c35eb 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -69,8 +69,8 @@ class TFLiteSupportedOperators:
)
)
mac_main_ops = (
- # RNN/LSTM/GRU
- set((Op.BlockLSTM,))
+ # LSTM
+ set((Op.UnidirectionalSequenceLstm,))
# conv/depthwiseconv/transposeconv
| convolution_like_ops
# pooling
@@ -320,6 +320,14 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
+ # UnidirectionalSequenceLstm specific checks:
+ op_type = Op.UnidirectionalSequenceLstm
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_cifg)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_peep_hole)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_projection)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_no_normalisation)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weights)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -888,3 +896,35 @@ class TFLiteSupportedOperators:
"IFM depth must be no greater than 127"
ifm_depth = op.inputs[0].shape[-1]
return ifm_depth <= 127, f"IFM depth is {ifm_depth}"
+
+ @staticmethod
+ def constraint_lstm_no_cifg(op):
+ "Must not use CIFG"
+ cifg = None not in op.inputs[2:5] + op.inputs[6:9]
+ cifg = cifg and op.inputs[1] is None
+ cifg = cifg and op.inputs[5] is None
+ return not cifg, "Op uses CIFG"
+
+ @staticmethod
+ def constraint_lstm_no_peep_hole(op):
+ "Must not use Peephole"
+ valid = all([tens is None for tens in op.inputs[9:12]])
+ return valid, "Op uses peephole"
+
+ @staticmethod
+ def constraint_lstm_no_projection(op):
+ "Must not use Projection"
+ valid = all([tens is None for tens in op.inputs[16:18]])
+ return valid, "Op uses projection"
+
+ @staticmethod
+ def constraint_lstm_no_normalisation(op):
+ "Must not use Normalisation"
+ valid = all([tens is None for tens in op.inputs[20:24]])
+ return valid, "Op uses normalisation"
+
+ @staticmethod
+ def constraint_lstm_weights(op):
+ "All input and recurrent weights must be available"
+ valid = None not in op.inputs[1:9]
+ return valid, "Op has missing weights"