aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_util.py')
-rw-r--r--ethosu/vela/register_command_stream_util.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index c7050a3..74c4f90 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -60,10 +60,18 @@ def check_alignment(payload, required_alignment):
raise ByteAlignmentError(f"Cmd1 payload of size: {payload} Bytes is not {required_alignment}-byte aligned")
-def check_size(payload, required_multiple):
+def check_size(payload, required_multiple, value_type):
# assuming payload is defined in bytes
if payload % required_multiple != 0:
- raise ByteSizeError(f"Cmd1 payload of size: {payload} Bytes is not a multiple of {required_multiple}")
+ raise ByteSizeError(f"Cmd1 {value_type} of size: {payload} Bytes is not a multiple of {required_multiple}")
+
+
+def check_stride(stride, required_multiple):
+ check_size(stride, required_multiple, "stride")
+
+
+def check_length(length, required_multiple):
+ check_size(length, required_multiple, "length")
def to_npu_kernel(kernel: Kernel) -> NpuKernel:
@@ -263,12 +271,12 @@ def check_strides(fm: NpuFeatureMap, strides: NpuShape3D):
if fm.layout == NpuLayout.NHCWB16:
strides_to_check = [strides.depth, strides.height]
- required_multiple = 16 * element_size_in_bytes
+ required_multiple = 16
else:
strides_to_check = [strides.height, strides.width]
required_multiple = element_size_in_bytes
for stride in strides_to_check:
- check_size(stride, required_multiple)
+ check_stride(stride, required_multiple)
def check_addresses(addresses: List[int], layout: NpuLayout, element_size, arch: ArchitectureFeatures):
@@ -384,11 +392,11 @@ def check_dma_op(dma_op: NpuDmaOperation, arch: ArchitectureFeatures):
check_alignment(dma_op.src.address, 16)
if dma_op.dest.region == BASE_PTR_INDEX_MEM2MEM:
check_alignment(dma_op.dest.address, 16)
- check_size(dma_op.src.length, 16)
+ check_length(dma_op.src.length, 16)
else:
check_alignment(dma_op.src.address, 16)
check_alignment(dma_op.dest.address, 16)
- check_size(dma_op.src.length, 16)
+ check_length(dma_op.src.length, 16)
# -------------------------------------------------------------------