authorJacob Bohlin <jacob.bohlin@arm.com>2021-06-21 17:22:20 +0200
committerJacob Bohlin <jacob.bohlin@arm.com>2021-06-22 14:00:55 +0200
commit98bfecd20f600a12de2f6a282d2fdbddb23dc081 (patch)
tree432b0c6a33a99af95f8a22f98f4778b26b1b4485
parent3016157e5099a50075d1a8b54d1b2cac2ee3899e (diff)
downloadethos-u-vela-98bfecd20f600a12de2f6a282d2fdbddb23dc081.tar.gz
MLBEDSW-4807 Elementwise IFM/OFM overlap
Reinstated allowing the IFM and OFM tensors to overlap for Elementwise operations.

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: Ide6db7781f3ca7a36c8ff9e3efdc7943a7bf6d7f
-rw-r--r-- ethosu/vela/live_range.py | 133
1 file changed, 36 insertions(+), 97 deletions(-)
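Note on what the reinstated behaviour buys: when an elementwise op's IFM and OFM share one live range, the allocator reserves a single buffer that both tensors reuse, so the output can be written over the input in place. The sketch below models that fusion in isolation; its standalone LiveRange and fuse_ranges are simplified stand-ins for Vela's LiveRangeGraph.fuse_ranges (which operates on tensors registered in the graph), and the sizes and times are made up for illustration.

# Illustrative sketch only, not the Vela implementation.
from dataclasses import dataclass, field
from typing import List


@dataclass
class LiveRange:
    tensor_names: List[str] = field(default_factory=list)
    start_time: int = 0
    end_time: int = 0
    size_bytes: int = 0


def fuse_ranges(a: LiveRange, b: LiveRange) -> LiveRange:
    # The fused range is allocated once and covers both tensors for the
    # union of their lifetimes; the buffer is shared, not duplicated.
    return LiveRange(
        tensor_names=a.tensor_names + b.tensor_names,
        start_time=min(a.start_time, b.start_time),
        end_time=max(a.end_time, b.end_time),
        size_bytes=max(a.size_bytes, b.size_bytes),
    )


ifm_lr = LiveRange(["ifm"], start_time=0, end_time=2, size_bytes=1024)
ofm_lr = LiveRange(["ofm"], start_time=2, end_time=4, size_bytes=1024)
fused = fuse_ranges(ifm_lr, ofm_lr)
print(fused.size_bytes)  # 1024 rather than 2048: the OFM reuses the IFM buffer

For every eligible elementwise op this saves roughly one feature map's worth of scratch memory, which is the saving this commit restores under the new scheduler.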
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index b687a9e7..2795b668 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -20,7 +20,6 @@ from typing import List
import numpy as np
-from .nn_graph import PassPlacement
from .operation import Op
from .tensor import MemArea
from .tensor import MemType
@@ -167,98 +166,40 @@ def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_se
return False
-# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
-def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
- for ps in sg.passes:
- if ps.placement == PassPlacement.MemoryOnly:
- # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
- input_tensor = ps.inputs[0]
- output_tensor = ps.outputs[0]
- if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
- tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
- ):
- lr_graph.fuse_ranges(input_tensor, output_tensor)
- elif ps.is_element_wise:
- merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
-
-
-# Tries to merge ifm/ofm live of elementwise op
-def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
- elem_op = None
- for op in ps.ops:
- if op.type.is_elementwise_op():
- assert elem_op is None
- elem_op = op
-
- if elem_op is not None and not tensor_should_be_ignored(
- lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
- ):
- # Check if overwriting the inputs can be allowed
- if elem_op.type not in (Op.SHL, Op.SHR):
- inps = []
- if (
- elem_op.ifm is not None
- and elem_op.ifm.shape != []
- and elem_op.ifm.mem_area == target_mem_area
- and elem_op.ifm.mem_type in target_mem_type_set
- ):
- inps.append(elem_op.ifm)
- if (
- elem_op.ifm2 is not None
- and elem_op.ifm2.shape != []
- and elem_op.ifm2.mem_area == target_mem_area
- and elem_op.ifm.mem_type in target_mem_type_set
- ):
- inps.append(elem_op.ifm2)
-
- if len(inps) > 0:
- for i, inp in enumerate(inps):
- # check input format, dtype, broadcasting or if there are more input consumers
- if (
- inp.format == elem_op.ofm.format
- and inp.dtype == elem_op.ofm.dtype
- and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
- and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
- ):
- lr_graph.fuse_ranges(inp, elem_op.ofm)
- break
-
-
-def extract_live_ranges_from_passes(
- sg, target_mem_area, target_mem_type_set=None, ignore_subgraph_input_output_tensors=False,
-):
- lr_graph = LiveRangeGraph()
-
- if ignore_subgraph_input_output_tensors:
- lr_graph.ignore_tensors.update(sg.input_tensors)
- lr_graph.ignore_tensors.update(sg.output_tensors)
-
- if target_mem_type_set is None:
- target_mem_type_set = set((MemType.Scratch, MemType.Scratch_fast))
-
- # Try to merge live ranges of operations in the NPU subgraphs
- if sg.placement == PassPlacement.Npu:
- merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
-
- for idx, ps in enumerate(sg.passes):
- ps.time = 2 * idx
-
- time_for_pass = ps.time
-
- for tens in ps.inputs + ps.intermediates + ps.outputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
- continue
- rng = lr_graph.get_or_create_range(tens)
- rng.mark_usage(time_for_pass)
-
- end_time = len(sg.passes) * 2
- for tens in sg.output_tensors:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
- continue
- rng = lr_graph.get_or_create_range(tens)
- rng.mark_usage(end_time)
+def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+ # Tries to merge ifm/ofm live ranges of elementwise op
+ if sched_op.op_type.is_elementwise_op():
+ elem_op = sched_op.parent_op
+ if not tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set):
+ # Check if overwriting the inputs can be allowed
+ if elem_op.type not in (Op.SHL, Op.SHR):
+ inps = []
+ if (
+ elem_op.ifm is not None
+ and elem_op.ifm.shape != []
+ and elem_op.ifm.mem_area == target_mem_area
+ and elem_op.ifm.mem_type in target_mem_type_set
+ ):
+ inps.append(elem_op.ifm)
+ if (
+ elem_op.ifm2 is not None
+ and elem_op.ifm2.shape != []
+ and elem_op.ifm2.mem_area == target_mem_area
+ and elem_op.ifm2.mem_type in target_mem_type_set
+ ):
+ inps.append(elem_op.ifm2)
- return lr_graph
+ if len(inps) > 0:
+ for i, inp in enumerate(inps):
+ # check input format, dtype, broadcasting or if there are more input consumers
+ if (
+ inp.format == elem_op.ofm.format
+ and inp.dtype == elem_op.ofm.dtype
+ and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
+ and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
+ ):
+ lr_graph.fuse_ranges(inp, elem_op.ofm)
+ break
def extract_live_ranges_from_cascaded_passes(
@@ -280,10 +221,6 @@ def extract_live_ranges_from_cascaded_passes(
lr_graph.ignore_tensors.update(sg.input_tensors)
lr_graph.ignore_tensors.update(sg.output_tensors)
- # Try to merge live ranges of operations in the NPU subgraphs
- if sg.placement == PassPlacement.Npu:
- merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
-
for cps in sg.cascaded_passes:
cps.time = lr_graph.current_time
@@ -347,7 +284,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
rng = lr_graph.get_or_create_range(tens)
rng.mark_usage(sg_time)
- for sched_op, op_info in sg.schedule.cost_map.items():
+ for _, op_info in sg.schedule.cost_map.items():
for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
rng = lr_graph.get_or_create_range(tensor)
@@ -360,6 +297,8 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
time_for_cascade = {}
for sched_op in sg.sched_ops:
+ merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)
+
op_info = sg.schedule.cost_map[sched_op]
cascade = op_info.cascade
cascade_info = sg.schedule.cascades.get(cascade, None)
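Closing note: the eligibility rules applied by the new merge_elementwise_op_ranges above can be read as a single predicate. The sketch below paraphrases them; the Tensor and ElemOp containers are illustrative rather than the Vela classes, and it compares raw shapes instead of the normalised ifm_shapes/ofm_shapes used in the real code.

# Standalone paraphrase of the overlap eligibility checks in the diff above.
# The Tensor and ElemOp containers are illustrative, not the Vela classes.
from dataclasses import dataclass
from typing import Optional, Set, Tuple


@dataclass
class Tensor:
    shape: Tuple[int, ...]
    dtype: str
    fmt: str
    mem_area: str
    mem_type: str
    consumer_count: int = 1


@dataclass
class ElemOp:
    op_type: str                 # e.g. "Add", "Mul", "SHL", "SHR"
    ifm: Optional[Tensor]
    ifm2: Optional[Tensor]
    ofm: Tensor


def overlap_candidate(op: ElemOp, target_mem_area: str, target_mem_types: Set[str]) -> Optional[Tensor]:
    # Shifts must not overwrite their inputs, matching the Op.SHL / Op.SHR exclusion.
    if op.op_type in ("SHL", "SHR"):
        return None
    for inp in (t for t in (op.ifm, op.ifm2) if t is not None):
        if (
            inp.shape != ()                    # skip scalar inputs
            and inp.mem_area == target_mem_area
            and inp.mem_type in target_mem_types
            and inp.fmt == op.ofm.fmt
            and inp.dtype == op.ofm.dtype
            and inp.shape == op.ofm.shape      # no broadcasting between this input and the OFM
            and inp.consumer_count == 1        # the elementwise op is the input's only consumer
        ):
            return inp
    return None

Returning the first matching input mirrors the break in the diff: at most one input has its live range fused with the OFM.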