aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_to_npu_op.py
diff options
context:
space:
mode:
author Rickard Bolin <rickard.bolin@arm.com> 2024-01-31 12:05:11 +0000
committer Rickard Bolin <rickard.bolin@arm.com> 2024-05-16 14:08:21 +0000
commit be78a053a57da7bdae240690c933824c0861f55b (patch)
tree e6eabce902b42fcbdc7ef4cf7cfbc8136e11246d /ethosu/vela/high_level_command_to_npu_op.py
parent 891468561ecfc61d27adcdc92b41ec216eaa1b08 (diff)
download ethos-u-vela-be78a053a57da7bdae240690c933824c0861f55b.tar.gz
MLBEDSW-8561: Striding support in H/W for StridedSlice [3.12.0.rc1]
Change-Id: Ie6f39d9c4125f7c16d27621de47cd76143c2e636
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r-- ethosu/vela/high_level_command_to_npu_op.py 17
1 files changed, 12 insertions, 5 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 52d07187..71181d05 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -410,16 +410,20 @@ def create_feature_map(
assert strides is not None
+ multiplied_strides = strides.copy()
if stride_multiplier and stride_multiplier != [1, 1, 1]:
assert (
tens.format == TensorFormat.NHWC
), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
# Multiply strides for C/H/W (in that order) with corresponding stride factor
for i, stride_factor in enumerate(stride_multiplier, start=1):
- strides[i] *= stride_factor
+ multiplied_strides[i] *= stride_factor
+
+ # Stride multiplier only affects tiles and addresses for OFM
+ _strides = multiplied_strides if is_ofm else strides
height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
- box.start_coord, box.end_coord, strides, op_shape4D
+ box.start_coord, box.end_coord, _strides, op_shape4D
)
for idx, offset in enumerate(tile_base_offsets):
@@ -427,7 +431,9 @@ def create_feature_map(
fm.tiles = NpuTileBox(
height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
)
- fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
+ fm.strides = NpuShape3D(
+ height=int(multiplied_strides[2]), width=int(multiplied_strides[3]), depth=int(multiplied_strides[1])
+ )
fm.name = tens.name
return fm
@@ -518,8 +524,9 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
ifm_height = cmd.ifm_box.get_block().height
ifm_width = cmd.ifm_box.get_block().width
ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
-
- npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
+ npu_op.ifm = create_feature_map(
+ cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0], op.ifm_stride_multiplier[0]
+ )
npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)