author    Patrik Gustavsson <patrik.gustavsson@arm.com>    2021-09-20 10:47:47 +0200
committer Patrik Gustavsson <patrik.gustavsson@arm.com>    2021-09-20 13:27:15 +0200
commit    46408a8049f6a51dda5bfa8a4c9959e037120265 (patch)
tree      68595457843f3ff4193da0542afbe5de56da8d31
parent    f436ada9caea87ec2dd686a92e41a15c1dcdeb1d (diff)
download  ethos-u-vela-46408a8049f6a51dda5bfa8a4c9959e037120265.tar.gz
TOSA: Elementwise Rank > 4 and Batch > 1
Added support for elementwise operations:
- Support for up to Rank == 6
- Support for Batch > 1 for Rank == 4
- For binary elementwise ops this includes handling of
  broadcasting in dimensions above the H-dimension
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
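For readers unfamiliar with the scheme in this patch: every dimension above H is folded into a single N dimension, and the elementwise op is then split into N single-batch operations. A minimal standalone sketch of the folding step, using plain Python lists rather than vela's Shape4D (the helper name fold_leading_dims is hypothetical, for illustration only):

    # Illustration only: fold all dims above the last three (H, W, D) into one N dim.
    def fold_leading_dims(shape, keep=3):
        lead = shape[:-keep] if len(shape) > keep else [1]
        n = 1
        for d in lead:
            n *= d  # product of the folded leading dims
        return [n] + list(shape[-keep:])

    print(fold_leading_dims([2, 3, 4, 8, 8, 16]))  # rank 6 -> [24, 8, 8, 16]
    # The op is then decomposed into 24 rank-3 (8, 8, 16) slices, each one
    # reading and writing at its own flattened N offset.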
 ethosu/vela/operation.py                |   1
 ethosu/vela/shape4d.py                  |   3
 ethosu/vela/tensor.py                   |   2
 ethosu/vela/tosa_graph_optimiser.py     | 143
 ethosu/vela/tosa_supported_operators.py |  52

5 files changed, 175 insertions, 26 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 1558b943..b4267926 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -545,6 +545,7 @@ class Operation:
         res.rounding_mode = self.rounding_mode
         res.explicit_scaling = self.explicit_scaling
         res.low_precision_scaling = self.low_precision_scaling
+        res.rescale = self.rescale

         return res
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index fd674031..08b2a6a0 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -136,6 +136,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
     def elements(self):
         return self.batch * self.width * self.height * self.depth

+    def dot_prod(self, rhs):
+        return self.batch * rhs.batch + self.width * rhs.width + self.height * rhs.height + self.depth * rhs.depth
+
     def elements_wh(self):
         return self.width * self.height
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 37fd06ea..2e70d72e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -632,7 +632,7 @@ class Tensor:
         self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
     ) -> Tuple[Optional[Shape], Optional[Shape]]:
         if coord is None:
-            coord = [0] * len(self.storage_shape)
+            coord = [0] * min(len(self.storage_shape), 4)

         if shape4D and self.is_standard_fm:
             augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1ef04449..f4aa4534 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,6 +503,142 @@ def convert_table_to_lut(op, arch, nng):
     return convert_to_lut(op, table.values, "table")


+def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
+    part_op = op.clone()
+    offset = Shape4D(0, 0, 0, 0)
+
+    part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
+    part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
+    part_op.write_offset = offset.with_batch(ofm_offset_n)
+    part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+    part_op.ifm_shapes = op.ifm_shapes.copy()
+    part_op.ofm_shapes = op.ofm_shapes.copy()
+    part_op.ifm.consumer_list.append(part_op)
+    op.ofm.ops.append(part_op)
+
+    if ifm2_offset_n is not None:
+        part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
+        part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+        part_op.ifm2.consumer_list.append(part_op)
+
+
+def get_nhwc_stride(shape):
+    stride_x = shape.depth
+    stride_y = shape.width * stride_x
+    stride_n = shape.height * stride_y
+    return Shape4D(stride_n, stride_y, stride_x, 1)
+
+
+def decomp_unary_elementwise(op):
+    """
+    Decompose unary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation is decomposed into N (of the ofm) elementwise
+    operations that read and write with offsets from/to the ifm/ofm.
+    """
+    ifm = op.ifm
+    ofm = op.ofm
+    assert op.type.is_unary_elementwise_op()
+    assert None not in (ifm, ofm)
+    assert ifm.shape == ofm.shape
+
+    rank = len(ofm.shape)
+    if rank > 3:
+        n = rank - 3
+        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+        op.ifm_shapes.append(Shape4D(new_ofm_shape))
+        op.ofm_shapes.append(Shape4D(new_ofm_shape))
+
+        if new_ofm_shape[0] == 1:
+            return
+
+        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+
+        for height in range(ofm_decomp_shape.height):
+            for width in range(ofm_decomp_shape.width):
+                for depth in range(ofm_decomp_shape.depth):
+                    # ifm and ofm have the same shape, so they share one offset
+                    offset_n = Shape4D(0, height, width, depth).dot_prod(ofm_decomp_stride)
+                    create_elem_part_op(op, offset_n, None, offset_n)
+
+        ifm.consumer_list.remove(op)
+        ofm.ops.remove(op)
+    return
+
+
+def decomp_binary_elementwise(op):
+    """
+    Decompose binary elementwise ops with Rank > 3 (H,W,D).
+    If Rank > 3, all the dimensions above H are viewed as the N dimension.
+    The elementwise operation is decomposed into N (of the ofm) elementwise
+    operations that read and write with offsets from/to the ifm(s)/ofm.
+    Note: broadcasting needs to be handled; TOSA allows broadcasting by both ifm and ifm2.
+    """
+    ifm = op.ifm
+    ifm2 = op.ifm2
+    ofm = op.ofm
+    assert op.type.is_binary_elementwise_op()
+    assert None not in (ifm, ifm2, ofm)
+
+    rank = len(ofm.shape)
+    if rank > 3:
+        n = rank - 3
+        ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+        ifm_decomp_shape = Shape4D(ifm.shape[0:n])
+        ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
+
+        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+        ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+        ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+
+        new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+        new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
+        new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
+
+        op.ofm_shapes.append(Shape4D(new_ofm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm_shape))
+        op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+        if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
+            return
+
+        for height in range(ofm_decomp_shape.height):
+            for width in range(ofm_decomp_shape.width):
+                for depth in range(ofm_decomp_shape.depth):
+                    ofm_offset = Shape4D(0, height, width, depth)
+
+                    ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                    ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+                    ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+
+                    ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+                    ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+                    ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+                    ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+
+                    ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
+                    ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
+                    ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
+                    create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
+
+        ifm.consumer_list.remove(op)
+        ifm2.consumer_list.remove(op)
+        ofm.ops.remove(op)
+    return
+
+
+def decomp_elementwise(tens, arch, nng):
+    """
+    Decompose elementwise ops with Rank > 3 (H,W,D).
+    """
+    assert len(tens.ops) == 1
+
+    if tens.ops[0].type.is_binary_elementwise_op():
+        decomp_binary_elementwise(tens.ops[0])
+    elif tens.ops[0].type.is_unary_elementwise_op():
+        decomp_unary_elementwise(tens.ops[0])
+    return tens
+
+
 def fixup_quantization(op, arch, nng):
     if op.ifm and op.ifm.quantization.zero_point is None:
         op.ifm.quantization.zero_point = 0
@@ -521,6 +657,13 @@ def supported_operator_check(op, arch, nng):


 def tosa_optimise_graph(nng, arch):
+
+    # Decompose to 4 dimensions
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+        )
+
     # Pre-processing step
     pre_process_list = [
         supported_operator_check,
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 98df27e3..f5eddccc 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -40,15 +40,15 @@ class TosaSupportedOperators:
     mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
     memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
     binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
+    elem_wise_ops = binary_elem_wise_add_mul_sub
     type_conversion_ops = set((Op.Rescale,))
     relu_ops = set((Op.Clamp, Op.ReluN,))
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))

     npu_post_ops = activation_ops
-    supported_operators = (
-        mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
-    )
+
+    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops

     # Supported data types
     # TODO will differ compared to TensorFlow Lite, currently set to the same
@@ -132,35 +132,37 @@ class TosaSupportedOperators:
         return valid, ", ".join(extra)

     # TODO This is for a HW limitation, that is to be resolved in SW later on
-    @staticmethod
-    def constraint_rank(op):
-        "Tensor rank must be <= 4"
+    @classmethod
+    def constraint_rank(cls, op):
+        "Tensor rank must be <= 4, unless op is elementwise"
         valid = True
         extra = []
-        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
-        if not tensors:
-            tensors = [tens for tens in op.inputs if tens]
-        for tens in tensors:
-            rank = len(tens.shape)
-            if not rank <= 4:
-                valid = False
-                extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+        if op.type not in cls.binary_elem_wise_add_mul_sub:
+            tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+            if not tensors:
+                tensors = [tens for tens in op.inputs if tens]
+            for tens in tensors:
+                rank = len(tens.shape)
+                if not rank <= 4:
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)

     # TODO This is for a HW limitation, that is to be resolved in SW later on
-    @staticmethod
-    def constraint_batch(op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1"
+    @classmethod
+    def constraint_batch(cls, op):
+        "If tensor rank is 4, batch of ifms/ofm must be 1, unless op is elementwise"
         valid = True
         extra = []
-        tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
-        if not tensors:
-            tensors = [tens for tens in op.inputs if tens]
-        for tens in tensors:
-            rank = len(tens.shape)
-            if rank == 4 and tens.shape[0] != 1:
-                valid = False
-                extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+        if op.type not in cls.binary_elem_wise_add_mul_sub:
+            tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+            if not tensors:
+                tensors = [tens for tens in op.inputs if tens]
+            for tens in tensors:
+                rank = len(tens.shape)
+                if rank == 4 and tens.shape[0] != 1:
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
         return valid, ", ".join(extra)

     @staticmethod
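The broadcast handling in decomp_binary_elementwise zeroes an input's coordinate in every folded leading dimension where that input is broadcast, then flattens the coordinate with that input's own strides (get_nhwc_stride plus dot_prod). A standalone sketch of the same arithmetic, using plain lists instead of vela's Shape4D (illustration only, not the vela API):

    def strides_3d(shape):
        # Dense row-major strides for three folded leading dims.
        d0, d1, d2 = shape
        return [d1 * d2, d2, 1]

    def dot_prod(coord, stride):
        return sum(c * s for c, s in zip(coord, stride))

    # ofm leading dims (2, 3, 4); ifm2 broadcast in the 1st and 3rd dims: (1, 3, 1)
    ofm_shape, ifm2_shape = [2, 3, 4], [1, 3, 1]
    ofm_stride, ifm2_stride = strides_3d(ofm_shape), strides_3d(ifm2_shape)

    for a in range(ofm_shape[0]):
        for b in range(ofm_shape[1]):
            for c in range(ofm_shape[2]):
                # zero the coordinate wherever ifm2 is broadcast
                coord2 = [x if s == o else 0
                          for x, s, o in zip((a, b, c), ifm2_shape, ofm_shape)]
                ofm_off = dot_prod((a, b, c), ofm_stride)  # 0..23, each slice once
                ifm2_off = dot_prod(coord2, ifm2_stride)   # 0..2, repeated
                # each (ifm2_off, ofm_off) pair becomes one single-batch part op

Every ofm slice thus reads the matching ifm2 slice, re-reading the same ifm2 data wherever ifm2 is broadcast.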