about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Patrik Gustavsson <patrik.gustavsson@arm.com> 2020-10-16 13:59:52 +0200
committer: patrik.gustavsson <patrik.gustavsson@arm.com> 2020-10-28 12:32:21 +0000
commit: a151f597bef24e0d8b51dbe833338057e8bcbc92 (patch)
tree: 88b6ec7aad0a423935b13594020de573a9dce7df
parent: c7c0b1ba5e7c3dea73d1ab175b03ff188658d27b (diff)
download: ethos-u-vela-a151f597bef24e0d8b51dbe833338057e8bcbc92.tar.gz
MLBEDSW-3212 Enable overlap of elementwise input/output
Enable overlap of elementwise input/output

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I6e6f11953319c843c8203bf038f96778df194332
-rw-r--r-- ethosu/vela/live_range.py 177
-rw-r--r-- ethosu/vela/scheduler.py 5
-rw-r--r-- ethosu/vela/tensor_allocation.py 1
3 files changed, 85 insertions, 98 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 23026c7..b884035 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -98,18 +98,6 @@ class LiveRange:
self.alignment = max(self.alignment, alignment)
-def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
- for ps in sg.passes:
- if ps.placement == PassPlacement.MemoryOnly:
- # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
- input_tensor = ps.inputs[0]
- output_tensor = ps.outputs[0]
- if not tensor_should_be_ignored(input_tensor, target_mem_area) and not tensor_should_be_ignored(
- output_tensor, target_mem_area
- ):
- lr_graph.fuse_ranges(input_tensor, output_tensor)
-
-
class LiveRangeGraph:
def __init__(self):
self.ranges = {} # tens -> range
@@ -138,10 +126,79 @@ class LiveRangeGraph:
return live_range
+def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
+ return True
+ if tens in lr_graph.ignore_tensors:
+ return True
+ if tens.name.endswith("reshape_shape_npu"):
+ # Reshape tensor, no need to allocate
+ lr_graph.ignore_tensors.add(tens)
+ return True
+ return False
+
+
+# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
+def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
+ for ps in sg.passes:
+ if ps.placement == PassPlacement.MemoryOnly:
+ # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
+ input_tensor = ps.inputs[0]
+ output_tensor = ps.outputs[0]
+ if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
+ tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
+ ):
+ lr_graph.fuse_ranges(input_tensor, output_tensor)
+ elif ps.is_element_wise:
+ merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
+
+
+# Tries to merge ifm/ofm live ranges of elementwise op
+def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
+ elem_op = None
+ for op in ps.ops:
+ if op.type.is_elementwise_op():
+ assert elem_op is None
+ elem_op = op
+
+ if elem_op is not None and not tensor_should_be_ignored(
+ lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
+ ):
+ # Check if overwriting the inputs can be allowed
+ if elem_op.type not in (Op.SHL, Op.SHR):
+ inps = []
+ if (
+ elem_op.ifm is not None
+ and elem_op.ifm.shape != []
+ and elem_op.ifm.mem_area == target_mem_area
+ and elem_op.ifm.mem_type in target_mem_type_set
+ ):
+ inps.append(elem_op.ifm)
+ if (
+ elem_op.ifm2 is not None
+ and elem_op.ifm2.shape != []
+ and elem_op.ifm2.mem_area == target_mem_area
+ and elem_op.ifm.mem_type in target_mem_type_set
+ ):
+ inps.append(elem_op.ifm2)
+
+ if len(inps) > 0:
+ for inp in inps:
+ # check input format, dtype, broadcasting or if there are more input consumers
+ if (
+ inp.format == elem_op.ofm.format
+ and inp.dtype == elem_op.ofm.dtype
+ and inp.shape == elem_op.ofm.shape
+ and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
+ ):
+ lr_graph.fuse_ranges(inp, elem_op.ofm)
+ break
+
+
def extract_live_ranges_from_passes(
sg,
target_mem_area,
- mark_output_tensors_overlapping_with_input_tensors=False,
+ target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)),
ignore_subgraph_input_output_tensors=False,
):
lr_graph = LiveRangeGraph()
@@ -150,50 +207,24 @@ def extract_live_ranges_from_passes(
lr_graph.ignore_tensors.update(sg.input_tensors)
lr_graph.ignore_tensors.update(sg.output_tensors)
- def tensor_should_be_ignored(tens, target_mem_area):
- if tens.mem_area != target_mem_area:
- return True
- if tens in lr_graph.ignore_tensors:
- return True
- if tens.name.endswith("reshape_shape_npu"):
- # Reshape tensor, no need to allocate
- lr_graph.ignore_tensors.add(tens)
- return True
- return False
-
- # Merge only memory operations in the NPU subgraphs
+ # Try to merge live ranges of operations in the NPU subgraphs
if sg.placement == PassPlacement.Npu:
- merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)
+ merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type)
for idx, ps in enumerate(sg.passes):
ps.time = 2 * idx
time_for_pass = ps.time
- for tens in ps.inputs:
- if tensor_should_be_ignored(tens, target_mem_area):
- continue
- rng = lr_graph.get_or_create_range(tens)
- rng.mark_usage(time_for_pass)
-
- for tens in ps.intermediates:
- if tensor_should_be_ignored(tens, target_mem_area):
+ for tens in ps.inputs + ps.intermediates + ps.outputs:
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
continue
rng = lr_graph.get_or_create_range(tens)
rng.mark_usage(time_for_pass)
- for tens in ps.outputs:
- if tensor_should_be_ignored(tens, target_mem_area):
- continue
- rng = lr_graph.get_or_create_range(tens)
- output_time = time_for_pass
- if not mark_output_tensors_overlapping_with_input_tensors and ps.is_element_wise:
- output_time += 1
- rng.mark_usage(output_time)
-
end_time = len(sg.passes) * 2
for tens in sg.output_tensors:
- if tensor_should_be_ignored(tens, target_mem_area):
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
continue
rng = lr_graph.get_or_create_range(tens)
rng.mark_usage(end_time)
@@ -205,7 +236,6 @@ def extract_live_ranges_from_cascaded_passes(
sg,
target_mem_area,
target_mem_type_set,
- mark_output_tensors_overlapping_with_input_tensors=False,
use_ifm_ofm_overlap=True,
ignore_subgraph_input_output_tensors=False,
lr_graph=None,
@@ -222,41 +252,17 @@ def extract_live_ranges_from_cascaded_passes(
lr_graph.ignore_tensors.update(sg.input_tensors)
lr_graph.ignore_tensors.update(sg.output_tensors)
- def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
- if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
- return True
- if tens in lr_graph.ignore_tensors:
- return True
- if tens.name.endswith("reshape_shape_npu"):
- # Reshape tensor, no need to allocate
- lr_graph.ignore_tensors.add(tens)
- return True
- return False
-
- def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set):
- for ps in sg.passes:
- if ps.placement == PassPlacement.MemoryOnly:
- # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
- input_tensor = ps.inputs[0]
- output_tensor = ps.outputs[0]
- if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
- tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
- ):
- lr_graph.fuse_ranges(input_tensor, output_tensor)
-
- # Merge only memory operations in the NPU subgraphs
+ # Try to merge live ranges of operations in the NPU subgraphs
if sg.placement == PassPlacement.Npu:
- merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set)
+ merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
for cps in sg.cascaded_passes:
cps.time = lr_graph.current_time
time_for_pass = cps.time
- is_element_wise = cps.is_element_wise
-
for tens in cps.inputs:
- if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, allocation_alignment)
rng.mark_usage(time_for_pass)
@@ -273,33 +279,18 @@ def extract_live_ranges_from_cascaded_passes(
# Use default allocation alignment of 16 for Npu tensors
npu_sg = cps_primary_op.attrs["subgraph"]
lr_graph = extract_live_ranges_from_cascaded_passes(
- npu_sg,
- target_mem_area,
- target_mem_type_set,
- mark_output_tensors_overlapping_with_input_tensors,
- use_ifm_ofm_overlap,
- False,
- lr_graph,
+ npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph,
)
# Set the new time after handling the Npu subgraph
time_for_pass = lr_graph.current_time
cps.time = time_for_pass
- for tens in cps.intermediates:
- if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+ for tens in cps.intermediates + cps.outputs:
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, allocation_alignment)
rng.mark_usage(time_for_pass)
- for tens in cps.outputs:
- if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
- continue
- rng = lr_graph.get_or_create_range(tens, allocation_alignment)
- output_time = time_for_pass
- if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
- output_time += 1
- rng.mark_usage(output_time)
-
if use_ifm_ofm_overlap:
# fill allowed overlap for ifm and ofm tensor
ifm_tensor = cps.passes[0].ifm_tensor
@@ -307,8 +298,8 @@ def extract_live_ranges_from_cascaded_passes(
if (
ifm_tensor is not None
and ofm_tensor is not None
- and not tensor_should_be_ignored(ifm_tensor, target_mem_area, target_mem_type_set)
- and not tensor_should_be_ignored(ofm_tensor, target_mem_area, target_mem_type_set)
+ and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set)
+ and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set)
):
lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
cps
@@ -322,7 +313,7 @@ def extract_live_ranges_from_cascaded_passes(
end_time = max(end_time, rng.end_time)
for tens in sg.output_tensors:
- if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, allocation_alignment)
rng.mark_usage(end_time)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 41e1529..31e6383 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -480,10 +480,7 @@ class DynamicProgrammingScheduler:
def calc_non_local_mem_usage(self):
ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
range_set = live_range.extract_live_ranges_from_passes(
- self.sg,
- self.mem_area,
- mark_output_tensors_overlapping_with_input_tensors=True,
- ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+ self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
)
range_dict = range_set.ranges
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 1efcd68..9f14ec4 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -137,7 +137,6 @@ def allocate_tensors(
sg,
mem_area,
mem_type_set,
- mark_output_tensors_overlapping_with_input_tensors=False,
use_ifm_ofm_overlap=use_ifm_ofm_overlap,
ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
lr_graph=lr_graph,