| field | value | date |
|---|---|---|
| author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-04-11 22:35:04 +0200 |
| committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2023-04-17 14:16:44 +0200 |
| commit | 0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (patch) | |
| tree | 9ccb766221987a415244079ed6c596a47d693b20 /ethosu/vela/scheduler.py | |
| parent | c1ad80b3a581dd39b39a112d6c2026f6560207a4 (diff) | |
MLBEDSW-7196 Add LSTM support
Added int8 and int16 UNIDIRECTIONAL_SEQUENCE_LSTM support.
The implementation does not include support for the following variants
(see the sketch after this list):
* CIFG
* Peephole
* Projection
* Normalisation
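
All four unsupported variants are signalled by optional tensors in the TFLite
UNIDIRECTIONAL_SEQUENCE_LSTM operator. The sketch below shows one way such
variants could be detected; the tensor indices follow the TFLite LSTM
convention, but the helper itself is hypothetical and is not Vela's actual
constraint code.

```python
# Hypothetical variant check for a TFLite UNIDIRECTIONAL_SEQUENCE_LSTM op.
# Optional tensors that are absent are modelled here as None entries.
from typing import Optional, Sequence

def lstm_variant_is_supported(inputs: Sequence[Optional[object]]) -> bool:
    """True if the op uses none of: CIFG, peephole, projection, normalisation."""

    def present(idx: int) -> bool:
        return idx < len(inputs) and inputs[idx] is not None

    if not present(1):                             # input_to_input_weights absent
        return False                               # -> CIFG
    if any(present(i) for i in (9, 10, 11)):       # cell_to_* peephole weights
        return False                               # -> peephole
    if present(16):                                # projection_weights
        return False                               # -> projection
    if any(present(i) for i in (20, 21, 22, 23)):  # layer-norm coefficients
        return False                               # -> normalisation
    return True

# Usage: 9 inputs, none of the optional feature tensors set -> supported
print(lstm_variant_is_supported([object()] * 9))  # True
```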
This change also:
* Removed unused Op.BlockLSTM operation type.
* Removed the single-consumer limitation on putting the SplitSliceRead
on the tensor consumer(s), provided all consumers fulfil the requirements
* Added Op.VariableTensorWrite as an Operation.memory_function to make
sure that writes to variable tensors (see the sketch after this list):
  * Always use linear mode
  * Are not moved to fast scratch
  * Are not fused with other elementwise operation tensor ranges
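
The intent of the Op.VariableTensorWrite rules above can be illustrated with
a small, self-contained sketch. All names below (MemoryFunction, FeatureMap,
place_variable_write_ofm) are hypothetical stand-ins rather than Vela's
actual types; only the three guarantees themselves come from the commit
message. (In Vela, "linear mode" corresponds to TensorFormat.NHWC and the
brick format to TensorFormat.NHCWB16.)

```python
# Illustrative model of the three VariableTensorWrite guarantees.
from dataclasses import dataclass
from enum import Enum, auto

class MemoryFunction(Enum):
    NONE = auto()
    VARIABLE_TENSOR_WRITE = auto()  # stands in for Op.VariableTensorWrite

@dataclass
class FeatureMap:
    name: str
    format: str = "NHCWB16"       # brick format by default
    in_fast_scratch: bool = False
    fused_range: bool = False

def place_variable_write_ofm(memory_function: MemoryFunction, ofm: FeatureMap) -> None:
    """Apply the guarantees for writes to variable tensors."""
    if memory_function is MemoryFunction.VARIABLE_TENSOR_WRITE:
        ofm.format = "NHWC"          # 1. always use linear mode
        ofm.in_fast_scratch = False  # 2. never moved to fast scratch
        ofm.fused_range = False      # 3. never fused with other elementwise ranges

# Usage:
ofm = FeatureMap("lstm_output_state")
place_variable_write_ofm(MemoryFunction.VARIABLE_TENSOR_WRITE, ofm)
assert ofm.format == "NHWC" and not ofm.in_fast_scratch
```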
Change-Id: Ief831738924ac3d1f2ba6d41f10bd6dc969911f3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 7 |
1 file changed, 6 insertions, 1 deletion
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 6fcb6c1d..cbd7ce44 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1242,7 +1242,11 @@ class Scheduler:
             cost = schedule.cost_map[sched_op]
             if cost.cascade == 0 and sched_op.get_dependants():
                 ofm_tens = sched_op.ofm.connection.parent_tens
-                if not any(cons is None for cons in ofm_tens.consumer_list):
+                # Do not move subgraph outputs or Variable Tensor Writes
+                if (
+                    not any(cons is None for cons in ofm_tens.consumer_list)
+                    and sched_op.parent_op.memory_function is not Op.VariableTensorWrite
+                ):
                     if ofm_tens not in self.scratched_fms:
                         # Remember default mem area and mem type, only done once
                         self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
@@ -1260,6 +1264,7 @@ class Scheduler:
                 mem_type_set,
                 lr_graph,
             )
+        max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)

         # If max_mem_usage does not exceed staging limit at any point all lrs fit and can stay in fast storage
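
The added max_mem_usage line feeds the check described in the final context
comment: if the temporal memory usage of the live-range graph never exceeds
the staging limit, every live range fits and all feature maps can stay in
fast storage. Below is a minimal sketch of that check, assuming
get_temporal_memory_usage returns per-timestep usage values; the helper and
the example numbers are illustrative, not Vela's actual code.

```python
# Hypothetical staging-limit check over per-timestep memory usage.
from typing import Sequence

def fits_in_fast_storage(temporal_usage: Sequence[int], staging_limit: int) -> bool:
    """True if fast-storage usage stays within the limit at every time step."""
    return max(temporal_usage, default=0) <= staging_limit

# Example: three timesteps of usage against an assumed 384 KiB staging limit
print(fits_in_fast_storage([100_000, 350_000, 200_000], 384 * 1024))  # True
print(fits_in_fast_storage([100_000, 500_000, 200_000], 384 * 1024))  # False
```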