aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/live_range.py
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2023-04-11 22:35:04 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-04-17 14:16:44 +0200
commit0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (patch)
tree9ccb766221987a415244079ed6c596a47d693b20 /ethosu/vela/live_range.py
parentc1ad80b3a581dd39b39a112d6c2026f6560207a4 (diff)
downloadethos-u-vela-0ac0804e76e098695ee2b8a9e24e2f0a1efc324f.tar.gz
MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support. The implementation does not include support for: * CIFG * Peephole * Projection * Normalisation This change also: * Removed unused Op.BlockLSTM operation type. * Removed the only one consumer limitation on putting the SplitSliceRead on the tensor consumer(s), if all consumers fullfills the requirements * Added Op.VariableTensorWrite as a Operation.memory_function to make sure writes to variable tensors: * Always use linear mode * Are not moved to fast scratch * Are not fused with other elementwise operation tensor ranges Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3 Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r--ethosu/vela/live_range.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 995a0ccb..3abcfcf0 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -166,9 +166,9 @@ def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
ifm_tens = None
- if sched_op.op_type.is_elementwise_op():
+ elem_op = sched_op.parent_op
+ if sched_op.op_type.is_elementwise_op() and elem_op.memory_function is not Op.VariableTensorWrite:
# Check if possible to merge ifm/ofm live ranges of elementwise op
- elem_op = sched_op.parent_op
if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
# Check if overwriting the inputs can be allowed
OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])