From 933f55ea6f686d0cf390f4767e87a391686c3df8 Mon Sep 17 00:00:00 2001
From: Louis Verhaard
Date: Wed, 25 Nov 2020 14:10:30 +0100
Subject: MLBEDSW-3599: Added API for finding block configs

Added public API function npu_find_block_configs.

Change-Id: Ib0925a62d7c5d19a9b9fbd8d808943c2ea2df02f
Signed-off-by: Louis Verhaard
---
 API.md                                             |  20 ++++
 ethosu/vela/api.py                                 |  24 +++--
 ethosu/vela/register_command_stream_generator.py   | 120 +++++++++------------
 .../test/extapi/test_extapi_find_block_configs.py  |  74 ++++++++++++++
 .../test/extapi/test_extapi_generate_commands.py   |  33 ++----
 5 files changed, 171 insertions(+), 100 deletions(-)
 create mode 100644 ethosu/vela/test/extapi/test_extapi_find_block_configs.py

diff --git a/API.md b/API.md
index 4607378a..25a1a16e 100644
--- a/API.md
+++ b/API.md
@@ -40,6 +40,26 @@ these basic NPU operations. Note that the compiler is responsible for all
 address planning, i.e. it needs to supply addresses of all input and output
 tensors, weights, and biases.
 
+### Finding block configs
+
+For all NPU operations, a block config must be set, which is the unit of work in
+which the NPU generates the output. There are restrictions on the size of block
+configs. The function `npu_find_block_configs` can be used to find all block
+configs that are valid for an operation.
+
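+For example, a minimal sketch of how a compiler can pick a block config (any
+config from the returned list is valid for the operation):
+
+```python
+from ethosu.vela.api import npu_find_block_configs, NpuAccelerator
+
+# op is an NpuOperation that has already been fully set up
+op.block_config = npu_find_block_configs(op, NpuAccelerator.Ethos_U55_128)[0]
+```
+
+If the operation has weights, they must have been encoded using the depth of
+the chosen block config.
+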
 ### Encoding of weights and biases
 
 All weights that are used in the NPU operations must be encoded using
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index e6286008..f972133d 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -320,20 +320,19 @@ class NpuBlockOperation(NpuOperation):
         self.ofm: Optional[NpuFeatureMap] = None
         self.kernel: Optional[NpuKernel] = None
         # Weights, one element for each NPU core, empty if no weights are used.
-        # Must have been compressed using weight_compressor.encode_weights()
+        # Must have been compressed using npu_encode_weights()
         self.weights: List[NpuAddressRange] = []
         # Biases, one element for each NPU core, empty if no bias is used.
-        # Must have been encoded using weight_compressor.encode_bias()
+        # Must have been encoded using npu_encode_bias()
         self.biases: List[NpuAddressRange] = []
         self.padding: Optional[NpuPadding] = None
         # Optional activation function to be applied
         self.activation: Optional[NpuActivation] = None
-        # The block config is the unit of work in which the NPU generates the OFM.
+        # The block config to be used, which must be valid for the given operation.
+        # See also npu_find_block_configs.
         # If the operation has weights, the depth of the block config must be the same as
-        # the ofm depth used in the call to weight_compressor.encode_weights()
-        # If set to None, vela will determine a suitable block size (can only be used if there are no weights)
-        # If block_config.width and height are set to -1, vela will determine suitable width/height
-        self.block_config: Optional[NpuShape3D] = None  # OFM_BLK parameters
+        # the ofm depth used in the call to npu_encode_weights()
+        self.block_config: NpuShape3D
         self.rounding_mode: NpuRoundingMode = NpuRoundingMode.TFL
         # Set to True if the operations is fused with a Quantize operation (affects scaling)
         self.fused_quantize: bool = False
@@ -441,6 +440,17 @@ def npu_encode_bias(bias: numpy.int64, scale: int, shift: int):
     return weight_compressor.encode_bias(bias, scale, shift)
 
 
+def npu_find_block_configs(npu_op: NpuOperation, accelerator: NpuAccelerator) -> List[NpuShape3D]:
+    """
+    Public facing API that returns a list of block configs that are valid for the given operation.
+    This function can be used to find a valid value for npu_op.block_config.
+    The block config is the unit of work in which the NPU generates the OFM.
+    """
+    from . import register_command_stream_generator
+
+    return register_command_stream_generator.find_block_configs(npu_op, accelerator)
+
+
 def npu_generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: NpuAccelerator) -> List[int]:
     """
     Public facing API for generating an Ethos-U register command stream.
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 9d79d58a..015a8c49 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -521,35 +521,15 @@ def generate_block_config(
     npu_op: NpuBlockOperation,
     arch: ArchitectureFeatures,
     shared_buffer: SharedBufferAllocation,
-) -> NpuShape3D:
-    """Selects a suitable block config if none has been set, and generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers"""
+):
+    """Generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers"""
     block_config = npu_op.block_config
-    if block_config is None or block_config.height < 0:
-        # Note: this code only used if the public API to generate command streams is used;
-        # in the "normal" flow, the block config selected by the scheduler is used
-        if npu_op.weights:
-            assert block_config is not None, "block_config.depth must be provided for ops with weights"
-        # Block config has not been provided: find one
-        blocks = find_suitable_block_configs(arch, shared_buffer)
-        # Return the block with biggest volume
-        # TODO: use a better algorithm to find the best block
-        best_block = None
-        best_value = 0
-        for block in blocks:
-            if block_config is not None and block[3] != block_config.depth:
-                continue
-            value = block[0] * block[1] * block[3]
-            if value > best_value:
-                best_value = value
-                best_block = block
-        assert best_block is not None, f"No suitable block config was found, {npu_op.op_type}"
-        block_config = NpuShape3D(height=best_block[0], width=best_block[1], depth=best_block[3])
+    assert block_config is not None, "block_config has not been set"
     alloc = shared_buffer.try_block(Block(block_config.width, block_config.height, block_config.depth))
     assert alloc is not None, f"Block config {block_config} does not fit, op: {npu_op.op_type}"
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config.height - 1)
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config.width - 1)
     emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config.depth - 1)
-    return block_config
 
 
 def generate_shram_registers_elementwise(
@@ -585,6 +565,24 @@ def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_
     emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
 
 
+def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
+    """Creates shared buffer allocation for the given operation"""
+    op_type = npu_op.op_type
+    block_type = NpuBlockType.Default
+    if op_type == NpuOperationType.Conv2D:
+        block_type = NpuBlockType.ConvolutionMxN
+    elif op_type == NpuOperationType.ConvDepthWise:
+        block_type = NpuBlockType.ConvolutionDepthWise
+    elif op_type == NpuOperationType.Pooling:
+        block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
+    elif op_type == NpuOperationType.ElementWise:
+        block_type = NpuBlockType.ElementWise
+    else:
+        assert 0, "Unsupported operation"
+    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
+    return shared_buffer_allocation_for_npu_op(arch, npu_op, block_type, ifm_resampling_mode)
+
+
 def generate_common(
     emit: CommandStreamEmitter,
     npu_op: NpuBlockOperation,
@@ -608,6 +606,12 @@ def generate_common(
     generate_weights(emit, npu_op.weights, arch)
     generate_biases(emit, npu_op.biases, arch)
     generate_activation(emit, npu_op.activation, npu_op.ofm)
+    shared_buffer = create_shared_buffer(npu_op, arch)
+    generate_block_config(emit, npu_op, arch, shared_buffer)
+    if npu_op.op_type == NpuOperationType.ElementWise:
+        generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
+    else:
+        generate_shram_registers_non_elementwise(emit, shared_buffer)
 
 
 # -------------------------------------------------------------------
@@ -962,13 +966,7 @@ def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperatio
     return npu_op.ofm.shape.depth
 
 
-def calc_blockdep(
-    arch: ArchitectureFeatures,
-    prev_op: Optional[NpuBlockOperation],
-    prev_block_config: Optional[NpuShape3D],
-    npu_op: NpuBlockOperation,
-    block_config: NpuShape3D,
-) -> int:
+def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation) -> int:
     """Calculates the value of the BLOCKDEP register"""
     if prev_op is None:
         return 0
@@ -976,6 +974,8 @@ def calc_blockdep(
         return ArchitectureFeatures.MAX_BLOCKDEP
     if prev_op.ofm.shape != npu_op.ifm.shape:
         return 0
+    prev_block_config = prev_op.block_config
+    block_config = npu_op.block_config
     prev_ifm_block_depth = get_ifm_ofm_block_depth(arch, prev_op)
     prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
     prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
@@ -1094,28 +1094,14 @@ def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
         assert 0, "Unsupported operation"
 
 
-def generate_conv2d_op(
-    emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, arch: ArchitectureFeatures
-) -> NpuShape3D:
+def generate_conv2d_op(emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, arch: ArchitectureFeatures):
     """Generates register commands for Conv2D operations"""
     generate_common(emit, npu_op, npu_op.block_traversal, arch)
-    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
-    shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ConvolutionMxN, ifm_resampling_mode)
-    block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
-    generate_shram_registers_non_elementwise(emit, shared_buffer)
-    return block_config
 
 
 def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
     """Generates register commands for depthwise convolution operations"""
     generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
-    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
-    shared_buffer = shared_buffer_allocation_for_npu_op(
-        arch, npu_op, NpuBlockType.ConvolutionDepthWise, ifm_resampling_mode
-    )
-    block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
-    generate_shram_registers_non_elementwise(emit, shared_buffer)
-    return block_config
 
 
 def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
@@ -1127,12 +1113,6 @@ def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation,
     # Pooling op specific
     if use_global_scale:
         generate_ofm_scaling_for_pooling(emit, npu_op)
-    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
-    npu_block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
-    shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, npu_block_type, ifm_resampling_mode)
-    block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
-    generate_shram_registers_non_elementwise(emit, shared_buffer)
-    return block_config
 
 
 def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOperation, arch: ArchitectureFeatures):
@@ -1160,11 +1140,6 @@ def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOp
         quantized_scalar = quantise(npu_op.ifm2_scalar, npu_op.ifm2.quantization)
         assert npu_op.ifm2.data_type.min_value() <= quantized_scalar <= npu_op.ifm2.data_type.max_value()
         emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, quantized_scalar)
-    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
-    shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ElementWise, ifm_resampling_mode)
-    block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
-    generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
-    return block_config
 
 
 def generate_dma_op(emit: CommandStreamEmitter, dma_op: NpuDmaOperation):
@@ -1177,28 +1152,24 @@ def generate_dma_op(emit: CommandStreamEmitter, dma_op: NpuDmaOperation):
     emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, dma_op.src.length)
 
 
-def generate_registers_for_op(
-    emit: CommandStreamEmitter, npu_op: NpuOperation, arch: ArchitectureFeatures
-) -> Optional[NpuShape3D]:
+def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation, arch: ArchitectureFeatures):
     """
     Generates register commands for the given operation, but not the final NPU_OP_... command.
-    Returns the selected block config
     """
     op_type = npu_op.op_type
-    block_config = None
     if op_type == NpuOperationType.Conv2D:
-        block_config = generate_conv2d_op(emit, npu_op, arch)
+        generate_conv2d_op(emit, npu_op, arch)
     elif op_type == NpuOperationType.ConvDepthWise:
-        block_config = generate_conv_depthwise_op(emit, npu_op, arch)
+        generate_conv_depthwise_op(emit, npu_op, arch)
     elif op_type == NpuOperationType.Pooling:
-        block_config = generate_pooling_op(emit, npu_op, arch)
+        generate_pooling_op(emit, npu_op, arch)
    elif op_type == NpuOperationType.ElementWise:
-        block_config = generate_elementwise_op(emit, npu_op, arch)
+        generate_elementwise_op(emit, npu_op, arch)
     elif op_type == NpuOperationType.Dma:
         generate_dma_op(emit, npu_op)
     else:
         assert 0, "Unsupported operation"
-    return block_config
 
 
 def generate_command_stream(
@@ -1216,19 +1187,16 @@ def generate_command_stream(
     emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
     dep_watermark = Watermark(0, 0)
     prev_op = None
-    prev_block_config = None
     # Generate register commands for all operations
     for op_index, npu_op in enumerate(npu_op_list):
         dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
-        block_config = generate_registers_for_op(emit, npu_op, arch)
+        generate_registers_for_op(emit, npu_op, arch)
         if not is_dma_op(npu_op):
             # Generate BLOCKDEP
-            assert block_config is not None
-            blockdep = calc_blockdep(arch, prev_op, prev_block_config, npu_op, block_config)
+            blockdep = calc_blockdep(arch, prev_op, npu_op)
             blockdep = min(blockdep, arch.max_blockdep)
             emit.cmd0_with_param(cmd0.NPU_SET_BLOCKDEP, blockdep)
             prev_op = npu_op
-            prev_block_config = block_config
 
         generate_cmd_waits(emit, cmd_waits)
         # Generate the actual NPU_OP command
@@ -1272,6 +1240,18 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
         print("command stream length in words", len(sg.register_command_stream))
 
 
+def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) -> List[NpuShape3D]:
+    """
+    Internal implementation of the public facing API for finding block configs.
+    """
+    if is_dma_op(npu_op):
+        return []
+    arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+    shared_buffer = create_shared_buffer(npu_op, arch)
+    blocks = find_suitable_block_configs(arch, shared_buffer)
+    return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+
+
 def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
     """
     Internal implementation of the public facing API for generating an Ethos-U register command stream.
diff --git a/ethosu/vela/test/extapi/test_extapi_find_block_configs.py b/ethosu/vela/test/extapi/test_extapi_find_block_configs.py
new file mode 100644
index 00000000..07cb9cb4
--- /dev/null
+++ b/ethosu/vela/test/extapi/test_extapi_find_block_configs.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains unit tests for npu_find_block_configs API for an external consumer
+from ethosu.vela.api import npu_find_block_configs
+from ethosu.vela.api import npu_generate_register_command_stream
+from ethosu.vela.api import NpuAccelerator
+from ethosu.vela.api import NpuAddressRange
+from ethosu.vela.api import NpuBlockTraversal
+from ethosu.vela.api import NpuConv2DOperation
+from ethosu.vela.api import NpuKernel
+from ethosu.vela.api import NpuPadding
+from ethosu.vela.api import NpuQuantization
+from ethosu.vela.api import NpuShape3D
+from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0
+from ethosu.vela.test.extapi.test_extapi_generate_commands import check_cmd0
+from ethosu.vela.test.extapi.test_extapi_generate_commands import create_feature_map
+
+
+def test_find_block_configs():
+    """Tests npu_find_block_configs"""
+    # Create a Conv2D operation
+    op = NpuConv2DOperation()
+    op.ifm = create_feature_map(
+        NpuShape3D(height=30, width=62, depth=46), 1, 512, quant=NpuQuantization(scale_f32=0.007843138, zero_point=128)
+    )
+    op.ofm = create_feature_map(
+        NpuShape3D(height=30, width=31, depth=46),
+        1,
+        0x14E40,
+        quant=NpuQuantization(scale_f32=0.20392157, zero_point=128),
+    )
+    op.kernel = NpuKernel(3, 2, 2, 1)
+    op.biases = [NpuAddressRange(region=0, address=32000, length=464)]
+    op.padding = NpuPadding(top=0, left=0, right=1, bottom=1)
+    op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+    # Find valid block configs
+    accelerator = NpuAccelerator.Ethos_U55_256
+    block_configs = npu_find_block_configs(op, accelerator)
+    # Select the last of the returned block configs
+    op.block_config = block_configs[-1]
+    # Note: the weights should be encoded with op.block_config.depth (not shown here)
+    op.weights = [NpuAddressRange(region=0, address=0, length=7696)]
+    # Check that generating register commands succeeds
+    cmds = npu_generate_register_command_stream([op], accelerator)
+    # Check that the selected block config was used
+    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, op.block_config.height - 1)
+    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, op.block_config.width - 1)
+    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, op.block_config.depth - 1)
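+
+
+def test_find_block_configs_dma_op():
+    """Tests npu_find_block_configs for a DMA operation: DMA operations have no
+    block config, so an empty list is expected. The addresses are arbitrary."""
+    from ethosu.vela.api import NpuDmaOperation
+
+    dma_op = NpuDmaOperation(
+        NpuAddressRange(region=0, address=0x40, length=496), NpuAddressRange(region=1, address=0x14E40, length=496)
+    )
+    assert npu_find_block_configs(dma_op, NpuAccelerator.Ethos_U55_256) == []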
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 86ef804a..812991a9 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -16,6 +16,7 @@
 #
 # Description:
 # Contains unit tests for npu_generate_register_command_stream API for an external consumer
+from ethosu.vela.api import npu_find_block_configs
 from ethosu.vela.api import npu_generate_register_command_stream
 from ethosu.vela.api import NpuAccelerator
 from ethosu.vela.api import NpuActivation
@@ -106,9 +107,7 @@ def test_conv2d():
     op.biases = [NpuAddressRange(region=0, address=32000, length=464)]
     op.padding = NpuPadding(top=0, left=0, right=1, bottom=1)
     op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
-    # In this example we assume that the weights were compressed with ofm depth 16;
-    # let vela choose suitable block width and height by setting these to -1
-    op.block_config = NpuShape3D(height=-1, width=-1, depth=16)
+    op.block_config = NpuShape3D(height=16, width=4, depth=16)
     cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
     check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 512)
@@ -165,12 +164,6 @@ def test_conv2d():
     check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0)
     check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0)
     check_cmd0(cmds, cmd0.NPU_OP_CONV, 0)
-    # Check that block width/height were generated that fit
-    blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1)
-    blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1)
-    assert blk_height > 0
-    assert blk_width > 0
-    assert (blk_height + 1) * (blk_width + 1) <= 64
 
 
 def create_fully_connected_op() -> NpuConv2DOperation:
@@ -194,9 +187,7 @@ def create_fully_connected_op() -> NpuConv2DOperation:
     op.biases = [NpuAddressRange(region=0, address=0x19BC0, length=960)]
     op.padding = NpuPadding(top=0, left=0, right=0, bottom=0)
     op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
-    # In this example we assume that the weights were compressed with ofm depth 96;
-    # let vela choose suitable block width and height by setting these to -1
-    op.block_config = NpuShape3D(height=-1, width=-1, depth=96)
+    op.block_config = NpuShape3D(height=2, width=4, depth=96)
     return op
 
 
@@ -222,7 +213,7 @@ def test_depthwise():
     op.padding = NpuPadding(top=1, left=1, right=1, bottom=1)
     op.weights = [weights_dest]
     op.biases = [NpuAddressRange(region=0, address=0, length=80)]
-    op.block_config = NpuShape3D(height=-1, width=-1, depth=8)
+    op.block_config = NpuShape3D(height=8, width=12, depth=8)
     cmds = npu_generate_register_command_stream([dma_op, op], NpuAccelerator.Ethos_U55_128)
     check_cmd0(cmds, cmd0.NPU_SET_DMA0_SRC_REGION, 0)
     check_cmd1(cmds, cmd1.NPU_SET_DMA0_SRC, 0x40)
@@ -233,10 +224,6 @@ def test_depthwise():
     # A DMA WAIT should have been inserted
     check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0)
     check_cmd0(cmds, cmd0.NPU_OP_DEPTHWISE, 0)
-    blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1)
-    blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1)
-    assert blk_height > 0
-    assert blk_width > 0
 
 
 def test_mul_with_broadcast_and_relu():
@@ -247,8 +234,10 @@ def test_mul_with_broadcast_and_relu():
     op.ofm = create_feature_map(NpuShape3D(height=31, width=22, depth=31), 1, 0x52C0)
     op.activation = NpuActivation(NpuActivationOp.NONE_OR_RELU)
     op.activation.min = 0  # RELU
-    # Do not set a block config, let vela choose one
-    cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_32)
+    accelerator = NpuAccelerator.Ethos_U55_32
+    # Select a block config using npu_find_block_configs
+    op.block_config = npu_find_block_configs(op, accelerator)[0]
+    cmds = npu_generate_register_command_stream([op], accelerator)
     check_cmd1(cmds, cmd1.NPU_SET_OFM_SCALE, 1073741824, 30)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
     check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 32)
@@ -298,9 +287,6 @@ def test_mul_with_broadcast_and_relu():
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_ZERO_POINT, 0)
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_PRECISION, 0)
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_BROADCAST, 5)
-    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, 23)
-    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, 3)
-    check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, 31)
     check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 16)
     check_cmd0(cmds, cmd0.NPU_SET_AB_START, 16)
     check_cmd0(cmds, cmd0.NPU_SET_IFM2_IB_START, 9)
@@ -330,7 +316,8 @@ def create_avg_pool_op() -> NpuPoolingOperation:
     )
     op.kernel = NpuKernel(8, 2, 3, 3)
     op.padding = NpuPadding(top=0, left=2, right=3, bottom=0)
-    # Do not set a block config, let vela choose one
+    # Select a block config
+    op.block_config = NpuShape3D(height=4, width=4, depth=16)
     return op
 
 
-- 
cgit v1.2.1