diff options
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r-- | ethosu/vela/high_level_command_to_npu_op.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 202917bd..228c76f8 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -112,14 +112,15 @@ resampling_mode_inv_map = { } -def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool: - if ifm_shape == []: +def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool: + + if ifm_shape is None: # Scalar needs to be in IFM2 return False - if ifm2_shape == []: + if ifm2_shape is None: return True - for ifm, ifm2 in zip(ifm_shape, ifm2_shape): + for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()): if ifm != ifm2 and ifm == 1: # Broadcasted FM needs to be in IFM2 return False @@ -553,8 +554,8 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu npu_op = NpuElementWiseOperation(elemwise_op) if elemwise_op not in UNARY_ELEMWISE_OPS: - ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list() - ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list() + ifm_shape = None if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0] + ifm2_shape = None if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1] if cmd.reversed_operands: assert ifm_ifm2_correct_order(ifm_shape, ifm2_shape) npu_op.reversed_operands = True |