From 90724965751e882c58de74a044cc7adab307bc55 Mon Sep 17 00:00:00 2001
From: Johan Alfven
Date: Thu, 2 Feb 2023 09:07:48 +0100
Subject: MLBEDSW-6260: Add support for using DMA to copy feature maps

- Reshape ops can be bypassed, so there is no need for the NPU to
  process them. However, there are use cases where the IFM must be
  preserved, so a memcpy is needed. This has been implemented with an
  AvgPool.
- To reduce the cost of the AvgPool, the IFM can instead be copied by
  DMA. This is faster, and it can also be turned into a real NOP in
  cases where the IFM and the OFM can use the same memory space.
- Added a new memcpy op. Only the NHWC format is supported, since DMA
  cannot change the format on the fly.
- Allow the ofm to reuse the ifm for the memcpy op.
- Make sure the DMA copy size is 16-byte aligned.

Change-Id: I3605a48d47646ff60d2bb3644dd3a23f872235a7
Signed-off-by: Johan Alfven
---
 ethosu/vela/graph_optimiser_util.py                | 19 ++++++--
 ethosu/vela/high_level_command_stream.py           | 18 ++++++-
 ethosu/vela/high_level_command_stream_generator.py | 55 ++++++++++++++--------
 ethosu/vela/high_level_command_to_npu_op.py        |  7 ++-
 ethosu/vela/live_range.py                          | 24 ++++++----
 ethosu/vela/npu_performance.py                     | 32 ++++++++++---
 ethosu/vela/operation.py                           |  5 ++
 ethosu/vela/operation_util.py                      |  8 +++-
 ethosu/vela/pass_packing.py                        | 18 ++++++-
 9 files changed, 142 insertions(+), 44 deletions(-)

(limited to 'ethosu/vela')

diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 24a55836..e8d5ac64 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,7 +27,7 @@ from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
-from .operation_util import create_avgpool_nop
+from .operation_util import create_memcpy
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
@@ -89,6 +89,11 @@ def _avoid_nhcwb16_for_shapes(tens):
     return False


+def _avoid_nhcwb16_for_memory_only(tens):
+    # check all producers/consumers to see if any op is preventing NHCWB16
+    return any(op.type == Op.Memcpy for op in (tens.consumer_list + tens.ops))
+
+
 # Check if non linear format can be used
 def check_format_restrictions(tens, arch):
     if len(tens.ops) < 1:
@@ -116,6 +121,10 @@ def check_format_restrictions(tens, arch):
         if _avoid_nhcwb16_for_shapes(tens):
             return

+    # Memory only ifm/ofm exception: DMA ops must use NHWC
+    if _avoid_nhcwb16_for_memory_only(tens):
+        return
+
     # Resize bilinear half pixel center implementation requires OFM with linear format to
     # allow stride modification in H/W dimensions.
     for op in tens.ops:
@@ -274,10 +283,10 @@ def record_optimised(op, arch):


 def insert_copy_op_before_op(op):
-    # Create a avg_pool nop op with ifm as input
+    # Create a memcpy op with ifm as input
     tens = op.ifm
     copy_tens = tens.clone()
-    copy_op = create_avgpool_nop(f"{tens.name}_avgpool")
+    copy_op = create_memcpy(f"{tens.name}_memcpy")
     copy_op.add_input_tensor(tens)
     copy_op.set_output_tensor(copy_tens)
     copy_op.set_ifm_ofm_shapes()
@@ -290,9 +299,9 @@ def insert_copy_op_before_op(op):
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()

-    # Create a avg_pool nop op with ifm as input
+    # Create a memcpy op with ifm as input
     copy_tens = tens.clone()
-    copy_op = create_avgpool_nop(tens.name + "_avgpool")
+    copy_op = create_memcpy(tens.name + "_memcpy")
     copy_op.add_input_tensor(tens)
     copy_op.set_output_tensor(copy_tens)
     copy_op.set_ifm_ofm_shapes()
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 609f8556..09c1805d 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -293,3 +293,19 @@ class DMA(Command):
     def get_operation_count(self):
         # returns numpy array of (DPU blocks, dma_ops)
         return np.array((0, 1))
+
+
+class NOP(Command):
+    def __init__(self, ps, in_tensor, out_tensor):
+        self.ps = ps
+        self.in_tensor = in_tensor
+        self.out_tensor = out_tensor
+
+    def __str__(self):
+        return f""
+
+    __repr__ = __str__
+
+    def get_operation_count(self):
+        # returns numpy array of (DPU blocks, dma_ops)
+        return np.array((0, 0))
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 5f6a93a3..770241bc 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -18,6 +18,7 @@
 # Generate a high-level command stream from a schedule
 from .high_level_command_stream import Box
 from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
 from .high_level_command_stream import NpuStripe
 from .numeric_util import round_up_divide
 from .operation import create_activation_function
@@ -33,6 +34,19 @@ def dma_if_necessary(ps, box, tensor):
         yield DMA(ps, src_tensor, tensor, box)


+def dma_feature_map_if_necessary(ps, src_tensor, dst_tensor):
+    box = Box([0] * len(src_tensor.shape), list(src_tensor.shape))
+    src_addr = src_tensor.address_for_coordinate(box.start_coord)
+    dst_addr = dst_tensor.address_for_coordinate(box.start_coord)
+
+    if src_addr != dst_addr or src_tensor.mem_area != dst_tensor.mem_area:
+        yield DMA(ps, src_tensor, dst_tensor, box)
+    else:
+        # Source and destination are the same, so there is no need for a DMA transaction
+        # Create a NOP for visibility when printing the high_level_command_stream
+        yield NOP(ps, src_tensor, dst_tensor)
+
+
 def generate_high_level_command_stream_for_schedule(nng, sg, arch, verbose_high_level_command_stream):
     res = []
     # sg.sched_ops are ordered by execution
@@ -224,21 +238,24 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
                 lut_dma_done = True
                 yield from dma_if_necessary(sched_op.parent_ps, lut_box, lut_tensor)

-            yield NpuStripe(
-                sched_op.parent_ps,
-                block_config.old_style_representation(),
-                is_first_h_stripe,
-                is_last_h_stripe,
-                ifm_tensor,
-                ifm_box,
-                ofm_tensor,
-                ofm_box,
-                weight_tensor,
-                weight_box,
-                scale_tensor,
-                ifm2_tensor=ifm2_tensor,
-                ifm2_box=ifm2_box,
-                pad_top=pad_top,
-                pad_bottom=pad_bottom,
-                reversed_operands=sched_op.reversed_operands,
-            )
+            if parent_op.type == Op.Memcpy:
+                yield from dma_feature_map_if_necessary(sched_op.parent_ps, ifm_tensor, ofm_tensor)
+            else:
+                yield NpuStripe(
+                    sched_op.parent_ps,
+                    block_config.old_style_representation(),
+                    is_first_h_stripe,
+                    is_last_h_stripe,
+                    ifm_tensor,
+                    ifm_box,
+                    ofm_tensor,
+                    ofm_box,
+                    weight_tensor,
+                    weight_box,
+                    scale_tensor,
+                    ifm2_tensor=ifm2_tensor,
+                    ifm2_box=ifm2_box,
+                    pad_top=pad_top,
+                    pad_bottom=pad_bottom,
+                    reversed_operands=sched_op.reversed_operands,
+                )
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 2c62c6f7..7634fe1f 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -54,6 +54,7 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .high_level_command_stream import Box
 from .high_level_command_stream import Command
 from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
 from .high_level_command_stream import NpuStripe
 from .numeric_util import quantise_float32
 from .numeric_util import round_up
@@ -627,7 +628,8 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
     else:
         src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
         dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
-    sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
+    # DMA must use 16-byte alignment (tensors are always aligned, but the sz calculation uses the actual size)
+    sz = round_up(cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr, 16)
     src = NpuAddressRange(src_region, int(src_addr), int(sz))
     dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
     return NpuDmaOperation(src, dest)
@@ -663,6 +665,9 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
     for cmd in sg.high_level_command_stream:
         if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
             print("Warning: Skipping register command stream generation for", cmd.ps)
+        elif isinstance(cmd, NOP):
+            # NOP should not generate anything
+            continue
         else:
             npu_op = convert_command_to_npu_op(cmd, arch)
             npu_op_list.append(npu_op)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 05e481e0..995a0ccb 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -165,16 +165,11 @@ def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):


 def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
-    def _tensor_should_be_ignored(tens):
-        if tens.ifm_write_protected:
-            return True
-        return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
-
-    # Check if possible to merge ifm/ofm live ranges of elementwise op
     ifm_tens = None
     if sched_op.op_type.is_elementwise_op():
+        # Check if possible to merge ifm/ofm live ranges of elementwise op
         elem_op = sched_op.parent_op
-        if not _tensor_should_be_ignored(elem_op.ofm):
+        if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
             # Check if overwriting the inputs can be allowed
             OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
             outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
@@ -183,7 +178,6 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
                 inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
             if elem_op.ifm2 is not None:
                 inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
-
             # find an input tensor that can be overwritten by the output
             for inp in inps:
                 if (
@@ -192,7 +186,8 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
                     # check input tensor is valid
                     and inp.tens is not None
                     and inp.tens.shape != []
-                    and not _tensor_should_be_ignored(inp.tens)
+                    and not inp.tens.ifm_write_protected
+                    and not tensor_should_be_ignored(inp.tens, target_mem_area, target_mem_type_set)
                     # check input and output tensors are compatible
                     and inp.tens.format == outp.tens.format
                     and inp.tens.dtype == outp.tens.dtype
@@ -203,6 +198,17 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
                 ):
                     ifm_tens = inp.tens
                     break
+    elif sched_op.op_type == Op.Memcpy:
+        # Check if possible to merge ifm/ofm live ranges of dma op
+        dma_op = sched_op.parent_op
+        ifm = dma_op.ifm
+        ofm = dma_op.ofm
+        if not (
+            tensor_should_be_ignored(ifm, target_mem_area, target_mem_type_set)
+            or tensor_should_be_ignored(ofm, target_mem_area, target_mem_type_set)
+        ):
+            # Currently DMA is only used when bypassing memory only ops, so it is ok to reuse the ifm
+            ifm_tens = ifm

     return ifm_tens
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 967a7ac0..80011244 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -472,6 +472,10 @@ def measure_cycle_cost(arch, op_type: Op, faf_type: Op, query: PerformanceQuery)
             _estimate_output_cycles_per_element(arch, op_type, faf_type, query)
             * Shape4D.round_up(query.ofm_shape, ofm_rounding).elements()
         )
+    # DMA cycle calculation
+    elif query.npu_block_type == NpuBlockType.Dma:
+        # Return 0 since this is not an actual NPU op
+        cycles.op_cycles = 0
     else:
         assert False

@@ -541,6 +545,10 @@ def measure_element_access(arch, query: PerformanceQuery):
         elif query.ifm2_bits > 8:
             # ifm2 is a non 8-bit scalar
             access.ifm_read[1] = Shape4D.round_up(query.ifm2_shape, ifm_rounding).elements()
+    # DMA
+    elif query.npu_block_type == NpuBlockType.Dma:
+        # Return empty access since this is not an actual NPU op
+        return access
     # Unknown
     else:
         assert False
@@ -646,18 +654,28 @@ def estimate_full_op_performance(
     # LUT Transfer
     parent_op = op.parent_op
-    lut_transfer_cycles = 0
+    dma_transfer_cycles = 0
     if parent_op.activation_lut:
         lut_tensor = [tens for tens in parent_op.inputs if tens.purpose == TensorPurpose.LUT][0]
         src_tensor = lut_tensor.src_tensor
         if src_tensor and lut_tensor.mem_area != src_tensor.mem_area:
             bw = src_tensor.storage_size()
-            lut_transfer_cycles = measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
+            dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
             bws[src_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
             # LUT read from SHRAM TODO remove?
             scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw

+    # DMA Transfer
+    if parent_op.type == Op.Memcpy:
+        src_tensor = parent_op.ifm
+        dst_tensor = parent_op.ofm
+        if src_tensor.mem_area != dst_tensor.mem_area:
+            bw = src_tensor.storage_size()
+            dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, dst_tensor.mem_area, bw)
+            bws[src_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Read] += bw
+            bws[dst_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Write] += bw
+
     if cost.npu_weights_tensor and cost.buffered_weight_tensors:
         # DMA Weight Transfer
         sz = 0
@@ -690,11 +708,11 @@
             cycles.op_cycles + cost.full_weight_transfer_cycles - min(ws_first_transfer_cycles, slack_cycles)
         )

-        # Add cycles for LUT Transfer
-        cycles_a[PassCycles.Npu] += lut_transfer_cycles
+        # Add cycles for LUT + memcpy op transfer
+        cycles_a[PassCycles.Npu] += dma_transfer_cycles
     else:
-        # Add cycles for LUT Transfer
-        cycles_a[PassCycles.Npu] += max(lut_transfer_cycles - slack_cycles, 0)
+        # Add cycles for LUT + memcpy op transfer
+        cycles_a[PassCycles.Npu] += max(dma_transfer_cycles - slack_cycles, 0)

     # OFM write
     ofm = op.parent_op.ofm
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 19b00b31..6be9dc25 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -51,6 +51,7 @@ class NpuBlockType(Enum):
     ConvolutionDepthWise = 4
     ElementWise = 5
     ReduceSum = 6
+    Dma = 7


 class Kernel:
@@ -174,6 +175,7 @@ class Op(Enum):
     )
     Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
     Div = OperatorInfo()
+    Memcpy = OperatorInfo(block_type=NpuBlockType.Dma, indices=NNG_IFM_INDICES)
     Elu = OperatorInfo()
     EmbeddingLookup = OperatorInfo()
     EmbeddingLookupSparse = OperatorInfo()
@@ -373,6 +375,9 @@ class Op(Enum):
     def is_resize_op(self):
         return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)

+    def is_memcpy_op(self):
+        return self.info.block_type == NpuBlockType.Dma
+
     def needs_bias(self):
         return bool(self.info.indices.biases)
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 7b66dff3..21f9dbed 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -51,6 +51,12 @@ def create_add_nop(name: str) -> Operation:
     return op


+def create_memcpy(name: str) -> Operation:
+    op = Operation(Op.Memcpy, name)
+    op.run_on_npu = True
+    return op
+
+
 def create_pad_nop(name: str) -> Operation:
     op = Operation(Op.Pad, name)
     op.run_on_npu = True
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5a9f9575..e43a9191 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -39,6 +39,7 @@ class PassFlags(enum.Flag):
     StartupInit = 64
     MemoryOnly = 128
     PostFusingLimited = 256
+    Memcpy = 512


 mac_main_ops = set(
@@ -95,6 +96,7 @@ memory_only_ops = set(
         Op.ExpandDims,
     )
 )
+memcpy_ops = set((Op.Memcpy,))


 test_sequence = [
@@ -158,6 +160,16 @@ test_sequence = [
         # flags_to_clear
         PassFlags.Empty,
     ),
+    (
+        # ops_set
+        memcpy_ops,
+        # incompatible_pack_flags
+        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Mac | PassFlags.Main | PassFlags.PostFusingLimited,
+        # flags_to_set
+        PassFlags.Npu | PassFlags.Memcpy | PassFlags.Main,
+        # flags_to_clear
+        PassFlags.Empty,
+    ),
     (
         # ops_set
         cpu_ops,
@@ -248,7 +260,11 @@ def pack_into_passes(nng, arch, verbose_packing=False):
             if flags_to_set & PassFlags.Npu:
                 if flags_to_set & (
-                    PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
+                    PassFlags.Mac
+                    | PassFlags.ElementWise
+                    | PassFlags.Post
+                    | PassFlags.PostFusingLimited
+                    | PassFlags.Memcpy
                 ):
                     assert len(curr_op.inputs) >= 1
                     ifm_tensor = curr_op.ifm
--
cgit v1.2.1
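
Note on the NOP path described in the commit message: the standalone Python sketch below mirrors the decision made by dma_feature_map_if_necessary() above, without any Vela dependencies. The FeatureMap type, the mem_area strings and the addresses are made-up stand-ins for illustration, not Vela's real tensor classes.

from collections import namedtuple

# Simplified stand-in for a feature map: a name, a memory area and a start address.
FeatureMap = namedtuple("FeatureMap", ["name", "mem_area", "address"])


def copy_or_nop(src: FeatureMap, dst: FeatureMap) -> str:
    # If the OFM already aliases the IFM (same memory area, same start address),
    # nothing has to move and the command degenerates to a NOP.
    if src.mem_area == dst.mem_area and src.address == dst.address:
        return f"NOP  {src.name} -> {dst.name}"
    # Anything else needs a real DMA transaction.
    return f"DMA  {src.name} -> {dst.name}"


ifm = FeatureMap("reshape_ifm", "Sram", 0x1000)
print(copy_or_nop(ifm, FeatureMap("reshape_ofm", "Sram", 0x1000)))  # NOP: the ofm reuses the ifm
print(copy_or_nop(ifm, FeatureMap("reshape_ofm", "Sram", 0x2000)))  # DMA: data actually moves

The ifm/ofm live-range merge added in live_range.py is what makes the first case possible in practice, since it lets the allocator place the ofm on top of the ifm.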
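
The 16-byte alignment applied to the DMA copy size in create_dma_op() can be checked with plain integer arithmetic. The round_up() helper and the example numbers below are illustrative; Vela uses its own round_up from numeric_util.

def round_up(value: int, alignment: int) -> int:
    # Round value up to the nearest multiple of alignment.
    return ((value + alignment - 1) // alignment) * alignment


# Example: a feature map occupying 1000 bytes starting at src_addr.
src_addr = 0x0400
end_addr = src_addr + 1000  # address just past the last byte of the feature map

sz = round_up(end_addr - src_addr, 16)
assert sz == 1008  # 1000 rounded up to the next 16-byte boundary
print(f"DMA length: {sz} bytes")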
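
The NHWC-only restriction can also be shown in isolation. The sketch below mimics what _avoid_nhcwb16_for_memory_only() checks: a single Memcpy producer or consumer is enough to keep a tensor in linear NHWC, because the DMA engine copies raw bytes and cannot convert to or from NHCWB16. The Tensor class and the plain strings used for op types are simplifications, not the real Vela types.

class Tensor:
    def __init__(self, name: str):
        self.name = name
        self.ops = []            # op types of the producers
        self.consumer_list = []  # op types of the consumers


def can_use_nhcwb16(tens: Tensor) -> bool:
    # One Memcpy neighbour forces the linear NHWC format.
    return not any(op_type == "Memcpy" for op_type in tens.consumer_list + tens.ops)


fm = Tensor("bypassed_reshape_ifm")
fm.consumer_list.append("Memcpy")
print(can_use_nhcwb16(fm))  # False: the tensor must stay in NHWC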