about summary refs log tree commit diff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-09-11 10:04:15 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-09-17 08:18:50 +0000
commit1a66697b80a527af6d6dd1ed235199264696767e (patch)
tree447f19903eedb0ed163348769da28267ccf3bf47 /ethosu/vela/tensor.py
parent1356c2ab034738bcf51822de18911cc499fa2e8e (diff)
downloadethos-u-vela-1a66697b80a527af6d6dd1ed235199264696767e.tar.gz
MLBEDSW-2809: Redo the Tensor addressing
Added a static class TensorAddressMap that stores all Tensor addresses based on their equivalence_id. Made the "address" field into a property whose getter and setter look up/set the tensor's address in TensorAddressMap. This makes the references to cpu_tensor/npu_tensor obsolete and they have been removed. Addition to scheduler: avoid SRAM spilling if an op has consumers in other subgraphs. Minor rework in LUTState; it will now assign a unique equivalence_id to the SHRAM lut tensor to avoid issues with addressing. The equivalence checks in LUTState now compare the values of the LUT instead of the equivalence_id. Updated LUT unit tests accordingly. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: I41de5a8a4e5f07b77d6544d8d4034b754993e503
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py37
1 file changed, 28 insertions, 9 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 49521e7a..0f8170d4 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -17,6 +17,7 @@
# Internal representation of a Neural Network Tensor.
import enum
import uuid
+from collections import defaultdict
import numpy as np
@@ -258,6 +259,25 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
return reshape_ofm if ifm_reshape else reshape_ifm
+# class that keeps track of all tensor addresses in the different memory types
+class TensorAddressMap:
+ address_map = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
+
+ @classmethod
+ def get_address_for_tens(cls, tens_id, mem_type):
+ return cls.address_map[tens_id].get(mem_type)
+
+ @classmethod
+ def set_address_for_tens(cls, tens_id, mem_type, address):
+ # Check previous address if there is one
+ previous_address = cls.address_map[tens_id].get(mem_type)
+ if previous_address is not None:
+ assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."
+
+ # Set tensor's address for memory type
+ cls.address_map[tens_id][mem_type] = address
+
+
class Tensor:
__slots__ = (
"shape",
@@ -285,13 +305,10 @@ class Tensor:
"weight_compression_config",
"storage_rounding_quantum",
"brick_size",
- "address",
"quantization",
"weight_compressed_offsets",
"element_size_bytes",
"block_traversal",
- "cpu_tensor",
- "npu_tensor",
"equivalence_id",
"resampling_mode",
"avoid_NHCWB16",
@@ -308,10 +325,6 @@ class Tensor:
self.ops = []
self.consumer_list = []
- # Below attributes are only set if a tensor has been cloned,
- # either from Cpu -> Npu or vice versa. Needed for offline allocation
- self.cpu_tensor = None # reference to the corresponding Cpu tensor
- self.npu_tensor = None # reference to the corresponding Npu tensor
self.values = None
self.quant_values = None
@@ -333,7 +346,6 @@ class Tensor:
self.weight_compressed_offsets = []
self.storage_rounding_quantum = (1, 1, 1, 1)
self.brick_size = (1, 1, 1, 1)
- self.address = None # start address of tensor. will be filled in by tensor allocator
self.element_size_bytes = 0
# quantization parameters
@@ -343,6 +355,14 @@ class Tensor:
self.avoid_NHCWB16 = False
+ @property
+ def address(self):
+ return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
+
+ @address.setter
+ def address(self, address):
+ TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
+
def element_size(self):
if self.element_size_bytes == 0:
return self.dtype.size_in_bits() / 8
@@ -367,7 +387,6 @@ class Tensor:
res.alignment = self.alignment
res.bandwidth_compression_scale = self.bandwidth_compression_scale
res.storage_rounding_quantum = self.storage_rounding_quantum
- res.address = None
if self.quantization is not None:
res.quantization = self.quantization.clone()