authorJacob Bohlin <jacob.bohlin@arm.com>2020-09-11 10:04:15 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-09-17 08:18:50 +0000
commit1a66697b80a527af6d6dd1ed235199264696767e (patch)
tree447f19903eedb0ed163348769da28267ccf3bf47
parent1356c2ab034738bcf51822de18911cc499fa2e8e (diff)
MLBEDSW-2809: Redo the Tensor addressing
Added a static class TensorAddressMap that stores all Tensor addresses based on their equivalence_id.

Made the "address" field into a property whose getter and setter look up/set the tensor's address in TensorAddressMap. This makes the references to cpu_tensor/npu_tensor obsolete, and they have been removed.

Addition to scheduler: avoid SRAM spilling if an op has consumers in other subgraphs.

Minor rework in LUTState; it will now assign a unique equivalence_id to the SHRAM lut tensor to avoid issues with addressing. The equivalence checks in LUTState now compare the values of the LUT instead of the equivalence_id. Updated LUT unit tests accordingly.

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: I41de5a8a4e5f07b77d6544d8d4034b754993e503
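For reference, a minimal standalone sketch of the addressing scheme described above, assuming simplified stand-in definitions for Tensor and MemType (the real Vela classes carry far more state than shown here):

    import uuid
    from collections import defaultdict
    from enum import Enum


    class MemType(Enum):
        # Simplified placeholder for Vela's memory types
        Scratch = 1
        Permanent = 2


    class TensorAddressMap:
        # equivalence_id -> {mem_type: address}
        address_map = defaultdict(dict)

        @classmethod
        def get_address_for_tens(cls, tens_id, mem_type):
            return cls.address_map[tens_id].get(mem_type)

        @classmethod
        def set_address_for_tens(cls, tens_id, mem_type, address):
            previous = cls.address_map[tens_id].get(mem_type)
            if previous is not None:
                assert previous == address, "Two different addresses cannot be assigned to the same tensor."
            cls.address_map[tens_id][mem_type] = address


    class Tensor:
        def __init__(self, name, mem_type):
            self.name = name
            self.mem_type = mem_type
            self.equivalence_id = uuid.uuid4()

        @property
        def address(self):
            return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

        @address.setter
        def address(self, address):
            TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

        def clone(self, suffix):
            # A clone keeps the equivalence_id, so it resolves to the same address
            res = Tensor(self.name + suffix, self.mem_type)
            res.equivalence_id = self.equivalence_id
            return res


    # A tensor and its clone share one address entry, which is what makes the
    # old cpu_tensor/npu_tensor back-references unnecessary.
    orig = Tensor("ifm", MemType.Scratch)
    npu_clone = orig.clone("_npu")
    orig.address = 0x100
    assert npu_clone.address == 0x100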
-rw-r--r--  ethosu/vela/extract_npu_subgraphs.py    6
-rw-r--r--  ethosu/vela/live_range.py              32
-rw-r--r--  ethosu/vela/lut.py                      6
-rw-r--r--  ethosu/vela/scheduler.py               17
-rw-r--r--  ethosu/vela/tensor.py                  37
-rw-r--r--  ethosu/vela/test/test_lut.py           14
6 files changed, 61 insertions, 51 deletions
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index 4adddc17..c0430b5d 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -70,10 +70,7 @@ def rewrite_tensor_cpu_producer_npu_consumers(
orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
is_const = orig_tens.ops[0].type == "Const"
-
new_tens = orig_tens.clone("_npu")
- orig_tens.npu_tensor = new_tens
- new_tens.cpu_tensor = orig_tens
op_type = "SubgraphInput"
if is_const:
@@ -107,9 +104,6 @@ def rewrite_tensor_npu_producer_cpu_consumers(
):
new_tens = orig_tens.clone("_cpu")
- new_tens.npu_tensor = orig_tens
- orig_tens.cpu_tensor = new_tens
-
npu_subgraph.output_tensors.append(orig_tens)
call_ps.outputs.append(new_tens)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 156090f7..9a8ee580 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -84,21 +84,11 @@ class LiveRange:
return self.name < other.name
def set_address(self, address):
- # Set address of all unaddressed tensors in LiveRange
+ # Set address of all tensors in LiveRange
for tens in self.tensors:
- if tens.address is None:
- addr = address
- else:
- # Limit to single tensor for the lr if the tensor address already assigned
- assert len(self.tensors) == 1
- addr = tens.address
- tens.address = addr
- # Also need to set the address to the tensor's cpu/npu clones
- if tens.cpu_tensor is not None:
- tens.cpu_tensor.address = addr
- if tens.npu_tensor is not None:
- tens.npu_tensor.address = addr
- return addr
+ tens.address = address
+
+ return address
def get_alignment(self):
return self.alignment
@@ -113,10 +103,6 @@ def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_ar
# For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
input_tensor = ps.inputs[0]
output_tensor = ps.outputs[0]
- # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
- # or output, fuse the live-range with the Cpu tensors' live-range instead.
- input_tensor = input_tensor.cpu_tensor if input_tensor.cpu_tensor is not None else input_tensor
- output_tensor = output_tensor.cpu_tensor if output_tensor.cpu_tensor is not None else output_tensor
if not tensor_should_be_ignored(input_tensor, target_mem_area) and not tensor_should_be_ignored(
output_tensor, target_mem_area
):
@@ -132,9 +118,9 @@ class LiveRangeGraph:
self.current_time = 0
def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
- for rng in self.ranges.values():
- # Return the live range of the tensor (or it's cpu/npu clone)
- if any(tensor in rng.tensors for tensor in [tens, tens.npu_tensor, tens.cpu_tensor]):
+ # Return the live range of the tensor (or any of its clones)
+ for existing_tensor, rng in self.ranges.items():
+ if tens.equivalent(existing_tensor):
rng.set_alignment(alignment)
return rng
@@ -252,10 +238,6 @@ def extract_live_ranges_from_cascaded_passes(
# For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
input_tensor = ps.inputs[0]
output_tensor = ps.outputs[0]
- # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
- # or output, fuse the live-range with the Cpu tensors' live-range instead.
- input_tensor = input_tensor.cpu_tensor if input_tensor.cpu_tensor is not None else input_tensor
- output_tensor = output_tensor.cpu_tensor if output_tensor.cpu_tensor is not None else output_tensor
if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
):
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 0e8dcc95..e3373ca2 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -42,9 +42,9 @@ class LUTState:
self.tensors = []
def get_equivalent(self, lut_tens):
- # Returns existing lut with same equivalence id, None if not found
+ # Returns existing lut with the same values, None if not found
for t in self.tensors:
- if t.equivalent(lut_tens):
+ if np.array_equal(t.values, lut_tens.values):
return t
return None
@@ -60,6 +60,7 @@ class LUTState:
end2 = start2 + tens.storage_size()
if not numeric_util.overlaps(start, end, start2, end2):
new_state.tensors.append(tens)
+
return new_state
def find_best_address(self, start, stop, step):
@@ -129,6 +130,7 @@ def optimize_high_level_cmd_stream(sg, arch):
# Place the LUT in the last 2 blocks of SHRAM
# Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc
address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
+ lut_tens.equivalence_id = uuid.uuid4()
lut_tens.address = address
cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
lut_state = lut_state.put(lut_tens)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index e9a93c19..47f8a47f 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -35,6 +35,7 @@ from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
+from .numeric_util import full_shape
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
@@ -43,7 +44,7 @@ from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
-from .numeric_util import full_shape
+
class ParetoMetric(enum.Enum):
BwCycMem = 1
@@ -652,6 +653,9 @@ class DynamicProgrammingScheduler:
for op in pred_candidate.ops:
if op.type == "ConcatSliceWrite":
return True
+ if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
+ # The op has consumers in other subgraphs
+ return True
return False
def search_ifm_streaming_partial(self, ps, block_config):
@@ -976,8 +980,15 @@ class DynamicProgrammingScheduler:
# be processed by CPU operations. No-op reshape consumers with empty lists
# (those that have no consumers, or null-consumers used as list terminators)
# must use normal NHWC output.
- incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape" or (consumer is last_op_in_subgraph))
- for consumer in op.outputs[0].consumer_list if consumer is not None ]
+ incompatible_consumers = [
+ (
+ not consumer.run_on_npu
+ or consumer.type == "Reshape"
+ or (consumer is last_op_in_subgraph)
+ )
+ for consumer in op.outputs[0].consumer_list
+ if consumer is not None
+ ]
if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
rewrites.append(op)
else:
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 49521e7a..0f8170d4 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -17,6 +17,7 @@
# Internal representation of a Neural Network Tensor.
import enum
import uuid
+from collections import defaultdict
import numpy as np
@@ -258,6 +259,25 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
return reshape_ofm if ifm_reshape else reshape_ifm
+# class that keeps track of all tensor addresses in the different memory types
+class TensorAddressMap:
+ address_map = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
+
+ @classmethod
+ def get_address_for_tens(cls, tens_id, mem_type):
+ return cls.address_map[tens_id].get(mem_type)
+
+ @classmethod
+ def set_address_for_tens(cls, tens_id, mem_type, address):
+ # Check previous address if there is one
+ previous_address = cls.address_map[tens_id].get(mem_type)
+ if previous_address is not None:
+ assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."
+
+ # Set tensor's address for memory type
+ cls.address_map[tens_id][mem_type] = address
+
+
class Tensor:
__slots__ = (
"shape",
@@ -285,13 +305,10 @@ class Tensor:
"weight_compression_config",
"storage_rounding_quantum",
"brick_size",
- "address",
"quantization",
"weight_compressed_offsets",
"element_size_bytes",
"block_traversal",
- "cpu_tensor",
- "npu_tensor",
"equivalence_id",
"resampling_mode",
"avoid_NHCWB16",
@@ -308,10 +325,6 @@ class Tensor:
self.ops = []
self.consumer_list = []
- # Below attributes are only set if a tensor has been cloned,
- # either from Cpu -> Npu or vice versa. Needed for offline allocation
- self.cpu_tensor = None # reference to the corresponding Cpu tensor
- self.npu_tensor = None # reference to the corresponding Npu tensor
self.values = None
self.quant_values = None
@@ -333,7 +346,6 @@ class Tensor:
self.weight_compressed_offsets = []
self.storage_rounding_quantum = (1, 1, 1, 1)
self.brick_size = (1, 1, 1, 1)
- self.address = None # start address of tensor. will be filled in by tensor allocator
self.element_size_bytes = 0
# quantization parameters
@@ -343,6 +355,14 @@ class Tensor:
self.avoid_NHCWB16 = False
+ @property
+ def address(self):
+ return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
+
+ @address.setter
+ def address(self, address):
+ TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
+
def element_size(self):
if self.element_size_bytes == 0:
return self.dtype.size_in_bits() / 8
@@ -367,7 +387,6 @@ class Tensor:
res.alignment = self.alignment
res.bandwidth_compression_scale = self.bandwidth_compression_scale
res.storage_rounding_quantum = self.storage_rounding_quantum
- res.address = None
if self.quantization is not None:
res.quantization = self.quantization.clone()
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index 3dda1793..ee1a40fe 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# Unit tests for LUT support
+import random
+
import numpy as np
from ethosu.vela import insert_dma
@@ -31,29 +33,29 @@ from ethosu.vela.test import testutil
def set_256_lut(op, key):
- values = list(range(256))
+ random.seed(key)
+ values = random.choices(range(256), k=256)
lut_tensor = create_const_tensor(
op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
)
- lut_tensor.equivalence_id = lut.create_equivalence_id(key)
op.set_activation_lut(lut_tensor)
def set_1K_lut(op, key):
- values = list(range(256))
+ random.seed(key)
+ values = random.choices(range(256), k=256)
lut_tensor = create_const_tensor(
op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
)
- lut_tensor.equivalence_id = lut.create_equivalence_id(key)
op.set_activation_lut(lut_tensor)
def set_2K_lut(op, key):
- values = list(range(512))
+ random.seed(key)
+ values = random.choices(range(512), k=512)
lut_tensor = create_const_tensor(
op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
)
- lut_tensor.equivalence_id = lut.create_equivalence_id(key)
op.set_activation_lut(lut_tensor)