aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-11-25 14:10:30 +0100
committerLouis Verhaard <louis.verhaard@arm.com>2020-11-26 08:13:50 +0100
commit933f55ea6f686d0cf390f4767e87a391686c3df8 (patch)
tree370321021ef2553df76e6e46b127cac07ec9d8be
parent34b9dc15b27219bd6485eb5104506d647e1f6d29 (diff)
downloadethos-u-vela-933f55ea6f686d0cf390f4767e87a391686c3df8.tar.gz
MLBEDSW-3599: Added API for finding block configs
Added public API function npu_find_block_configs. Change-Id: Ib0925a62d7c5d19a9b9fbd8d808943c2ea2df02f Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--API.md7
-rw-r--r--ethosu/vela/api.py24
-rw-r--r--ethosu/vela/register_command_stream_generator.py120
-rw-r--r--ethosu/vela/test/extapi/test_extapi_find_block_configs.py63
-rw-r--r--ethosu/vela/test/extapi/test_extapi_generate_commands.py33
5 files changed, 147 insertions, 100 deletions
diff --git a/API.md b/API.md
index 4607378..25a1a16 100644
--- a/API.md
+++ b/API.md
@@ -40,6 +40,13 @@ these basic NPU operations. Note that the compiler is responsible for all
address planning, i.e. it needs to supply addresses of all input and output
tensors, weights, and biases.
+### Finding block configs
+
+For all NPU operations, a block config must be set, which is the unit of work in
+which the NPU generates the output. There are restrictions to the size of block
+configs. Function `npu_find_block_configs` can be used to find valid block
+configs for an operation.
+
### Encoding of weights and biases
All weights that are used in the NPU operations must be encoded using
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index e628600..f972133 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -320,20 +320,19 @@ class NpuBlockOperation(NpuOperation):
self.ofm: Optional[NpuFeatureMap] = None
self.kernel: Optional[NpuKernel] = None
# Weights, one element for each NPU core, empty if no weights are used.
- # Must have been compressed using weight_compressor.encode_weights()
+ # Must have been compressed using npu_encode_weights()
self.weights: List[NpuAddressRange] = []
# Biases, one element for each NPU core, empty if no bias is used.
- # Must have been encoded using weight_compressor.encode_bias()
+ # Must have been encoded using npu_encode_bias()
self.biases: List[NpuAddressRange] = []
self.padding: Optional[NpuPadding] = None
# Optional activation function to be applied
self.activation: Optional[NpuActivation] = None
- # The block config is the unit of work in which the NPU generates the OFM.
+ # The block config to be used, which must be valid for the given operation.
+ # See also npu_find_block_configs.
# If the operation has weights, the depth of the block config must be the same as
- # the ofm depth used in the call to weight_compressor.encode_weights()
- # If set to None, vela will determine a suitable block size (can only be used if there are no weights)
- # If block_config.width and height are set to -1, vela will determine suitable width/height
- self.block_config: Optional[NpuShape3D] = None # OFM_BLK parameters
+ # the ofm depth used in the call to npu_encode_weights()
+ self.block_config: NpuShape3D
self.rounding_mode: NpuRoundingMode = NpuRoundingMode.TFL
# Set to True if the operations is fused with a Quantize operation (affects scaling)
self.fused_quantize: bool = False
@@ -441,6 +440,17 @@ def npu_encode_bias(bias: numpy.int64, scale: int, shift: int):
return weight_compressor.encode_bias(bias, scale, shift)
+def npu_find_block_configs(npu_op: NpuOperation, accelerator: NpuAccelerator) -> List[NpuShape3D]:
+ """
+ Public facing API that returns a list of block configs that are valid for the given operation.
+ This function can be used to find a valid value for npu_op.block_config.
+ The block config is the unit of work in which the NPU generates the OFM.
+ """
+ from . import register_command_stream_generator
+
+ return register_command_stream_generator.find_block_configs(npu_op, accelerator)
+
+
def npu_generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: NpuAccelerator) -> List[int]:
"""
Public facing API for generating an Ethos-U register command stream.
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 9d79d58..015a8c4 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -521,35 +521,15 @@ def generate_block_config(
npu_op: NpuBlockOperation,
arch: ArchitectureFeatures,
shared_buffer: SharedBufferAllocation,
-) -> NpuShape3D:
- """Selects a suitable block config if none has been set, and generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers"""
+):
+ """Generates OFM_BLK_HEIGHT/WIDTH/DEPTH registers"""
block_config = npu_op.block_config
- if block_config is None or block_config.height < 0:
- # Note: this code only used if the public API to generate command streams is used;
- # in the "normal" flow, the block config selected by the scheduler is used
- if npu_op.weights:
- assert block_config is not None, "block_config.depth must be provided for ops with weights"
- # Block config has not been provided: find one
- blocks = find_suitable_block_configs(arch, shared_buffer)
- # Return the block with biggest volume
- # TODO: use a better algorithm to find the best block
- best_block = None
- best_value = 0
- for block in blocks:
- if block_config is not None and block[3] != block_config.depth:
- continue
- value = block[0] * block[1] * block[3]
- if value > best_value:
- best_value = value
- best_block = block
- assert best_block is not None, f"No suitable block config was found, {npu_op.op_type}"
- block_config = NpuShape3D(height=best_block[0], width=best_block[1], depth=best_block[3])
+ assert block_config is not None, "block_config has not been set"
alloc = shared_buffer.try_block(Block(block_config.width, block_config.height, block_config.depth))
assert alloc is not None, f"Block config {block_config} does not fit, op: {npu_op.op_type}"
emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config.height - 1)
emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config.width - 1)
emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config.depth - 1)
- return block_config
def generate_shram_registers_elementwise(
@@ -585,6 +565,24 @@ def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_
emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
+def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
+ """Creates shared buffer allocation for the given operation"""
+ op_type = npu_op.op_type
+ block_type = NpuBlockType.Default
+ if op_type == NpuOperationType.Conv2D:
+ block_type = NpuBlockType.ConvolutionMxN
+ elif op_type == NpuOperationType.ConvDepthWise:
+ block_type = NpuBlockType.ConvolutionDepthWise
+ elif op_type == NpuOperationType.Pooling:
+ block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
+ elif op_type == NpuOperationType.ElementWise:
+ block_type = NpuBlockType.ElementWise
+ else:
+ assert 0, "Unsupported operation"
+ ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
+ return shared_buffer_allocation_for_npu_op(arch, npu_op, block_type, ifm_resampling_mode)
+
+
def generate_common(
emit: CommandStreamEmitter,
npu_op: NpuBlockOperation,
@@ -608,6 +606,12 @@ def generate_common(
generate_weights(emit, npu_op.weights, arch)
generate_biases(emit, npu_op.biases, arch)
generate_activation(emit, npu_op.activation, npu_op.ofm)
+ shared_buffer = create_shared_buffer(npu_op, arch)
+ generate_block_config(emit, npu_op, arch, shared_buffer)
+ if npu_op.op_type == NpuOperationType.ElementWise:
+ generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
+ else:
+ generate_shram_registers_non_elementwise(emit, shared_buffer)
# -------------------------------------------------------------------
@@ -962,13 +966,7 @@ def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperatio
return npu_op.ofm.shape.depth
-def calc_blockdep(
- arch: ArchitectureFeatures,
- prev_op: Optional[NpuBlockOperation],
- prev_block_config: Optional[NpuShape3D],
- npu_op: NpuBlockOperation,
- block_config: NpuShape3D,
-) -> int:
+def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation,) -> int:
"""Calculates the value of the BLOCKDEP register"""
if prev_op is None:
return 0
@@ -976,6 +974,8 @@ def calc_blockdep(
return ArchitectureFeatures.MAX_BLOCKDEP
if prev_op.ofm.shape != npu_op.ifm.shape:
return 0
+ prev_block_config = prev_op.block_config
+ block_config = npu_op.block_config
prev_ifm_block_depth = get_ifm_ofm_block_depth(arch, prev_op)
prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
@@ -1094,28 +1094,14 @@ def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
assert 0, "Unsupported operation"
-def generate_conv2d_op(
- emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, arch: ArchitectureFeatures
-) -> NpuShape3D:
+def generate_conv2d_op(emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, arch: ArchitectureFeatures):
"""Generates register commands for Conv2D operations"""
generate_common(emit, npu_op, npu_op.block_traversal, arch)
- ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
- shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ConvolutionMxN, ifm_resampling_mode)
- block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
- generate_shram_registers_non_elementwise(emit, shared_buffer)
- return block_config
def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
"""Generates register commands for depthwise convolution operations"""
generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
- ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
- shared_buffer = shared_buffer_allocation_for_npu_op(
- arch, npu_op, NpuBlockType.ConvolutionDepthWise, ifm_resampling_mode
- )
- block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
- generate_shram_registers_non_elementwise(emit, shared_buffer)
- return block_config
def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
@@ -1127,12 +1113,6 @@ def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation,
# Pooling op specific
if use_global_scale:
generate_ofm_scaling_for_pooling(emit, npu_op)
- ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
- npu_block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
- shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, npu_block_type, ifm_resampling_mode)
- block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
- generate_shram_registers_non_elementwise(emit, shared_buffer)
- return block_config
def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOperation, arch: ArchitectureFeatures):
@@ -1160,11 +1140,6 @@ def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOp
quantized_scalar = quantise(npu_op.ifm2_scalar, npu_op.ifm2.quantization)
assert npu_op.ifm2.data_type.min_value() <= quantized_scalar <= npu_op.ifm2.data_type.max_value()
emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, quantized_scalar)
- ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
- shared_buffer = shared_buffer_allocation_for_npu_op(arch, npu_op, NpuBlockType.ElementWise, ifm_resampling_mode)
- block_config = generate_block_config(emit, npu_op, arch, shared_buffer)
- generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
- return block_config
def generate_dma_op(emit: CommandStreamEmitter, dma_op: NpuDmaOperation):
@@ -1177,28 +1152,24 @@ def generate_dma_op(emit: CommandStreamEmitter, dma_op: NpuDmaOperation):
emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, dma_op.src.length)
-def generate_registers_for_op(
- emit: CommandStreamEmitter, npu_op: NpuOperation, arch: ArchitectureFeatures
-) -> Optional[NpuShape3D]:
+def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation, arch: ArchitectureFeatures):
"""
Generates register commands for the given operation, but not the final NPU_OP_... command.
Returns the selected block config
"""
op_type = npu_op.op_type
- block_config = None
if op_type == NpuOperationType.Conv2D:
- block_config = generate_conv2d_op(emit, npu_op, arch)
+ generate_conv2d_op(emit, npu_op, arch)
elif op_type == NpuOperationType.ConvDepthWise:
- block_config = generate_conv_depthwise_op(emit, npu_op, arch)
+ generate_conv_depthwise_op(emit, npu_op, arch)
elif op_type == NpuOperationType.Pooling:
- block_config = generate_pooling_op(emit, npu_op, arch)
+ generate_pooling_op(emit, npu_op, arch)
elif op_type == NpuOperationType.ElementWise:
- block_config = generate_elementwise_op(emit, npu_op, arch)
+ generate_elementwise_op(emit, npu_op, arch)
elif op_type == NpuOperationType.Dma:
generate_dma_op(emit, npu_op)
else:
assert 0, "Unsupported operation"
- return block_config
def generate_command_stream(
@@ -1216,19 +1187,16 @@ def generate_command_stream(
emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
dep_watermark = Watermark(0, 0)
prev_op = None
- prev_block_config = None
# Generate register commands for all operations
for op_index, npu_op in enumerate(npu_op_list):
dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
- block_config = generate_registers_for_op(emit, npu_op, arch)
+ generate_registers_for_op(emit, npu_op, arch)
if not is_dma_op(npu_op):
# Generate BLOCKDEP
- assert block_config is not None
- blockdep = calc_blockdep(arch, prev_op, prev_block_config, npu_op, block_config)
+ blockdep = calc_blockdep(arch, prev_op, npu_op)
blockdep = min(blockdep, arch.max_blockdep)
emit.cmd0_with_param(cmd0.NPU_SET_BLOCKDEP, blockdep)
prev_op = npu_op
- prev_block_config = block_config
generate_cmd_waits(emit, cmd_waits)
# Generate the actual NPU_OP command
@@ -1272,6 +1240,18 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
print("command stream length in words", len(sg.register_command_stream))
+def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) -> List[NpuShape3D]:
+ """
+ Internal implementation of the public facing API for finding block configs.
+ """
+ if is_dma_op(npu_op):
+ return []
+ arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+ shared_buffer = create_shared_buffer(npu_op, arch)
+ blocks = find_suitable_block_configs(arch, shared_buffer)
+ return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+
+
def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
"""
Internal implementation of the public facing API for generating an Ethos-U register command stream.
diff --git a/ethosu/vela/test/extapi/test_extapi_find_block_configs.py b/ethosu/vela/test/extapi/test_extapi_find_block_configs.py
new file mode 100644
index 0000000..07cb9cb
--- /dev/null
+++ b/ethosu/vela/test/extapi/test_extapi_find_block_configs.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Contains unit tests for npu_find_block_configs API for an external consumer
+from ethosu.vela.api import npu_find_block_configs
+from ethosu.vela.api import npu_generate_register_command_stream
+from ethosu.vela.api import NpuAccelerator
+from ethosu.vela.api import NpuAddressRange
+from ethosu.vela.api import NpuBlockTraversal
+from ethosu.vela.api import NpuConv2DOperation
+from ethosu.vela.api import NpuKernel
+from ethosu.vela.api import NpuPadding
+from ethosu.vela.api import NpuQuantization
+from ethosu.vela.api import NpuShape3D
+from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0
+from ethosu.vela.test.extapi.test_extapi_generate_commands import check_cmd0
+from ethosu.vela.test.extapi.test_extapi_generate_commands import create_feature_map
+
+
+def test_find_block_configs():
+ """Tests npu_find_block_configs"""
+ # Create a Conv2D operation
+ op = NpuConv2DOperation()
+ op.ifm = create_feature_map(
+ NpuShape3D(height=30, width=62, depth=46), 1, 512, quant=NpuQuantization(scale_f32=0.007843138, zero_point=128)
+ )
+ op.ofm = create_feature_map(
+ NpuShape3D(height=30, width=31, depth=46),
+ 1,
+ 0x14E40,
+ quant=NpuQuantization(scale_f32=0.20392157, zero_point=128),
+ )
+ op.kernel = NpuKernel(3, 2, 2, 1)
+ op.biases = [NpuAddressRange(region=0, address=32000, length=464)]
+ op.padding = NpuPadding(top=0, left=0, right=1, bottom=1)
+ op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+ # Find valid block configs
+ accelerator = NpuAccelerator.Ethos_U55_256
+ block_configs = npu_find_block_configs(op, accelerator)
+ # Select the last one
+ op.block_config = block_configs[-1]
+ # Note: the weights should be encoded with op.block_config.depth (not shown here)
+ op.weights = [NpuAddressRange(region=0, address=0, length=7696)]
+ # Check that generating register commands succeeds
+ cmds = npu_generate_register_command_stream([op], accelerator)
+ # Check that the selected block config was used
+ check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, op.block_config.height - 1)
+ check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, op.block_config.width - 1)
+ check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, op.block_config.depth - 1)
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 86ef804..812991a 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -16,6 +16,7 @@
#
# Description:
# Contains unit tests for npu_generate_register_command_stream API for an external consumer
+from ethosu.vela.api import npu_find_block_configs
from ethosu.vela.api import npu_generate_register_command_stream
from ethosu.vela.api import NpuAccelerator
from ethosu.vela.api import NpuActivation
@@ -106,9 +107,7 @@ def test_conv2d():
op.biases = [NpuAddressRange(region=0, address=32000, length=464)]
op.padding = NpuPadding(top=0, left=0, right=1, bottom=1)
op.block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
- # In this example we assume that the weights were compressed with ofm depth 16;
- # let vela choose suitable block width and height by setting these to -1
- op.block_config = NpuShape3D(height=-1, width=-1, depth=16)
+ op.block_config = NpuShape3D(height=16, width=4, depth=16)
cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_128)
check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 512)
@@ -165,12 +164,6 @@ def test_conv2d():
check_cmd0(cmds, cmd0.NPU_SET_ACC_FORMAT, 0)
check_cmd0(cmds, cmd0.NPU_SET_BLOCKDEP, 0)
check_cmd0(cmds, cmd0.NPU_OP_CONV, 0)
- # Check that block width/height were generated that fit
- blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1)
- blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1)
- assert blk_height > 0
- assert blk_width > 0
- assert (blk_height + 1) * (blk_width + 1) <= 64
def create_fully_connected_op() -> NpuConv2DOperation:
@@ -194,9 +187,7 @@ def create_fully_connected_op() -> NpuConv2DOperation:
op.biases = [NpuAddressRange(region=0, address=0x19BC0, length=960)]
op.padding = NpuPadding(top=0, left=0, right=0, bottom=0)
op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
- # In this example we assume that the weights were compressed with ofm depth 96;
- # let vela choose suitable block width and height by setting these to -1
- op.block_config = NpuShape3D(height=-1, width=-1, depth=96)
+ op.block_config = NpuShape3D(height=2, width=4, depth=96)
return op
@@ -222,7 +213,7 @@ def test_depthwise():
op.padding = NpuPadding(top=1, left=1, right=1, bottom=1)
op.weights = [weights_dest]
op.biases = [NpuAddressRange(region=0, address=0, length=80)]
- op.block_config = NpuShape3D(height=-1, width=-1, depth=8)
+ op.block_config = NpuShape3D(height=8, width=12, depth=8)
cmds = npu_generate_register_command_stream([dma_op, op], NpuAccelerator.Ethos_U55_128)
check_cmd0(cmds, cmd0.NPU_SET_DMA0_SRC_REGION, 0)
check_cmd1(cmds, cmd1.NPU_SET_DMA0_SRC, 0x40)
@@ -233,10 +224,6 @@ def test_depthwise():
# A DMA WAIT should have been inserted
check_cmd0(cmds, cmd0.NPU_OP_DMA_WAIT, 0)
check_cmd0(cmds, cmd0.NPU_OP_DEPTHWISE, 0)
- blk_height = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1)
- blk_width = find_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1)
- assert blk_height > 0
- assert blk_width > 0
def test_mul_with_broadcast_and_relu():
@@ -247,8 +234,10 @@ def test_mul_with_broadcast_and_relu():
op.ofm = create_feature_map(NpuShape3D(height=31, width=22, depth=31), 1, 0x52C0)
op.activation = NpuActivation(NpuActivationOp.NONE_OR_RELU)
op.activation.min = 0 # RELU
- # Do not set a block config, let vela choose one
- cmds = npu_generate_register_command_stream([op], NpuAccelerator.Ethos_U55_32)
+ accelerator = NpuAccelerator.Ethos_U55_32
+ # Select a block config using npu_find_block_configs
+ op.block_config = npu_find_block_configs(op, accelerator)[0]
+ cmds = npu_generate_register_command_stream([op], accelerator)
check_cmd1(cmds, cmd1.NPU_SET_OFM_SCALE, 1073741824, 30)
check_cmd0(cmds, cmd0.NPU_SET_IFM_REGION, 1)
check_cmd1(cmds, cmd1.NPU_SET_IFM_BASE0, 32)
@@ -298,9 +287,6 @@ def test_mul_with_broadcast_and_relu():
check_cmd0(cmds, cmd0.NPU_SET_IFM2_ZERO_POINT, 0)
check_cmd0(cmds, cmd0.NPU_SET_IFM2_PRECISION, 0)
check_cmd0(cmds, cmd0.NPU_SET_IFM2_BROADCAST, 5)
- check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, 23)
- check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_WIDTH_M1, 3)
- check_cmd0(cmds, cmd0.NPU_SET_OFM_BLK_DEPTH_M1, 31)
check_cmd0(cmds, cmd0.NPU_SET_IFM_IB_END, 16)
check_cmd0(cmds, cmd0.NPU_SET_AB_START, 16)
check_cmd0(cmds, cmd0.NPU_SET_IFM2_IB_START, 9)
@@ -330,7 +316,8 @@ def create_avg_pool_op() -> NpuPoolingOperation:
)
op.kernel = NpuKernel(8, 2, 3, 3)
op.padding = NpuPadding(top=0, left=2, right=3, bottom=0)
- # Do not set a block config, let vela choose one
+ # Select a block config
+ op.block_config = NpuShape3D(height=4, width=4, depth=16)
return op