diff options
Diffstat (limited to 'ethosu/vela/register_command_stream_util.py')
-rw-r--r-- | ethosu/vela/register_command_stream_util.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index c7050a3..74c4f90 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -60,10 +60,18 @@ def check_alignment(payload, required_alignment):
         raise ByteAlignmentError(f"Cmd1 payload of size: {payload} Bytes is not {required_alignment}-byte aligned")
 
 
-def check_size(payload, required_multiple):
+def check_size(payload, required_multiple, value_type):
     # assuming payload is defined in bytes
     if payload % required_multiple != 0:
-        raise ByteSizeError(f"Cmd1 payload of size: {payload} Bytes is not a multiple of {required_multiple}")
+        raise ByteSizeError(f"Cmd1 {value_type} of size: {payload} Bytes is not a multiple of {required_multiple}")
+
+
+def check_stride(stride, required_multiple):
+    check_size(stride, required_multiple, "stride")
+
+
+def check_length(length, required_multiple):
+    check_size(length, required_multiple, "length")
 
 
 def to_npu_kernel(kernel: Kernel) -> NpuKernel:
@@ -263,12 +271,12 @@ def check_strides(fm: NpuFeatureMap, strides: NpuShape3D):
 
     if fm.layout == NpuLayout.NHCWB16:
         strides_to_check = [strides.depth, strides.height]
-        required_multiple = 16 * element_size_in_bytes
+        required_multiple = 16
     else:
         strides_to_check = [strides.height, strides.width]
         required_multiple = element_size_in_bytes
     for stride in strides_to_check:
-        check_size(stride, required_multiple)
+        check_stride(stride, required_multiple)
 
 
 def check_addresses(addresses: List[int], layout: NpuLayout, element_size, arch: ArchitectureFeatures):
@@ -384,11 +392,11 @@ def check_dma_op(dma_op: NpuDmaOperation, arch: ArchitectureFeatures):
             check_alignment(dma_op.src.address, 16)
         if dma_op.dest.region == BASE_PTR_INDEX_MEM2MEM:
             check_alignment(dma_op.dest.address, 16)
-        check_size(dma_op.src.length, 16)
+        check_length(dma_op.src.length, 16)
     else:
         check_alignment(dma_op.src.address, 16)
         check_alignment(dma_op.dest.address, 16)
-        check_size(dma_op.src.length, 16)
+        check_length(dma_op.src.length, 16)
 
 
 # -------------------------------------------------------------------