author     Johan Alfven <johan.alfven@arm.com>  2023-02-02 09:07:48 +0100
committer  Johan Alfven <johan.alfven@arm.com>  2023-03-14 11:00:58 +0100
commit     90724965751e882c58de74a044cc7adab307bc55 (patch)
tree       425ccea87487b66ca298a801b298fbf8567f86d9
parent     bb9885190f5f7ea959f171b38ee1dd44d3e1e75e (diff)
download   ethos-u-vela-90724965751e882c58de74a044cc7adab307bc55.tar.gz
MLBEDSW-6260: Add support for using DMA to copy feature maps
- Reshape ops can be bypassed and do not need to be processed by the NPU. There are use cases where the IFM must be preserved, so a memcpy is needed. This is implemented by an AvgPool.
- To reduce the cost of the AvgPool, the IFM can be copied by DMA. This is faster, and it can also be turned into a real NOP in cases where the IFM and the OFM can use the same memory space.
- Added a new memcpy op. Only NHWC format is supported, since DMA cannot change the format on the fly.
- Allow the OFM to reuse the IFM for the memcpy op.
- Make sure the DMA copy size is 16-byte aligned.

Change-Id: I3605a48d47646ff60d2bb3644dd3a23f872235a7
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--  ethosu/vela/graph_optimiser_util.py                 19
-rw-r--r--  ethosu/vela/high_level_command_stream.py            18
-rw-r--r--  ethosu/vela/high_level_command_stream_generator.py  55
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py          7
-rw-r--r--  ethosu/vela/live_range.py                           24
-rw-r--r--  ethosu/vela/npu_performance.py                      32
-rw-r--r--  ethosu/vela/operation.py                            5
-rw-r--r--  ethosu/vela/operation_util.py                       8
-rw-r--r--  ethosu/vela/pass_packing.py                         18
9 files changed, 142 insertions, 44 deletions
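
The change boils down to one decision, made per feature-map copy: issue a real DMA transfer only when source and destination actually differ, otherwise emit a pure no-op. A minimal sketch of that rule, with illustrative names and values rather than vela's API (the real check is dma_feature_map_if_necessary in high_level_command_stream_generator.py below):

    def choose_copy_command(src_addr: int, dst_addr: int,
                            src_mem_area: str, dst_mem_area: str) -> str:
        # A DMA transaction is only needed when the copy actually moves data;
        # same address in the same memory area means the op is a pure no-op.
        if src_addr != dst_addr or src_mem_area != dst_mem_area:
            return "DMA"
        return "NOP"

    assert choose_copy_command(0x100, 0x200, "Sram", "Sram") == "DMA"
    assert choose_copy_command(0x100, 0x100, "Sram", "Sram") == "NOP"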
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 24a55836..e8d5ac64 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,7 +27,7 @@ from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .operation import Op
-from .operation_util import create_avgpool_nop
+from .operation_util import create_memcpy
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import QuantizationParameters
@@ -89,6 +89,11 @@ def _avoid_nhcwb16_for_shapes(tens):
return False
+def _avoid_nhcwb16_for_memory_only(tens):
+ # check all producers/consumers to see if any op is preventing NHCWB16
+ return any(op.type == Op.Memcpy for op in (tens.consumer_list + tens.ops))
+
+
# Check if non linear format can be used
def check_format_restrictions(tens, arch):
if len(tens.ops) < 1:
@@ -116,6 +121,10 @@ def check_format_restrictions(tens, arch):
if _avoid_nhcwb16_for_shapes(tens):
return
+ # Memory only ifm/ofm exception: DMA ops must use NHWC
+ if _avoid_nhcwb16_for_memory_only(tens):
+ return
+
# Resize bilinear half pixel center implementation requires OFM with linear format to
# allow stride modification in H/W dimensions.
for op in tens.ops:
@@ -274,10 +283,10 @@ def record_optimised(op, arch):
def insert_copy_op_before_op(op):
- # Create a avg_pool nop op with ifm as input
+ # Create a memcpy op with ifm as input
tens = op.ifm
copy_tens = tens.clone()
- copy_op = create_avgpool_nop(f"{tens.name}_avgpool")
+ copy_op = create_memcpy(f"{tens.name}_memcpy")
copy_op.add_input_tensor(tens)
copy_op.set_output_tensor(copy_tens)
copy_op.set_ifm_ofm_shapes()
@@ -290,9 +299,9 @@ def insert_copy_op_before_op(op):
def insert_copy_op_after_tens(tens):
tens_cons_list_copy = tens.consumer_list.copy()
- # Create a avg_pool nop op with ifm as input
+ # Create a memcpy op with ifm as input
copy_tens = tens.clone()
- copy_op = create_avgpool_nop(tens.name + "_avgpool")
+ copy_op = create_memcpy(tens.name + "_memcpy")
copy_op.add_input_tensor(tens)
copy_op.set_output_tensor(copy_tens)
copy_op.set_ifm_ofm_shapes()
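
Sketching the new format restriction: a tensor with any Memcpy producer or consumer must stay in linear NHWC, since the DMA engine cannot convert the NHCWB16 brick format in flight. A self-contained toy version, with StubOp/StubTensor standing in for vela's classes:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class StubOp:
        type: str

    @dataclass
    class StubTensor:
        ops: List[StubOp] = field(default_factory=list)            # producers
        consumer_list: List[StubOp] = field(default_factory=list)  # consumers

    def avoid_nhcwb16_for_memory_only(tens: StubTensor) -> bool:
        # Any Memcpy producer or consumer forces linear NHWC, because the
        # DMA engine cannot convert the brick format on the fly.
        return any(op.type == "Memcpy" for op in (tens.consumer_list + tens.ops))

    t = StubTensor(ops=[StubOp("Conv2D")], consumer_list=[StubOp("Memcpy")])
    assert avoid_nhcwb16_for_memory_only(t)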
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 609f8556..09c1805d 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -293,3 +293,19 @@ class DMA(Command):
def get_operation_count(self):
# returns numpy array of (DPU blocks, dma_ops)
return np.array((0, 1))
+
+
+class NOP(Command):
+ def __init__(self, ps, in_tensor, out_tensor):
+ self.ps = ps
+ self.in_tensor = in_tensor
+ self.out_tensor = out_tensor
+
+ def __str__(self):
+ return f"<NOP: in={self.in_tensor.name}, out={self.out_tensor.name}>"
+
+ __repr__ = __str__
+
+ def get_operation_count(self):
+ # returns numpy array of (DPU blocks, dma_ops)
+ return np.array((0, 0))
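
get_operation_count lets later stages total up a command stream uniformly: a DMA counts as one dma_op, the new NOP as nothing. A small illustration of that aggregation (the stripe's 3 DPU blocks are an assumed figure):

    import numpy as np

    # Per-command (DPU blocks, dma_ops) counts: a DMA contributes (0, 1),
    # the new NOP (0, 0), and a stripe here is assumed to span 3 DPU blocks.
    counts = [np.array((0, 1)), np.array((0, 0)), np.array((3, 0))]
    dpu_blocks, dma_ops = np.sum(counts, axis=0)
    assert (dpu_blocks, dma_ops) == (3, 1)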
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 5f6a93a3..770241bc 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -18,6 +18,7 @@
# Generate a high-level command stream from a schedule
from .high_level_command_stream import Box
from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
from .high_level_command_stream import NpuStripe
from .numeric_util import round_up_divide
from .operation import create_activation_function
@@ -33,6 +34,19 @@ def dma_if_necessary(ps, box, tensor):
yield DMA(ps, src_tensor, tensor, box)
+def dma_feature_map_if_necessary(ps, src_tensor, dst_tensor):
+ box = Box([0] * len(src_tensor.shape), list(src_tensor.shape))
+ src_addr = src_tensor.address_for_coordinate(box.start_coord)
+ dst_addr = dst_tensor.address_for_coordinate(box.start_coord)
+
+ if src_addr != dst_addr or src_tensor.mem_area != dst_tensor.mem_area:
+ yield DMA(ps, src_tensor, dst_tensor, box)
+ else:
+ # Source and destination are the same so no need for a DMA transaction
+ # Create a NOP for visibility when printing the high_level_command_stream
+ yield NOP(ps, src_tensor, dst_tensor)
+
+
def generate_high_level_command_stream_for_schedule(nng, sg, arch, verbose_high_level_command_stream):
res = []
# sg.sched_ops are ordered by execution
@@ -224,21 +238,24 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
lut_dma_done = True
yield from dma_if_necessary(sched_op.parent_ps, lut_box, lut_tensor)
- yield NpuStripe(
- sched_op.parent_ps,
- block_config.old_style_representation(),
- is_first_h_stripe,
- is_last_h_stripe,
- ifm_tensor,
- ifm_box,
- ofm_tensor,
- ofm_box,
- weight_tensor,
- weight_box,
- scale_tensor,
- ifm2_tensor=ifm2_tensor,
- ifm2_box=ifm2_box,
- pad_top=pad_top,
- pad_bottom=pad_bottom,
- reversed_operands=sched_op.reversed_operands,
- )
+ if parent_op.type == Op.Memcpy:
+ yield from dma_feature_map_if_necessary(sched_op.parent_ps, ifm_tensor, ofm_tensor)
+ else:
+ yield NpuStripe(
+ sched_op.parent_ps,
+ block_config.old_style_representation(),
+ is_first_h_stripe,
+ is_last_h_stripe,
+ ifm_tensor,
+ ifm_box,
+ ofm_tensor,
+ ofm_box,
+ weight_tensor,
+ weight_box,
+ scale_tensor,
+ ifm2_tensor=ifm2_tensor,
+ ifm2_box=ifm2_box,
+ pad_top=pad_top,
+ pad_bottom=pad_bottom,
+ reversed_operands=sched_op.reversed_operands,
+ )
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 2c62c6f7..7634fe1f 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -54,6 +54,7 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
@@ -627,7 +628,8 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
else:
src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
- sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
+ # DMA must use 16-byte alignment (tensors are always aligned, but the sz calculation uses the actual size)
+ sz = round_up(cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr, 16)
src = NpuAddressRange(src_region, int(src_addr), int(sz))
dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
return NpuDmaOperation(src, dest)
@@ -663,6 +665,9 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
for cmd in sg.high_level_command_stream:
if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
print("Warning: Skipping register command stream generation for", cmd.ps)
+ elif isinstance(cmd, NOP):
+ # NOP should not generate anything
+ continue
else:
npu_op = convert_command_to_npu_op(cmd, arch)
npu_op_list.append(npu_op)
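
The rounding uses vela's round_up helper; a sketch of the same arithmetic under the usual round-to-multiple definition:

    def round_up(value: int, alignment: int) -> int:
        # Round value up to the next multiple of alignment.
        return ((value + alignment - 1) // alignment) * alignment

    # A 100-byte copy is padded to 112 bytes so the DMA length stays
    # 16-byte aligned; tensor allocations already honour this alignment,
    # so the extra bytes never spill into another tensor.
    assert round_up(100, 16) == 112
    assert round_up(112, 16) == 112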
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 05e481e0..995a0ccb 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -165,16 +165,11 @@ def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
- def _tensor_should_be_ignored(tens):
- if tens.ifm_write_protected:
- return True
- return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
-
- # Check if possible to merge ifm/ofm live ranges of elementwise op
ifm_tens = None
if sched_op.op_type.is_elementwise_op():
+ # Check if possible to merge ifm/ofm live ranges of elementwise op
elem_op = sched_op.parent_op
- if not _tensor_should_be_ignored(elem_op.ofm):
+ if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
# Check if overwriting the inputs can be allowed
OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
@@ -183,7 +178,6 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
if elem_op.ifm2 is not None:
inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
-
# find an input tensor that can be overwritten by the output
for inp in inps:
if (
@@ -192,7 +186,8 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
# check input tensor is valid
and inp.tens is not None
and inp.tens.shape != []
- and not _tensor_should_be_ignored(inp.tens)
+ and not inp.tens.ifm_write_protected
+ and not tensor_should_be_ignored(inp.tens, target_mem_area, target_mem_type_set)
# check input and output tensors are compatible
and inp.tens.format == outp.tens.format
and inp.tens.dtype == outp.tens.dtype
@@ -203,6 +198,17 @@ def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
):
ifm_tens = inp.tens
break
+ elif sched_op.op_type == Op.Memcpy:
+ # Check if possible to merge ifm/ofm live ranges of dma op
+ dma_op = sched_op.parent_op
+ ifm = dma_op.ifm
+ ofm = dma_op.ofm
+ if not (
+ tensor_should_be_ignored(ifm, target_mem_area, target_mem_type_set)
+ or tensor_should_be_ignored(ofm, target_mem_area, target_mem_type_set)
+ ):
+ # Currently DMA is only used when bypassing memory only ops, so it is ok to reuse the ifm
+ ifm_tens = ifm
return ifm_tens
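
The new Memcpy branch fuses the ifm and ofm live ranges whenever neither tensor is excluded from the target memory area/type, which is what later lets the copy collapse into a NOP. A toy model of that condition, with the ignore test reduced to a plain predicate:

    def ifm_to_fuse_for_memcpy(ifm, ofm, should_be_ignored):
        # The ifm buffer may double as the ofm only when neither tensor is
        # excluded from this memory area/type; DMA is currently only used
        # when bypassing memory-only ops, so the reuse is safe.
        if not (should_be_ignored(ifm) or should_be_ignored(ofm)):
            return ifm
        return None

    assert ifm_to_fuse_for_memcpy("ifm", "ofm", lambda t: False) == "ifm"
    assert ifm_to_fuse_for_memcpy("ifm", "ofm", lambda t: t == "ofm") is None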
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 967a7ac0..80011244 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -472,6 +472,10 @@ def measure_cycle_cost(arch, op_type: Op, faf_type: Op, query: PerformanceQuery)
_estimate_output_cycles_per_element(arch, op_type, faf_type, query)
* Shape4D.round_up(query.ofm_shape, ofm_rounding).elements()
)
+ # DMA cycle calculation
+ elif query.npu_block_type == NpuBlockType.Dma:
+ # Return 0 since this is not an actual NPU op
+ cycles.op_cycles = 0
else:
assert False
@@ -541,6 +545,10 @@ def measure_element_access(arch, query: PerformanceQuery):
elif query.ifm2_bits > 8:
# ifm2 is a non 8-bit scalar
access.ifm_read[1] = Shape4D.round_up(query.ifm2_shape, ifm_rounding).elements()
+ # DMA
+ elif query.npu_block_type == NpuBlockType.Dma:
+ # Return empty access since this is not an actual NPU op
+ return access
# Unknown
else:
assert False
@@ -646,18 +654,28 @@ def estimate_full_op_performance(
# LUT Transfer
parent_op = op.parent_op
- lut_transfer_cycles = 0
+ dma_transfer_cycles = 0
if parent_op.activation_lut:
lut_tensor = [tens for tens in parent_op.inputs if tens.purpose == TensorPurpose.LUT][0]
src_tensor = lut_tensor.src_tensor
if src_tensor and lut_tensor.mem_area != src_tensor.mem_area:
bw = src_tensor.storage_size()
- lut_transfer_cycles = measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
+ dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
bws[src_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
# LUT read from SHRAM TODO remove?
scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
+ # DMA Transfer
+ if parent_op.type == Op.Memcpy:
+ src_tensor = parent_op.ifm
+ dst_tensor = parent_op.ofm
+ if src_tensor.mem_area != dst_tensor.mem_area:
+ bw = src_tensor.storage_size()
+ dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, dst_tensor.mem_area, bw)
+ bws[src_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Read] += bw
+ bws[dst_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Write] += bw
+
if cost.npu_weights_tensor and cost.buffered_weight_tensors:
# DMA Weight Transfer
sz = 0
@@ -690,11 +708,11 @@ def estimate_full_op_performance(
cycles.op_cycles + cost.full_weight_transfer_cycles - min(ws_first_transfer_cycles, slack_cycles)
)
- # Add cycles for LUT Transfer
- cycles_a[PassCycles.Npu] += lut_transfer_cycles
+ # Add cycles for LUT + memcpy op transfer
+ cycles_a[PassCycles.Npu] += dma_transfer_cycles
else:
- # Add cycles for LUT Transfer
- cycles_a[PassCycles.Npu] += max(lut_transfer_cycles - slack_cycles, 0)
+ # Add cycles for LUT + memcpy op transfer
+ cycles_a[PassCycles.Npu] += max(dma_transfer_cycles - slack_cycles, 0)
# OFM write
ofm = op.parent_op.ofm
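
The cost model charges DMA cycles only when the memcpy crosses memory areas, reusing the measure_mem2mem_cycles path that LUT transfers already take. A rough standalone sketch under an assumed fixed bytes-per-cycle bandwidth (the constant is illustrative, not vela's calibrated figure):

    def memcpy_transfer_cycles(size_bytes: int, src_mem_area: str,
                               dst_mem_area: str, bytes_per_cycle: float = 4.0) -> int:
        # Within one memory area the memcpy collapses into a NOP, so it
        # costs nothing; across areas, cycles scale with the copy size.
        if src_mem_area == dst_mem_area:
            return 0
        return int(size_bytes / bytes_per_cycle)

    assert memcpy_transfer_cycles(1024, "Dram", "Sram") == 256
    assert memcpy_transfer_cycles(1024, "Sram", "Sram") == 0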
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 19b00b31..6be9dc25 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -51,6 +51,7 @@ class NpuBlockType(Enum):
ConvolutionDepthWise = 4
ElementWise = 5
ReduceSum = 6
+ Dma = 7
class Kernel:
@@ -174,6 +175,7 @@ class Op(Enum):
)
Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
Div = OperatorInfo()
+ Memcpy = OperatorInfo(block_type=NpuBlockType.Dma, indices=NNG_IFM_INDICES)
Elu = OperatorInfo()
EmbeddingLookup = OperatorInfo()
EmbeddingLookupSparse = OperatorInfo()
@@ -373,6 +375,9 @@ class Op(Enum):
def is_resize_op(self):
return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)
+ def is_memcpy_op(self):
+ return self.info.block_type == NpuBlockType.Dma
+
def needs_bias(self):
return bool(self.info.indices.biases)
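
Op.Memcpy is the first operator mapped to the new Dma block type, and is_memcpy_op keys off that block type rather than the enum member. A stripped-down version of the pattern, with OperatorInfo reduced to just block_type:

    import enum
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class OperatorInfo:
        block_type: str = "Default"

    class Op(enum.Enum):
        AvgPool = OperatorInfo("Pooling")
        Memcpy = OperatorInfo("Dma")

        def is_memcpy_op(self) -> bool:
            # Keyed on the block type rather than the member itself, so any
            # future DMA-backed op is covered automatically.
            return self.value.block_type == "Dma"

    assert Op.Memcpy.is_memcpy_op()
    assert not Op.AvgPool.is_memcpy_op()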
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 7b66dff3..21f9dbed 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -51,6 +51,12 @@ def create_add_nop(name: str) -> Operation:
return op
+def create_memcpy(name: str) -> Operation:
+ op = Operation(Op.Memcpy, name)
+ op.run_on_npu = True
+ return op
+
+
def create_pad_nop(name: str) -> Operation:
op = Operation(Op.Pad, name)
op.run_on_npu = True
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5a9f9575..e43a9191 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -39,6 +39,7 @@ class PassFlags(enum.Flag):
StartupInit = 64
MemoryOnly = 128
PostFusingLimited = 256
+ Memcpy = 512
mac_main_ops = set(
@@ -95,6 +96,7 @@ memory_only_ops = set(
Op.ExpandDims,
)
)
+memcpy_ops = set((Op.Memcpy,))
test_sequence = [
@@ -160,6 +162,16 @@ test_sequence = [
),
(
# ops_set
+ memcpy_ops,
+ # incompatible_pack_flags
+ PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Mac | PassFlags.Main | PassFlags.PostFusingLimited,
+ # flags_to_set
+ PassFlags.Npu | PassFlags.Memcpy | PassFlags.Main,
+ # flags_to_clear
+ PassFlags.Empty,
+ ),
+ (
+ # ops_set
cpu_ops,
# incompatible_pack_flags
PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
@@ -248,7 +260,11 @@ def pack_into_passes(nng, arch, verbose_packing=False):
if flags_to_set & PassFlags.Npu:
if flags_to_set & (
- PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
+ PassFlags.Mac
+ | PassFlags.ElementWise
+ | PassFlags.Post
+ | PassFlags.PostFusingLimited
+ | PassFlags.Memcpy
):
assert len(curr_op.inputs) >= 1
ifm_tensor = curr_op.ifm
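
PassFlags is an enum.Flag, so the packing rules above are bitwise mask tests. A compact illustration of how incompatible_pack_flags blocks two ops from sharing a pass (flag values are illustrative; can_pack is a hypothetical condensation of the pack_into_passes logic):

    import enum

    class PassFlags(enum.Flag):
        Empty = 0
        Npu = 2
        Mac = 8
        MemoryOnly = 128
        Memcpy = 512

    def can_pack(accumulated: PassFlags, incompatible: PassFlags) -> bool:
        # An op cannot join a pass whose accumulated flags intersect its
        # incompatible_pack_flags mask.
        return not (accumulated & incompatible)

    # A pass already flagged Mac rejects a memcpy op, whose mask includes Mac:
    assert not can_pack(PassFlags.Mac, PassFlags.Mac | PassFlags.MemoryOnly)
    assert can_pack(PassFlags.Npu | PassFlags.Memcpy, PassFlags.Mac | PassFlags.MemoryOnly)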