aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Patrik Gustavsson <patrik.gustavsson@arm.com> 2021-09-21 14:18:44 +0200
committer: Patrik Gustavsson <patrik.gustavsson@arm.com> 2021-09-21 14:37:10 +0200
commit: 3f22ec2025c8e1afe6780785fd8c62c015824a63 (patch)
tree: b7d3324def750afc3a0f4806b195872069e08b62
parent: 46408a8049f6a51dda5bfa8a4c9959e037120265 (diff)
download: ethos-u-vela-3f22ec2025c8e1afe6780785fd8c62c015824a63.tar.gz
TOSA: Decompose elem op tensors
Added decomposition of tensors exceeding maximum size supported by NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I17a99cb72947d2f1064a631ad6975ce895c258d5
-rw-r--r-- ethosu/vela/shape4d.py | 3
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py | 234
-rw-r--r-- ethosu/vela/tosa_supported_operators.py | 21
3 files changed, 159 insertions, 99 deletions
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 08b2a6a..fd1ee94 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -111,6 +111,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
def __sub__(self, rhs):
return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)
+ def floordiv_const(self, const):
+ return Shape4D(self.batch // const, self.height // const, self.width // const, self.depth // const)
+
def __floordiv__(self, rhs):
return Shape4D(
self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index f4aa453..1e059cc 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,21 +503,71 @@ def convert_table_to_lut(op, arch, nng):
return convert_to_lut(op, table.values, "table")
-def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
- part_op = op.clone()
- offset = Shape4D(0, 0, 0, 0)
+def decompose_tensors_hwc(op):
+ max_t_size = 65535
+ ofm_shape = op.ofm_shapes[0]
+ ifm_shape = op.ifm_shapes[0]
+ ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
+
+ limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
+
+ if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
+ ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)
- part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
- part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
- part_op.write_offset = offset.with_batch(ofm_offset_n)
- part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+ for height in range(ofm_split.height):
+ for width in range(ofm_split.width):
+ for depth in range(ofm_split.depth):
+ ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
+ ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
+ ofm_cut = (ofm_offset, ofm_part_shape)
+
+ ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
+ ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
+ ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+ ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
+ ifm_cut = (ifm_offset, ifm_part_shape)
+
+ if ifm2_shape is not None:
+ ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
+ ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
+ ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+ ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
+ ifm2_cut = (ifm2_offset, ifm2_part_shape)
+ else:
+ ifm2_offset = None
+ ifm2_cut = (None, None)
+
+ create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
+ op.ofm.ops.remove(op)
+ op.ifm.consumer_list.remove(op)
+ if op.ifm2 is not None:
+ op.ifm2.consumer_list.remove(op)
+ return
+
+
+def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
+ part_op = op.clone()
+ ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
+ ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
+ ifm_offset, ifm_shape = ifm_cut
+ ofm_offset, ofm_shape = ofm_cut
+
+ part_op.read_offsets[0] = ifm_read_offset + ifm_offset
+ part_op.read_shapes[0] = ifm_shape
+ part_op.write_offset = ofm_write_offset + ofm_offset
+ part_op.write_shape = ofm_shape
part_op.ifm_shapes = op.ifm_shapes.copy()
part_op.ofm_shapes = op.ofm_shapes.copy()
part_op.ifm.consumer_list.append(part_op)
op.ofm.ops.append(part_op)
- if ifm2_offset_n:
- part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
- part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+
+ ifm2_offset, ifm2_shape = ifm2_cut
+ if ifm2_offset:
+ ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
+ part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
+ part_op.read_shapes[1] = ifm2_shape
part_op.ifm2.consumer_list.append(part_op)
@@ -528,114 +578,120 @@ def get_nhwc_stride(shape):
return Shape4D(stride_n, stride_y, stride_x, 1)
-def decomp_unary_elementwise(op):
+def get_elem_shapes_removed_singles(op):
"""
- Decompose binary elementwise ops with Rank > 3 (H,W,D).
- If Rank > 3, all the dimensions above H are viewed as the N dimension.
- the elementwise operation will be decomposed to N (of ofm) elementwise operations.
- By reading and writing with offsets from/to the ifm/ofm.
+ Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
"""
- ifm = op.ifm
- ofm = op.ofm
- assert op.type.is_unary_elementwise_op()
- assert None not in (ifm, ofm)
- assert ifm.shape == ofm.shape
-
- rank = len(ofm.shape)
- if rank > 3:
- n = rank - 3
- ofm_decomp_shape = Shape4D(ofm.shape[0:n])
- new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
- op.ifm_shapes.append(Shape4D(new_ofm_shape))
- op.ofm_shapes.append(Shape4D(new_ofm_shape))
-
- if new_ofm_shape[0] == 1:
- return
-
- for height in range(ofm_decomp_shape.height):
- for width in range(ofm_decomp_shape.width):
- for depth in range(ofm_decomp_shape.depth):
- ifm_offset, ofm_offset = Shape4D(0, height, width, depth)
- create_elem_part_op(op, ifm_offset, None, ofm_offset)
-
- ifm.consumer_list.remove(op)
- ofm.ops.remove(op)
- return
-
-
-def decomp_binary_elementwise(op):
+ rank = len(op.ofm.shape)
+ binary = op.ifm2 is not None
+ new_ofm_shape = []
+ new_ifm_shape = []
+ new_ifm2_shape = []
+ for idx in range(rank):
+ if op.ofm.shape[idx] != 1:
+ new_ofm_shape.append(op.ofm.shape[idx])
+ new_ifm_shape.append(op.ifm.shape[idx])
+ if binary:
+ new_ifm2_shape.append(op.ifm2.shape[idx])
+ if new_ofm_shape == []:
+ new_ofm_shape = [1]
+ new_ifm_shape = [1]
+ new_ifm2_shape = [1] if binary else None
+
+ return new_ofm_shape, new_ifm_shape, new_ifm2_shape
+
+
+def decomp_dims_elementwise(op):
"""
- Decompose binary elementwise ops with Rank > 3 (H,W,D).
+ Decompose elementwise ops with Rank > 3 (H,W,D).
If Rank > 3, all the dimensions above H are viewed as the N dimension.
the elementwise operation will be decomposed to N (of ofm) elementwise operations.
By reading and writing with offsets from/to the ifm(s)/ofm.
- Note: Broadcast need to be handled, and TOSA allowes for broadcast by both ifm and ifm2
+ Note: Broadcast need to be handled for binary elementwise ops, and TOSA allowes for broadcast by both ifm and ifm2
"""
ifm = op.ifm
ifm2 = op.ifm2
ofm = op.ofm
- assert op.type.is_binary_elementwise_op()
- assert None not in (ifm, ifm2, ofm)
+ binary = op.ifm2 is not None
+ assert len(ofm.shape) <= 6
+
+ # Remove dimensions that are all 1
+ new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
+ rank = len(new_ofm_shape)
- rank = len(ofm.shape)
if rank > 3:
n = rank - 3
- ofm_decomp_shape = Shape4D(ofm.shape[0:n])
- ifm_decomp_shape = Shape4D(ifm.shape[0:n])
- ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
-
+ ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
- ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
- ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
-
- new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
- new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
- new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
-
- op.ofm_shapes.append(Shape4D(new_ofm_shape))
- op.ifm_shapes.append(Shape4D(new_ifm_shape))
- op.ifm_shapes.append(Shape4D(new_ifm2_shape))
-
- if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
- return
+ ofm_part_shape = Shape4D(new_ofm_shape[n:])
+ op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
+
+ if binary:
+ ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
+ ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
+ ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+ ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+ ifm_part_shape = Shape4D(new_ifm_shape[n:])
+ ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
+ op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
+ op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
+ else:
+ op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
+ op_list = []
for height in range(ofm_decomp_shape.height):
for width in range(ofm_decomp_shape.width):
for depth in range(ofm_decomp_shape.depth):
ofm_offset = Shape4D(0, height, width, depth)
+ ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
+ ofm_cut = (ofm_offset, ofm_part_shape)
+
+ if binary:
+ ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+ ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
+ ifm_cut = (ifm_offset, ifm_part_shape)
+
+ ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+ ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
+ ifm2_cut = (ifm2_offset, ifm2_part_shape)
+ op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
+ else:
+ op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut))
- ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
- ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
- ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
- ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
-
- ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
- ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
- ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
- ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
-
- ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
- ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
- ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
- create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
-
- ifm.consumer_list.remove(op)
- ifm2.consumer_list.remove(op)
ofm.ops.remove(op)
- return
+ ifm.consumer_list.remove(op)
+ if binary:
+ ifm2.consumer_list.remove(op)
+ else:
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+ return [op]
def decomp_elementwise(tens, arch, nng):
"""
Decompose elementwise ops with Rank > 3 (H,W,D).
+ Decompose size of tensors exceeding NPU max size
"""
- assert len(tens.ops) == 1
+ if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+ op = tens.ops[0]
+ rank = len(op.ofm.shape)
+ assert rank <= 6
+
+ decomp_list = []
+ decomp_list = decomp_dims_elementwise(op)
- if tens.ops[0].type.is_binary_elementwise_op():
- decomp_binary_elementwise(tens.ops[0])
- elif tens.ops[0].type.is_unary_elementwise_op():
- decomp_unary_elementwise(tens.ops[0])
+ for part_op in decomp_list:
+ decompose_tensors_hwc(part_op)
return tens
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index f5eddcc..1012a61 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -117,18 +117,19 @@ class TosaSupportedOperators:
# This is for a HW limitation, that is to be resolved in SW later on
@classmethod
@docstring_format_args(tens_dim_range)
- def constraint_tens_dimension(cls, op):
- "Tensor dimensions must be in the range [{}, {}]"
- tens_min, tens_max = cls.tens_dim_range
+ def constraint_tens_dimension(self, op):
+ "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+ tens_min, tens_max = self.tens_dim_range
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- if not all(tens_min <= dim <= tens_max for dim in tens.shape):
- valid = False
- extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if not all(tens_min <= dim <= tens_max for dim in tens.shape):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on