author    Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-09-23 13:52:34 +0200
committer Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-09-28 13:57:12 +0200
commit    c2b129dbdd9b493fee904bb07838bb0b9e247e96 (patch)
tree      9a4e02909ebf74839e459278aa5ecb59e3fa278a
parent    3f22ec2025c8e1afe6780785fd8c62c015824a63 (diff)
download  ethos-u-vela-c2b129dbdd9b493fee904bb07838bb0b9e247e96.tar.gz
TOSA: Decomposition of CONCAT
- Added support for an unlimited number of dimensions
- Added support for tensors with dimension sizes exceeding the maximum limit of the NPU

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I3cc7327ac759e69042a600e686160aeb18a5ec59
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py      144
-rw-r--r-- ethosu/vela/tosa_supported_operators.py   28
2 files changed, 117 insertions, 55 deletions
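
The core of the change is lowering CONCAT into one elementwise copy per input, each written into the output at a running offset along the concat axis. A minimal NumPy sketch of that idea (illustrative only; the real pass emits NPU add ops through create_add_for_concat with Shape4D write offsets):

    import numpy as np

    def concat_as_offset_writes(inputs, axis):
        # Build the output shape: same as the inputs except along the
        # concat axis, which is the sum of the input extents.
        out_shape = list(inputs[0].shape)
        out_shape[axis] = sum(t.shape[axis] for t in inputs)
        out = np.zeros(out_shape, dtype=inputs[0].dtype)
        offset = 0
        for t in inputs:
            index = [slice(None)] * t.ndim
            index[axis] = slice(offset, offset + t.shape[axis])
            out[tuple(index)] = t  # one "add" op with a write offset
            offset += t.shape[axis]
        return out
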
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1e059cc..5cd9d21 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -265,32 +265,21 @@ def remove_splitsliceread(op, arch):
DebugDatabase.add_optimised(op, add_op)
-def rewrite_concat_ops(op, arch):
+def rewrite_concat(op):
if not op.run_on_npu or not op.type == Op.Concat:
return
- axis_4D = 0
- ofm = op.ofm
- ofm.ops = []
offset = 0
-
inputs = op.inputs
- axis = op.attrs["axis"]
+ axis_4D = op.attrs["axis4D"]
for idx, inp in enumerate(inputs):
- op.ifm_shapes[idx] = Shape4D(inp.shape)
- if axis >= 0:
- axis_4D = axis + (4 - len(inp.shape))
- else:
- axis_4D = axis
write_offset = [0, 0, 0, 0]
write_offset[axis_4D] = offset
concat_end = offset + op.ifm_shapes[idx][axis_4D]
create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
offset = concat_end
- assert ofm.shape[axis] == offset
-
- return op
+ assert op.ofm_shapes[0][axis_4D] == offset
def remove_reshapes(op, arch):
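
The per-input axis_4D arithmetic removed here does not disappear: reshape_concat (added further down) computes it once and stores it in op.attrs["axis4D"]. For intuition, padding a shape to 4D prepends dimensions, which shifts the axis index right; a sketch of the old per-input computation:

    # Padding a rank-3 H,W,C shape to 4D prepends N, shifting axes by one.
    shape, axis = [8, 16, 32], 1           # concat along W of an H,W,C tensor
    axis_4D = axis + (4 - len(shape))      # 1 + (4 - 3) == 2, i.e. W in N,H,W,C
    assert axis_4D == 2
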
@@ -503,12 +492,15 @@ def convert_table_to_lut(op, arch, nng):
return convert_to_lut(op, table.values, "table")
-def decompose_tensors_hwc(op):
+def decompose_elem_tensors_hwc(op):
+ """
+ Decomposes an elementwise op if any of the ifm(s)/ofm are too large in any dimension to be handled by the NPU
+ """
max_t_size = 65535
- ofm_shape = op.ofm_shapes[0]
- ifm_shape = op.ifm_shapes[0]
+ ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
+ ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
-
+ ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
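
The decomposition now honours slices already attached to the op by an earlier cut: for every shape it prefers the op's read/write shape over the full tensor shape. The precedence rule in isolation (a hypothetical helper name, same logic as the lines above):

    def effective_shape(sliced_shape, full_shape):
        # Prefer an already-applied read/write slice; fall back to the
        # op's full ifm/ofm shape when no slice exists.
        return sliced_shape if sliced_shape is not None else full_shape
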
@@ -536,7 +528,6 @@ def decompose_tensors_hwc(op):
ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
ifm2_cut = (ifm2_offset, ifm2_part_shape)
else:
- ifm2_offset = None
ifm2_cut = (None, None)
create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
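
Each oversized dimension is walked in steps of max_t_size, and create_elem_part_op emits one partial op per cut. The tiling arithmetic on a single dimension (a sketch; the real code clips whole Shape4D objects rather than scalars):

    MAX_T_SIZE = 65535  # per-dimension NPU limit used by the pass

    def cut_ranges(dim_size, limit=MAX_T_SIZE):
        # Yield (offset, length) pairs tiling one oversized dimension.
        offset = 0
        while offset < dim_size:
            length = min(limit, dim_size - offset)
            yield offset, length
            offset += length

    # A 150000-wide axis splits into (0, 65535), (65535, 65535), (131070, 18930).
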
@@ -582,17 +573,23 @@ def get_elem_shapes_removed_singles(op):
"""
Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
"""
- rank = len(op.ofm.shape)
binary = op.ifm2 is not None
+ ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
+ ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
+ if binary:
+ ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ifm_shapes) > 1 else op.ifm2.shape
+
+ rank = len(ofm_shape)
new_ofm_shape = []
new_ifm_shape = []
new_ifm2_shape = []
for idx in range(rank):
- if op.ofm.shape[idx] != 1:
- new_ofm_shape.append(op.ofm.shape[idx])
- new_ifm_shape.append(op.ifm.shape[idx])
+ if ofm_shape[idx] != 1:
+ new_ofm_shape.append(ofm_shape[idx])
+ new_ifm_shape.append(ifm_shape[idx])
if binary:
- new_ifm2_shape.append(op.ifm2.shape[idx])
+ new_ifm2_shape.append(ifm2_shape[idx])
+
if new_ofm_shape == []:
new_ofm_shape = [1]
new_ifm_shape = [1]
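
What get_elem_shapes_removed_singles computes, reduced to plain lists (a sketch with a hypothetical name, following the same fall-back-to-[1] rule as above):

    def drop_unit_dims(ofm_shape, ifm_shape, ifm2_shape=None):
        # Keep only the axes where the OFM extent is not 1; the ifm(s)
        # keep the same axes so the tensors stay dimension-aligned.
        keep = [i for i, d in enumerate(ofm_shape) if d != 1]
        pick = lambda s: [s[i] for i in keep] if keep else [1]
        return (
            pick(ofm_shape),
            pick(ifm_shape),
            pick(ifm2_shape) if ifm2_shape is not None else None,
        )

    # drop_unit_dims([1, 8, 1, 32], [1, 8, 1, 32]) -> ([8, 32], [8, 32], None)
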
@@ -614,7 +611,6 @@ def decomp_dims_elementwise(op):
ifm2 = op.ifm2
ofm = op.ofm
binary = op.ifm2 is not None
- assert len(ofm.shape) <= 6
# Remove dimensions that are all 1
new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
@@ -679,19 +675,74 @@ def decomp_dims_elementwise(op):
def decomp_elementwise(tens, arch, nng):
"""
- Decompose elementwise ops with Rank > 3 (H,W,D).
+ Decompose elementwise ops with Rank > 3 (H,W,C).
Decompose size of tensors exceeding NPU max size
"""
- if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+ tens_ops = tens.ops.copy()
+ for op in tens_ops:
+ if op.type.is_elementwise_op():
+ decomp_list = decomp_dims_elementwise(op)
+ for part_op in decomp_list:
+ decompose_elem_tensors_hwc(part_op)
+ return tens
+
+
+def reshape_concat_shape(shape, rank, axis):
+ new_h = 1
+ for i in range(axis):
+ new_h *= shape[i]
+ new_c = 1
+ for i in range(axis + 1, rank):
+ new_c *= shape[i]
+ if axis == (rank - 1):
+ new_shape = [new_h, shape[axis], 1]
+ else:
+ new_shape = [new_h, shape[axis], new_c]
+ return new_shape
+
+
+def reshape_concat(op):
+ """
+ Reshapes concat ops with Rank > 3 (H,W,C).
+ """
+ ofm = op.ofm
+ rank = len(ofm.shape)
+ axis = op.attrs["axis"]
+ if axis < 0:
+ axis += rank
+
+ if rank > 3:
+ # Reshape so that the axis to be concatenated along is the W dimension
+ # Reshape inputs
+ for inp in op.inputs:
+ new_shape = reshape_concat_shape(inp.shape, rank, axis)
+ op.ifm_shapes.append(Shape4D(new_shape))
+ # Reshape output
+ new_shape = reshape_concat_shape(ofm.shape, rank, axis)
+ op.ofm_shapes.append(Shape4D(new_shape))
+ op.attrs["axis4D"] = 2
+ else:
+ for inp in op.inputs:
+ op.ifm_shapes.append(Shape4D(inp.shape))
+ op.ofm_shapes.append(Shape4D(ofm.shape))
+ op.attrs["axis4D"] = axis + (4 - rank)
+
+
+def decomp_rewrite_concat(tens, arch, nng):
+ """
+ Decompose concat ops with Rank > 3 (H,W,C).
+ Rewrite concat into elementwise operations
+ """
+ if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
op = tens.ops[0]
- rank = len(op.ofm.shape)
- assert rank <= 6
- decomp_list = []
- decomp_list = decomp_dims_elementwise(op)
+ reshape_concat(op)
+ rewrite_concat(op)
+
+ op.ofm.ops.remove(op)
+ for inp in op.inputs:
+ inp.consumer_list.remove(op)
- for part_op in decomp_list:
- decompose_tensors_hwc(part_op)
return tens
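
A worked example of the reshape arithmetic above: dimensions before the concat axis fold into H and dimensions after it fold into C, so the concatenation always happens along W (hence the fixed axis4D = 2 for rank > 3):

    # Rank-5 concat on axis 2: [2, 3, 4, 5, 6] -> [2*3, 4, 5*6] == [6, 4, 30];
    # padded to 4D this is [1, 6, 4, 30], so the concat axis is W (axis4D == 2).
    assert reshape_concat_shape([2, 3, 4, 5, 6], rank=5, axis=2) == [6, 4, 30]
    # When the concat axis is last, C degenerates to 1:
    assert reshape_concat_shape([2, 3, 4, 5, 6], rank=5, axis=4) == [120, 6, 1]
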
@@ -714,21 +765,27 @@ def supported_operator_check(op, arch, nng):
def tosa_optimise_graph(nng, arch):
- # Decomposing to 4 dimensions
+ # TODO the supported operator checking needs to be split into semantic and HW checks
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+ nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
)
- # Pre-processing step
- pre_process_list = [
- supported_operator_check,
- set_ifm_ofm_op_shapes,
- ]
+ # Decomposition and rewrite of concat
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
+ )
+ # Decomposition of elementwise ops
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+ nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+ )
+
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
)
# Removal of Transpose
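
Taken together, the optimiser now runs four separate phases before the transpose/reshape passes. One way to read the repeated loops above as a single driver (a sketch with the same semantics as the calls in the diff):

    phases = [
        ([], [supported_operator_check]),   # 1. semantic/HW checks only
        ([decomp_rewrite_concat], []),      # 2. reshape + rewrite CONCAT
        ([decomp_elementwise], []),         # 3. decompose elementwise ops
        ([], [set_ifm_ofm_op_shapes]),      # 4. annotate op shapes
    ]
    for tensor_fns, op_fns in phases:
        for idx, sg in enumerate(nng.subgraphs):
            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
                nng, sg, arch, tensor_fns, op_fns, rewrite_unsupported=False
            )
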
@@ -743,11 +800,6 @@ def tosa_optimise_graph(nng, arch):
nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
)
- # Rewrite concat ops
- for idx, sg in enumerate(nng.subgraphs):
- rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
- sg.refresh_after_modification()
-
# Removal of reshapes
for sg in nng.subgraphs:
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 1012a61..d71e575 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,6 +46,10 @@ class TosaSupportedOperators:
activation_ops = relu_ops | set((Op.Table,))
pad_ops = set((Op.Pad,))
+ rank_unlimited_ops = set((Op.Concat,))
+ rank6_limited_ops = elem_wise_ops
+ batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
+ large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
npu_post_ops = activation_ops
supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
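
The four new sets record which HW limitation each operator class is exempt from, and the relaxed constraints below key off them. A hedged paraphrase of the resulting rank policy (not the actual method):

    def rank_limit_for(op_type):
        # CONCAT is fully decomposed in SW, so its rank is unlimited;
        # elementwise ops are decomposed down from rank 6; all other
        # ops keep the original rank <= 4 limit.
        if op_type in TosaSupportedOperators.rank_unlimited_ops:
            return None
        if op_type in TosaSupportedOperators.rank6_limited_ops:
            return 6
        return 4
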
@@ -60,8 +64,10 @@ class TosaSupportedOperators:
self.generic_constraints = []
self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension) # TODO as not supported yet
- self.generic_constraints.append(TosaSupportedOperators.constraint_rank) # TODO as not supported yet
- self.generic_constraints.append(TosaSupportedOperators.constraint_batch) # TODO as not supported yet
+ self.generic_constraints.append(TosaSupportedOperators.constraint_rank) # TODO as not supported for all ops yet
+ self.generic_constraints.append(
+ TosaSupportedOperators.constraint_batch
+ ) # TODO as not supported for all ops yet
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
@@ -118,11 +124,11 @@ class TosaSupportedOperators:
@classmethod
@docstring_format_args(tens_dim_range)
def constraint_tens_dimension(self, op):
- "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+ "Tensor dimensions must be in the range [{}, {}]"
tens_min, tens_max = self.tens_dim_range
valid = True
extra = []
- if op.type not in self.binary_elem_wise_add_mul_sub:
+ if op.type not in self.large_tens_dims_enabled_ops:
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
if not tensors:
tensors = [tens for tens in op.inputs if tens]
@@ -135,16 +141,20 @@ class TosaSupportedOperators:
# TODO This is for a HW limitation, that is to be resolved in SW later on
@classmethod
def constraint_rank(self, op):
- "Tensor rank must be <= 4, if not elementwise"
+ "Tensor rank must be <= 6 or <= 4 depending on operator"
valid = True
extra = []
- if op.type not in self.binary_elem_wise_add_mul_sub:
+ if op.type not in self.rank_unlimited_ops:
+ if op.type in self.rank6_limited_ops:
+ rank_limit = 6
+ else:
+ rank_limit = 4
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
if not tensors:
tensors = [tens for tens in op.inputs if tens]
for tens in tensors:
rank = len(tens.shape)
- if not rank <= 4:
+ if not rank <= rank_limit:
valid = False
extra.append(f"Tensor '{tens.name}' has rank: {rank}")
return valid, ", ".join(extra)
@@ -152,10 +162,10 @@ class TosaSupportedOperators:
# TODO This is for a HW limitation, that is to be resolved in SW later on
@classmethod
def constraint_batch(self, op):
- "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
+ "If Tensor rank is 4 batch of ifms/ofm must be 1"
valid = True
extra = []
- if op.type not in self.binary_elem_wise_add_mul_sub:
+ if op.type not in self.batch_enabled_ops:
tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
if not tensors:
tensors = [tens for tens in op.inputs if tens]