diff options
author | Rickard Bolin <rickard.bolin@arm.com> | 2024-01-31 12:05:11 +0000 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2024-05-16 14:08:21 +0000 |
commit | be78a053a57da7bdae240690c933824c0861f55b (patch) | |
tree | e6eabce902b42fcbdc7ef4cf7cfbc8136e11246d /ethosu/vela/high_level_command_to_npu_op.py | |
parent | 891468561ecfc61d27adcdc92b41ec216eaa1b08 (diff) | |
download | ethos-u-vela-be78a053a57da7bdae240690c933824c0861f55b.tar.gz |
MLBEDSW-8561: Striding support in H/W for StridedSlice3.12.0.rc1
Change-Id: Ie6f39d9c4125f7c16d27621de47cd76143c2e636
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r-- | ethosu/vela/high_level_command_to_npu_op.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 52d07187..71181d05 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -410,16 +410,20 @@ def create_feature_map( assert strides is not None + multiplied_strides = strides.copy() if stride_multiplier and stride_multiplier != [1, 1, 1]: assert ( tens.format == TensorFormat.NHWC ), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format" # Multiply strides for C/H/W (in that order) with corresponding stride factor for i, stride_factor in enumerate(stride_multiplier, start=1): - strides[i] *= stride_factor + multiplied_strides[i] *= stride_factor + + # Stride multiplier only affects tiles and addresses for OFM + _strides = multiplied_strides if is_ofm else strides height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer( - box.start_coord, box.end_coord, strides, op_shape4D + box.start_coord, box.end_coord, _strides, op_shape4D ) for idx, offset in enumerate(tile_base_offsets): @@ -427,7 +431,9 @@ def create_feature_map( fm.tiles = NpuTileBox( height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses] ) - fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1])) + fm.strides = NpuShape3D( + height=int(multiplied_strides[2]), width=int(multiplied_strides[3]), depth=int(multiplied_strides[1]) + ) fm.name = tens.name return fm @@ -518,8 +524,9 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit ifm_height = cmd.ifm_box.get_block().height ifm_width = cmd.ifm_box.get_block().width ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box) - - npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0]) + npu_op.ifm = create_feature_map( + cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0], op.ifm_stride_multiplier[0] + ) npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth) npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor) |