diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 42d95262..3990164d 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -26,6 +26,27 @@ from .numeric_util import round_up_divide from .range_set import MemoryRangeSet +class MemType(enum.IntFlag): + Unknown = 0 + Permanent_NPU = 1 + Permanent_CPU = 2 + Scratch = 3 + Scratch_fast = 4 + Size = Scratch_fast + 1 + + def display_name(self): + return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value] + + def identifier_name(self): + return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value] + + def all(): + return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast) + + def __str__(self): + return self.name + + class MemArea(enum.IntFlag): Unknown = 0 Sram = 1 @@ -209,6 +230,7 @@ class Tensor: "quant_values", "compressed_values", "mem_area", + "mem_type", "format", "purpose", "sub_purpose", @@ -252,6 +274,7 @@ class Tensor: self.quant_values = None self.compressed_values = None self.mem_area = MemArea.Unknown + self.mem_type = MemType.Unknown self.format = TensorFormat.Unknown self.purpose = TensorPurpose.Unknown self.sub_purpose = TensorSubPurpose.Standard @@ -291,6 +314,7 @@ class Tensor: res.values = self.values res.quant_values = self.quant_values res.mem_area = self.mem_area + res.mem_type = self.mem_type res.format = self.format res.purpose = self.purpose res.sub_purpose = self.sub_purpose @@ -312,6 +336,7 @@ class Tensor: def clone_into_fast_storage(self, arch): res = self.clone(suffix="_fast_storage") res.mem_area = arch.fast_storage_mem_area + res.mem_type = MemType.Scratch_fast return res def copy_compressed_weight_info(self, src_tens): @@ -641,6 +666,11 @@ class Tensor: assert address_offset <= self.storage_size() return address_offset + def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area): + if self.mem_area == scratch_tensor_mem_area and (self.mem_type in set((MemType.Scratch, MemType.Scratch_fast))): + return True + return False + def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) |