author    Tim Hall <tim.hall@arm.com>  2021-06-24 18:29:53 +0100
committer Tim Hall <tim.hall@arm.com>  2021-07-09 11:51:22 +0100
commit    ffe8e288c8321dfc55a3b75f1aedc08769ecb23a (patch)
tree      7445c018851422a5c1b08f1ae5d7f5529fa461aa
parent    5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093 (diff)
MLBEDSW-4839: Fix issues with Elementwise IFM/OFM overlap
- Fixed typo with not using ifm.mem_type
- Fixed bug with using ifm1 properties when only ifm2 is a potential match
- Removed restriction on not considering SHL and SHR for overlap
- Removed some dead reshape code

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Id9bcc3c2b3ee9ac7b6276187d3e2f513b4acd4b5
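Why overwriting an input is safe here: for an elementwise op whose IFM and
OFM shapes match, each output element depends only on the input elements at
the same position, so the output may reuse an input's buffer. A toy
illustration of the principle (not Vela code):

    import numpy as np

    a = np.arange(8, dtype=np.int32)
    b = np.ones(8, dtype=np.int32)
    # OFM written over an IFM; safe because shapes match (no broadcasting)
    np.add(a, b, out=a)
    print(a)  # [1 2 3 4 5 6 7 8]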
 ethosu/vela/live_range.py    | 61
 ethosu/vela/pass_packing.py  |  3
 ethosu/vela/tflite_writer.py |  2
 3 files changed, 32 insertions(+), 34 deletions(-)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 2795b668..7ff1b28d 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -16,6 +16,7 @@
# Description:
# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
# Can work with either a pass packed subgraph or a scheduled subgraph.
+from collections import namedtuple
from typing import List
import numpy as np
@@ -159,47 +160,45 @@ def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_se
return True
if tens in lr_graph.ignore_tensors:
return True
- if tens.name.endswith("reshape_shape_npu"):
- # Reshape tensor, no need to allocate
- lr_graph.ignore_tensors.add(tens)
- return True
return False
def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+ def _tensor_should_be_ignored(tens):
+ return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)
+
# Tries to merge ifm/ofm live ranges of elementwise op
if sched_op.op_type.is_elementwise_op():
elem_op = sched_op.parent_op
- if not tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set):
+ if not _tensor_should_be_ignored(elem_op.ofm):
# Check if overwriting the inputs can be allowed
- if elem_op.type not in (Op.SHL, Op.SHR):
- inps = []
- if (
- elem_op.ifm is not None
- and elem_op.ifm.shape != []
- and elem_op.ifm.mem_area == target_mem_area
- and elem_op.ifm.mem_type in target_mem_type_set
- ):
- inps.append(elem_op.ifm)
+ OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
+ outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
+ inps = []
+ if elem_op.ifm is not None:
+ inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
+ if elem_op.ifm2 is not None:
+ inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
+
+ # find an input tensor that can be overwritten by the output
+ for inp in inps:
if (
- elem_op.ifm2 is not None
- and elem_op.ifm2.shape != []
- and elem_op.ifm2.mem_area == target_mem_area
- and elem_op.ifm.mem_type in target_mem_type_set
+ # check op input and output shapes allow overlapping
+ inp.op_shape == outp.op_shape
+ # check input tensor is valid
+ and inp.tens is not None
+ and inp.tens.shape != []
+ and not _tensor_should_be_ignored(inp.tens)
+ # check input and output tensors are compatible
+ and inp.tens.format == outp.tens.format
+ and inp.tens.dtype == outp.tens.dtype
+ # check input tensor only has one consumer
+ and len(inp.tens.consumer_list) == 1
+ # check output tensor only has one producer
+ and len(outp.tens.ops) == 1
):
- inps.append(elem_op.ifm2)
-
- if len(inps) > 0:
- for i, inp in enumerate(inps):
- # check input format, dtype, broadcasting or if there are more input consumers
- if (
- inp.format == elem_op.ofm.format
- and inp.dtype == elem_op.ofm.dtype
- and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0]
- and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
- ):
- lr_graph.fuse_ranges(inp, elem_op.ofm)
- break
+ lr_graph.fuse_ranges(inp.tens, outp.tens)
+ break
def extract_live_ranges_from_cascaded_passes(
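The rewritten merge logic above can be read as a single eligibility
predicate. The following is a condensed paraphrase of the hunk, not code
from the commit (the _tensor_should_be_ignored filter is elided because it
needs the lr_graph context):

    from collections import namedtuple

    OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])

    def can_overwrite(inp, outp):
        # Pairing each candidate with its own op shape and tensor is the
        # core of the fix: ifm2 is now checked against ifm2's properties,
        # where the old code read elem_op.ifm.mem_type while testing ifm2.
        return (
            inp.op_shape == outp.op_shape            # shapes allow overlap
            and inp.tens is not None                 # input tensor is valid
            and inp.tens.shape != []
            and inp.tens.format == outp.tens.format  # tensors compatible
            and inp.tens.dtype == outp.tens.dtype
            and len(inp.tens.consumer_list) == 1     # sole consumer
            and len(outp.tens.ops) == 1              # sole producer
        )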
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 518b2436..b28f4eb4 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -348,8 +348,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
ps.ifm_shapes.append(op.ifm_shapes[0])
elif ps.ifm_tensor == op.ifm2:
ps.ifm_shapes.append(op.ifm_shapes[1])
- for op in input_ops_list + [primary_op]:
- if op.run_on_npu:
+
if ps.ifm2_tensor == op.ifm:
ps.ifm_shapes.append(op.ifm_shapes[0])
elif ps.ifm2_tensor == op.ifm2:
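With the duplicated loop header removed, both IFM shape lookups now run in
a single pass over the producing ops. A hedged reconstruction of the
resulting region of pack_into_passes, with surrounding names taken from the
context lines above:

    for op in input_ops_list + [primary_op]:
        if op.run_on_npu:
            # record the op shape seen by the primary IFM ...
            if ps.ifm_tensor == op.ifm:
                ps.ifm_shapes.append(op.ifm_shapes[0])
            elif ps.ifm_tensor == op.ifm2:
                ps.ifm_shapes.append(op.ifm_shapes[1])

            # ... and by the secondary IFM, in the same pass instead of a
            # second loop over the same ops
            if ps.ifm2_tensor == op.ifm:
                ps.ifm_shapes.append(op.ifm_shapes[0])
            elif ps.ifm2_tensor == op.ifm2:
                ps.ifm_shapes.append(op.ifm_shapes[1])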
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 3701893e..fd3bf421 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -39,7 +39,7 @@ from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
from .tflite_mapping import datatype_inv_map
-# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
+# the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
tflite_version = 3
tflite_file_identifier = "TFL" + str(tflite_version)
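For context on what is being patched in: a FlatBuffers file identifier is
four ASCII bytes stored immediately after the four-byte root-table offset
at the start of the buffer. A minimal sketch of placing it by hand
(illustrative only; `serialized` is a hypothetical finished buffer, not the
writer's real output):

    tflite_version = 3
    tflite_file_identifier = "TFL" + str(tflite_version)  # "TFL3"

    serialized = bytearray(b"\x10\x00\x00\x00" + b"\x00" * 28)  # dummy buffer
    # bytes 0..3: root-table offset; bytes 4..7: file identifier
    serialized[4:8] = tflite_file_identifier.encode("ascii")
    assert serialized[4:8] == b"TFL3"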