aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-20 10:47:47 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-09-20 13:27:15 +0200
commit46408a8049f6a51dda5bfa8a4c9959e037120265 (patch)
tree68595457843f3ff4193da0542afbe5de56da8d31
parentf436ada9caea87ec2dd686a92e41a15c1dcdeb1d (diff)
downloadethos-u-vela-46408a8049f6a51dda5bfa8a4c9959e037120265.tar.gz
TOSA: Elementwise Rank > 4 and Batch > 1
Added support for elementwise operations:
- Support for up to Rank == 6
- Support for Batch > 1 for Rank == 4
- For binary elementwise ops this includes handling of broadcasting in
  dimensions above the H-dimension

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
-rw-r--r--ethosu/vela/operation.py1
-rw-r--r--ethosu/vela/shape4d.py3
-rw-r--r--ethosu/vela/tensor.py2
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py143
-rw-r--r--ethosu/vela/tosa_supported_operators.py52
5 files changed, 175 insertions, 26 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 1558b943..b4267926 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -545,6 +545,7 @@ class Operation:
res.rounding_mode = self.rounding_mode
res.explicit_scaling = self.explicit_scaling
res.low_precision_scaling = self.low_precision_scaling
+ res.rescale = self.rescale
return res
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index fd674031..08b2a6a0 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -136,6 +136,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
def elements(self):
return self.batch * self.width * self.height * self.depth
+ def dot_prod(self, rhs):
+ return self.batch * rhs.batch + self.width * rhs.width + self.height * rhs.height + self.depth * rhs.depth
+
def elements_wh(self):
return self.width * self.height
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 37fd06ea..2e70d72e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -632,7 +632,7 @@ class Tensor:
self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
) -> Tuple[Optional[Shape], Optional[Shape]]:
if coord is None:
- coord = [0] * len(self.storage_shape)
+ coord = [0] * min(len(self.storage_shape), 4)
if shape4D and self.is_standard_fm:
augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1ef04449..f4aa4534 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,6 +503,142 @@ def convert_table_to_lut(op, arch, nng):
return convert_to_lut(op, table.values, "table")
+def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
+ part_op = op.clone()
+ offset = Shape4D(0, 0, 0, 0)
+
+ part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
+ part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
+ part_op.write_offset = offset.with_batch(ofm_offset_n)
+ part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+ part_op.ifm_shapes = op.ifm_shapes.copy()
+ part_op.ofm_shapes = op.ofm_shapes.copy()
+ part_op.ifm.consumer_list.append(part_op)
+ op.ofm.ops.append(part_op)
+ if ifm2_offset_n:
+ part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
+ part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+ part_op.ifm2.consumer_list.append(part_op)
+
+
+def get_nhwc_stride(shape):
+ stride_x = shape.depth
+ stride_y = shape.width * stride_x
+ stride_n = shape.height * stride_y
+ return Shape4D(stride_n, stride_y, stride_x, 1)
+
+
+def decomp_unary_elementwise(op):
+    """
+    Decompose unary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation will be decomposed to N (of ofm) elementwise operations,
+    by reading and writing with offsets from/to the ifm/ofm.
+    """
+ ifm = op.ifm
+ ofm = op.ofm
+ assert op.type.is_unary_elementwise_op()
+ assert None not in (ifm, ofm)
+ assert ifm.shape == ofm.shape
+
+ rank = len(ofm.shape)
+ if rank > 3:
+ n = rank - 3
+ ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+ new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+ op.ifm_shapes.append(Shape4D(new_ofm_shape))
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+
+ if new_ofm_shape[0] == 1:
+ return
+
+ for height in range(ofm_decomp_shape.height):
+ for width in range(ofm_decomp_shape.width):
+ for depth in range(ofm_decomp_shape.depth):
+ ifm_offset, ofm_offset = Shape4D(0, height, width, depth)
+ create_elem_part_op(op, ifm_offset, None, ofm_offset)
+
+ ifm.consumer_list.remove(op)
+ ofm.ops.remove(op)
+ return
+
+
+def decomp_binary_elementwise(op):
+    """
+    Decompose binary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation will be decomposed to N (of ofm) elementwise operations,
+    by reading and writing with offsets from/to the ifm(s)/ofm.
+    Note: Broadcasting needs to be handled, and TOSA allows for broadcast by both ifm and ifm2
+    """
+
+ ifm = op.ifm
+ ifm2 = op.ifm2
+ ofm = op.ofm
+ assert op.type.is_binary_elementwise_op()
+ assert None not in (ifm, ifm2, ofm)
+
+ rank = len(ofm.shape)
+ if rank > 3:
+ n = rank - 3
+ ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+ ifm_decomp_shape = Shape4D(ifm.shape[0:n])
+ ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
+
+ ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+ ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+ ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+
+ new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+ new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
+ new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
+
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+ if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
+ return
+
+ for height in range(ofm_decomp_shape.height):
+ for width in range(ofm_decomp_shape.width):
+ for depth in range(ofm_decomp_shape.depth):
+ ofm_offset = Shape4D(0, height, width, depth)
+
+ ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+
+ ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+
+ ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
+ ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
+ ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
+ create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
+
+ ifm.consumer_list.remove(op)
+ ifm2.consumer_list.remove(op)
+ ofm.ops.remove(op)
+ return
+
+
+def decomp_elementwise(tens, arch, nng):
+ """
+ Decompose elementwise ops with Rank > 3 (H,W,D).
+ """
+ assert len(tens.ops) == 1
+
+ if tens.ops[0].type.is_binary_elementwise_op():
+ decomp_binary_elementwise(tens.ops[0])
+ elif tens.ops[0].type.is_unary_elementwise_op():
+ decomp_unary_elementwise(tens.ops[0])
+ return tens
+
+
def fixup_quantization(op, arch, nng):
if op.ifm and op.ifm.quantization.zero_point is None:
op.ifm.quantization.zero_point = 0
@@ -521,6 +657,13 @@ def supported_operator_check(op, arch, nng):
def tosa_optimise_graph(nng, arch):
+
+ # Decomposing to 4 dimensions
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+ )
+
# Pre-processing step
pre_process_list = [
supported_operator_check,
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 98df27e3..f5eddccc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,15 +40,15 @@ class TosaSupportedOperators:
mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+ elem_wise_ops = binary_elem_wise_add_mul_sub
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops | set((Op.Table,))
pad_ops = set((Op.Pad,))
npu_post_ops = activation_ops
- supported_operators = (
- mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
- )
+
+ supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -132,35 +132,37 @@ class TosaSupportedOperators:
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_rank(op):
- "Tensor rank must be <= 4"
+ @classmethod
+ def constraint_rank(self, op):
+ "Tensor rank must be <= 4, if not elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if not rank <= 4:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if not rank <= 4:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: {rank}")
return valid, ", ".join(extra)
# TODO This is for a HW limitation, that is to be resolved in SW later on
- @staticmethod
- def constraint_batch(op):
- "If Tensor rank is 4 batch of ifms/ofm must be 1"
+ @classmethod
+ def constraint_batch(self, op):
+ "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
valid = True
extra = []
- tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
- if not tensors:
- tensors = [tens for tens in op.inputs if tens]
- for tens in tensors:
- rank = len(tens.shape)
- if rank == 4 and tens.shape[0] != 1:
- valid = False
- extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+ if op.type not in self.binary_elem_wise_add_mul_sub:
+ tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ rank = len(tens.shape)
+ if rank == 4 and tens.shape[0] != 1:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
return valid, ", ".join(extra)
@staticmethod