From e8a5a78dd16ec979c7a7bb1f5bd87e9b2909c32d Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Mon, 2 Nov 2020 18:04:27 +0100 Subject: MLBEDSW-839: Code generation using external API Added external API to generate register command streams. Existing code generation has been refactored to make use of this API. Change-Id: Ibb4c2b167809869f16470b14da24f08a65c82b7b Signed-off-by: Louis Verhaard --- ethosu/vela/api.py | 369 ++++ ethosu/vela/architecture_features.py | 10 +- ethosu/vela/compiler_driver.py | 2 +- ethosu/vela/graph_optimiser.py | 5 +- ethosu/vela/high_level_command_stream.py | 64 - ethosu/vela/high_level_command_stream_generator.py | 3 +- ethosu/vela/high_level_command_to_npu_op.py | 497 +++++ ethosu/vela/lut.py | 4 +- ethosu/vela/npu_performance.py | 2 +- ethosu/vela/operation.py | 75 +- ethosu/vela/register_command_stream_generator.py | 1970 +++++++++++--------- ethosu/vela/shared_buffer_allocation.py | 192 +- ethosu/vela/softmax.py | 39 +- ethosu/vela/supported_operators.py | 10 +- .../vela/test/extapi/test_extapi_encode_weights.py | 5 +- .../test/extapi/test_extapi_generate_commands.py | 370 ++++ .../vela/test/test_register_command_generator.py | 104 ++ ethosu/vela/test/test_supported_operators.py | 3 +- ethosu/vela/tflite_reader.py | 5 +- ethosu/vela/tflite_writer.py | 3 +- ethosu/vela/weight_compressor.py | 20 +- 21 files changed, 2686 insertions(+), 1066 deletions(-) create mode 100644 ethosu/vela/api.py create mode 100644 ethosu/vela/high_level_command_to_npu_op.py create mode 100644 ethosu/vela/test/extapi/test_extapi_generate_commands.py create mode 100644 ethosu/vela/test/test_register_command_generator.py diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py new file mode 100644 index 0000000..06de0d9 --- /dev/null +++ b/ethosu/vela/api.py @@ -0,0 +1,369 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Description: +# Contains data types used in the external API for code generation +from enum import auto +from enum import Enum +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple + + +class NpuElementWiseOp(Enum): + """ + Elementwise operation + """ + + ADD = auto() + SUB = auto() + MUL = auto() + ABS = auto() + MIN = auto() + MAX = auto() + LRELU = auto() # Leaky relu + CLZ = auto() # Number leading zeros + SHR = auto() # Rounded right-shift + SHL = auto() # Bitwise shift-left + + +class NpuPoolingOp(Enum): + """ + Pooling operation + """ + + MAX = auto() + AVERAGE = auto() + REDUCE_SUM = auto() + + +class NpuActivationOp(Enum): + """ + Activation function + """ + + NONE_OR_RELU = auto() # Clamps output using min/max + TANH = auto() + SIGMOID = auto() + TABLE_LOOKUP = auto() # Performs table look-up, using the provided table lookup index + + +class NpuRoundingMode(Enum): + """ + Available rounding modes + """ + + TFL = auto() # TensorFlow Lite rounding + TRUNCATE = auto() # Truncate towards zero + NATURAL = auto() # Round to nearest with x.5 rounded up, towards +infinity + + +class NpuLayout(Enum): + """ + Tensor layout of feature maps + """ + + NHWC = auto() + NHCWB16 = auto() + + def __str__(self): + return self.name + + +class NpuResamplingMode(Enum): + """ + Resampling mode + """ + + NONE = auto() # No resampling is performed + NEAREST = auto() # 2x2 insert nearest + TRANSPOSE = auto() # 2x2 transpose + + +class NpuBlockTraversal(Enum): + """ + Block-traversal of weights + """ + + DEPTH_FIRST = auto() + PART_KERNEL_FIRST = auto() + + +class NpuDataType(Enum): + """ + Supported data types in feature maps + """ + + UINT8 = 8, False, auto() + INT8 = 8, True, auto() + UINT16 = 16, False, auto() + INT16 = 16, True, auto() + INT32 = 32, True, auto() + + def is_signed(self) -> bool: + """Checks if this data type is signed or unsigned""" + return self.value[1] + + def size_in_bits(self) -> int: + """ Size of the data type in bits""" + return self.value[0] + + def size_in_bytes(self) -> int: + """ Size of the data type in bytes""" + return self.value[0] // 8 + + def min_value(self) -> int: + """Minimum value of this type""" + if self.is_signed(): + return -(1 << (self.size_in_bits() - 1)) + else: + return 0 + + def max_value(self) -> int: + """Maximum value of this type""" + if self.is_signed(): + return (1 << (self.size_in_bits() - 1)) - 1 + else: + return (1 << self.size_in_bits()) - 1 + + def __str__(self): + return self.name + + __repr__ = __str__ + + +class NpuAddressRange(NamedTuple): + """ + Address range + """ + + region: int # Memory region, a value between 0 and 7 + address: int # Address, offset from the region's base address + length: int # The length of the range, in bytes + + def __str__(self): + return f"(region={self.region}, address={hex(self.address)}, length={self.length})" + + +class NpuTileBox(NamedTuple): + """ + Specifies the addresses and dimensions of the tiles of a feature map. 
+ A feature map can use 1 to 4 tiles + """ + + height_0: int # The height of tile 0 + height_1: int # The height of tile 1, 0 if unused + width_0: int # the width of tile 0, and tile 2 (if used) + addresses: List[int] # A list of 4 addresses, set unused addresses to 0 + + +class NpuShape3D(NamedTuple): + """ + Shape of (part of) a feature map + """ + + height: int + width: int + depth: int + + +class NpuQuantization(NamedTuple): + """ + Quantization parameters + """ + + scale_f32: Optional[float] + zero_point: int + + +class NpuPadding(NamedTuple): + """ + Padding to be applied to a convolution operation + """ + + top: int + left: int + bottom: int + right: int + + +class NpuActivation: + """ + Activation function, fused with NPU operations + """ + + def __init__(self, op_type: NpuActivationOp): + self.op_type = op_type # The activation operation to be performed + # min/max are optional + self.min: Optional[float] = None # E.g. set to 0.0 for RELU + self.max: Optional[float] = None # E.g. set to 6.0 for RELU6 + # Table lookup index, only applicable for TABLE_LOOKUP activation, 0-7 + self.lookup_table_index: int = 0 + + +class NpuFeatureMap: + """ + Basic information about IFM, IFM2, OFM + """ + + def __init__(self): + self.data_type: NpuDataType = NpuDataType.UINT8 + # The memory region, a value 0-7 + self.region: int = 0 + # Shape of the feature map + self.shape: NpuShape3D = NpuShape3D(height=0, width=0, depth=0) + # The tiles that comprise the feature map. In the normal case when only 1 tile is used, + # height_0 == self.shape.height, height_1 is 0, width_0 == self.shape.width, addresses[1:] are set to 0 + self.tiles: NpuTileBox = NpuTileBox(height_0=0, height_1=0, width_0=0, addresses=[0, 0, 0, 0]) + self.quantization: Optional[NpuQuantization] + self.layout: NpuLayout = NpuLayout.NHWC + # x/y/c strides used by the NPU when traversing the feature map, if None, vela will use default strides + self.strides: Optional[NpuShape3D] = None + + +class NpuKernel: + """ + Kernel information for NPU operations + """ + + def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1): + assert stride_x > 0 and stride_y > 0 + assert dilation_x > 0 and dilation_y > 0 + self.width = w + self.height = h + self.stride_x = stride_x + self.stride_y = stride_y + self.dilation_x = dilation_x + self.dilation_y = dilation_y + + +class NpuOperationType(Enum): + """ + Type of NPU operation + """ + + Dma = auto() + Conv2D = auto() + ConvDepthWise = auto() + Pooling = auto() + ElementWise = auto() + + +class NpuOperation: + """ + Base class for all NPU operations + """ + + def __init__(self, op_type: NpuOperationType): + self.op_type = op_type + + +class NpuDmaOperation(NpuOperation): + """ + DMA operation + """ + + def __init__(self, src: NpuAddressRange, dest: NpuAddressRange): + super().__init__(NpuOperationType.Dma) + self.src = src + self.dest = dest + # DMA channel, usually 0 (user channel) + self.channel: int = 0 + # Channel mode, 0 = external, 1 = internal (should usually be 0) + self.mode: int = 0 + + +class NpuBlockOperation(NpuOperation): + """ + Base class for operations which produce an OFM + """ + + def __init__(self, op_type: NpuOperationType): + super().__init__(op_type) + self.ifm: Optional[NpuFeatureMap] = None + self.ifm2: Optional[NpuFeatureMap] = None + # The non-quantized scalar value in a binary elementwise operation. 
Only set if IFM2 is scalar + self.ifm2_scalar: Optional[float] = None + self.ofm: Optional[NpuFeatureMap] = None + self.kernel: Optional[NpuKernel] = None + # Weights, one element for each NPU core, empty if no weights are used. + # Must have been compressed using weight_compressor.encode_weights() + self.weights: List[NpuAddressRange] = [] + # Biases, one element for each NPU core, empty if no bias is used. + # Must have been encoded using weight_compressor.encode_bias() + self.biases: List[NpuAddressRange] = [] + self.padding: Optional[NpuPadding] = None + # Optional activation function to be applied + self.activation: Optional[NpuActivation] = None + # The block config is the unit of work in which the NPU generates the OFM. + # If the operation has weights, the depth of the block config must be the same as + # the ofm depth used in the call to weight_compressor.encode_weights() + # If set to None, vela will determine a suitable block size (can only be used if there are no weights) + # If block_config.width and height are set to -1, vela will determine suitable width/height + self.block_config: Optional[NpuShape3D] = None # OFM_BLK parameters + self.rounding_mode: NpuRoundingMode = NpuRoundingMode.TFL + # Set to True if the operations is fused with a Quantize operation (affects scaling) + self.fused_quantize: bool = False + # IFM upscaling to be applied + self.ifm_upscale: NpuResamplingMode = NpuResamplingMode.NONE + + +class NpuConv2DOperation(NpuBlockOperation): + """ + NPU_OP_CONV operation + """ + + def __init__(self): + super().__init__(NpuOperationType.Conv2D) + # Block traversal must be consistent with the block_traversal parameter specified in + # weight_compressor.encode_weights() + self.block_traversal: NpuBlockTraversal = NpuBlockTraversal.PART_KERNEL_FIRST + + +class NpuConvDepthWiseOperation(NpuBlockOperation): + """ + NPU_OP_DEPTHWISE operation + """ + + def __init__(self): + super().__init__(NpuOperationType.ConvDepthWise) + + +class NpuPoolingOperation(NpuBlockOperation): + """ + NPU_OP_POOL operation + """ + + def __init__(self, pooling_op_type: NpuPoolingOp): + super().__init__(NpuOperationType.Pooling) + self.sub_op_type: NpuPoolingOp = pooling_op_type + # Set to a float value for ResizeBilinear operations (affects scaling), else to None + self.rescale: Optional[float] = None + + +class NpuElementWiseOperation(NpuBlockOperation): + """ + NPU_OP_ELEMENTWISE operation + """ + + def __init__(self, elementwise_op_type: NpuElementWiseOp): + super().__init__(NpuOperationType.ElementWise) + self.sub_op_type: NpuElementWiseOp = elementwise_op_type + # Set to True for binary operators where IFM2 should be used as first operand + self.reversed_operands: bool = False + # Set to a tuple (scale, shift) for explicit rescale, else to None + self.rescale: Optional[Tuple] = None diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py index b77205b..6a02a4e 100644 --- a/ethosu/vela/architecture_features.py +++ b/ethosu/vela/architecture_features.py @@ -496,7 +496,7 @@ class ArchitectureFeatures: return (start_coord, end_coord, 1) # start, end, total jobs def get_prev_job_output_volume( - self, ifm: Block, ofm: Rect, ifm_block_depth, ofm_block: Block, kernel: Kernel, block_offset + self, ifm: Rect, ofm: Rect, ifm_block_depth, ofm_block: Block, kernel: Kernel, block_offset ): assert block_offset >= 0 @@ -518,13 +518,13 @@ class ArchitectureFeatures: def calc_block_dep( self, - prev_ifm: Block, - prev_ofm: Block, + prev_ifm: Rect, + prev_ofm: Rect, 
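As an illustration of the data model that api.py introduces (this sketch is editorial, not part of the patch: the import path, addresses, shapes, lengths and quantization values are invented, and in a real call the weight and bias ranges would have to come from weight_compressor.encode_weights()/encode_bias() as the comments above require), an external caller might describe a 1x1 convolution roughly like this:

    from ethosu.vela.api import (NpuAddressRange, NpuBlockTraversal, NpuConv2DOperation,
                                 NpuDataType, NpuFeatureMap, NpuKernel, NpuPadding,
                                 NpuQuantization, NpuShape3D, NpuTileBox)

    def make_fm(address: int) -> NpuFeatureMap:
        # Single-tile NHWC INT8 feature map: height_1 = 0 and addresses[1:] = 0, as described above
        fm = NpuFeatureMap()
        fm.data_type = NpuDataType.INT8
        fm.region = 1
        fm.shape = NpuShape3D(height=16, width=16, depth=8)
        fm.tiles = NpuTileBox(height_0=16, height_1=0, width_0=16, addresses=[address, 0, 0, 0])
        fm.quantization = NpuQuantization(scale_f32=0.0078125, zero_point=0)
        return fm

    op = NpuConv2DOperation()
    op.ifm = make_fm(0x0)
    op.ofm = make_fm(0x4000)
    op.kernel = NpuKernel(w=1, h=1)
    op.padding = NpuPadding(top=0, left=0, bottom=0, right=0)
    op.weights = [NpuAddressRange(region=0, address=0x8000, length=208)]  # from encode_weights()
    op.biases = [NpuAddressRange(region=0, address=0x9000, length=80)]    # from encode_bias()
    op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST              # must match encode_weights()
    op.block_config = NpuShape3D(height=16, width=16, depth=8)            # depth must equal the encoded OFM depth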
prev_ifm_block_depth, prev_ofm_block: Block, prev_kernel: Kernel, - ifm: Block, - ofm: Block, + ifm: Rect, + ofm: Rect, ifm_block_depth, ofm_block: Block, kernel: Kernel, diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py index e089b70..32eef30 100644 --- a/ethosu/vela/compiler_driver.py +++ b/ethosu/vela/compiler_driver.py @@ -292,7 +292,7 @@ def compiler_driver(nng, arch, options, scheduler_options): nng, sg, arch, options.verbose_high_level_command_stream ) lut.optimize_high_level_cmd_stream(sg, arch) - register_command_stream_generator.generate_register_command_stream( + register_command_stream_generator.generate_register_command_stream_for_sg( nng, sg, arch, options.verbose_register_command_stream ) scratch_tens, scratch_fast_tens, flash_tens = npu_serialisation.serialise_npu_subgraph_into_tensors( diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 899da07..fb5235d 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -31,6 +31,7 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import clamp_sigmoid from .numeric_util import full_shape from .numeric_util import round_away_zero +from .operation import create_activation_function from .operation import create_avgpool_nop from .operation import NpuBlockType from .operation import Op @@ -413,7 +414,7 @@ def fixup_pack_input(op, arch, nng): def unfuse_activation_function(op, arch, nng): if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None: - act_op = Operation(op.activation, op.name + op.activation.name) + act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name) op.activation = None out_tens = op.outputs[0] intermediate_tens = out_tens.clone("_act_intermediate") @@ -641,7 +642,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng): # Override this op with its own primary op (avgpool) relu_fused_op = create_avgpool_nop(op.name + "_avgpool") # And fuse the original activation function to it - relu_fused_op.activation = op.type + relu_fused_op.activation = create_activation_function(op.type) # Tidy up and assign the ifm and ofm to the new op ifm.consumer_list.remove(op) diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py index a5372d7..4c3a9cf 100644 --- a/ethosu/vela/high_level_command_stream.py +++ b/ethosu/vela/high_level_command_stream.py @@ -21,12 +21,6 @@ import numpy as np from .numeric_util import round_up_divide from .operation import NpuBlockType -from .operation import Op -from .range_set import AccessDirection -from .range_set import MemoryAccessSet -from .range_set import MemoryRangeSet -from .tensor import MemArea -from .tensor import TensorPurpose class Box: @@ -159,9 +153,6 @@ class Command: def is_npu_pass_command(self): return False - def get_memory_accesses(self): - return None - def get_operation_count(self): # returns numpy array of (DPU blocks, dma_ops). 
Should line up with the CommandType enum return np.array((0, 0)) @@ -213,48 +204,6 @@ class NpuStripe(Command): for i in range(len(self.ofm_box.end_coord)): assert self.ofm_box.end_coord[i] <= self.ofm_tensor.shape[i] - def get_memory_accesses(self): - res = MemoryAccessSet() - if self.ifm_tensor is not None and self.ifm_tensor.shape != []: - res.add( - self.ifm_tensor.get_address_ranges_for_coordinates(self.ifm_box.start_coord, self.ifm_box.end_coord), - AccessDirection.Read, - ) - if self.ifm2_tensor is not None and self.ifm2_tensor.shape != []: - res.add( - self.ifm2_tensor.get_address_ranges_for_coordinates(self.ifm2_box.start_coord, self.ifm2_box.end_coord), - AccessDirection.Read, - ) - if self.ofm_tensor is not None: - res.add( - self.ofm_tensor.get_address_ranges_for_coordinates(self.ofm_box.start_coord, self.ofm_box.end_coord), - AccessDirection.Write, - ) - if self.weight_tensor is not None: - res.add( - self.weight_tensor.get_address_ranges_for_coordinates( - self.weight_box.start_coord, self.weight_box.end_coord - ), - AccessDirection.Read, - ) - if self.scale_tensor is not None and self.scale_tensor.ops[0].type == Op.DMA: - res.add( - self.scale_tensor.get_address_ranges_for_coordinates([0], self.scale_tensor.shape), - AccessDirection.Read, - ) - # Add read access to SHRAM by any LUT-s - for tens in self.ps.intermediates: - if tens.purpose == TensorPurpose.LUT and tens.mem_area == MemArea.Shram: - res.add( - MemoryRangeSet(tens.mem_area, tens.address, tens.address + tens.storage_size()), - AccessDirection.Read, - ) - # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks - res.add( - self.ps.shared_buffer.get_shram_memory_access_range(), AccessDirection.Write, - ) - return res - def is_npu_pass_command(self): return True @@ -391,19 +340,6 @@ class DMA(Command): __repr__ = __str__ - def get_memory_accesses(self): - res = MemoryAccessSet() - - res.add( - self.in_tensor.get_address_ranges_for_coordinates(self.box.start_coord, self.box.end_coord), - AccessDirection.Read, - ) - res.add( - self.out_tensor.get_address_ranges_for_coordinates(self.box.start_coord, self.box.end_coord), - AccessDirection.Write, - ) - return res - def get_operation_count(self): # returns numpy array of (DPU blocks, dma_ops) return np.array((0, 1)) diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 871a048..14cd051 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -24,6 +24,7 @@ from .high_level_command_stream import NpuStripe from .nn_graph import PassPlacement from .nn_graph import SchedulingStrategy from .numeric_util import round_up_divide +from .operation import create_activation_function from .operation import NpuBlockType from .operation import Op from .tensor import TensorPurpose @@ -109,7 +110,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id concat_offset = concat_start ps.primary_op.memory_function = op.type elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid): - ps.primary_op.activation = op.type + ps.primary_op.activation = create_activation_function(op.type) if strat == SchedulingStrategy.WeightStream: ofm_step = block_config[-1] diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py new file mode 100644 index 0000000..7750121 --- /dev/null +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -0,0 +1,497 @@ +# Copyright (C) 2020 Arm 
Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Conversion from high level command to NpuOperation +from enum import IntEnum +from typing import List +from typing import Optional + +from .api import NpuActivation +from .api import NpuActivationOp +from .api import NpuAddressRange +from .api import NpuBlockOperation +from .api import NpuBlockTraversal +from .api import NpuConv2DOperation +from .api import NpuConvDepthWiseOperation +from .api import NpuDataType +from .api import NpuDmaOperation +from .api import NpuElementWiseOp +from .api import NpuElementWiseOperation +from .api import NpuFeatureMap +from .api import NpuKernel +from .api import NpuLayout +from .api import NpuOperation +from .api import NpuPadding +from .api import NpuPoolingOp +from .api import NpuPoolingOperation +from .api import NpuQuantization +from .api import NpuResamplingMode +from .api import NpuRoundingMode +from .api import NpuShape3D +from .api import NpuTileBox +from .architecture_features import ArchitectureFeatures +from .data_type import DataType +from .high_level_command_stream import Box +from .high_level_command_stream import Command +from .high_level_command_stream import CommandType +from .high_level_command_stream import DMA +from .high_level_command_stream import NpuStripe +from .operation import Kernel +from .operation import NpuBlockType +from .operation import Op +from .operation import Operation +from .tensor import MemType +from .tensor import Tensor +from .tensor import TensorBlockTraversal +from .tensor import TensorFormat +from .tensor import TensorPurpose + + +unary_elementwise_ops = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,)) + + +class BasePointerIndex(IntEnum): + WeightTensor = 0 # base address index for the Weight tensor + ScratchTensor = 1 # base address index for the Scratch_tensor in the TensorArena + ScratchFastTensor = 2 # base address for the Scratch_fast_tensor + Mem2Mem = (1 << 8) | (3 << 0) # base address slot for memory 2 memory transfer + + +dtype_map = { + DataType.uint8: NpuDataType.UINT8, + DataType.int8: NpuDataType.INT8, + DataType.uint16: NpuDataType.UINT16, + DataType.int16: NpuDataType.INT16, + DataType.int32: NpuDataType.INT32, +} + + +block_traversal_map = { + TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST, + TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST, +} + + +# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE +elementwise_op_map = { + Op.Mul: NpuElementWiseOp.MUL, + Op.Add: NpuElementWiseOp.ADD, + Op.Sub: NpuElementWiseOp.SUB, + Op.Minimum: NpuElementWiseOp.MIN, + Op.Maximum: NpuElementWiseOp.MAX, + Op.LeakyRelu: NpuElementWiseOp.LRELU, + Op.Abs: NpuElementWiseOp.ABS, + Op.CLZ: NpuElementWiseOp.CLZ, + Op.SHR: NpuElementWiseOp.SHR, + Op.SHL: NpuElementWiseOp.SHL, +} + + +def to_npu_kernel(kernel: Kernel) -> NpuKernel: + """Converts the given internally used 
kernel object to NpuKernel (of public API)""" + return NpuKernel( + kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y + ) + + +def to_kernel(kernel: Optional[NpuKernel]) -> Kernel: + """Converts the given public API object to Kernel (used internally)""" + if kernel is None: + return Kernel(1, 1) + return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y) + + +def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool: + if ifm_shape == []: + # Scalar needs to be in IFM2 + return False + if ifm2_shape == []: + return True + + for ifm, ifm2 in zip(ifm_shape, ifm2_shape): + if ifm != ifm2 and ifm == 1: + # Broadcasted FM needs to be in IFM2 + return False + return True + + +def get_rounding_mode(op: Operation) -> NpuRoundingMode: + """Specifies type of rounding to be used""" + rounding_mode = NpuRoundingMode.TFL + if op.type == Op.ResizeBilinear: + rounding_mode = NpuRoundingMode.TRUNCATE + elif ( + op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise) + and op.ifm.dtype == DataType.int16 + ): + rounding_mode = NpuRoundingMode.NATURAL + elif op.type.is_avgpool_op() and op.memory_function == Op.ConcatSliceWrite and op.kernel.elements_wh() == 1: + rounding_mode = NpuRoundingMode.NATURAL + rounding_mode = op.attrs.get("rounding_mode", rounding_mode) + return rounding_mode + + +def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding: + if primary_op.type.npu_block_type == NpuBlockType.VectorProduct: + return NpuPadding(top=0, left=0, bottom=0, right=0) + explicit_padding = list(primary_op.attrs["explicit_padding"]) # (top, left, bottom, right) + + # Check if this is for horizontal ifm streaming + if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe): + explicit_padding[0] = cmd.pad_top + explicit_padding[2] = cmd.pad_bottom + + # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output, + # because of activation function needed to be fused. 
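(Editorial aside, with made-up shapes.) ifm_ifm2_correct_order above encodes the rule that a scalar or broadcast operand must end up in IFM2; create_npu_elementwise_op further down swaps the operands and sets reversed_operands when it returns False:

    ifm_ifm2_correct_order([], [1, 8, 8, 16])             # False: scalar is in IFM, operands must be swapped
    ifm_ifm2_correct_order([1, 8, 8, 16], [])             # True: scalar is already in IFM2
    ifm_ifm2_correct_order([1, 1, 8, 16], [1, 8, 8, 16])  # False: the broadcast feature map must be IFM2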
+ if cmd.ifm_box.start_coord[-2] > 0: + explicit_padding[1] = 0 + if cmd.ifm_box.end_coord[-2] < cmd.ifm_tensor.shape[-2]: + explicit_padding[3] = 0 + return NpuPadding( + top=explicit_padding[0], left=explicit_padding[1], bottom=explicit_padding[2], right=explicit_padding[3] + ) + + +def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int: + if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area: + base_ptr_idx_map = { + MemType.Permanent_NPU: BasePointerIndex.WeightTensor, + MemType.Permanent_CPU: BasePointerIndex.WeightTensor, + MemType.Scratch: BasePointerIndex.ScratchTensor, + MemType.Scratch_fast: BasePointerIndex.ScratchTensor, + } + else: + base_ptr_idx_map = { + MemType.Permanent_NPU: BasePointerIndex.WeightTensor, + MemType.Permanent_CPU: BasePointerIndex.WeightTensor, + MemType.Scratch: BasePointerIndex.ScratchTensor, + MemType.Scratch_fast: BasePointerIndex.ScratchFastTensor, + } + return int(base_ptr_idx_map[tens.mem_type]) + + +def get_upscale(op: Operation) -> NpuResamplingMode: + upscale = NpuResamplingMode.NONE + if op.type == Op.ResizeBilinear: + # perform nearest neighbor upscale + upscale = NpuResamplingMode.NEAREST + elif op.type == Op.Conv2DBackpropInputSwitchedBias: + # perform insert zero upscale + upscale = NpuResamplingMode.TRANSPOSE + return upscale + + +def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int: + if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum): + shape = ifm_box.get_size_shape() + else: + shape = ofm_box.get_size_shape() + return shape[-1] + + +def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool: + """Checks if quantization should use 0 as zero point""" + if tens.dtype == DataType.int32 and is_ifm_tensor: + return True + if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL): + return False + fused_quantize = any(op.type == Op.Quantize for op in ps.ops) + forced_ofm_quantization = ps.primary_op.forced_output_quantization + use_0 = ( + (ps.primary_op.activation is None or forced_ofm_quantization is not None) + and (ps.primary_op.memory_function != Op.ConcatSliceWrite) + and not fused_quantize + ) + return use_0 + + +def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]: + """Gets quantization for IFM/IFM2""" + if tens.quantization is None: + return None + if use_zero_point_0(ps, tens, True): + zero_point = 0 + else: + zero_point = int(tens.quantization.zero_point) + return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point) + + +def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]: + """Gets quantization for OFM""" + op = ps.primary_op + # Check if operation's output quantization is should be used instead of the output tensor's quantization + # (used in LUTs) + ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization + if ofm_quant is None: + return None + if use_zero_point_0(ps, tens, False): + zero_point = 0 + else: + zero_point = int(ofm_quant.zero_point) + return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point) + + +def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap: + """Creates feature map with common fields populated""" + fm = NpuFeatureMap() + fm.region = get_region(tens, arch) + fm.data_type = dtype_map[tens.dtype] + if tens.format == TensorFormat.NHWC: + fm.layout = NpuLayout.NHWC + elif tens.format == 
TensorFormat.NHCWB16: + fm.layout = NpuLayout.NHCWB16 + else: + assert 0, "Incorrect tensor format" + height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord) + for idx, addr in enumerate(addresses): + if addr is None: + addresses[idx] = 0 + fm.tiles = NpuTileBox( + height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses] + ) + strides = tens.get_strides() + fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1])) + return fm + + +def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]: + """Returns address ranges for weights""" + weights = [] + stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord) + weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index] + substreams = len(weight_substream_offsets) - 1 # Offset list must terminate with full stream length + + # Extract weight substream offsets and calculate their lengths + assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0) + weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord) + region = get_region(weight_tensor, arch) + for core in range(substreams): + address = weight_addr + weight_substream_offsets[core] + length = weight_substream_offsets[core + 1] - weight_substream_offsets[core] + addr_range = NpuAddressRange(region, int(address), int(length)) + weights.append(addr_range) + return weights + + +def create_biases( + weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures +) -> List[NpuAddressRange]: + """Returns address ranges for biases""" + biases = [] + stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord) + scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index] + substreams = len(scale_substream_offsets) - 1 # Offset list must terminate with full stream length + + # Extract scale substream offsets and calculate their lengths + assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0) + scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:]) + + region = get_region(scale_tensor, arch) + for core in range(substreams): + address = scale_addr + scale_substream_offsets[core] + length = scale_substream_offsets[core + 1] - scale_substream_offsets[core] + addr_range = NpuAddressRange(region, int(address), int(length)) + biases.append(addr_range) + return biases + + +def create_npu_activation(op: Operation) -> NpuActivation: + """Creates fused activation function""" + if op.activation is None: + return NpuActivation(NpuActivationOp.NONE_OR_RELU) + faf = op.activation.op_type + act_op = NpuActivationOp.NONE_OR_RELU + if faf == Op.Tanh: + act_op = NpuActivationOp.TANH + elif faf == Op.Sigmoid: + act_op = NpuActivationOp.SIGMOID + elif faf == Op.LUT: + act_op = NpuActivationOp.TABLE_LOOKUP + elif not faf.is_relu_op(): + raise Exception("Unsupported fused_activation_function = " + faf.name) + + act = NpuActivation(act_op) + act.min = op.activation.min + act.max = op.activation.max + act.lookup_table_index = op.activation.lut_index + return act + + +def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures): + """Sets common fields of the given operation""" + ps = cmd.ps + op = ps.primary_op + in_shape = cmd.ifm_box.get_size_shape() + out_shape = 
cmd.ofm_box.get_size_shape() + ofm_height = out_shape[-3] if len(out_shape) >= 4 else 1 + ofm_width = out_shape[-2] if len(out_shape) >= 2 else 1 + ofm_depth = out_shape[-1] if len(out_shape) >= 1 else 1 + ifm_height = in_shape[-3] if len(in_shape) >= 4 else 1 + if op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum): + ifm_depth = in_shape[-1] if len(in_shape) >= 1 else 1 + else: + ifm_depth = ofm_depth + + npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch) + npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=cmd.ifm_tensor.shape[-2], depth=ifm_depth) + npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor) + npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch) + npu_op.ofm.shape = NpuShape3D(height=ofm_height, width=ofm_width, depth=ofm_depth) + npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor) + + if cmd.weight_tensor is not None: + npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch) + if cmd.scale_tensor is not None: + npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch) + npu_op.activation = create_npu_activation(op) + npu_op.rounding_mode = get_rounding_mode(op) + npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3]) + + if not op.type.is_elementwise_op(): + npu_op.padding = create_padding(cmd, op) + npu_op.kernel = to_npu_kernel(op.kernel) + npu_op.ifm_upscale = get_upscale(op) + npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops) + return npu_op + + +def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation: + """Converts the command to NpuConv2DOperation""" + npu_op = NpuConv2DOperation() + set_common_op_fields(npu_op, cmd, arch) + if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct: + npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST + else: + npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal] + return npu_op + + +def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation: + """Converts the command to NpuConvDepthWiseOperation""" + npu_op = NpuConvDepthWiseOperation() + set_common_op_fields(npu_op, cmd, arch) + return npu_op + + +def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation: + """Converts the command to NpuPoolingOperation""" + ps = cmd.ps + op = ps.primary_op + pool_op = NpuPoolingOp.AVERAGE + if op.type.is_maxpool_op(): + pool_op = NpuPoolingOp.MAX + elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear: + pool_op = NpuPoolingOp.AVERAGE + elif op.type == Op.ReduceSum: + pool_op = NpuPoolingOp.REDUCE_SUM + else: + assert 0, f"Unknown pool type {op.type}" + npu_op = NpuPoolingOperation(pool_op) + set_common_op_fields(npu_op, cmd, arch) + # Pooling specific info + if op.type == Op.ResizeBilinear and "rescale" in op.attrs: + npu_op.rescale = op.attrs["rescale"] + return npu_op + + +def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation: + """Converts the command to NpuElementWiseOperation""" + ps = cmd.ps + op = ps.primary_op + assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}" + elemwise_op = elementwise_op_map[op.type] + npu_op = NpuElementWiseOperation(elemwise_op) + if elemwise_op not in unary_elementwise_ops: + if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, 
cmd.ifm2_tensor.shape): + # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms + cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor + cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box + npu_op.reversed_operands = True + npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch) + npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor) + if cmd.ifm2_tensor.shape == []: + # scalar + assert cmd.ifm2_tensor.quant_values.size == 1 + npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0) + npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0) + else: + box_shp = cmd.ifm2_box.get_size_shape() + height = box_shp[-3] if len(box_shp) >= 3 else 1 + npu_op.ifm2.shape = NpuShape3D(height=height, width=cmd.ifm2_tensor.shape[-2], depth=box_shp[-1]) + set_common_op_fields(npu_op, cmd, arch) + # Check if output scale needs to be overridden + output_scale = None + if op.type == Op.Add and "resizebilinear" in op.attrs: + # Force output scale same as the input scale for + # resizebilinear 1x1 that is converted to add + output_scale = npu_op.ifm2.quantization.scale_f32 + if op.type == Op.LeakyRelu: + output_scale = op.attrs["alpha"] + if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs: + npu_op.rescale = op.attrs.get("rescale") + if op.type in (Op.Add, Op.Mul, Op.Sub): + if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh): + output_scale = 1 / 0x3000 + if output_scale is not None: + npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point) + return npu_op + + +def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation: + """Converts the command to NpuDmaOperation""" + src_region = get_region(cmd.in_tensor, arch) + if cmd.out_tensor.purpose == TensorPurpose.LUT: + dest_region = BasePointerIndex.Mem2Mem + else: + dest_region = get_region(cmd.out_tensor, arch) + + start_coord = cmd.box.start_coord + src_addr = cmd.in_tensor.address_for_coordinate(start_coord) + dest_addr = cmd.out_tensor.address_for_coordinate(start_coord) + + if cmd.in_tensor.compressed_values is not None: + if cmd.out_tensor.purpose == TensorPurpose.FSBias: + sz = cmd.in_tensor.storage_size() + else: + stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord) + sz = cmd.in_tensor.size_of_compressed_stream(stream_index) + else: + sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr + src = NpuAddressRange(src_region, int(src_addr), int(sz)) + dest = NpuAddressRange(dest_region, int(dest_addr), int(sz)) + return NpuDmaOperation(src, dest) + + +def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation: + """Converts the high level command to NpuOperation""" + if cmd.cmdtype == CommandType.DMA: + npu_op = create_dma_op(cmd, arch) + elif cmd.cmdtype == CommandType.NpuStripe: + npu_block_type = cmd.ps.primary_op.type.npu_block_type + if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct): + npu_op = create_npu_conv2d_op(cmd, arch) + elif npu_block_type == NpuBlockType.ConvolutionDepthWise: + npu_op = create_npu_conv_depthwise_op(cmd, arch) + elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum): + npu_op = create_npu_pool_op(cmd, arch) + elif npu_block_type == NpuBlockType.ElementWise: + npu_op = create_npu_elementwise_op(cmd, arch) + else: + assert 0, f"Unknown command type {npu_block_type}" + # add a link to the high level command for 
debugging purposes + npu_op.cmd = cmd + return npu_op diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py index 69aa2a0..8e28b95 100644 --- a/ethosu/vela/lut.py +++ b/ethosu/vela/lut.py @@ -115,14 +115,14 @@ def optimize_high_level_cmd_stream(sg, arch): if existing_tens is not None: # LUT is already in SHRAM, no need to perform DMA lut_tens.address = existing_tens.address - cmd.ps.primary_op.attrs["lut_index"] = get_lut_index(arch, existing_tens) + cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, existing_tens) continue # Place the LUT in the last 2 blocks of SHRAM # Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size()) lut_tens.equivalence_id = uuid.uuid4() lut_tens.address = address - cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size + cmd.ps.primary_op.activation.lut_index = (address - lut_start) // slot_size lut_state = lut_state.put(lut_tens) cmd_stream.append(cmd) sg.high_level_command_stream = cmd_stream diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py index d8995a1..80f2c27 100644 --- a/ethosu/vela/npu_performance.py +++ b/ethosu/vela/npu_performance.py @@ -225,7 +225,7 @@ def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, block_traversa def estimate_output_cycles( arch, npu_block_type, primary_op, num_elems, ifm_tensor, ofm_tensor, ifm2_tensor, use_acc_40bits=False ): - faf = primary_op.activation + faf = None if primary_op.activation is None else primary_op.activation.op_type if npu_block_type == NpuBlockType.ElementWise and ifm_tensor.dtype == DataType.int32: if ifm2_tensor is None: # Unary op diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 1ba2a38..9f7d544 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -15,8 +15,10 @@ # limitations under the License. # Description: # Internal representation of a Neural Network Operation. 
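(Editorial sketch; lut_tensor and the index value are placeholders.) The lut.py change above works because the LUT slot index now lives on the fused ActivationFunction introduced in operation.py below, from where create_npu_activation copies it into the public NpuActivation:

    op = Operation(Op.AvgPool, "avgpool_with_lut")
    op.set_activation_lut(lut_tensor)   # sets op.activation = ActivationFunction(Op.LUT)
    op.activation.lut_index = 3         # previously stored as op.attrs["lut_index"]
    act = create_npu_activation(op)     # NpuActivation(TABLE_LOOKUP) with lookup_table_index == 3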
+import copy from collections import namedtuple from enum import Enum +from typing import Optional from .numeric_util import full_shape @@ -35,24 +37,30 @@ class NpuBlockType(Enum): class Kernel: - def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1): - assert sx > 0 and sy > 0 - assert dx > 0 and dy > 0 + """ + Kernel information for NPU operations + """ + + def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1): + assert stride_x > 0 and stride_y > 0 + assert dilation_x > 0 and dilation_y > 0 self.width = w self.height = h - self.stride = PointXY(sx, sy) - self.dilation = PointXY(dx, dy) - self.upscale = 1 + self.stride = PointXY(stride_x, stride_y) + self.dilation = PointXY(dilation_x, dilation_y) - def elements_wh(self): + def elements_wh(self) -> int: return self.width * self.height - def area_width(self): + def area_width(self) -> int: return (self.width - 1) * self.dilation.x + 1 - def area_height(self): + def area_height(self) -> int: return (self.height - 1) * self.dilation.y + 1 + def __str__(self): + return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}" + # Classifies operators of type Custom class CustomType(Enum): @@ -109,6 +117,7 @@ class Op(Enum): Call = OperatorInfo() Cast = OperatorInfo() Ceil = OperatorInfo() + Clip = OperatorInfo() # NPU specific fused activation function for clipping between activation.min/max Concat = OperatorInfo(indices=CONCAT_INDICES) ConcatEmbeddings = OperatorInfo() ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES) @@ -282,7 +291,7 @@ class Op(Enum): return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary def is_relu_op(self): - return self in (Op.Relu, Op.Relu6, Op.ReluN1To1) + return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip) def is_activation_op(self): return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT) @@ -310,6 +319,42 @@ class Op(Enum): return self.value.id < other.value.id +class ActivationFunction: + """Fused activation function""" + + def __init__(self, op_type: Op): + self.op_type = op_type # The activation operation to be performed + # min/max are optional; if present they are non-quantized values + self.min: Optional[float] = None + self.max: Optional[float] = None + # Table lookup index, only applicable for Op.LUT activation, 0-7 + self.lut_index: int = 0 + + def clone(self): + res = copy.copy(self) + return res + + +def create_activation_function(op_type: Op) -> ActivationFunction: + """Creates activation function with min/max depending on op_type""" + act = ActivationFunction(op_type) + if op_type == Op.Relu: + act.min = 0.0 + elif op_type == Op.Relu6: + act.min = 0.0 + act.max = 6.0 + elif op_type == Op.ReluN1To1: + act.min = -1.0 + act.max = 1.0 + elif op_type == Op.Tanh: + act.min = -1.0 + act.max = 1.0 + elif op_type == Op.Sigmoid: + act.min = 0.0 + act.max = 1.0 + return act + + def create_avgpool_nop(name): op = Operation(Op.AvgPool, name) op.attrs["padding"] = b"VALID" @@ -358,7 +403,7 @@ class Operation: "_kernel", ) - def __init__(self, op_type, name): + def __init__(self, op_type: Op, name: str): self.type = op_type self.name = name self.attrs = {} @@ -367,7 +412,7 @@ class Operation: self.flops = 0 self.run_on_npu = True # Fused activation function. If not none: operator code. 
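(Editorial example; op stands for any NPU Operation.) Operation.activation now holds one of these ActivationFunction objects instead of a bare Op, which is why graph_optimiser and the high level command stream generator above were changed to call create_activation_function:

    act = create_activation_function(Op.Relu6)
    # act.op_type == Op.Relu6, act.min == 0.0, act.max == 6.0
    op.activation = act  # create_npu_activation later maps this to NpuActivationOp.NONE_OR_RELU
                         # with a quantised 0..6 clamp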
- self.activation = None + self.activation: Optional[ActivationFunction] = None # Fused memory function, if not None: operator code self.memory_function = None # If not none: contains QuantizationParameters to be used as output quantization @@ -386,7 +431,7 @@ class Operation: res.outputs = list(self.outputs) res.flops = self.flops res.run_on_npu = self.run_on_npu - res.activation = self.activation + res.activation = None if self.activation is None else self.activation.clone() res.memory_function = self.memory_function res.forced_output_quantization = self.forced_output_quantization res.scheduled_pass = self.scheduled_pass @@ -405,6 +450,8 @@ class Operation: weight_shape = full_shape(4, weights.shape, 1) h = weight_shape[-4] w = weight_shape[-3] + elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs: + h, w = self.attrs["ksize"][1:3] else: h = self.attrs.get("filter_height", 1) w = self.attrs.get("filter_width", 1) @@ -597,7 +644,7 @@ class Operation: return input_tens, outputs, axis, offset_start, offset_end def set_activation_lut(self, lut_tensor): - self.activation = Op.LUT + self.activation = ActivationFunction(Op.LUT) self.activation_lut = lut_tensor self.add_input_tensor(lut_tensor) diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index e3fedfc..30b5e04 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -14,47 +14,72 @@ # See the License for the specific language governing permissions and # limitations under the License. # Description: -# Register level (low-level) command stream generation for Ethos-U55. Takes a high-level command stream and generates +# Register level (low-level) command stream generation for Ethos-U55. Takes a list of NPU operations and generates # all the register settings. Calculates dependencies between commands and inserts wait operations. And generates a bit # stream suitable for interpretation by the Ethos-U55 processor. from collections import defaultdict from collections import namedtuple from enum import Enum from enum import IntEnum +from typing import List +from typing import Optional import numpy as np +from . import numeric_util from . 
import scaling +from .api import NpuActivation +from .api import NpuActivationOp +from .api import NpuAddressRange +from .api import NpuBlockOperation +from .api import NpuBlockTraversal +from .api import NpuConv2DOperation +from .api import NpuDataType +from .api import NpuDmaOperation +from .api import NpuElementWiseOp +from .api import NpuElementWiseOperation +from .api import NpuFeatureMap +from .api import NpuKernel +from .api import NpuLayout +from .api import NpuOperation +from .api import NpuOperationType +from .api import NpuPadding +from .api import NpuPoolingOp +from .api import NpuPoolingOperation +from .api import NpuQuantization +from .api import NpuResamplingMode +from .api import NpuRoundingMode +from .api import NpuShape3D +from .api import NpuTileBox +from .architecture_features import Accelerator from .architecture_features import ArchitectureFeatures from .architecture_features import Block from .architecture_features import Rect from .architecture_features import SharedBufferArea from .architecture_features import SHRAMElements -from .data_type import BaseType -from .data_type import DataType from .debug_database import DebugDatabase from .ethos_u55_regs.ethos_u55_regs import acc_format from .ethos_u55_regs.ethos_u55_regs import activation from .ethos_u55_regs.ethos_u55_regs import cmd0 from .ethos_u55_regs.ethos_u55_regs import cmd1 from .ethos_u55_regs.ethos_u55_regs import elementwise_mode -from .ethos_u55_regs.ethos_u55_regs import ifm_precision from .ethos_u55_regs.ethos_u55_regs import pooling_mode from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .ethos_u55_regs.ethos_u55_regs import rounding from .high_level_command_stream import CommandType -from .numeric_util import clamp_sigmoid -from .numeric_util import clamp_tanh -from .numeric_util import full_shape +from .high_level_command_to_npu_op import convert_command_to_npu_op +from .high_level_command_to_npu_op import to_kernel +from .high_level_command_to_npu_op import unary_elementwise_ops from .numeric_util import quantise_float32 from .numeric_util import round_away_zero from .numeric_util import round_up_to_int from .operation import NpuBlockType -from .operation import Op -from .tensor import MemType -from .tensor import TensorBlockTraversal -from .tensor import TensorFormat -from .tensor import TensorPurpose +from .range_set import AccessDirection +from .range_set import MemoryAccessSet +from .range_set import MemoryRangeSet +from .shared_buffer_allocation import find_suitable_block_configs +from .shared_buffer_allocation import shared_buffer_allocation_for_npu_op +from .shared_buffer_allocation import SharedBufferAllocation class RegisterMachine: @@ -80,22 +105,6 @@ class CmdMode(IntEnum): CmdOpMask = 0x03FF -class BasePointerIndex(IntEnum): - WeightTensor = 0 # base address index for the Weight tensor - ScratchTensor = 1 # base address index for the Scratch_tensor in the TensorArena - ScratchFastTensor = 2 # base address for the Scratch_fast_tensor - Mem2Mem = (1 << 8) | (3 << 0) # base address slot for memory 2 memory transfer - - -# TODO: Replace with definitions from ethos_u55_regs -class IFM2Broadcast(IntEnum): - BroadcastHdim = 1 << 0 - BroadcastWdim = 1 << 1 - BroadcastCdim = 1 << 2 - ReverseOperandOrder = 1 << 6 - UseIFM2Scalar = 1 << 7 - - class CommandStreamEmitter: WORD_SIZE = 4 @@ -117,7 +126,7 @@ class CommandStreamEmitter: sz += len(cmd) * CommandStreamEmitter.WORD_SIZE return sz - def to_list(self): + def to_list(self) -> List[int]: return [elem for cmd in self.cmd_stream for 
elem in cmd] def print_cmds(self): @@ -146,7 +155,7 @@ class CommandStreamEmitter: print(s) - def cmd0_with_param(self, cmd, param): + def cmd0_with_param(self, cmd: cmd0, param): if isinstance(param, Enum): param = int(param.value) else: @@ -160,7 +169,7 @@ class CommandStreamEmitter: self.cmd_stream.append((command,)) self.offset += CommandStreamEmitter.WORD_SIZE - def cmd1_with_offset(self, cmd, offset, param=0x0): + def cmd1_with_offset(self, cmd: cmd1, offset, param=0x0): offset = int(offset) & 0xFFFFFFFFF command = cmd.value | CmdMode.Payload32.value | (param << 16) @@ -171,13 +180,13 @@ class CommandStreamEmitter: self.cmd_stream.append((command, offset)) self.offset += CommandStreamEmitter.WORD_SIZE * 2 - def cmd_wait(self, cmd, channel, outstanding_count): + def cmd_wait(self, cmd: cmd0, channel: int, outstanding_count: int): param = (16 * channel) + outstanding_count command = ((param & 0xFFFF) << 16) | cmd.value self.cmd_stream.append((command,)) self.offset += CommandStreamEmitter.WORD_SIZE - def cmd_do_operation(self, cmd, param=0): + def cmd_do_operation(self, cmd: cmd0, param=0): param = int(param) command = ((param & 0xFFFF) << 16) | cmd.value @@ -186,13 +195,674 @@ class CommandStreamEmitter: self.get_reg_machine(cmd).switch_bank() +# ------------------------------------------------------------------- +# REGISTER GENERATION +# ------------------------------------------------------------------- + + +class BasePointerIndex(IntEnum): + WeightTensor = 0 # base address index for the Weight tensor + ScratchTensor = 1 # base address index for the Scratch_tensor in the TensorArena + ScratchFastTensor = 2 # base address for the Scratch_fast_tensor + Mem2Mem = (1 << 8) | (3 << 0) # base address slot for memory 2 memory transfer + + +# TODO: Replace with definitions from ethos_u55_regs +class IFM2Broadcast(IntEnum): + BroadcastHdim = 1 << 0 + BroadcastWdim = 1 << 1 + BroadcastCdim = 1 << 2 + ReverseOperandOrder = 1 << 6 + UseIFM2Scalar = 1 << 7 + + +pooling_op_map = { + NpuPoolingOp.MAX: pooling_mode.MAX.value, + NpuPoolingOp.AVERAGE: pooling_mode.AVERAGE.value, + NpuPoolingOp.REDUCE_SUM: pooling_mode.REDUCE_SUM.value, +} + +elementwise_op_map = { + NpuElementWiseOp.MUL: elementwise_mode.MUL.value, + NpuElementWiseOp.ADD: elementwise_mode.ADD.value, + NpuElementWiseOp.SUB: elementwise_mode.SUB.value, + NpuElementWiseOp.MIN: elementwise_mode.MIN.value, + NpuElementWiseOp.MAX: elementwise_mode.MAX.value, + NpuElementWiseOp.LRELU: elementwise_mode.LRELU.value, + NpuElementWiseOp.ABS: elementwise_mode.ABS.value, + NpuElementWiseOp.CLZ: elementwise_mode.CLZ.value, + NpuElementWiseOp.SHR: elementwise_mode.SHR.value, + NpuElementWiseOp.SHL: elementwise_mode.SHL.value, +} + +activation_op_map = { + NpuActivationOp.NONE_OR_RELU: activation.NONE, + NpuActivationOp.TANH: activation.TANH, + NpuActivationOp.SIGMOID: activation.SIGMOID, +} + +# Maps an AccumulatorType enum to the corresponding acc_format value +acc_format_map = { + SHRAMElements.Acc16: acc_format.FP_S5_10.value, + SHRAMElements.Acc32: acc_format.INT_32BIT.value, + SHRAMElements.Acc40: acc_format.INT_40BIT.value, +} + +resampling_mode_map = { + NpuResamplingMode.NONE: resampling_mode.NONE, + NpuResamplingMode.NEAREST: resampling_mode.NEAREST, + NpuResamplingMode.TRANSPOSE: resampling_mode.TRANSPOSE, +} + +# Maps data type size in bits to activation precision +precision_map = {8: 0, 16: 1, 32: 2} + +# Maps rounding mode to the corresponding value +rounding_mode_map = { + NpuRoundingMode.TFL: rounding.TFL.value, + 
NpuRoundingMode.TRUNCATE: rounding.TRUNCATE.value, + NpuRoundingMode.NATURAL: rounding.NATURAL.value, +} + + +def quantise(value: float, quant: Optional[NpuQuantization]) -> int: + """Quantizes the given value""" + scale = 1 if quant is None or quant.scale_f32 is None else quant.scale_f32 + zp = 0 if quant is None else quant.zero_point + return quantise_float32(value, scale, zp) + + +def has_ifm2(npu_op: NpuBlockOperation) -> bool: + """Checks if op has non-scalar IFM2""" + return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None + + +def is_dma_op(npu_op: NpuOperation) -> bool: + """Checks if op is a DMA operation""" + return npu_op.op_type == NpuOperationType.Dma + + +def generate_padding(emit: CommandStreamEmitter, padding: NpuPadding): + """Generates IFM_PAD registers""" + emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, padding.top) + emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, padding.left) + emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, padding.bottom) + emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, padding.right) + + +def generate_activation(emit: CommandStreamEmitter, activation: Optional[NpuActivation], ofm: NpuFeatureMap): + """Generates ACTIVATION registers""" + act = activation if activation is not None else NpuActivation(NpuActivationOp.NONE_OR_RELU) + + if act.min is None: + quantized_min = ofm.data_type.min_value() + else: + quantized_min = quantise(act.min, ofm.quantization) + if act.max is None: + quantized_max = ofm.data_type.max_value() + else: + quantized_max = quantise(act.max, ofm.quantization) + quantized_min = max(quantized_min, np.iinfo(np.int16).min, ofm.data_type.min_value()) + quantized_max = min(quantized_max, np.iinfo(np.int16).max, ofm.data_type.max_value()) + if act.op_type == NpuActivationOp.TABLE_LOOKUP: + assert 0 <= act.lookup_table_index < 8 + activation_value = 16 + act.lookup_table_index + if ofm.data_type == NpuDataType.INT32: + activation_value |= 3 << 12 # Force I8 range + quantized_min = max(-128, quantized_min) + quantized_max = min(127, quantized_max) + else: + activation_value = activation_op_map[act.op_type] + emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation_value) + emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, quantized_min) + emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MAX, quantized_max) + + +def generate_addresses(emit: CommandStreamEmitter, ptr_cmds: List[cmd1], addresses: List[int], layout: NpuLayout): + """Generates xFM_BASE registers""" + if layout == NpuLayout.NHCWB16: + # Check that all BasePointer addresses are aligned to 16 bytes + assert all((int(addr) % 16) == 0 for addr in addresses) + emit.cmd1_with_offset(ptr_cmds[0], addresses[0]) + emit.cmd1_with_offset(ptr_cmds[1], addresses[1]) + emit.cmd1_with_offset(ptr_cmds[2], addresses[2]) + emit.cmd1_with_offset(ptr_cmds[3], addresses[3]) + + +def generate_tiles(emit: CommandStreamEmitter, tile_cmds: List[cmd0], tiles: NpuTileBox): + """Generates xFM_HEIGHT0/HEIGHT1/WIDTH0 registers""" + emit.cmd0_with_param(tile_cmds[0], tiles.height_0 - 1) + emit.cmd0_with_param(tile_cmds[1], tiles.height_1 - 1) + emit.cmd0_with_param(tile_cmds[2], tiles.width_0 - 1) + + +def generate_strides( + emit: CommandStreamEmitter, fm: NpuFeatureMap, stride_c_cmd: cmd1, stride_y_cmd: cmd1, stride_x_cmd: cmd1 +): + """Generates STRIDE_C/Y/X registers""" + strides = get_strides(fm) + emit.cmd1_with_offset(stride_c_cmd, strides.depth) # stride between 16-byte channel blocks (C) + emit.cmd1_with_offset(stride_y_cmd, strides.height) # stride between vertical values (H) + 
emit.cmd1_with_offset(stride_x_cmd, strides.width)  # stride between horizontal values (W)
+
+
+def generate_ifm_precision(emit: CommandStreamEmitter, fm: NpuFeatureMap, op_to_scale: int, precision_cmd: cmd0):
+    """Generates IFM/IFM2_PRECISION register"""
+    dtype = fm.data_type
+    prec = 1 if dtype.is_signed() else 0
+    activation_precision = precision_map[dtype.size_in_bits()]
+    prec += activation_precision << 2
+
+    if fm.layout == NpuLayout.NHCWB16:
+        prec |= 1 << 6
+
+    prec |= op_to_scale << 8
+    emit.cmd0_with_param(precision_cmd, prec)
+
+
+def generate_ofm_precision(emit: CommandStreamEmitter, npu_op: NpuBlockOperation, use_global_scale: bool):
+    """Generates OFM_PRECISION register"""
+    dtype = npu_op.ofm.data_type
+    prec = 1 if dtype.is_signed() else 0
+    activation_precision = precision_map[dtype.size_in_bits()]
+    prec += activation_precision << 1
+
+    if use_global_scale:
+        # Set global scale bit, as opposed to using per channel scale
+        prec |= 1 << 8
+    if npu_op.ofm.layout == NpuLayout.NHCWB16:
+        prec |= 1 << 6
+    prec |= rounding_mode_map[npu_op.rounding_mode] << 14
+    emit.cmd0_with_param(cmd0.NPU_SET_OFM_PRECISION, prec)
+
+
+def generate_ifm2_broadcast(emit: CommandStreamEmitter, npu_op: NpuElementWiseOperation):
+    """Generates IFM2_BROADCAST register for binary elementwise operations"""
+    ifm2_broadcast = 0
+    ifm = npu_op.ifm
+    ifm2 = npu_op.ifm2
+    if npu_op.reversed_operands:
+        ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder
+    if npu_op.ifm2_scalar is not None:
+        # IFM2 is a constant, set UseIFM2Scalar bit to IFM2_BROADCAST
+        ifm2_broadcast |= IFM2Broadcast.UseIFM2Scalar
+    else:
+        if ifm.shape.height != ifm2.shape.height:
+            # Broadcast in 'H' dimension
+            assert ifm2.shape.height == 1
+            ifm2_broadcast |= IFM2Broadcast.BroadcastHdim
+
+        if ifm.shape.width != ifm2.shape.width:
+            # Broadcast in 'W' dimension
+            assert ifm2.shape.width == 1
+            ifm2_broadcast |= IFM2Broadcast.BroadcastWdim
+
+        if ifm.shape.depth != ifm2.shape.depth:
+            # Broadcast in 'C' dimension
+            assert ifm2.shape.depth == 1
+            ifm2_broadcast |= IFM2Broadcast.BroadcastCdim
+
+    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast)
+
+
+def generate_ifm(emit: CommandStreamEmitter, ifm: NpuFeatureMap):
+    """Generates general IFM registers"""
+    emit.cmd0_with_param(cmd0.NPU_SET_IFM_REGION, ifm.region)
+    generate_addresses(
+        emit,
+        [cmd1.NPU_SET_IFM_BASE0, cmd1.NPU_SET_IFM_BASE1, cmd1.NPU_SET_IFM_BASE2, cmd1.NPU_SET_IFM_BASE3],
+        ifm.tiles.addresses,
+        ifm.layout,
+    )
+    generate_tiles(
+        emit, [cmd0.NPU_SET_IFM_HEIGHT0_M1, cmd0.NPU_SET_IFM_HEIGHT1_M1, cmd0.NPU_SET_IFM_WIDTH0_M1], ifm.tiles
+    )
+    emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, ifm.shape.depth - 1)
+    generate_strides(emit, ifm, cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X)
+    emit.cmd0_with_param(cmd0.NPU_SET_IFM_ZERO_POINT, int(ifm.quantization.zero_point))
+
+
+def generate_ifm2(emit: CommandStreamEmitter, ifm2: NpuFeatureMap, has_scalar: bool):
+    """Generates general IFM2 registers"""
+    if not has_scalar:
+        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_REGION, ifm2.region)
+        generate_addresses(
+            emit,
+            [cmd1.NPU_SET_IFM2_BASE0, cmd1.NPU_SET_IFM2_BASE1, cmd1.NPU_SET_IFM2_BASE2, cmd1.NPU_SET_IFM2_BASE3],
+            ifm2.tiles.addresses,
+            ifm2.layout,
+        )
+        generate_tiles(
+            emit, [cmd0.NPU_SET_IFM2_HEIGHT0_M1, cmd0.NPU_SET_IFM2_HEIGHT1_M1, cmd0.NPU_SET_IFM2_WIDTH0_M1], ifm2.tiles
+        )
+        generate_strides(emit, ifm2, cmd1.NPU_SET_IFM2_STRIDE_C, cmd1.NPU_SET_IFM2_STRIDE_Y, cmd1.NPU_SET_IFM2_STRIDE_X)
+
emit.cmd0_with_param(cmd0.NPU_SET_IFM2_ZERO_POINT, int(ifm2.quantization.zero_point)) + + +def generate_ofm(emit: CommandStreamEmitter, ofm: NpuFeatureMap): + """Generates general OFM registers""" + emit.cmd0_with_param(cmd0.NPU_SET_OFM_REGION, ofm.region) + generate_addresses( + emit, + [cmd1.NPU_SET_OFM_BASE0, cmd1.NPU_SET_OFM_BASE1, cmd1.NPU_SET_OFM_BASE2, cmd1.NPU_SET_OFM_BASE3], + ofm.tiles.addresses, + ofm.layout, + ) + generate_tiles( + emit, [cmd0.NPU_SET_OFM_HEIGHT0_M1, cmd0.NPU_SET_OFM_HEIGHT1_M1, cmd0.NPU_SET_OFM_WIDTH0_M1], ofm.tiles + ) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, ofm.shape.height - 1) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, ofm.shape.width - 1) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, ofm.shape.depth - 1) + generate_strides(emit, ofm, cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_ZERO_POINT, int(ofm.quantization.zero_point)) + + +def generate_kernel(emit: CommandStreamEmitter, kernel: NpuKernel, block_traversal: NpuBlockTraversal): + """Generates KERNEL related registers""" + emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, kernel.dilation_y * (kernel.height - 1)) + emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, kernel.dilation_x * (kernel.width - 1)) + # set kernel x stride low bit + stride = (kernel.stride_x - 1) & 1 + # set kernel y stride low bit + stride |= (kernel.stride_y - 1 & 1) << 1 + # set kernel x stride extension bits + stride |= (kernel.stride_x - 1 >> 1) << 6 + # set kernel y stride extension bits + stride |= (kernel.stride_y - 1 >> 1) << 9 + stride |= (kernel.dilation_x - 1) << 3 + stride |= (kernel.dilation_y - 1) << 4 + if block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST: + stride |= 1 << 2 + emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, stride) + + +def generate_weights(emit: CommandStreamEmitter, weights: List[NpuAddressRange], arch: ArchitectureFeatures): + """Generates WEIGHT registers""" + if len(weights) == 0: + return + emit.cmd0_with_param(cmd0.NPU_SET_WEIGHT_REGION, weights[0].region) + # Set weights sources for active and present cores + for core, (addr, length) in enumerate( + [ + (cmd1.NPU_SET_WEIGHT_BASE, cmd1.NPU_SET_WEIGHT_LENGTH), + (cmd1.NPU_SET_WEIGHT1_BASE, cmd1.NPU_SET_WEIGHT1_LENGTH), + ] + ): + if core < len(weights): + emit.cmd1_with_offset(addr, weights[core].address) + emit.cmd1_with_offset(length, weights[core].length) + elif core < arch.ncores: + emit.cmd1_with_offset(addr, weights[0].address) + emit.cmd1_with_offset(length, 0) + + +def generate_biases(emit: CommandStreamEmitter, biases: List[NpuAddressRange], arch: ArchitectureFeatures): + """Generates SCALE registers""" + if len(biases) == 0: + return + emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, biases[0].region) + # Set weights sources for active and present cores + for core, (addr, length) in enumerate( + [(cmd1.NPU_SET_SCALE_BASE, cmd1.NPU_SET_SCALE_LENGTH), (cmd1.NPU_SET_SCALE1_BASE, cmd1.NPU_SET_SCALE1_LENGTH)] + ): + if core < len(biases): + emit.cmd1_with_offset(addr, biases[core].address) + emit.cmd1_with_offset(length, biases[core].length) + elif core < arch.ncores: + emit.cmd1_with_offset(addr, biases[0].address) + emit.cmd1_with_offset(length, 0) + + +def generate_block_config( + emit: CommandStreamEmitter, + npu_op: NpuBlockOperation, + arch: ArchitectureFeatures, + shared_buffer: SharedBufferAllocation, +) -> NpuShape3D: + """Selects a suitable block config if none has been set, and generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers""" 
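+    # The block config is the OFM block size (height, width, depth) that the hardware processes at a time;
+    # the selected config is returned so that the caller can reuse it later, e.g. for the BLOCKDEP calculation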
+ block_config = npu_op.block_config + if block_config is None or block_config.height < 0: + # Note: this code only used if the public API to generate command streams is used; + # in the "normal" flow, the block config selected by the scheduler is used + if npu_op.weights: + assert block_config is not None, "block_config.depth must be provided for ops with weights" + # Block config has not been provided: find one + blocks = find_suitable_block_configs(arch, shared_buffer) + # Return the block with biggest volume + # TODO: use a better algorithm to find the best block + best_block = None + best_value = 0 + for block in blocks: + if block_config is not None and block[3] != block_config.depth: + continue + value = block[0] * block[1] * block[3] + if value > best_value: + best_value = value + best_block = block + assert best_block is not None, f"No suitable block config was found, {npu_op.op_type}" + block_config = NpuShape3D(height=best_block[0], width=best_block[1], depth=best_block[3]) + alloc = shared_buffer.try_block(Block(block_config.width, block_config.height, block_config.depth)) + assert alloc is not None, f"Block config {block_config} does not fit, op: {npu_op.op_type}" + emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config.height - 1) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config.width - 1) + emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config.depth - 1) + return block_config + + +def generate_shram_registers_elementwise( + emit: CommandStreamEmitter, + npu_op: NpuElementWiseOperation, + arch: ArchitectureFeatures, + shared_buffer: SharedBufferAllocation, +): + """Generates IB_END/IB_START/AB_START registers for elementwise operations""" + # For elementwise set the required SHRAM to be equal to the total size of available SHRAM + uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP + shram_required = arch.available_shram_banks(uses_lut) + + # Acc buffers not needed so set AB_START to size of SHRAM + emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required) + emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required) + if has_ifm2(npu_op): + # Set IFM2_IB_START to the latter half of the IB space + ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM] + emit.cmd0_with_param( + cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start, + ) + emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element]) + + +def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_buffer: SharedBufferAllocation): + """Generates IB_END/IB_START/AB_START registers for non-elementwise operations""" + emit.cmd0_with_param( + cmd0.NPU_SET_IFM_IB_END, + shared_buffer.bank_locations[SharedBufferArea.IFM] + shared_buffer.banks_required[SharedBufferArea.IFM], + ) + emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shared_buffer.bank_locations[SharedBufferArea.Accumulators]) + emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element]) + + +def generate_common( + emit: CommandStreamEmitter, + npu_op: NpuBlockOperation, + block_traversal: NpuBlockTraversal, + arch: ArchitectureFeatures, + use_global_scale: bool = False, + op_to_scale: int = 0, +): + """Generate registers that are common to most operations""" + assert npu_op.ifm is not None and npu_op.ofm is not None + generate_ifm(emit, npu_op.ifm) + generate_ifm_precision(emit, npu_op.ifm, op_to_scale, 
cmd0.NPU_SET_IFM_PRECISION) + emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode_map[npu_op.ifm_upscale]) + if npu_op.padding is not None: + generate_padding(emit, npu_op.padding) + generate_ofm(emit, npu_op.ofm) + generate_ofm_precision(emit, npu_op, use_global_scale) + if npu_op.op_type != NpuOperationType.ElementWise: + assert npu_op.kernel is not None + generate_kernel(emit, npu_op.kernel, block_traversal) + generate_weights(emit, npu_op.weights, arch) + generate_biases(emit, npu_op.biases, arch) + generate_activation(emit, npu_op.activation, npu_op.ofm) + + +# ------------------------------------------------------------------- +# SCALING +# ------------------------------------------------------------------- + + +def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoolingOperation): + """Generates OFM_SCALE register for pooling operations""" + # For valid padding vela has to output scaling values + kernel = pool_op.kernel + ifm_quant = pool_op.ifm.quantization + ofm_quant = pool_op.ofm.quantization + if pool_op.activation is not None and pool_op.activation.op_type in (NpuActivationOp.SIGMOID, NpuActivationOp.TANH): + assert ifm_quant.scale_f32 is not None + rescale = 0x3000 * ifm_quant.scale_f32 + if pool_op.ifm.data_type == NpuDataType.INT16: + # Calculate scale and shift for the output scale of 1/(3*4096) + shift = 0 + max_rescale = np.iinfo(np.int16).max / 2 + while rescale <= max_rescale and shift <= 30: + shift += 1 + rescale *= 2 + scale = int(rescale) + else: + rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 + scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits) + scale = int(round_away_zero(scale * rescale)) + elif pool_op.fused_quantize: + # Quantize op requires different scaling + ifm_scale_f64 = np.double(ifm_quant.scale_f32) + ofm_scale_f64 = np.double(ofm_quant.scale_f32) + scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64) + elif pool_op.rescale is not None: + # for ResizeBilinear operations with "rescale" in primary_op.attrs + rescale = pool_op.rescale + rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 + scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits) + scale = int(round_away_zero(scale * rescale)) + else: + # In case avg pool fused with concat or other memory operation, rescaling might be needed. + # kernel height == kernel width == 1 is always true in this case + # Normally the scale is maximised, to get maximum precision, which means that + # if rescale != 1, scale need to consider the number of bits needed for rescaling + if ofm_quant.scale_f32 is not None and ifm_quant.scale_f32 is not None: + rescale = ifm_quant.scale_f32 / ofm_quant.scale_f32 + rescale_bits = 0 + if kernel.height == kernel.width == 1: + if rescale > 1: + rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 + elif rescale < 1: + rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1) + scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits) + scale = int(round_away_zero(scale * rescale)) + else: + scale = 1 + shift = 0 + + emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift) + + +def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElementWiseOperation) -> int: + """ + Generates OFM/OPA/OPB_SCALE registers for elementwise operators. 
+ Returns the operator to scale + """ + op_to_scale = 0 + if npu_op.sub_op_type in (NpuElementWiseOp.ADD, NpuElementWiseOp.MUL, NpuElementWiseOp.SUB): + input_scale = npu_op.ifm.quantization.scale_f32 if npu_op.ifm.quantization else None + input2_scale = npu_op.ifm2.quantization.scale_f32 if npu_op.ifm2.quantization else None + output_scale = npu_op.ofm.quantization.scale_f32 if npu_op.ofm.quantization else None + + if npu_op.activation is not None and npu_op.activation.op_type in ( + NpuActivationOp.SIGMOID, + NpuActivationOp.TANH, + ): + output_scale = 1 / 0x3000 + + if npu_op.sub_op_type == NpuElementWiseOp.MUL: + if None in (input_scale, input2_scale, output_scale): + ofm_scale = 1 + shift = 0 + else: + ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale) + emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) + else: # Add/Sub + if None in (input_scale, input2_scale, output_scale): + opa_scale = opb_scale = ofm_scale = 1 + opa_shift = shift = 0 + if npu_op.rescale is not None: + ofm_scale, shift = npu_op.rescale + elif input_scale == input2_scale: + opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( + input_scale, input2_scale, output_scale + ) + opa_shift = 0 # Unused for this case + else: + # Use advanced implementation only when input scales differ + bitdepth = npu_op.ifm.data_type.size_in_bits() + (opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale( + input_scale, input2_scale, output_scale, bitdepth + ) + opb_scale = 0 # Unused for this case + if npu_op.reversed_operands: + # If the operand order is reversed we also have to swap which operand is scaled + if op_to_scale == scaling.OperandToScale.OPa: + op_to_scale = scaling.OperandToScale.OPb + else: + op_to_scale = scaling.OperandToScale.OPa + emit.cmd1_with_offset(cmd1.NPU_SET_OPA_SCALE, opa_scale, opa_shift) + emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale) + emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) + elif npu_op.sub_op_type in (NpuElementWiseOp.LRELU, NpuElementWiseOp.ABS): + output_scale = npu_op.ofm.quantization.scale_f32 + ofm_scale, shift = scaling.quantise_scale(output_scale) + emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) + else: + emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0) + return op_to_scale + + +# ------------------------------------------------------------------- +# ADDRESSING/STRIDES (helper functions) +# ------------------------------------------------------------------- + + +def ranges_overlap(range1: NpuAddressRange, range2: NpuAddressRange) -> bool: + """Checks if the ranges overlap""" + return range1.region == range2.region and numeric_util.overlaps( + range1.address, range1.address + range1.length, range2.address, range2.address + range2.length + ) + + +def get_strides(fm: NpuFeatureMap) -> NpuShape3D: + """Calculates STRIDE_C/Y/X""" + if fm.strides is not None: + return fm.strides + elem_size = fm.data_type.size_in_bytes() + if fm.layout == NpuLayout.NHWC: + stride_c = elem_size + stride_x = fm.shape.depth * stride_c + stride_y = fm.shape.width * stride_x + else: + stride_x = 16 * elem_size + stride_c = stride_x * fm.shape.width + stride_y = elem_size * fm.shape.width * numeric_util.round_up(fm.shape.depth, 16) + return NpuShape3D(depth=stride_c, height=stride_y, width=stride_x) + + +def get_address(fm: NpuFeatureMap, strides: NpuShape3D, y: int, x: int, c: int) -> int: + """Returns address of given coordinate""" + t = 0 
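+    # t is the index into fm.tiles.addresses (0..3); it is incremented by 1 below if the x coordinate
+    # falls in the right-hand tile, and by 2 if the y coordinate falls in one of the lower tiles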
+    BRICK = 16
+    stride_c = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHWC else strides.depth
+    stride_x = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHCWB16 else strides.width
+    if x >= fm.tiles.width_0:
+        x -= fm.tiles.width_0
+        t = 1
+    if y >= fm.tiles.height_1:
+        y -= fm.tiles.height_1
+        t += 2
+    elif y >= fm.tiles.height_0:
+        y -= fm.tiles.height_0
+        t += 2
+    elem_size = fm.data_type.size_in_bytes()
+    return (
+        fm.tiles.addresses[t] + y * strides.height + x * stride_x + (c // BRICK) * stride_c + int(c % BRICK) * elem_size
+    )
+
+
+def get_address_range(
+    fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> NpuAddressRange:
+    """Gets address range for (y0, x0, c0) - (y1, x1, c1)"""
+    addr0 = get_address(fm, strides, y0, x0, c0)
+    addr1 = get_address(fm, strides, y1, x1, c1)
+    return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
+
+
+def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
+    """Returns 4 address ranges, one for every tile, None if the tile is not in use"""
+    strides = get_strides(fm)
+    height, width, depth = fm.shape.height, fm.shape.width, fm.shape.depth
+    height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+    t0 = get_address_range(fm, strides, 0, 0, 0, min(height, height_0) - 1, min(width, width_0) - 1, depth - 1,)
+    if width > width_0:
+        t1 = get_address_range(fm, strides, 0, width_0, 0, min(height, height_1) - 1, width - 1, depth - 1)
+    else:
+        t1 = None
+    if height > height_0:
+        t2 = get_address_range(fm, strides, height_0, 0, 0, height - 1, min(width, width_0) - 1, depth - 1)
+    else:
+        t2 = None
+    if t1 is not None and t2 is not None:
+        t3 = get_address_range(fm, strides, height_0, width_0, 0, height - 1, width - 1, depth - 1)
+    else:
+        t3 = None
+    return [t0, t1, t2, t3]
+
+
+# -------------------------------------------------------------------
+# DMA_WAIT/KERNEL_WAIT
+# -------------------------------------------------------------------
+
+
 Watermark = namedtuple("Watermark", ["npu", "dma"])


-def get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, watermark: Watermark):
-    cmd = cmd_stream[cmd_index]
-    cmd_access = memory_accesses[cmd]
-    index = cmd_index - 1
+def memory_range_set(range: NpuAddressRange) -> MemoryRangeSet:
+    return MemoryRangeSet(range.region, range.address, range.address + range.length)
+
+
+def get_dma_memory_accesses(dma_op: NpuDmaOperation) -> MemoryAccessSet:
+    """Returns the addresses that are read and written by the given DMA operation"""
+    res = MemoryAccessSet()
+    res.add(memory_range_set(dma_op.src), AccessDirection.Read)
+    res.add(memory_range_set(dma_op.dest), AccessDirection.Write)
+    return res
+
+
+def get_op_memory_accesses(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> MemoryAccessSet:
+    """Returns the addresses that are read and written by the given operation"""
+    assert npu_op.ifm is not None and npu_op.ofm is not None
+    # Read addresses
+    read_ranges = get_address_ranges(npu_op.ifm)
+    if has_ifm2(npu_op):
+        assert npu_op.ifm2 is not None
+        read_ranges.extend(get_address_ranges(npu_op.ifm2))
+    read_ranges.extend(npu_op.weights)
+    read_ranges.extend(npu_op.biases)
+    if npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP:
+        address = arch.available_shram_banks(True) * arch.shram_bank_size
+        read_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem,
address=address, length=2048)) + # Written addresses + write_ranges = get_address_ranges(npu_op.ofm) + # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks + uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP + written_shram_size = arch.available_shram_banks(uses_lut) * arch.shram_bank_size + write_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem, address=0, length=written_shram_size)) + + res = MemoryAccessSet() + for read_range in read_ranges: + if read_range is not None: + res.add(memory_range_set(read_range), AccessDirection.Read) + for write_range in write_ranges: + if write_range is not None: + res.add(memory_range_set(write_range), AccessDirection.Write) + return res + + +def get_wait_dependency( + arch: ArchitectureFeatures, npu_op_list: List[NpuOperation], memory_accesses, op_index: int, watermark: Watermark +): + """Used to calculate whether DMA wait or kernel wait operations are needed""" + npu_op = npu_op_list[op_index] + op_access = memory_accesses[npu_op] + index = op_index - 1 # NPU dependency tracking npu_outstanding = -1 @@ -211,33 +881,32 @@ def get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, waterm # the command that issues the wait. # NPU->NPU dependency is handled via blockdep. while (index >= npu_index) or (index >= dma_index): - prev_cmd = cmd_stream[index] - prev_access = memory_accesses[prev_cmd] - - # Check DMA consuming NPU output - if prev_cmd.cmdtype == CommandType.NpuStripe: - if index >= npu_index: - if (cmd.cmdtype == CommandType.DMA) and (npu_outstanding == -1) and prev_access.conflicts(cmd_access): - npu_outstanding = npu_ops - npu_ops = npu_ops + 1 # Count NPU ops in the pipeline - if npu_ops >= arch.max_outstanding_kernels: - npu_index = max(index + 1, npu_index) + prev_op = npu_op_list[index] + prev_access = memory_accesses[prev_op] # Check NPU consuming DMA output - elif prev_cmd.cmdtype == CommandType.DMA: + if is_dma_op(prev_op): if index >= dma_index: - if cmd.cmdtype == CommandType.NpuStripe: - if (dma_outstanding == -1) and prev_access.conflicts(cmd_access): + if not is_dma_op(npu_op): + if (dma_outstanding == -1) and prev_access.conflicts(op_access): dma_outstanding = dma_ops - dma_ops = dma_ops + 1 # Count DMA ops in the pipeline + dma_ops += 1 # Count DMA ops in the pipeline if dma_ops >= arch.max_outstanding_dma: dma_index = max(index + 1, dma_index) + # Check DMA consuming NPU output + else: + if index >= npu_index: + if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access): + npu_outstanding = npu_ops + npu_ops += 1 # Count NPU ops in the pipeline + if npu_ops >= arch.max_outstanding_kernels: + npu_index = max(index + 1, npu_index) - index = index - 1 + index -= 1 # Update DMA watermark if we didn't see any and the NPU pipeline is full if (dma_ops == 0) and (npu_ops >= arch.max_outstanding_kernels): - dma_index = cmd_index + dma_index = op_index # Bring the search watermark forwards as we complete for those dependencies watermark = Watermark(npu_index, dma_index) @@ -246,873 +915,380 @@ def get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, waterm return watermark, outstanding -def has_prev_op_dependency(prev_cmd, cmd): - if prev_cmd is None: - return False - if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps): - if prev_cmd.ofm_tensor.equivalent(cmd.ifm_tensor): - return True - elif cmd.ifm2_tensor is not None: - return 
prev_cmd.ofm_tensor.equivalent(cmd.ifm2_tensor) - return False - +def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark): + if cmd_waits.npu >= 0: + emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu) -def get_op_ofm_rect(cmd): - start = full_shape(4, cmd.ofm_box.start_coord, 0) - end = full_shape(4, cmd.ofm_box.end_coord, 1) - return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1) + if cmd_waits.dma >= 0: + emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma) -def get_op_ifm_rect(cmd): - start = full_shape(4, cmd.ifm_box.start_coord, 0) - end = full_shape(4, cmd.ifm_box.end_coord, 1) - return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1) +# ------------------------------------------------------------------- +# BLOCKDEP +# ------------------------------------------------------------------- -def get_op_ifmofm_block_depth(arch, cmd): - # Note: NOT equivalent to the normal ifm block depth calculation since - # it takes into account 'depthless' block operations by returning full - # depth - if cmd.ps.npu_block_type in ( - NpuBlockType.ConvolutionDepthWise, - NpuBlockType.Pooling, - NpuBlockType.ElementWise, - NpuBlockType.ReduceSum, - ): - return cmd.ofm_box.get_size_shape()[-1] +def is_dependent_on_prev_op(prev_op: NpuBlockOperation, npu_op: NpuBlockOperation) -> bool: + """Checks if npu_op's input is dependent on prev_op's output""" + assert npu_op.ifm is not None + assert prev_op.ofm is not None + curr_input_ranges = get_address_ranges(npu_op.ifm) - return arch.calc_ifm_block_depth(cmd.ifm_box.get_size_shape()[-1], cmd.ifm_tensor.dtype.bits) - - -def get_op_padding_lt(cmd): - if cmd.ps.npu_block_type not in ( - NpuBlockType.ConvolutionDepthWise, - NpuBlockType.Pooling, - NpuBlockType.ConvolutionMxN, - NpuBlockType.ReduceSum, - ): - return (0, 0) - - explicit_padding = list(cmd.ps.primary_op.attrs["explicit_padding"]) # (top, left, bottom, right) - - # Check if this is for horizontal ifm streaming - if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe): - explicit_padding[0] = cmd.pad_top - explicit_padding[2] = cmd.pad_bottom - - return (explicit_padding[1], explicit_padding[0]) - - -def ifm_ifm2_correct_order(ifm_shape, ifm2_shape): - if ifm_shape == []: - # Scalar needs to be in IFM2 - return False - elif ifm2_shape == []: - return True - - for ifm, ifm2 in zip(ifm_shape, ifm2_shape): - if ifm != ifm2 and ifm == 1: - # Broadcasted FM needs to be in IFM2 - return False + if has_ifm2(npu_op): + assert npu_op.ifm2 is not None + curr_input_ranges.extend(get_address_ranges(npu_op.ifm2)) + for prev_range in get_address_ranges(prev_op.ofm): + if prev_range is None: + continue + for curr_range in curr_input_ranges: + if curr_range is not None and ranges_overlap(prev_range, curr_range): + return True + return False - return True +def shape3d_to_rect(shape: NpuShape3D) -> Rect: + return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1) -def generate_register_command_stream(nng, sg, arch, verbose=False): - emit = CommandStreamEmitter() - if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area: - base_ptr_idx_map = { - MemType.Permanent_NPU: BasePointerIndex.WeightTensor, - MemType.Permanent_CPU: BasePointerIndex.WeightTensor, - MemType.Scratch: BasePointerIndex.ScratchTensor, - MemType.Scratch_fast: BasePointerIndex.ScratchTensor, - } +def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperation) -> int: + # Note: NOT equivalent to the normal ifm block depth 
calculation since + # it takes into account 'depthless' block operations by returning full + # depth + if npu_op.op_type == NpuOperationType.Conv2D: + res = arch.calc_ifm_block_depth(npu_op.ifm.shape.depth, npu_op.ifm.data_type.size_in_bits()) + return res + return npu_op.ofm.shape.depth + + +def calc_blockdep( + arch: ArchitectureFeatures, + prev_op: Optional[NpuBlockOperation], + prev_block_config: Optional[NpuShape3D], + npu_op: NpuBlockOperation, + block_config: NpuShape3D, +) -> int: + """Calculates the value of the BLOCKDEP register""" + if prev_op is None: + return 0 + if not is_dependent_on_prev_op(prev_op, npu_op): + return ArchitectureFeatures.MAX_BLOCKDEP + if prev_op.ofm.shape != npu_op.ifm.shape: + return 0 + prev_ifm_block_depth = get_ifm_ofm_block_depth(arch, prev_op) + prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth) + prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape) + prev_ifm_rect = shape3d_to_rect(prev_op.ifm.shape) + cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op) + cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth) + cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape) + cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape) + cur_padLT = (0, 0) if npu_op.padding is None else (npu_op.padding.left, npu_op.padding.top) + blockdep = arch.calc_block_dep( + prev_ifm_rect, + prev_ofm_rect, + prev_ifm_block_depth, + prev_ofm_block, + to_kernel(prev_op.kernel), + cur_ifm_rect, + cur_ofm_rect, + cur_ifm_block_depth, + cur_ofm_block, + to_kernel(npu_op.kernel), + cur_padLT, + ) + return blockdep + + +# ------------------------------------------------------------------- +# PRINT +# ------------------------------------------------------------------- + + +def print_feature_map(fm: NpuFeatureMap, name: str): + if fm is not None: + q = ( + "no quantization" + if fm.quantization is None + else f"scale: {fm.quantization.scale_f32}, zero: {fm.quantization.zero_point}" + ) + h, w, c = fm.shape + sz = h * w * c * fm.data_type.size_in_bytes() + print(f" {name}: h={h},w={w},c={c}, region={fm.region}, {fm.layout}, {fm.data_type}, size={sz}, {q}") + strides = get_strides(fm) + stride_str = f"Stride y/x/c: {strides.height}/{strides.width}/{strides.depth}" + t = fm.tiles + addresses = [hex(addr) for addr in t.addresses] + print(f" {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}") + + +def print_operation(npu_op: NpuOperation, index: int = 0): + pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else "" + if is_dma_op(npu_op): + print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}") + return + k = None if npu_op.kernel is None else to_kernel(npu_op.kernel) + if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise): + print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}") else: - base_ptr_idx_map = { - MemType.Permanent_NPU: BasePointerIndex.WeightTensor, - MemType.Permanent_CPU: BasePointerIndex.WeightTensor, - MemType.Scratch: BasePointerIndex.ScratchTensor, - MemType.Scratch_fast: BasePointerIndex.ScratchFastTensor, - } - - # Maps an AccumulatorType enum to the corresponding acc_format value - acc_format_map = { - SHRAMElements.Acc16: acc_format.FP_S5_10.value, - SHRAMElements.Acc32: acc_format.INT_32BIT.value, - SHRAMElements.Acc40: acc_format.INT_40BIT.value, - } - - # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE - elementwise_mode_map = { - Op.Mul: 
elementwise_mode.MUL.value, - Op.Add: elementwise_mode.ADD.value, - Op.Sub: elementwise_mode.SUB.value, - Op.Minimum: elementwise_mode.MIN.value, - Op.Maximum: elementwise_mode.MAX.value, - Op.LeakyRelu: elementwise_mode.LRELU.value, - Op.Abs: elementwise_mode.ABS.value, - Op.CLZ: elementwise_mode.CLZ.value, - Op.SHR: elementwise_mode.SHR.value, - Op.SHL: elementwise_mode.SHL.value, - } - - cmd_stream = [] - memory_accesses = {} - for cmd in sg.high_level_command_stream: - if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default: - print("Warning: Skipping register command stream generation for", cmd.ps) + if ( + npu_op.op_type == NpuOperationType.Conv2D + and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1 + ): + fc = "FullyConnected " else: - cmd_stream.append(cmd) - memory_accesses[cmd] = cmd.get_memory_accesses() - - def emit_cmd_waits(cmd_waits): - if cmd_waits.npu >= 0: - emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu) - - if cmd_waits.dma >= 0: - emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma) + fc = "" + print(f"{index} {fc}{npu_op.op_type.name}{pass_info}") + print_feature_map(npu_op.ifm, "IFM") + if npu_op.ifm2_scalar is not None: + quant_val = quantise(npu_op.ifm2_scalar, npu_op.ifm2.quantization) + print(f" IFM2: Scalar={npu_op.ifm2_scalar} (quantized: {quant_val}), {npu_op.ifm2.quantization}") + else: + print_feature_map(npu_op.ifm2, "IFM2") + print_feature_map(npu_op.ofm, "OFM") + if k is not None and npu_op.op_type != NpuOperationType.ElementWise: + print(f" Kernel: {k}") + if npu_op.padding is not None: + print(f" {npu_op.padding}") + for weights in npu_op.weights: + print(f" Weights: {weights}") + for bias in npu_op.biases: + print(f" Scales: {bias}") + if npu_op.activation is not None: + act = npu_op.activation + if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None: + lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else "" + print(f" Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}") + if npu_op.op_type == NpuOperationType.Conv2D: + print(f" {npu_op.block_traversal}") + bh, bw, bc = npu_op.block_config + rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else "" + print(f" Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}") + + +def print_operations(npu_op_list: List[NpuOperation]): + for index, npu_op in enumerate(npu_op_list): + print_operation(npu_op, index) + + +# ------------------------------------------------------------------- +# OPERATIONS +# ------------------------------------------------------------------- + + +def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation): + """Generates NPU_OP_* command""" + op_type = npu_op.op_type + if op_type == NpuOperationType.Dma: + emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode) + elif op_type == NpuOperationType.Conv2D: + emit.cmd_do_operation(cmd0.NPU_OP_CONV) + elif op_type == NpuOperationType.ConvDepthWise: + emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE) + elif op_type == NpuOperationType.Pooling: + emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type]) + elif op_type == NpuOperationType.ElementWise: + emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type]) + else: + assert 0, "Unsupported operation" + + +def generate_conv2d_op( + emit: CommandStreamEmitter, npu_op: 
NpuConv2DOperation, arch: ArchitectureFeatures +) -> NpuShape3D: + """Generates register commands for Conv2D operations""" + generate_common(emit, npu_op, npu_op.block_traversal, arch) + ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale] + shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ConvolutionMxN, ifm_resampling_mode) + block_config = generate_block_config(emit, npu_op, arch, shared_buffer) + generate_shram_registers_non_elementwise(emit, shared_buffer) + return block_config + + +def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures): + """Generates register commands for depthwise convolution operations""" + generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch) + ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale] + shared_buffer = shared_buffer_allocation_for_npu_op( + arch, npu_op, NpuBlockType.ConvolutionDepthWise, ifm_resampling_mode + ) + block_config = generate_block_config(emit, npu_op, arch, shared_buffer) + generate_shram_registers_non_elementwise(emit, shared_buffer) + return block_config + + +def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures): + """Generates register commands for pooling operations""" + use_global_scale = ( + npu_op.sub_op_type in (NpuPoolingOp.AVERAGE, NpuPoolingOp.REDUCE_SUM) and sum(npu_op.padding) == 0 + ) + generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale) + # Pooling op specific + if use_global_scale: + generate_ofm_scaling_for_pooling(emit, npu_op) + ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale] + npu_block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling + shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, npu_block_type, ifm_resampling_mode) + block_config = generate_block_config(emit, npu_op, arch, shared_buffer) + generate_shram_registers_non_elementwise(emit, shared_buffer) + return block_config + + +def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOperation, arch: ArchitectureFeatures): + """Generates register commands for elementwise operations""" + use_global_scale = npu_op.sub_op_type in ( + NpuElementWiseOp.ADD, + NpuElementWiseOp.SUB, + NpuElementWiseOp.MUL, + NpuElementWiseOp.LRELU, + NpuElementWiseOp.ABS, + ) + op_to_scale = generate_scaling_for_elementwise(emit, npu_op) + generate_common( + emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale, op_to_scale=op_to_scale + ) + # Elementwise op specific + if npu_op.sub_op_type not in unary_elementwise_ops: + # Binary operation; generate IFM2 registers + assert npu_op.ifm2 is not None + has_scalar = npu_op.ifm2_scalar is not None + generate_ifm2(emit, npu_op.ifm2, has_scalar) + generate_ifm_precision(emit, npu_op.ifm2, 0, cmd0.NPU_SET_IFM2_PRECISION) + generate_ifm2_broadcast(emit, npu_op) + if has_scalar: + quantized_scalar = quantise(npu_op.ifm2_scalar, npu_op.ifm2.quantization) + assert npu_op.ifm2.data_type.min_value() <= quantized_scalar <= npu_op.ifm2.data_type.max_value() + emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, quantized_scalar) + ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale] + shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ElementWise, ifm_resampling_mode) + block_config = generate_block_config(emit, npu_op, arch, shared_buffer) + 
generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer) + return block_config + + +def generate_dma_op(emit: CommandStreamEmitter, dma_op: NpuDmaOperation): + """Generates register commands for DMA operations""" + emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, dma_op.src.region) + emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, dma_op.src.address) + emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, dma_op.dest.region) + + emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dma_op.dest.address) + emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, dma_op.src.length) + + +def generate_registers_for_op( + emit: CommandStreamEmitter, npu_op: NpuOperation, arch: ArchitectureFeatures +) -> Optional[NpuShape3D]: + """ + Generates register commands for the given operation, but not the final NPU_OP_... command. + Returns the selected block config + """ + op_type = npu_op.op_type + block_config = None + if op_type == NpuOperationType.Conv2D: + block_config = generate_conv2d_op(emit, npu_op, arch) + elif op_type == NpuOperationType.ConvDepthWise: + block_config = generate_conv_depthwise_op(emit, npu_op, arch) + elif op_type == NpuOperationType.Pooling: + block_config = generate_pooling_op(emit, npu_op, arch) + elif op_type == NpuOperationType.ElementWise: + block_config = generate_elementwise_op(emit, npu_op, arch) + elif op_type == NpuOperationType.Dma: + generate_dma_op(emit, npu_op) + else: + assert 0, "Unsupported operation" + return block_config - # Initialise operator dependency state - prev_ifm_rect = cur_ifm_rect = None - prev_ifm_block_depth = cur_ifm_block_depth = None - prev_ofm_rect = cur_ofm_rect = None - prev_ofm_block = cur_ofm_block = None - prev_kernel = cur_kernel = None - prev_cmd = None +def generate_command_stream( + emit: CommandStreamEmitter, npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, add_to_debug_db=None +): + """Generates register commands for the given list of NPU operations""" + # Calculate memory accesses for every operation + memory_accesses = {} + for npu_op in npu_op_list: + if is_dma_op(npu_op): + memory_accesses[npu_op] = get_dma_memory_accesses(npu_op) + else: + memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch) if arch.is_yoda_system: emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1) - dep_watermark = Watermark(0, 0) - - stream_id = DebugDatabase.add_stream(sg) - DebugDatabase.set_stream_offset(sg, 0) # Default to zero, can only set during file writing - - for cmd_index, cmd in enumerate(cmd_stream): - dep_watermark, cmd_waits = get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, dep_watermark) - - if cmd.cmdtype == CommandType.DMA: - start_coord = cmd.box.start_coord - - src_addr = cmd.in_tensor.address_for_coordinate(start_coord) - dst_addr = cmd.out_tensor.address_for_coordinate(start_coord) - - if cmd.in_tensor.compressed_values is not None: - if cmd.out_tensor.purpose == TensorPurpose.FSBias: - sz = cmd.in_tensor.storage_size() - else: - stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord) - sz = cmd.in_tensor.size_of_compressed_stream(stream_index) - else: - sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr - - emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_type]) - emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr) - if cmd.out_tensor.purpose == TensorPurpose.LUT: - emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, BasePointerIndex.Mem2Mem) - else: - 
emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type]) - - emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr) - emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz) - dma_channel = 0 - mode = 0 # From external to external - - emit_cmd_waits(cmd_waits) - emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, dma_channel * 16 + mode) - - elif cmd.cmdtype == CommandType.NpuStripe: - - ps = cmd.ps - primary_op = ps.primary_op - npu_block_type = ps.npu_block_type - # Specifies if global scale from the NPU_SET_OFM_SCALE register should be used instead of per-channel scale - use_global_scale = False - # Specifies type of rounding to be used. - rounding_mode = ( - rounding.NATURAL if primary_op.attrs.get("rounding_mode", "") == b"NATURAL" else rounding.TFL - ) - if primary_op.type == Op.ResizeBilinear: - rounding_mode = rounding.TRUNCATE - fmf = primary_op.memory_function - faf = primary_op.activation - fused_quantize = any(op.type == Op.Quantize for op in ps.ops) - # Force output scale, used in operations with fused LUT - # Note: with current LUT support, forced_ofm_quantization is always equal to cmd.ofm_tensor.quantization - # except when primary_op is AddAct + 0 (no-op) + LUT - forced_ofm_quantization = primary_op.forced_output_quantization - ofm_quant = cmd.ofm_tensor.quantization - if forced_ofm_quantization is not None: - ofm_quant = forced_ofm_quantization - - # Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB - op_to_scale = 0 - - # Update state history - prev_ifm_rect = cur_ifm_rect - prev_ifm_block_depth = cur_ifm_block_depth - prev_ofm_rect = cur_ofm_rect - prev_ofm_block = cur_ofm_block - prev_kernel = cur_kernel - cur_kernel = ps.primary_op.kernel if ps.primary_op else None - - block_config = ps.block_config - emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config[0] - 1) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config[1] - 1) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config[3] - 1) - - shared_buffer = ps.shared_buffer - - if npu_block_type == NpuBlockType.ElementWise: - ifm2_broadcast = 0 - - if cmd.ifm2_tensor and not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape): - # The scalar has to be the ifm2 tensor so switch the ifms - cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor - cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box - - # Set ReverseOperandOrder bit to IFM2_BROADCAST - ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder - - # Calculate scales needed for arithmetic elementwise operators - if primary_op.type in set((Op.Add, Op.Mul, Op.Sub,)): - input_scale = cmd.ifm_tensor.quantization.scale_f32 if cmd.ifm_tensor.quantization else None - input2_scale = cmd.ifm2_tensor.quantization.scale_f32 if cmd.ifm2_tensor.quantization else None - output_scale = ofm_quant.scale_f32 if ofm_quant else None - use_global_scale = True - - if output_scale is not None and faf in (Op.Sigmoid, Op.Tanh): - output_scale = 1 / 0x3000 - - if primary_op.type == Op.Mul: - if None in (input_scale, input2_scale, output_scale): - ofm_scale = 1 - shift = 0 - else: - ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale) - emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) - else: # AddAct/SubAct - # Force output scale same as the input scale for - # resizebilinear 1x1 that is converted to add - if "resizebilinear" in primary_op.attrs: - output_scale = input2_scale - - if None in (input_scale, input2_scale, 
output_scale): - opa_scale = opb_scale = ofm_scale = 1 - opa_shift = shift = 0 - ofm_scale, shift = primary_op.attrs.get("rescale", [1, 0]) - elif input_scale == input2_scale: - opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( - input_scale, input2_scale, output_scale - ) - opa_shift = 0 # Unused for this case - else: - # Use advanced implementation only when input scales differ - bitdepth = cmd.ifm_tensor.dtype.bits - ( - opa_scale, - opa_shift, - ofm_scale, - shift, - op_to_scale, - ) = scaling.advanced_elementwise_add_sub_scale( - input_scale, input2_scale, output_scale, bitdepth - ) - opb_scale = 0 # Unused for this case - if ifm2_broadcast & IFM2Broadcast.ReverseOperandOrder: - # If the operand order is reversed we also have to swap which operand is scaled - if op_to_scale == scaling.OperandToScale.OPa: - op_to_scale = scaling.OperandToScale.OPb - else: - op_to_scale = scaling.OperandToScale.OPa - - emit.cmd1_with_offset(cmd1.NPU_SET_OPA_SCALE, opa_scale, opa_shift) - emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale) - emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) - - elif primary_op.type in set((Op.LeakyRelu, Op.Abs,)): - output_scale = ofm_quant.scale_f32 - use_global_scale = True - - if primary_op.type == Op.LeakyRelu: - output_scale = primary_op.attrs["alpha"] - - ofm_scale, shift = scaling.quantise_scale(output_scale) - emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) - else: - emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0) - - # For elementwise set the required SHRAM to be equal to the total size of available SHRAM - uses_lut = primary_op.activation_lut is not None - shram_required = arch.available_shram_banks(uses_lut) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required) - - # Acc buffers not needed so set AB_START to size of SHRAM - emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required) - - # Is not a unary operator - if cmd.ifm2_tensor is not None: - if cmd.ifm2_tensor.shape == []: - # IFM2 is a constant, set UseIFM2Scalar bit to IFM2_BROADCAST - ifm2_broadcast |= IFM2Broadcast.UseIFM2Scalar - else: - ifm_box_shape = cmd.ifm_box.get_size_shape() - ifm2_box_shape = cmd.ifm2_box.get_size_shape() - - if len(cmd.ifm_tensor.shape) > 1 and ifm_box_shape[1] != ifm2_box_shape[1]: - # Broadcast in 'H' dimension - assert cmd.ifm2_tensor.shape[1] == 1 - ifm2_broadcast |= IFM2Broadcast.BroadcastHdim - - if len(cmd.ifm_tensor.shape) > 2 and ifm_box_shape[2] != ifm2_box_shape[2]: - # Broadcast in 'W' dimension - assert cmd.ifm2_tensor.shape[2] == 1 - ifm2_broadcast |= IFM2Broadcast.BroadcastWdim - - if len(cmd.ifm_tensor.shape) > 3 and ifm_box_shape[3] != ifm2_box_shape[3]: - # Broadcast in 'C' dimension - assert cmd.ifm2_tensor.shape[3] == 1 - ifm2_broadcast |= IFM2Broadcast.BroadcastCdim - - # Set IFM2_IB_START to the latter half of the IB space - ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM] - emit.cmd0_with_param( - cmd0.NPU_SET_IFM2_IB_START, - (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start, - ) - - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast) - - else: - emit.cmd0_with_param( - cmd0.NPU_SET_IFM_IB_END, - shared_buffer.bank_locations[SharedBufferArea.IFM] - + shared_buffer.banks_required[SharedBufferArea.IFM], - ) - emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shared_buffer.bank_locations[SharedBufferArea.Accumulators]) - - emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, 
acc_format_map[shared_buffer.use_accumulator_element]) - - if primary_op.type == Op.ResizeBilinear: - # perform nearest neighbor upscale - emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NEAREST) - elif primary_op.type == Op.Conv2DBackpropInputSwitchedBias: - # perform insert zero upscale - emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.TRANSPOSE) - else: - emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NONE) - - if npu_block_type in set( - ( - NpuBlockType.ConvolutionMxN, - NpuBlockType.ConvolutionDepthWise, - NpuBlockType.Pooling, - NpuBlockType.ReduceSum, - ) - ): - # Set up padding - explicit_padding = list(primary_op.attrs["explicit_padding"]) # (top, left, bottom, right) - - # Check if this is for horizontal ifm streaming - if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe): - explicit_padding[0] = cmd.pad_top - explicit_padding[2] = cmd.pad_bottom - - # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output, - # because of activation function needed to be fused. - if cmd.ifm_box.start_coord[-2] > 0: - explicit_padding[1] = 0 - if cmd.ifm_box.end_coord[-2] < cmd.ifm_tensor.shape[-2]: - explicit_padding[3] = 0 - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, explicit_padding[0]) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, explicit_padding[1]) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, explicit_padding[2]) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, explicit_padding[3]) - - # set kernel x stride low bit - stride = primary_op.attrs["strides"][2] - 1 & 1 - # set kernel y stride low bit - stride |= (primary_op.attrs["strides"][1] - 1 & 1) << 1 - # set kernel x stride extension bits - stride |= (primary_op.attrs["strides"][2] - 1 >> 1) << 6 - # set kernel y stride extension bits - stride |= (primary_op.attrs["strides"][1] - 1 >> 1) << 9 - - if npu_block_type in set((NpuBlockType.Pooling, NpuBlockType.ReduceSum)): - k_height, k_width = primary_op.attrs["ksize"][1:3] - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, k_height - 1) - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, k_width - 1) - - valid_padding = sum(explicit_padding) == 0 - - if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.ReduceSum)) and valid_padding: - # For valid padding vela has to output scaling values - if faf == Op.Sigmoid or faf == Op.Tanh: - rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32 - if cmd.ifm_tensor.dtype == DataType.int16: - # Calculate scale and shift for the output scale of 1/(3*4096) - shift = 0 - max_rescale = np.iinfo(np.int16).max / 2 - while rescale <= max_rescale and shift <= 30: - shift += 1 - rescale *= 2 - scale = int(rescale) - else: - rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 - scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits) - scale = int(round_away_zero(scale * rescale)) - elif fused_quantize: - # Quantize op requires different scaling - ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32) - ofm_scale_f64 = np.double(ofm_quant.scale_f32) - scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64) - elif primary_op.type == Op.ResizeBilinear and "rescale" in primary_op.attrs: - rescale = primary_op.attrs["rescale"] - rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 - scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits) - scale = int(round_away_zero(scale * rescale)) - else: - # In case avg pool fused with concat or other memory operation, rescaling 
might be needed. - # k_height == k_width == 1 is allways true in this case - # Normally the scale is maximised, to get maximum precision, which means that - # if rescale != 1, scale need to consider the number of bits needed for rescaling - if None not in (ofm_quant.scale_f32, cmd.ifm_tensor.quantization.scale_f32,): - rescale = cmd.ifm_tensor.quantization.scale_f32 / ofm_quant.scale_f32 - rescale_bits = 0 - if k_height == k_width == 1: - if fmf == Op.ConcatSliceWrite: - rounding_mode = rounding.NATURAL - if rescale > 1: - rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1 - elif rescale < 1: - rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1) - scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits) - scale = int(round_away_zero(scale * rescale)) - else: - scale = 1 - shift = 0 - - emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift) - # Valid-padded average pool should use the global scale from - # NPU_SET_OFM_SCALE register, which is set above. - use_global_scale = True - - else: # Convolution - assert cmd.weight_tensor.block_traversal != TensorBlockTraversal.Default - # Reduced precision quantization and natural rounding used for int16 - if cmd.ifm_tensor.dtype == DataType.int16: - rounding_mode = rounding.NATURAL - stride |= (cur_kernel.dilation.y - 1) << 4 - stride |= (cur_kernel.dilation.x - 1) << 3 - emit.cmd0_with_param( - cmd0.NPU_SET_KERNEL_HEIGHT_M1, cur_kernel.dilation.y * (cmd.weight_tensor.shape[0] - 1) - ) - emit.cmd0_with_param( - cmd0.NPU_SET_KERNEL_WIDTH_M1, cur_kernel.dilation.x * (cmd.weight_tensor.shape[1] - 1) - ) - if cmd.weight_tensor.block_traversal == TensorBlockTraversal.PartKernelFirst: - # Part-kernel-first weight ordering - assert npu_block_type == NpuBlockType.ConvolutionMxN - stride |= 1 << 2 - - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, stride) - - elif npu_block_type in set((NpuBlockType.VectorProduct,)): - # Vector product is implemented using a 1x1 convolution so need - # to setup the appropriate padding and kernel info - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, 0) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, 0) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, 0) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, 0) - - # kernel stride reg = 0 means stride(1,1) + depth first weight - # order + dilation(0,0) + kernel_split_size=8 - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, 0) - - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, 0) - emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, 0) - - if npu_block_type in set( - (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct) - ): - # Emit Weight base address commands, only maps the area required for - # this command's weights from the larger tensor. 
- stream_index = cmd.weight_tensor.compressed_stream_index_from_coord(cmd.weight_box.start_coord) - weight_substream_offsets = cmd.weight_tensor.compressed_values_substream_offsets[stream_index] - substreams = len(weight_substream_offsets) - 1 # Offset list must terminate with full stream length - - # Extract weight substream offsets and calculate their lengths - assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0) - weight_addr = cmd.weight_tensor.address_for_coordinate(cmd.weight_box.start_coord) - - # Set weights sources for active and present cores - for core, param in enumerate( - [ - (cmd1.NPU_SET_WEIGHT_BASE, cmd1.NPU_SET_WEIGHT_LENGTH), - (cmd1.NPU_SET_WEIGHT1_BASE, cmd1.NPU_SET_WEIGHT1_LENGTH), - ] - ): - if core < substreams: - emit.cmd1_with_offset(param[0], weight_addr + weight_substream_offsets[core]) - emit.cmd1_with_offset( - param[1], weight_substream_offsets[core + 1] - weight_substream_offsets[core] - ) - elif core < arch.ncores: - emit.cmd1_with_offset(param[0], weight_addr) - emit.cmd1_with_offset(param[1], 0) - - weight_region = base_ptr_idx_map[cmd.weight_tensor.mem_type] - emit.cmd0_with_param(cmd0.NPU_SET_WEIGHT_REGION, weight_region) - - # Emit Scale & Bias base address commands, with length matching the amount required by - # the weight tensors. - if cmd.scale_tensor is not None: - scale_substream_offsets = cmd.scale_tensor.compressed_values_substream_offsets[stream_index] - substreams = len(scale_substream_offsets) - 1 # Offset list must terminate with full stream length - - # Extract scale substream offsets and calculate their lengths - assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0) - scale_addr = cmd.scale_tensor.address_for_coordinate(cmd.weight_box.start_coord[-1:]) - - # Set scale sources for active and present cores - for core, param in enumerate( - [ - (cmd1.NPU_SET_SCALE_BASE, cmd1.NPU_SET_SCALE_LENGTH), - (cmd1.NPU_SET_SCALE1_BASE, cmd1.NPU_SET_SCALE1_LENGTH), - ] - ): - if core < substreams: - emit.cmd1_with_offset(param[0], scale_addr + scale_substream_offsets[core]) - emit.cmd1_with_offset( - param[1], scale_substream_offsets[core + 1] - scale_substream_offsets[core] - ) - elif core < arch.ncores: - emit.cmd1_with_offset(param[0], scale_addr) - emit.cmd1_with_offset(param[1], 0) - - # Emit base address for NPU to access scale & bias data - scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_type] - emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region) - - ofm_quant_qmin = ofm_quant.quant_min if ofm_quant else np.iinfo(np.int16).min - ofm_quant_qmax = ofm_quant.quant_max if ofm_quant else np.iinfo(np.int16).max - ifm_min = cmd.ifm_tensor.quantization.min if cmd.ifm_tensor.quantization else np.iinfo(np.int16).min - ifm_max = cmd.ifm_tensor.quantization.max if cmd.ifm_tensor.quantization else np.iinfo(np.int16).max - - # Emit commands for any fused activation function - if faf is None: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE) - # Even if no activation function, values need to be set to override previous values - faf_min = ofm_quant_qmin - faf_max = ofm_quant_qmax - elif faf == Op.Relu: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE) - faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = ofm_quant_qmax - elif faf == Op.Relu6: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE) - faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(6.0, 
ofm_quant.scale_f32, ofm_quant.zero_point) - elif faf == Op.ReluN1To1: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE) - faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point) - elif faf == Op.Tanh: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.TANH) - if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)): - faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point) - else: - faf_min = quantise_float32(clamp_tanh(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(clamp_tanh(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point) - elif faf == Op.Sigmoid: - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.SIGMOID) - if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)): - faf_min = quantise_float32(0, ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point) - else: - faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point) - faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point) - elif faf == Op.LUT: - lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1) - assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range." - if cmd.ofm_tensor.dtype == DataType.int32: - lut_index |= 3 << 12 # Force I8 range - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index) - faf_min = ofm_quant_qmin - faf_max = ofm_quant_qmax - else: - raise Exception("Unsupported fused_activation_function = " + faf.name) - - # Activation range needs to be set based upon the quantisation range and the fused activation range - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, max(ofm_quant_qmin, faf_min)) - emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MAX, min(ofm_quant_qmax, faf_max)) - - out_shape = cmd.ofm_box.get_size_shape() - if len(out_shape) >= 4: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, out_shape[-3] - 1) - else: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, 0) - if len(out_shape) >= 2: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, out_shape[-2] - 1) - else: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, 0) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, out_shape[-1] - 1) - - if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)): - in_shape = cmd.ifm_box.get_size_shape() - emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, in_shape[-1] - 1) - else: - emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, out_shape[-1] - 1) - - for tens, box, region_op, ptr_ops, stride_ops, zero_point_op in ( - ( - cmd.ifm_tensor, - cmd.ifm_box, - cmd0.NPU_SET_IFM_REGION, - (cmd1.NPU_SET_IFM_BASE0, cmd1.NPU_SET_IFM_BASE1, cmd1.NPU_SET_IFM_BASE2, cmd1.NPU_SET_IFM_BASE3), - (cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X), - cmd0.NPU_SET_IFM_ZERO_POINT, - ), - ( - cmd.ifm2_tensor, - cmd.ifm2_box, - cmd0.NPU_SET_IFM2_REGION, - ( - cmd1.NPU_SET_IFM2_BASE0, - cmd1.NPU_SET_IFM2_BASE1, - cmd1.NPU_SET_IFM2_BASE2, - cmd1.NPU_SET_IFM2_BASE3, - ), - (cmd1.NPU_SET_IFM2_STRIDE_C, cmd1.NPU_SET_IFM2_STRIDE_Y, cmd1.NPU_SET_IFM2_STRIDE_X), - cmd0.NPU_SET_IFM2_ZERO_POINT, - ), - ( - cmd.ofm_tensor, - cmd.ofm_box, - cmd0.NPU_SET_OFM_REGION, - (cmd1.NPU_SET_OFM_BASE0, 
cmd1.NPU_SET_OFM_BASE1, cmd1.NPU_SET_OFM_BASE2, cmd1.NPU_SET_OFM_BASE3), - (cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X), - cmd0.NPU_SET_OFM_ZERO_POINT, - ), - ): - - if tens is None: - continue - - need_zero_point = ( - (faf is not None and forced_ofm_quantization is None) - or (fmf == Op.ConcatSliceWrite) - or fused_quantize - ) - if ( - (primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL)) and not need_zero_point) - or ( - tens.dtype == DataType.int32 - and zero_point_op in (cmd0.NPU_SET_IFM_ZERO_POINT, cmd0.NPU_SET_IFM2_ZERO_POINT) - ) - or tens.quantization is None - ): - # Actual integer operation, just set scale to 1 and zero point to 0 - emit.cmd0_with_param(zero_point_op, 0) - else: - assert tens.quantization.zero_point is not None, "need an actual zero point set" - if cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op and forced_ofm_quantization is not None: - zero_point = forced_ofm_quantization.zero_point - elif ( - "resizebilinear" in primary_op.attrs - and primary_op.type == Op.Add - and cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op - ): - # Force output zero point same as the input zero point - # for resizebilinear 1x1 that is converted to add - zero_point = cmd.ifm2_tensor.quantization.zero_point - else: - zero_point = tens.quantization.zero_point - emit.cmd0_with_param(zero_point_op, int(zero_point)) - - if tens.shape == []: - # Empty shape, elementwise constant - ifm2_scalar = tens.quant_values - assert ifm2_scalar.size == 1 - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, int(ifm2_scalar.item(0))) - continue - - height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer( - box.start_coord, box.end_coord - ) - if npu_block_type != NpuBlockType.VectorProduct: - if tens == cmd.ifm_tensor: - emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT0_M1, height_0 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT1_M1, height_1 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, width_0 - 1) - elif tens == cmd.ofm_tensor: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT0_M1, height_0 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT1_M1, height_1 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, width_0 - 1) - if tens == cmd.ifm2_tensor: - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT0_M1, height_0 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT1_M1, height_1 - 1) - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_WIDTH0_M1, width_0 - 1) - else: - if len(out_shape) == 2: - assert out_shape[0] == 1 - if tens == cmd.ifm_tensor: - emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, 0) - elif tens == cmd.ofm_tensor: - emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, 0) - else: - assert False - - emit.cmd0_with_param(region_op, base_ptr_idx_map[tens.mem_type]) - - for idx, addr in enumerate(addresses): - if addr is None: - addresses[idx] = 0 - - emit.cmd1_with_offset(ptr_ops[0], addresses[0]) - emit.cmd1_with_offset(ptr_ops[1], addresses[1]) - emit.cmd1_with_offset(ptr_ops[2], addresses[2]) - emit.cmd1_with_offset(ptr_ops[3], addresses[3]) - - strides = tens.get_strides() - emit.cmd1_with_offset(stride_ops[0], strides[1]) # stride between 16-byte channel blocks (C) - emit.cmd1_with_offset(stride_ops[2], strides[3]) # stride between horisontal values (W) - emit.cmd1_with_offset(stride_ops[1], strides[2]) # stride between vertical values (H) - - if tens.format == TensorFormat.NHCWB16: - # Check that all BasePointer addresses are aligned to 16 bytes - assert (int(addresses[0]) % 16) == 0 - assert (int(addresses[1]) % 16) == 0 - 
assert (int(addresses[2]) % 16) == 0 - assert (int(addresses[3]) % 16) == 0 - - ofm_dtype = cmd.ofm_tensor.dtype - assert ofm_dtype.type & BaseType.Int - prec = 0 - if ofm_dtype.size_in_bits() == 8: - prec = 0 - elif ofm_dtype.size_in_bits() == 16: - prec = 2 - elif ofm_dtype.size_in_bits() == 32: - prec = 4 - else: - assert 0 - - if ofm_dtype.type & BaseType.Signed: - prec += 1 - - if use_global_scale: - # Set global scale bit, as opposed to using per channel scale - prec |= 1 << 8 - - if cmd.ofm_tensor.format == TensorFormat.NHCWB16: - prec |= 1 << 6 - - prec |= rounding_mode.value << 14 - - emit.cmd0_with_param(cmd0.NPU_SET_OFM_PRECISION, prec) - - prec = None - weight_bits = 8 - if cmd.weight_tensor is not None: - weight_bits = cmd.weight_tensor.dtype.size_in_bits() - - ifm_dtype = cmd.ifm_tensor.dtype - - assert weight_bits == 8, "Unsupported weight bit depth" - assert ( - ifm_dtype.size_in_bits() in {8, 16} - or ifm_dtype.size_in_bits() == 32 - and npu_block_type in (NpuBlockType.ElementWise, NpuBlockType.ReduceSum) - ), "Unsupported ifm bit depth" - - if ifm_dtype.size_in_bits() == 8: - if ifm_dtype.type & BaseType.Signed: - prec = ifm_precision.S8 - else: - prec = ifm_precision.U8 - elif ifm_dtype.size_in_bits() == 16: - if ifm_dtype.type & BaseType.Signed: - prec = ifm_precision.S16 - else: - prec = ifm_precision.U16 - elif ifm_dtype == DataType.int32: - prec = ifm_precision.S32 - - ifm_prec = prec.value - ifm2_prec = ifm_prec - - if cmd.ifm_tensor.format == TensorFormat.NHCWB16: - ifm_prec |= 1 << 6 - - ifm_prec |= op_to_scale << 8 - - emit.cmd0_with_param(cmd0.NPU_SET_IFM_PRECISION, ifm_prec) - - if cmd.ifm2_tensor is not None: - if cmd.ifm2_tensor.format == TensorFormat.NHCWB16: - ifm2_prec |= 1 << 6 - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_PRECISION, ifm2_prec) - - # Get op parameters - cur_ifm_block_depth = get_op_ifmofm_block_depth(arch, cmd) - cur_ofm_block = Block(ps.block_config[1], ps.block_config[0], ps.block_config[3]) - cur_ofm_rect = get_op_ofm_rect(cmd) - cur_ifm_rect = get_op_ifm_rect(cmd) - cur_padLT = get_op_padding_lt(cmd) - if (prev_kernel is not None) and (cur_kernel is not None) and has_prev_op_dependency(prev_cmd, cmd): - if cmd.ifm_tensor.shape == prev_cmd.ofm_tensor.shape: - blockdep = arch.calc_block_dep( - prev_ifm_rect, - prev_ofm_rect, - prev_ifm_block_depth, - prev_ofm_block, - prev_kernel, - cur_ifm_rect, - cur_ofm_rect, - cur_ifm_block_depth, - cur_ofm_block, - cur_kernel, - cur_padLT, - ) - else: - blockdep = 0 - else: - blockdep = ArchitectureFeatures.MAX_BLOCKDEP - - # Set between every op (dependent or not) + prev_op = None + prev_block_config = None + # Generate register commands for all operations + for op_index, npu_op in enumerate(npu_op_list): + dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark) + block_config = generate_registers_for_op(emit, npu_op, arch) + if not is_dma_op(npu_op): + # Generate BLOCKDEP + assert block_config is not None + blockdep = calc_blockdep(arch, prev_op, prev_block_config, npu_op, block_config) blockdep = min(blockdep, arch.max_blockdep) emit.cmd0_with_param(cmd0.NPU_SET_BLOCKDEP, blockdep) - prev_cmd = cmd - - emit_cmd_waits(cmd_waits) - DebugDatabase.add_command(stream_id, emit.offset, primary_op) - - if npu_block_type == NpuBlockType.ConvolutionMxN: - emit.cmd_do_operation(cmd0.NPU_OP_CONV) - elif npu_block_type == NpuBlockType.ConvolutionDepthWise: - emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE) - elif npu_block_type == NpuBlockType.VectorProduct: - # 
Vector product is implemented using a 1x1 convolution - emit.cmd_do_operation(cmd0.NPU_OP_CONV) - elif npu_block_type == NpuBlockType.Pooling: - param = pooling_mode.MAX.value if primary_op.type.is_maxpool_op() else pooling_mode.AVERAGE.value - emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param) - elif npu_block_type == NpuBlockType.ReduceSum: - emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_mode.REDUCE_SUM.value) - elif npu_block_type == NpuBlockType.ElementWise: - param = elementwise_mode_map[primary_op.type] - emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param) - else: - print("Warning: Skipping register command stream generation for", ps) - + prev_op = npu_op + prev_block_config = block_config + + generate_cmd_waits(emit, cmd_waits) + # Generate the actual NPU_OP command + generate_operation_code(emit, npu_op) + if add_to_debug_db is not None: + add_to_debug_db(npu_op, emit.offset) # Fill in final part of command stream: emit.cmd_do_operation(cmd0.NPU_OP_STOP, param=0xFFFF) + +def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False): + """Generates command stream for the subgraph, adds it to sg.register_command_stream""" + # Convert high level command stream to list of NpuOperation + npu_op_list = [] + npu_op_to_cmd = dict() # map from npu op to high level command + for cmd in sg.high_level_command_stream: + if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default: + print("Warning: Skipping register command stream generation for", cmd.ps) + else: + npu_op = convert_command_to_npu_op(cmd, arch) + npu_op_list.append(npu_op) + npu_op_to_cmd[npu_op] = cmd + if verbose: + print_operations(npu_op_list) + # Generate register commands + stream_id = DebugDatabase.add_stream(sg) + DebugDatabase.set_stream_offset(sg, 0) # Default to zero, can only set during file writing + emit = CommandStreamEmitter() + + def add_to_debug_db(npu_op: NpuOperation, offset: int): + """Adds info to the debug database""" + if not is_dma_op(npu_op): + cmd = npu_op_to_cmd[npu_op] + DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op) + + generate_command_stream(emit, npu_op_list, arch, add_to_debug_db) sg.register_command_stream = emit.to_list() if verbose: emit.print_cmds() print("number of commands", len(emit.cmd_stream)) print("command stream length in words", len(sg.register_command_stream)) + + +def generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: Accelerator) -> List[int]: + """ + Public facing API for generating an ethosu register command stream. + Calculates dependencies between commands and inserts wait operations if needed. 
+ + :param npu_op_list: List[NpuOperation] list of high level NPU operations + :param accelerator: architecture_features.Accelerator enum to pick the correct ethosu accelerator + :return ethosu instructions, as a list of 32-bit integers + """ + emit = CommandStreamEmitter() + arch = ArchitectureFeatures( + vela_config=None, + system_config=None, + accelerator_config=accelerator.value, + override_block_config=None, + block_config_limit=None, + global_memory_clock_scale=1.0, + max_blockdep=ArchitectureFeatures.MAX_BLOCKDEP, + weight_estimation_scaling=1.0, + ) + generate_command_stream(emit, npu_op_list, arch) + return emit.to_list() diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py index 51fb168..c957be8 100644 --- a/ethosu/vela/shared_buffer_allocation.py +++ b/ethosu/vela/shared_buffer_allocation.py @@ -15,14 +15,20 @@ # limitations under the License. # Description: # Shared buffer allocation works out how to allocate the Ethos-U55 shared buffer for a given pass. +from typing import List +from typing import Tuple + import numpy as np +from .api import NpuActivationOp +from .api import NpuBlockOperation from .architecture_features import ArchitectureFeatures from .architecture_features import Block from .architecture_features import SharedBufferArea from .architecture_features import SHRAMElements from .errors import VelaError from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .high_level_command_to_npu_op import to_kernel from .operation import Kernel from .operation import NpuBlockType from .range_set import MemoryRangeSet @@ -30,24 +36,30 @@ from .tensor import MemArea class SharedBufferAllocation: - def __init__(self, arch, ps): + def __init__( + self, + arch, + kernel, + uses_lut, + npu_block_type, + all_fms_have_quant, + ifm_resampling_mode, + ifm_bits, + ifm_depth, + ifm_count, + ofm_shape, + ): self.arch = arch self.bank_locations = np.zeros(SharedBufferArea.Size) self.banks_required = np.zeros(SharedBufferArea.Size) - ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm() - - self.kernel = Kernel(1, 1) - self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise - self.uses_lut = False - self.ifm_count = 1 - - if ps.primary_op: - self.kernel = ps.primary_op.kernel - self.uses_lut = ps.primary_op.activation_lut is not None + self.kernel = Kernel(1, 1) if kernel is None else kernel + self.is_elementwise = npu_block_type == NpuBlockType.ElementWise + self.uses_lut = uses_lut + self.ifm_count = ifm_count - self.is_equal_depth_op = self.is_elementwise or ps.npu_block_type in ( + self.is_equal_depth_op = self.is_elementwise or npu_block_type in ( NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, ) @@ -58,42 +70,26 @@ class SharedBufferAllocation: else: self.use_ifm_element = SHRAMElements.IFM8 - self.ifm_resampling_mode = resampling_mode.NONE - self.ifm_bits = 0 - self.ifm_depth = 0 - if ifm_tensor: - self.ifm_resampling_mode = ifm_tensor.resampling_mode - self.ifm_bits = ifm_tensor.dtype.size_in_bits() - - if ifm_tensor.shape != []: - self.ifm_depth = ifm_tensor.shape[-1] - - if self.is_elementwise: - self.ifm_count = 2 - if ifm_tensor.shape == []: # Scalar in ifm1 - assert ifm2_tensor - self.ifm_depth = ifm2_tensor.shape[-1] - self.ifm_count = 1 - elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2 - self.ifm_count = 1 - - if self.ifm_bits == 16: - if is_acc_40bits_used(ps.npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor): - self.use_accumulator_element = 
SHRAMElements.Acc40 - self.use_ifm_element = self.use_ifm_element + 1 - assert (self.use_ifm_element == SHRAMElements.IFM16) or ( - self.use_ifm_element == SHRAMElements.IFM16_Elementwise - ) - elif self.ifm_bits == 32: - assert ( - self.is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum - ), "Unsupported 32-bit IFM operation" - self.use_ifm_element = SHRAMElements.IFM32 - else: - assert self.ifm_bits == 8, "Unexpected IFM bitdepth" + self.ifm_resampling_mode = ifm_resampling_mode + self.ifm_bits = ifm_bits + self.ifm_depth = ifm_depth + self.ifm_count = ifm_count + + if self.ifm_bits == 16: + if npu_block_type != NpuBlockType.Pooling and all_fms_have_quant: + self.use_accumulator_element = SHRAMElements.Acc40 + self.use_ifm_element = self.use_ifm_element + 1 + assert (self.use_ifm_element == SHRAMElements.IFM16) or ( + self.use_ifm_element == SHRAMElements.IFM16_Elementwise + ) + elif self.ifm_bits == 32: + assert self.is_elementwise or npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation" + self.use_ifm_element = SHRAMElements.IFM32 + else: + assert self.ifm_bits == 8, "Unexpected IFM bitdepth" self.ifm_block_depth = arch.calc_ifm_block_depth(self.ifm_depth, self.ifm_bits) - self.ofm_tensor = ofm_tensor + self.ofm_shape = ofm_shape self.banks_required[SharedBufferArea.Weights] = arch.shram_reserved_weight_banks self.banks_required[SharedBufferArea.OFM] = arch.shram_reserved_output_banks @@ -168,15 +164,63 @@ class SharedBufferAllocation: ) -def is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor=None): +def _all_fms_have_quant(ifm_tensor, ofm_tensor, ifm2_tensor=None) -> bool: tensors = [t for t in (ifm_tensor, ifm2_tensor, ofm_tensor) if t is not None] scales = [t.quantization.scale_f32 for t in tensors if t.quantization is not None] - has_scale = len(tensors) == len(scales) and None not in scales - return npu_block_type != NpuBlockType.Pooling and has_scale + return len(tensors) == len(scales) and None not in scales -def shared_buffer_allocation_for_pass_and_block_config(arch, ps, block_config): - alloc = SharedBufferAllocation(arch, ps) +def is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor=None): + return npu_block_type != NpuBlockType.Pooling and _all_fms_have_quant(ifm_tensor, ofm_tensor, ifm2_tensor) + + +def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation: + ifm_tensor, ifm2_tensor, _, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm() + all_fms_have_quant = _all_fms_have_quant(ifm_tensor, ifm2_tensor, ofm_tensor) + + kernel = Kernel(1, 1) + is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise + uses_lut = False + ifm_count = 1 + + if ps.primary_op: + kernel = ps.primary_op.kernel + uses_lut = ps.primary_op.activation_lut is not None + + ifm_resampling_mode = resampling_mode.NONE + ifm_bits = 0 + ifm_depth = 0 + if ifm_tensor: + ifm_resampling_mode = ifm_tensor.resampling_mode + ifm_bits = ifm_tensor.dtype.size_in_bits() + + if ifm_tensor.shape != []: + ifm_depth = ifm_tensor.shape[-1] + + if is_elementwise: + ifm_count = 2 + if ifm_tensor.shape == []: # Scalar in ifm1 + assert ifm2_tensor + ifm_depth = ifm2_tensor.shape[-1] + ifm_count = 1 + elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2 + ifm_count = 1 + return SharedBufferAllocation( + arch, + kernel, + uses_lut, + npu_block_type=ps.npu_block_type, + all_fms_have_quant=all_fms_have_quant, + ifm_resampling_mode=ifm_resampling_mode, + ifm_bits=ifm_bits, + ifm_depth=ifm_depth, + ifm_count=ifm_count, + 
ofm_shape=ofm_tensor.shape, + ) + + +def shared_buffer_allocation_for_pass_and_block_config(arch, ps, block_config) -> SharedBufferAllocation: + alloc = shared_buffer_allocation_for_pass(arch, ps) assert (alloc.ifm_block_depth == block_config[2]) or alloc.is_equal_depth_op if alloc.try_block(Block(block_config[1], block_config[0], block_config[3])): return alloc @@ -184,9 +228,34 @@ def shared_buffer_allocation_for_pass_and_block_config(arch, ps, block_config): return None -def find_block_configs_suitable_for_pass_and_shared_buffer(arch, ps): - alloc = SharedBufferAllocation(arch, ps) - +def shared_buffer_allocation_for_npu_op( + arch, npu_op: NpuBlockOperation, npu_block_type: NpuBlockType, ifm_resampling_mode +) -> SharedBufferAllocation: + uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP + fms = [npu_op.ifm, npu_op.ofm] + if npu_op.ifm2 is not None: + fms.append(npu_op.ifm2) + all_fms_have_quant = not any(fm.quantization is None or fm.quantization.scale_f32 is None for fm in fms) + ifm_bits = npu_op.ifm.data_type.size_in_bits() + ifm_depth = npu_op.ifm.shape.depth + ifm_count = 2 if npu_op.ifm2 is not None and npu_op.ifm2_scalar is None else 1 + ofm_shape = [1, npu_op.ofm.shape.height, npu_op.ofm.shape.width, npu_op.ofm.shape.depth] + return SharedBufferAllocation( + arch, + to_kernel(npu_op.kernel), + uses_lut, + npu_block_type=npu_block_type, + all_fms_have_quant=all_fms_have_quant, + ifm_resampling_mode=ifm_resampling_mode, + ifm_bits=ifm_bits, + ifm_depth=ifm_depth, + ifm_count=ifm_count, + ofm_shape=ofm_shape, + ) + + +def find_suitable_block_configs(arch, alloc: SharedBufferAllocation) -> List[Tuple]: + """Returns list of block configs that would fit with the given shared buffer allocation""" if arch.override_block_config: config = alloc.try_block(arch.override_block_config) if config is None: @@ -195,14 +264,14 @@ def find_block_configs_suitable_for_pass_and_shared_buffer(arch, ps): # Constrain the search space if the OFM is smaller than the max block size # - Add other block search constraints here if required - if len(alloc.ofm_tensor.shape) <= 2: - max_block_height = max_block_width = alloc.ofm_tensor.shape[0] + if len(alloc.ofm_shape) <= 2: + max_block_height = max_block_width = alloc.ofm_shape[0] else: - max_block_width = alloc.ofm_tensor.shape[-2] - max_block_height = alloc.ofm_tensor.shape[-3] + max_block_width = alloc.ofm_shape[-2] + max_block_height = alloc.ofm_shape[-3] # Common block depth - max_block_depth = alloc.ofm_tensor.shape[-1] + max_block_depth = alloc.ofm_shape[-1] # Constrain to valid ranges before search max_block_width = min(arch.ofm_block_max.width, max_block_width) @@ -224,3 +293,8 @@ def find_block_configs_suitable_for_pass_and_shared_buffer(arch, ps): assert len(valid_block_configs) > 0 return valid_block_configs + + +def find_block_configs_suitable_for_pass_and_shared_buffer(arch, ps) -> List[Tuple]: + alloc = shared_buffer_allocation_for_pass(arch, ps) + return find_suitable_block_configs(arch, alloc) diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index efd91a3..01146ee 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -24,8 +24,10 @@ import numpy as np from . import fp_math from . 
import scaling +from .api import NpuRoundingMode from .data_type import DataType from .debug_database import DebugDatabase +from .operation import ActivationFunction from .operation import Op from .operation import Operation from .tensor import create_const_tensor @@ -227,6 +229,12 @@ class SoftMax: no_scale_quant = ifm.quantization.clone() no_scale_quant.scale_f32 = None no_scale_quant.zero_point = 0 + activation = ActivationFunction(Op.Clip) + activation.min = ifm.quantization.quant_min + activation.max = ifm.quantization.quant_max + activation2 = activation.clone() + activation2.min = 2 * ifm.quantization.quant_min + activation2.max = 2 * ifm.quantization.quant_max one_scale_quant = ifm.quantization.clone() one_scale_quant.scale_f32 = 1.0 one_scale_quant.zero_point = 0 @@ -263,20 +271,23 @@ class SoftMax: ifm_exp = Tensor(ifm.shape, DataType.int32, sub_op.name + "_0") ifm_exp.quantization = one_scale_quant.clone() ifm_exp.quantization.zero_point = 127 - ifm_exp.quantization.quant_min = -128 - ifm_exp.quantization.quant_max = 127 + sub_op.activation = ActivationFunction(Op.LUT) + # Note: activation.min/max are non-quantized values + sub_op.activation.min = -128 - ifm_exp.quantization.zero_point + sub_op.activation.max = 127 - ifm_exp.quantization.zero_point sub_op.set_output_tensor(ifm_exp) DebugDatabase.add_optimised(self.op, sub_op) # PASS 2 - SHR shr2_op = Operation(Op.SHR, self.op.name + "_shr2") - shr2_op.attrs["rounding_mode"] = b"NATURAL" + shr2_op.attrs["rounding_mode"] = NpuRoundingMode.NATURAL shr2_op.add_input_tensor(ifm_exp) shr2_op.add_input_tensor( create_const_tensor( shr2_op.name + "_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant ), ) + shr2_op.activation = activation.clone() rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0") rescaled_exp.quantization = no_scale_quant shr2_op.set_output_tensor(rescaled_exp) @@ -292,6 +303,7 @@ class SoftMax: reduce_sum_op.attrs["strides"] = [1, reduce_sum_op.attrs["stride_h"], reduce_sum_op.attrs["stride_w"], 1] reduce_sum_op.attrs["ksize"] = [1, reduce_sum_op.attrs["filter_height"], reduce_sum_op.attrs["filter_width"], 1] reduce_sum_op.add_input_tensor(rescaled_exp) + reduce_sum_op.activation = activation.clone() reduce_sum_shape = [1, rescaled_exp.shape[1], rescaled_exp.shape[2], 1] sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0") @@ -302,6 +314,7 @@ class SoftMax: # PASS 4 - CLZ clz_op = Operation(Op.CLZ, self.op.name + "_clz4") clz_op.add_input_tensor(sum_of_exp) + clz_op.activation = activation.clone() headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0") headroom_plus_one.quantization = no_scale_quant clz_op.set_output_tensor(headroom_plus_one) @@ -320,6 +333,7 @@ class SoftMax: ), ) sub5_op.add_input_tensor(headroom_plus_one) + sub5_op.activation = activation.clone() right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0") right_shift.quantization = no_scale_quant sub5_op.set_output_tensor(right_shift) @@ -330,6 +344,7 @@ class SoftMax: sub6_op = Operation(Op.Sub, self.op.name + "_sub6") sub6_op.add_input_tensor(headroom_plus_one) sub6_op.add_input_tensor(one) + sub6_op.activation = activation.clone() headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0") headroom.quantization = no_scale_quant sub6_op.set_output_tensor(headroom) @@ -339,8 +354,10 @@ class SoftMax: shl7_op = Operation(Op.SHL, self.op.name + "_shl7") shl7_op.add_input_tensor(sum_of_exp) shl7_op.add_input_tensor(headroom) 
+ shl7_op.activation = activation.clone() shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0") shifted_sum.quantization = no_scale_quant + shl7_op.set_output_tensor(shifted_sum) DebugDatabase.add_optimised(self.op, shl7_op) @@ -352,6 +369,7 @@ class SoftMax: "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant ), ) + sub8_op.activation = activation.clone() shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0") shifted_sum_minus_one.quantization = no_scale_quant sub8_op.set_output_tensor(shifted_sum_minus_one) @@ -361,6 +379,7 @@ class SoftMax: shl9_op = Operation(Op.SHL, self.op.name + "_shl9") shl9_op.add_input_tensor(shifted_sum_minus_one) shl9_op.add_input_tensor(one) + shl9_op.activation = activation.clone() shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0") shifted_sum_minus_one.quantization = no_scale_quant shl9_op.set_output_tensor(shifted_sum_minus_one) @@ -374,7 +393,8 @@ class SoftMax: ), ) add10_op.add_input_tensor(shifted_sum_minus_one) - add10_op.attrs["rescale"] = [1, 1] + add10_op.activation = activation.clone() + add10_op.attrs["rescale"] = (1, 1) half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0") half_denominator.quantization = one_scale_quant add10_op.set_output_tensor(half_denominator) @@ -396,6 +416,7 @@ class SoftMax: rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0") rescaled.quantization = one_scale_quant.clone() rescaled.quantization.scale_f32 = 2.0 + mul11_op.activation = activation2.clone() mul11_op.set_output_tensor(rescaled) DebugDatabase.add_optimised(self.op, mul11_op) @@ -407,6 +428,7 @@ class SoftMax: "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant ), ) + add12_op.activation = activation.clone() rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0") rescale_w_offset.quantization = one_scale_quant add12_op.set_output_tensor(rescale_w_offset) @@ -424,6 +446,7 @@ class SoftMax: mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (13 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(half_denominator) + mul_op.activation = activation2.clone() half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") half_denominator_times_x.quantization = one_scale_quant.clone() half_denominator_times_x.quantization.scale_f32 = 2.0 @@ -433,6 +456,7 @@ class SoftMax: sub_op = Operation(Op.Sub, self.op.name + "_sub%d" % (14 + i * 5)) sub_op.add_input_tensor(F2_one) sub_op.add_input_tensor(half_denominator_times_x) + sub_op.activation = activation.clone() one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0") one_minus_half_denominator_times_x.quantization = one_scale_quant sub_op.set_output_tensor(one_minus_half_denominator_times_x) @@ -441,6 +465,7 @@ class SoftMax: mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (15 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(one_minus_half_denominator_times_x) + mul_op.activation = activation2.clone() to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") to_rescale.quantization = one_scale_quant.clone() to_rescale.quantization.scale_f32 = 2.0 @@ -450,6 +475,7 @@ class SoftMax: shl_op = Operation(Op.Mul, self.op.name + "_mul%d" % (16 + i * 5)) shl_op.add_input_tensor(to_rescale) shl_op.add_input_tensor(four) + shl_op.activation = activation.clone() 
to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0") to_add.quantization = no_scale_quant shl_op.set_output_tensor(to_add) @@ -458,6 +484,7 @@ class SoftMax: add_op = Operation(Op.Add, self.op.name + "_add%d" % (17 + i * 5)) add_op.add_input_tensor(nr_x) add_op.add_input_tensor(to_add) + add_op.activation = activation.clone() nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0") nr_x.quantization = one_scale_quant add_op.set_output_tensor(nr_x) @@ -469,6 +496,7 @@ class SoftMax: mul28_op.add_input_tensor( create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant) ) + mul28_op.activation = activation.clone() scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0") scale_factor.quantization = one_scale_quant mul28_op.set_output_tensor(scale_factor) @@ -478,6 +506,7 @@ class SoftMax: mul_op = Operation(Op.Mul, self.op.name + "_mul29") mul_op.add_input_tensor(ifm_exp) mul_op.add_input_tensor(scale_factor) + mul_op.activation = activation2.clone() scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") scaled_exp.quantization = one_scale_quant.clone() scaled_exp.quantization.scale_f32 = 2.0 @@ -486,7 +515,7 @@ class SoftMax: # PASS 30 - SHR shr30_op = Operation(Op.SHR, self.op.name + "_shr30") - shr30_op.attrs["rounding_mode"] = b"NATURAL" + shr30_op.attrs["rounding_mode"] = NpuRoundingMode.NATURAL shr30_op.add_input_tensor(scaled_exp) shr30_op.add_input_tensor(right_shift) shr30_op.set_output_tensor(ofm) diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 3e649e0..b537b65 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -379,9 +379,13 @@ class SupportedOperators: @docstring_format_args([supported_fused_activations]) def constraint_faf(cls, op): "The fused activation function (if present) must be one of type: {}" - faf = op.activation - valid = (faf is None) or (faf in cls.supported_fused_activations) - return valid, f"Op has its fused activation function as: {faf}" + if op.activation is None: + res = True, "Op has no fused activation function" + else: + faf = op.activation.op_type + valid = faf in cls.supported_fused_activations + res = valid, f"Op has its fused activation function as: {faf}" + return res @staticmethod def constraint_stride_type(op): diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py index 47ca02b..356bbc1 100644 --- a/ethosu/vela/test/extapi/test_extapi_encode_weights.py +++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py @@ -19,6 +19,7 @@ import numpy as np import pytest from ethosu.vela import weight_compressor +from ethosu.vela.api import NpuBlockTraversal from ethosu.vela.architecture_features import Accelerator @@ -52,7 +53,7 @@ def test_encode_weights( weights_hwio = np.random.randint(val_max, size=weights_shape, dtype=np.uint8) weights_ohwi = np.transpose(weights_hwio, (3, 0, 1, 2)) is_depthwise = True if depth_control == 2 else False - is_partkernel = True if depth_control == 3 else False + block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST if depth_control == 3 else NpuBlockTraversal.DEPTH_FIRST dilation_xy = (dilation_x, dilation_y) encoded_stream = weight_compressor.encode_weights( @@ -62,7 +63,7 @@ def test_encode_weights( ifm_bitdepth=ifm_bitdepth, ofm_block_depth=ofm_block_depth, is_depthwise=is_depthwise, - is_partkernel=is_partkernel, + block_traversal=block_traversal, ) assert 
type(encoded_stream) == bytearray diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py new file mode 100644 index 0000000..49b24b2 --- /dev/null +++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py @@ -0,0 +1,370 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Contains unit tests for generate_register_command_stream API for an external consumer +from ethosu.vela.api import NpuActivation +from ethosu.vela.api import NpuActivationOp +from ethosu.vela.api import NpuAddressRange +from ethosu.vela.api import NpuBlockTraversal +from ethosu.vela.api import NpuConv2DOperation +from ethosu.vela.api import NpuConvDepthWiseOperation +from ethosu.vela.api import NpuDataType +from ethosu.vela.api import NpuDmaOperation +from ethosu.vela.api import NpuElementWiseOp +from ethosu.vela.api import NpuElementWiseOperation +from ethosu.vela.api import NpuFeatureMap +from ethosu.vela.api import NpuKernel +from ethosu.vela.api import NpuLayout +from ethosu.vela.api import NpuPadding +from ethosu.vela.api import NpuPoolingOp +from ethosu.vela.api import NpuPoolingOperation +from ethosu.vela.api import NpuQuantization +from ethosu.vela.api import NpuShape3D +from ethosu.vela.api import NpuTileBox +from ethosu.vela.architecture_features import Accelerator +from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0 +from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd1 +from ethosu.vela.register_command_stream_generator import CmdMode +from ethosu.vela.register_command_stream_generator import generate_register_command_stream +from ethosu.vela.register_command_stream_generator import get_address_ranges + + +def check_cmd0(cmd_stream, cmd, param): + """Checks that the command stream contains the given command + parameter""" + param = int(param) & 0xFFFF + command = cmd.value | (param << 16) + assert command in cmd_stream, f"Not in command stream: {cmd} {param}" + + +def check_cmd1(cmd_stream, cmd, offset, param=0x0): + """Checks that the command stream contains the given command + parameter""" + offset = int(offset) & 0xFFFFFFFFF + command = cmd.value | CmdMode.Payload32.value | (param << 16) + for i in range(len(cmd_stream) - 1): + if cmd_stream[i] == command and cmd_stream[i + 1] == offset: + return # found + assert False, f"Not in command stream: {cmd} {offset} {param}" + + +def find_cmd0(cmd_stream, cmd) -> int: + """Returns parameter of the first command in the stream that matches the given command""" + for command in cmd_stream: + if (command & 0xFFFF) == cmd.value: + return (command >> 16) & 0xFFFF + assert False, f"Not in command stream: {cmd}" + + +def create_feature_map( + shape: NpuShape3D, + region: int, + address: int, + dtype: NpuDataType = NpuDataType.UINT8, + layout: NpuLayout = NpuLayout.NHWC, + quant=NpuQuantization(scale_f32=1, zero_point=0), +) -> NpuFeatureMap: + """Creates feature 
map using 1 tile""" + fm = NpuFeatureMap() + fm.data_type = dtype + fm.shape = shape + fm.tiles = NpuTileBox( + width_0=shape.width, height_0=shape.height, height_1=shape.height, addresses=[address, 0, 0, 0] + ) + fm.region = region + fm.layout = layout + fm.quantization = quant + return fm + + +def test_conv2d(): + """Tests command stream generation for a conv2d operation""" + op = NpuConv2DOperation() + op.ifm = create_feature_map( + NpuShape3D(height=30, width=62, depth=46), 1, 512, quant=NpuQuantization(scale_f32=0.007843138, zero_point=128) + ) + op.ofm = create_feature_map( + NpuShape3D(height=30, width=31, depth=46), + 1, + 0x14E40, + quant=NpuQuantization(scale_f32=0.20392157, zero_point=128), + ) + op.kernel = NpuKernel(3, 2, 2, 1) + op.weights = [NpuAddressRange(region=0, address=0, length=7696)] + op.biases = [NpuAddressRange(region=0, address=32000, length=464)] + op.padding = NpuPadding(top=0, left=0, right=1, bottom=1) + op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST + # In this example we assume that the weights were compressed with ofm depth 16; + # let vela choose suitable block width and height by setting these to -1 + op.block_config = NpuShape3D(height=-1, width=-1, depth=16) + cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 512) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE1, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE2, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE3, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_HEIGHT0_M1, 29) + check_cmd0(cmds, cmd0.NPU_SET_IFM_HEIGHT1_M1, 29) + check_cmd0(cmds, cmd0.NPU_SET_IFM_WIDTH0_M1, 61) + check_cmd0(cmds, cmd0.NPU_SET_IFM_DEPTH_M1, 45) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_C, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_Y, 2852) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_X, 46) + check_cmd0(cmds, cmd0.NPU_SET_IFM_ZERO_POINT, 128) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PRECISION, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_UPSCALE, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PAD_TOP, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PAD_LEFT, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PAD_BOTTOM, 1) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PAD_RIGHT, 1) + check_cmd0(cmds, cmd0.NPU_SET_OFM_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE0, 85568) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE1, 0) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE2, 0) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE3, 0) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT0_M1, 29) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT1_M1, 29) + check_cmd0(cmds, cmd0.NPU_SET_OFM_WIDTH0_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT_M1, 29) + check_cmd0(cmds, cmd0.NPU_SET_OFM_WIDTH_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_OFM_DEPTH_M1, 45) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_C, 1) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_Y, 1426) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_X, 46) + check_cmd0(cmds, cmd0.NPU_SET_OFM_ZERO_POINT, 128) + check_cmd0(cmds, cmd0.NPU_SET_OFM_PRECISION, 0) + check_cmd0(cmds, cmd0.NPU_SET_KERNEL_HEIGHT_M1, 1) + check_cmd0(cmds, cmd0.NPU_SET_KERNEL_WIDTH_M1, 2) + check_cmd0(cmds, cmd0.NPU_SET_KERNEL_STRIDE, 5) + check_cmd0(cmds, cmd0.NPU_SET_WEIGHT_REGION, 0) + check_cmd1(cmds, cmd1.NPU_SET_WEIGHT_BASE, 0) + check_cmd1(cmds, cmd1.NPU_SET_WEIGHT_LENGTH, 7696) + check_cmd0(cmds, cmd0.NPU_SET_SCALE_REGION, 0) + check_cmd1(cmds, cmd1.NPU_SET_SCALE_BASE, 32000) + check_cmd1(cmds, cmd1.NPU_SET_SCALE_LENGTH, 464) + check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION, 0) + 
check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION_MIN, 0) + check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION_MAX, 255) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, 15) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, 3) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, 15) + check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 14) + check_cmd0(cmds, cmd0.NPU_SET_AB_START, 14) + check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0) + check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0) + check_cmd0(cmds, cmd0.NPU_OP_CONV, 0) + # Check that block width/height were generated that fit + blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1) + blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1) + assert blk_height > 0 + assert blk_width > 0 + assert (blk_height + 1) * (blk_width + 1) <= 64 + + +def create_fully_connected_op() -> NpuConv2DOperation: + op = NpuConv2DOperation() + op.ifm = create_feature_map( + NpuShape3D(height=1, width=1, depth=114), + 1, + 0, + quant=NpuQuantization(scale_f32=0.007843138, zero_point=128), + layout=NpuLayout.NHCWB16, + ) + op.ofm = create_feature_map( + NpuShape3D(height=1, width=1, depth=96), + 1, + 0x6A0, + quant=NpuQuantization(scale_f32=0.20392157, zero_point=128), + layout=NpuLayout.NHCWB16, + ) + op.kernel = NpuKernel(1, 1) + op.weights = [NpuAddressRange(region=0, address=0x16880, length=13120)] + op.biases = [NpuAddressRange(region=0, address=0x19BC0, length=960)] + op.padding = NpuPadding(top=0, left=0, right=0, bottom=0) + op.block_traversal = NpuBlockTraversal.DEPTH_FIRST + # In this example we assume that the weights were compressed with ofm depth 96; + # let vela choose suitable block width and height by setting these to -1 + op.block_config = NpuShape3D(height=-1, width=-1, depth=96) + return op + + +def test_fully_connected(): + """Tests command stream generation for a fully connected operation""" + op = create_fully_connected_op() + cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + check_cmd0(cmds, cmd0.NPU_OP_CONV, 0) + assert len(cmds) > 20 + + +def test_depthwise(): + """Test depthwise operation, preceeded by DMA operation""" + weights_src = NpuAddressRange(region=0, address=0x40, length=96) + weights_dest = NpuAddressRange(region=1, address=0x10000, length=96) + dma_op = NpuDmaOperation(weights_src, weights_dest) + op = NpuConvDepthWiseOperation() + ifm_quant = NpuQuantization(scale_f32=0.007843138, zero_point=128) + op.ifm = create_feature_map(NpuShape3D(height=64, width=64, depth=8), 1, 0x0, quant=ifm_quant) + ofm_quant = NpuQuantization(scale_f32=0.062745101749897, zero_point=128) + op.ofm = create_feature_map(NpuShape3D(height=64, width=64, depth=8), 1, 0x8000, quant=ofm_quant) + op.kernel = NpuKernel(3, 3) + op.padding = NpuPadding(top=1, left=1, right=1, bottom=1) + op.weights = [weights_dest] + op.biases = [NpuAddressRange(region=0, address=0, length=80)] + op.block_config = NpuShape3D(height=-1, width=-1, depth=8) + cmds = generate_register_command_stream([dma_op, op], Accelerator.Ethos_U55_128) + check_cmd0(cmds, cmd0.NPU_SET_DMA0_SRC_REGION, 0) + check_cmd1(cmds, cmd1.NPU_SET_DMA0_SRC, 0x40) + check_cmd0(cmds, cmd0.NPU_SET_DMA0_DST_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_DMA0_DST, 0x10000) + check_cmd1(cmds, cmd1.NPU_SET_DMA0_LEN, 96) + check_cmd0(cmds, cmd0.NPU_OP_DMA_START, 0) + # A DMA WAIT should have been inserted + check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0) + check_cmd0(cmds, cmd0.NPU_OP_DEPTHWISE, 0) + blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1) + blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1) + 
assert blk_height > 0 + assert blk_width > 0 + + +def test_mul_with_broadcast_and_relu(): + """Test multiplication with broadcasted IFM2""" + op = NpuElementWiseOperation(NpuElementWiseOp.MUL) + op.ifm = create_feature_map(NpuShape3D(height=31, width=22, depth=31), 1, 0x20) + op.ifm2 = create_feature_map(NpuShape3D(height=1, width=22, depth=1), 1, 0) + op.ofm = create_feature_map(NpuShape3D(height=31, width=22, depth=31), 1, 0x52C0) + op.activation = NpuActivation(NpuActivationOp.NONE_OR_RELU) + op.activation.min = 0 # RELU + # Do not set a block config, let vela choose one + cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_32) + check_cmd1(cmds, cmd1.NPU_SET_OFM_SCALE, 1073741824, 30) + check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 32) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE1, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE2, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE3, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_HEIGHT0_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_IFM_HEIGHT1_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_IFM_WIDTH0_M1, 21) + check_cmd0(cmds, cmd0.NPU_SET_IFM_DEPTH_M1, 30) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_C, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_Y, 682) + check_cmd1(cmds, cmd1.NPU_SET_IFM_STRIDE_X, 31) + check_cmd0(cmds, cmd0.NPU_SET_IFM_ZERO_POINT, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_PRECISION, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM_UPSCALE, 0) + check_cmd0(cmds, cmd0.NPU_SET_OFM_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE0, 21184) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE1, 0) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE2, 0) + check_cmd1(cmds, cmd1.NPU_SET_OFM_BASE3, 0) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT0_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT1_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_OFM_WIDTH0_M1, 21) + check_cmd0(cmds, cmd0.NPU_SET_OFM_HEIGHT_M1, 30) + check_cmd0(cmds, cmd0.NPU_SET_OFM_WIDTH_M1, 21) + check_cmd0(cmds, cmd0.NPU_SET_OFM_DEPTH_M1, 30) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_C, 1) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_Y, 682) + check_cmd1(cmds, cmd1.NPU_SET_OFM_STRIDE_X, 31) + check_cmd0(cmds, cmd0.NPU_SET_OFM_ZERO_POINT, 0) + check_cmd0(cmds, cmd0.NPU_SET_OFM_PRECISION, 256) + check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION, 0) + check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION_MIN, 0) + check_cmd0(cmds, cmd0.NPU_SET_ACTIVATION_MAX, 255) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_REGION, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_BASE0, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_BASE1, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_BASE2, 0) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_BASE3, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_HEIGHT0_M1, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_HEIGHT1_M1, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_WIDTH0_M1, 21) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_STRIDE_C, 1) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_STRIDE_Y, 22) + check_cmd1(cmds, cmd1.NPU_SET_IFM2_STRIDE_X, 1) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_ZERO_POINT, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_PRECISION, 0) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_BROADCAST, 5) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, 23) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, 3) + check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, 31) + check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 16) + check_cmd0(cmds, cmd0.NPU_SET_AB_START, 16) + check_cmd0(cmds, cmd0.NPU_SET_IFM2_IB_START, 9) + check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0) + check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0) + check_cmd0(cmds, cmd0.NPU_OP_ELEMENTWISE, 0) + # Check that 
block width/height were generated that fit + blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1) + blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1) + blk_depth = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1) + assert blk_height >= 0 + assert blk_width >= 0 + assert blk_depth >= 0 + assert (blk_height + 1) * (blk_width + 1) + (blk_depth + 1) <= 3072 + + +def create_avg_pool_op() -> NpuPoolingOperation: + op = NpuPoolingOperation(NpuPoolingOp.AVERAGE) + op.ifm = create_feature_map( + NpuShape3D(height=29, width=30, depth=27), 2, 0, quant=NpuQuantization(scale_f32=0.007843138, zero_point=128) + ) + op.ofm = create_feature_map( + NpuShape3D(height=10, width=10, depth=27), + 2, + 0x5BD0, + quant=NpuQuantization(scale_f32=0.20392157, zero_point=128), + ) + op.kernel = NpuKernel(8, 2, 3, 3) + op.padding = NpuPadding(top=0, left=2, right=3, bottom=0) + # Do not set a block config, let vela choose one + return op + + +def test_avg_pool(): + """Tests average pool operation""" + op = create_avg_pool_op() + cmds = generate_register_command_stream([op], Accelerator.Ethos_U55_128) + check_cmd0(cmds, cmd0.NPU_OP_POOL, 1) + assert len(cmds) > 10 + + +def test_two_operations(): + """Tests code generation with 2 operations""" + op1 = create_fully_connected_op() + op2 = create_avg_pool_op() + cmds = generate_register_command_stream([op1, op2], Accelerator.Ethos_U55_64) + check_cmd0(cmds, cmd0.NPU_OP_POOL, 1) + check_cmd0(cmds, cmd0.NPU_OP_CONV, 0) + check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0) + # The operations are not dependent, so expect a blockdep 3 + check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 3) + assert len(cmds) > 10 + + +def test_dma_op(): + """Tests DMA operation followed by average pool. The DMA provides the contents of the average pool's IFM.""" + pool_op = create_avg_pool_op() + assert pool_op.ofm is not None + dest = get_address_ranges(pool_op.ofm)[0] + assert dest is not None + src = NpuAddressRange(0, 0x24000, dest.length) + dma_op = NpuDmaOperation(src, dest) + cmds = generate_register_command_stream([dma_op, pool_op], Accelerator.Ethos_U55_64) + check_cmd0(cmds, cmd0.NPU_OP_DMA_START, 0) + # A DMA WAIT should have been inserted + check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0) + check_cmd0(cmds, cmd0.NPU_OP_POOL, 1) diff --git a/ethosu/vela/test/test_register_command_generator.py b/ethosu/vela/test/test_register_command_generator.py new file mode 100644 index 0000000..f2a1609 --- /dev/null +++ b/ethosu/vela/test/test_register_command_generator.py @@ -0,0 +1,104 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Description: +# Contains unit tests for register command stream generator +from ethosu.vela.api import NpuAddressRange +from ethosu.vela.api import NpuDataType +from ethosu.vela.api import NpuFeatureMap +from ethosu.vela.api import NpuLayout +from ethosu.vela.api import NpuShape3D +from ethosu.vela.api import NpuTileBox +from ethosu.vela.register_command_stream_generator import get_address_ranges +from ethosu.vela.register_command_stream_generator import get_strides + + +def test_get_fm_strides(): + """Tests calculation of feature map strides""" + fm = NpuFeatureMap() + fm.layout = NpuLayout.NHCWB16 + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=7, width=10, depth=24) + assert get_strides(fm) == NpuShape3D(height=640, width=32, depth=320) + fm.layout = NpuLayout.NHWC + assert get_strides(fm) == NpuShape3D(height=480, width=48, depth=2) + fm.data_type = NpuDataType.UINT8 + assert get_strides(fm) == NpuShape3D(height=240, width=24, depth=1) + + +def test_get_address_ranges_one_tile(): + """Tests calculation of feature map address ranges, with 1 tile used""" + fm = NpuFeatureMap() + fm.region = 4 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=40, depth=3) + fm.tiles = NpuTileBox(height_0=50, height_1=50, width_0=40, addresses=[8000, 0, 0, 0]) + ranges = get_address_ranges(fm) + assert ranges == [NpuAddressRange(region=4, address=8000, length=12000), None, None, None] + + +def test_get_address_ranges_horizontal_tiles(): + """Tests calculation of feature map address ranges, with 2 horizontal tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.tiles = NpuTileBox(height_0=20, height_1=30, width_0=10, addresses=[256, 0, 16000, 0]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=256, length=8000), + None, + NpuAddressRange(region=6, address=16000, length=12000), + None, + ] + + +def test_get_address_ranges_vertical_tiles(): + """Tests calculation of feature map address ranges, with 2 vertical tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHWC + fm.data_type = NpuDataType.INT8 + # Set strides explicitly + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.strides = NpuShape3D(height=100, width=20, depth=1) + fm.tiles = NpuTileBox(height_0=50, height_1=50, width_0=5, addresses=[16, 32000, 0, 0]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=16, length=5000), + NpuAddressRange(region=6, address=32000, length=5000), + None, + None, + ] + + +def test_get_address_ranges_4_tiles(): + """Tests calculation of feature map address ranges, with 4 tiles used""" + fm = NpuFeatureMap() + fm.region = 6 + fm.layout = NpuLayout.NHCWB16 + fm.data_type = NpuDataType.INT16 + fm.shape = NpuShape3D(height=50, width=10, depth=20) + fm.tiles = NpuTileBox(height_0=30, height_1=10, width_0=3, addresses=[16, 32000, 8000, 16000]) + ranges = get_address_ranges(fm) + assert ranges == [ + NpuAddressRange(region=6, address=16, length=18952), + NpuAddressRange(region=6, address=32000, length=6280), + NpuAddressRange(region=6, address=8000, length=12552), + NpuAddressRange(region=6, address=28800, length=12680), + ] diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 245ebcf..7e13f42 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ 
b/ethosu/vela/test/test_supported_operators.py @@ -19,6 +19,7 @@ import numpy as np from ethosu.vela.data_type import DataType +from ethosu.vela.operation import ActivationFunction from ethosu.vela.operation import Op from ethosu.vela.supported_operators import SupportedOperators from ethosu.vela.tensor import create_const_tensor @@ -102,7 +103,7 @@ def test_constraint_tens_quant_scale(): def test_constraint_faf(): # Fused activation functions, if set, must be a valid op type op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8]) - op.activation = Op.Conv2D + op.activation = ActivationFunction(Op.Conv2D) assert not support.is_operator_supported(op) diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 24f9f87..b3b0720 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -23,6 +23,7 @@ from .errors import InputFileError from .errors import TensorError from .nn_graph import Graph from .nn_graph import Subgraph +from .operation import create_activation_function from .operation import Op from .operation import Operation from .tensor import QuantizationParameters @@ -186,7 +187,9 @@ class TFLiteSubgraph: if "depth_multiplier" in op.attrs: op.attrs["channel_multiplier"] = op.attrs["depth_multiplier"] - op.activation = op.attrs.pop("fused_activation_function", None) + faf = op.attrs.pop("fused_activation_function", None) + if faf is not None: + op.activation = create_activation_function(faf) if custom_code is not None: op.attrs["custom_code"] = custom_code diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 0f20878..de0ee74 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -278,7 +278,8 @@ class TFLiteSerialiser: attrs["dilation_w_factor"] = attrs["dilation"][2] if "channel_multiplier" in attrs: attrs["depth_multiplier"] = attrs["channel_multiplier"] - attrs["fused_activation_function"] = op.activation + if op.activation is not None: + attrs["fused_activation_function"] = op.activation.op_type builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs) diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py index b0187b6..c07229f 100644 --- a/ethosu/vela/weight_compressor.py +++ b/ethosu/vela/weight_compressor.py @@ -20,6 +20,7 @@ from collections import namedtuple import numpy as np +from .api import NpuBlockTraversal from .architecture_features import Accelerator from .architecture_features import ArchitectureFeatures from .data_type import DataType @@ -53,7 +54,7 @@ def encode_weights( ifm_bitdepth: int, ofm_block_depth: int, is_depthwise: bool, - is_partkernel: bool, + block_traversal: NpuBlockTraversal, ): """ Public facing API to use the ethosu weight encoding. 
@@ -64,7 +65,7 @@ def encode_weights(
     :param ifm_bitdepth: the bitdepth of input feature map
     :param ofm_block_depth: the depth of blocks for ethosu processing
     :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
-    :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis
+    :param block_traversal: indicates how these weights are traversed on sub-kernel basis
     :return: a bytearray of compressed weights
     """
@@ -75,13 +76,15 @@ def encode_weights(
     assert isinstance(ifm_bitdepth, int)
     assert isinstance(ofm_block_depth, int)
     assert isinstance(is_depthwise, bool)
-    assert isinstance(is_partkernel, bool)
+    assert isinstance(block_traversal, NpuBlockTraversal)
 
     # Checks for weight layout
     assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
 
     # It cannot be both partkernel and depthwise
-    assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive"
+    assert not (
+        is_depthwise and block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
+    ), "encode_weights :: partkernel and depthwise are mutually exclusive"
 
     # Check valid values for dilation
     assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
@@ -95,7 +98,7 @@ def encode_weights(
         brick_weights=weights_volume,
         ofm_block_depth=ofm_block_depth,
         is_depthwise=is_depthwise,
-        is_partkernel=is_partkernel,
+        is_partkernel=block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST,
         ifm_bitdepth=ifm_bitdepth,
         dilation=dilation_xy,
     )
@@ -335,7 +338,10 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
         tens.block_traversal = TensorBlockTraversal.DepthFirst
 
     is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
-    is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
+    if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
+        block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+    else:
+        block_traversal = NpuBlockTraversal.DEPTH_FIRST
 
     if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
         # Transpose Convoluion, reverse weights in H and W axes
@@ -370,7 +376,7 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
             ifm_bitdepth=ifm_bitdepth,
             ofm_block_depth=block_depth,
             is_depthwise=is_depthwise,
-            is_partkernel=is_partkernel,
+            block_traversal=block_traversal,
         )
         encoded_stream.extend(encoded_substream)
         substream_offsets.append(len(encoded_stream))
-- 
cgit v1.2.1
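
For reference, a short usage sketch of the new public code-generation entry point, assembled from the average-pool and DMA unit tests added in this patch. The make_feature_map helper below is hypothetical (it mirrors the create_feature_map test helper), and all shapes, addresses and quantization values are illustrative only:

from ethosu.vela.api import NpuAddressRange
from ethosu.vela.api import NpuDataType
from ethosu.vela.api import NpuDmaOperation
from ethosu.vela.api import NpuFeatureMap
from ethosu.vela.api import NpuKernel
from ethosu.vela.api import NpuLayout
from ethosu.vela.api import NpuPadding
from ethosu.vela.api import NpuPoolingOp
from ethosu.vela.api import NpuPoolingOperation
from ethosu.vela.api import NpuQuantization
from ethosu.vela.api import NpuShape3D
from ethosu.vela.api import NpuTileBox
from ethosu.vela.architecture_features import Accelerator
from ethosu.vela.register_command_stream_generator import generate_register_command_stream
from ethosu.vela.register_command_stream_generator import get_address_ranges


def make_feature_map(shape, region, address, quant):
    # Single-tile UINT8/NHWC feature map, following the create_feature_map test helper
    fm = NpuFeatureMap()
    fm.data_type = NpuDataType.UINT8
    fm.layout = NpuLayout.NHWC
    fm.shape = shape
    fm.tiles = NpuTileBox(
        width_0=shape.width, height_0=shape.height, height_1=shape.height, addresses=[address, 0, 0, 0]
    )
    fm.region = region
    fm.quantization = quant
    return fm


pool = NpuPoolingOperation(NpuPoolingOp.AVERAGE)
pool.ifm = make_feature_map(
    NpuShape3D(height=29, width=30, depth=27), 2, 0, NpuQuantization(scale_f32=0.007843138, zero_point=128)
)
pool.ofm = make_feature_map(
    NpuShape3D(height=10, width=10, depth=27), 2, 0x5BD0, NpuQuantization(scale_f32=0.20392157, zero_point=128)
)
pool.kernel = NpuKernel(8, 2, 3, 3)
pool.padding = NpuPadding(top=0, left=2, right=3, bottom=0)
# No block config is set, so vela chooses a suitable one

cmds = generate_register_command_stream([pool], Accelerator.Ethos_U55_128)
print("command stream length in words", len(cmds))

# A DMA transfer can be queued in the same list; its destination overlaps memory used by
# the pooling operation, so the generator inserts the required DMA wait automatically
dest = get_address_ranges(pool.ofm)[0]
src = NpuAddressRange(0, 0x24000, dest.length)
cmds = generate_register_command_stream([NpuDmaOperation(src, dest), pool], Accelerator.Ethos_U55_64)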
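
A corresponding sketch for the updated weight-encoding entry point: the boolean is_partkernel argument is replaced by an NpuBlockTraversal value, as exercised by the updated test_encode_weights test. The weight shape and the remaining parameter values below are placeholders, not values taken from this patch:

import numpy as np

from ethosu.vela import weight_compressor
from ethosu.vela.api import NpuBlockTraversal
from ethosu.vela.architecture_features import Accelerator

# OHWI weight volume; shape and contents are illustrative only
weights_ohwi = np.zeros((16, 3, 3, 32), dtype=np.uint8)

encoded = weight_compressor.encode_weights(
    accelerator=Accelerator.Ethos_U55_128,
    weights_volume=weights_ohwi,
    dilation_xy=(1, 1),
    ifm_bitdepth=8,
    ofm_block_depth=16,
    is_depthwise=False,
    block_traversal=NpuBlockTraversal.PART_KERNEL_FIRST,  # replaces is_partkernel=True
)
assert isinstance(encoded, bytearray)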