diff options
Diffstat (limited to 'ethosu/vela/mark_tensors.py')
-rw-r--r-- | ethosu/vela/mark_tensors.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py index 5a475841..114649d9 100644 --- a/ethosu/vela/mark_tensors.py +++ b/ethosu/vela/mark_tensors.py @@ -24,7 +24,7 @@ from .tensor import TensorPurpose def get_format(purpose, arch): - if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch): + if purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT, TensorPurpose.Scratch, TensorPurpose.ScratchFast): fmt = arch.default_feature_map_format elif purpose == TensorPurpose.Weights: fmt = arch.default_weight_format @@ -46,7 +46,11 @@ def mark_purpose(tens, arch, purpose): tens.mem_area = arch.tensor_storage_mem_area[tens.purpose] tens.mem_type = arch.tensor_storage_mem_type[tens.purpose] - if len(tens.ops) == 1 and tens.ops[0].type == Op.Const: + if ( + len(tens.ops) == 1 + and tens.ops[0].type == Op.Const + and purpose not in (TensorPurpose.Scratch, TensorPurpose.ScratchFast) + ): tens.mem_area = arch.permanent_storage_mem_area # special case constants, as they must be in permanent storage tens.mem_type = MemType.Permanent_NPU @@ -79,6 +83,11 @@ def rewrite_mark_tensor_purpose(op, arch): if scratch_tensor.name.endswith("_scratch"): scratch_tensor.purpose = TensorPurpose.Scratch + if len(op.inputs) >= 4: + scratch_fast_tensor = op.inputs[3] # should be existing scratch fast tensor + if scratch_fast_tensor.name.endswith("_scratch_fast"): + scratch_fast_tensor.purpose = TensorPurpose.ScratchFast + if scratch_tensor is None: op.error("Scratch tensor not found.") |