aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_to_npu_op.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 5e9dffa..53df096 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -383,6 +383,7 @@ def create_feature_map(
op_shape4D: Shape4D,
tile_base_offsets: List[int],
stride_multiplier: Optional[List[int]] = None,
+ is_ofm: bool = False,
) -> NpuFeatureMap:
"""Creates feature map with common fields populated"""
fm = NpuFeatureMap()
@@ -395,7 +396,16 @@ def create_feature_map(
else:
assert 0, "Incorrect tensor format"
- strides = tens.get_strides(op_shape4D)
+ if is_ofm and tens.ops[0] is not None and tens.ops[0].original_type == Op.Transpose:
+ # op_shape4D has ifm shape, see fixup_transpose. Stride calculations needs to be
+ # based on the correct ofm shape.
+ op_shape4D_ofm_shape = Shape4D([op_shape4D.batch, op_shape4D.width, op_shape4D.height, op_shape4D.depth])
+ strides = tens.get_strides(op_shape4D_ofm_shape)
+ # Swap h and w strides which will cause the transpose to happen
+ strides[-3], strides[-2] = strides[-2], strides[-3]
+ else:
+ strides = tens.get_strides(op_shape4D)
+
assert strides is not None
if stride_multiplier and stride_multiplier != [1, 1, 1]:
@@ -513,7 +523,13 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
out_block = cmd.ofm_box.get_block()
npu_op.ofm = create_feature_map(
- cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.tile_base_offsets_ofm, op.ofm_stride_multiplier
+ cmd.ofm_tensor,
+ cmd.ofm_box,
+ arch,
+ ps.ofm_shapes[0],
+ op.tile_base_offsets_ofm,
+ op.ofm_stride_multiplier,
+ is_ofm=True,
)
npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)