aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/mark_tensors.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/mark_tensors.py')
-rw-r--r--ethosu/vela/mark_tensors.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 5a475841..114649d9 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -24,7 +24,7 @@ from .tensor import TensorPurpose
def get_format(purpose, arch):
- if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch):
+ if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch, TensorPurpose.ScratchFast):
fmt = arch.default_feature_map_format
elif purpose == TensorPurpose.Weights:
fmt = arch.default_weight_format
@@ -46,7 +46,11 @@ def mark_purpose(tens, arch, purpose):
tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
tens.mem_type = arch.tensor_storage_mem_type[tens.purpose]
- if len(tens.ops) == 1 and tens.ops[0].type == Op.Const:
+ if (
+ len(tens.ops) == 1
+ and tens.ops[0].type == Op.Const
+ and purpose not in (TensorPurpose.Scratch, TensorPurpose.ScratchFast)
+ ):
tens.mem_area = arch.permanent_storage_mem_area # special case constants, as they must be in permanent storage
tens.mem_type = MemType.Permanent_NPU
@@ -79,6 +83,11 @@ def rewrite_mark_tensor_purpose(op, arch):
if scratch_tensor.name.endswith("_scratch"):
scratch_tensor.purpose = TensorPurpose.Scratch
+ if len(op.inputs) >= 4:
+ scratch_fast_tensor = op.inputs[3] # should be existing scratch fast tensor
+ if scratch_fast_tensor.name.endswith("_scratch_fast"):
+ scratch_fast_tensor.purpose = TensorPurpose.ScratchFast
+
if scratch_tensor is None:
op.error("Scratch tensor not found.")