diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2021-01-27 16:53:41 +0100 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2021-01-28 09:08:43 +0000 |
commit | e22ba8cb3090886b2d80a2df0e599dbf4cd7f483 (patch) | |
tree | c980972830ec15613ab291c7731968ae1a010609 /ethosu/vela/mark_tensors.py | |
parent | 60232140a2927865d1f6f9bc48871df3b2bb135b (diff) | |
download | ethos-u-vela-e22ba8cb3090886b2d80a2df0e599dbf4cd7f483.tar.gz |
[MLBEDSW-3891] Fix reading back in an ethos-u custom op
Fixed assertion when reading back in an ethos-u custom op.
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I275ec9187ffead1e96f2522ecbd658328fa4ef69
Diffstat (limited to 'ethosu/vela/mark_tensors.py')
-rw-r--r-- | ethosu/vela/mark_tensors.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py index 5a475841..114649d9 100644 --- a/ethosu/vela/mark_tensors.py +++ b/ethosu/vela/mark_tensors.py @@ -24,7 +24,7 @@ from .tensor import TensorPurpose def get_format(purpose, arch): - if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch): + if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch, TensorPurpose.ScratchFast): fmt = arch.default_feature_map_format elif purpose == TensorPurpose.Weights: fmt = arch.default_weight_format @@ -46,7 +46,11 @@ def mark_purpose(tens, arch, purpose): tens.mem_area = arch.tensor_storage_mem_area[tens.purpose] tens.mem_type = arch.tensor_storage_mem_type[tens.purpose] - if len(tens.ops) == 1 and tens.ops[0].type == Op.Const: + if ( + len(tens.ops) == 1 + and tens.ops[0].type == Op.Const + and purpose not in (TensorPurpose.Scratch, TensorPurpose.ScratchFast) + ): tens.mem_area = arch.permanent_storage_mem_area # special case constants, as they must be in permanent storage tens.mem_type = MemType.Permanent_NPU @@ -79,6 +83,11 @@ def rewrite_mark_tensor_purpose(op, arch): if scratch_tensor.name.endswith("_scratch"): scratch_tensor.purpose = TensorPurpose.Scratch + if len(op.inputs) >= 4: + scratch_fast_tensor = op.inputs[3] # should be existing scratch fast tensor + if scratch_fast_tensor.name.endswith("_scratch_fast"): + scratch_fast_tensor.purpose = TensorPurpose.ScratchFast + if scratch_tensor is None: op.error("Scratch tensor not found.") |