aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_to_npu_op.py
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-19 11:20:12 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-20 14:38:55 +0200
commit56a71b0108f43a1cb118b1e2fae902c31b2a9969 (patch)
tree4692139c4dcd6e53b55b1a07ff1b09eb8461a4a8 /ethosu/vela/high_level_command_to_npu_op.py
parentfba0a7dc43373a69f3c0792587d3d9b0cc010ccf (diff)
downloadethos-u-vela-56a71b0108f43a1cb118b1e2fae902c31b2a9969.tar.gz
MLBEDSW-7019: Update to elementwise cascading
- The cascade builder is using the ifm_ifm2_correct_order function in order to decide if the operator is cascadable or not. The problem is that this function expects a full shape or no shape and the cascade builder did not provide that, so the operator was reported to be non cascadable. - The fix is to provide a full 4D shape, also refactoring ifm_ifm2_correct_order to use 4D shape to avoid confusion in the future. - Refactoring code so that the scheduler can perform a correct ifm and ifm2 swap. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I9a86c4690612f332afa428456a07e67698852495
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 202917bd..228c76f8 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -112,14 +112,15 @@ resampling_mode_inv_map = {
}
-def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
- if ifm_shape == []:
+def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool:
+
+ if ifm_shape is None:
# Scalar needs to be in IFM2
return False
- if ifm2_shape == []:
+ if ifm2_shape is None:
return True
- for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
+ for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()):
if ifm != ifm2 and ifm == 1:
# Broadcasted FM needs to be in IFM2
return False
@@ -553,8 +554,8 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
npu_op = NpuElementWiseOperation(elemwise_op)
if elemwise_op not in UNARY_ELEMWISE_OPS:
- ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
- ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
+ ifm_shape = None if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0]
+ ifm2_shape = None if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1]
if cmd.reversed_operands:
assert ifm_ifm2_correct_order(ifm_shape, ifm2_shape)
npu_op.reversed_operands = True