aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_generator.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-07-13 16:01:51 +0200
committerJacob Bohlin <jacob.bohlin@arm.com>2020-08-10 15:49:40 +0200
commite99b893beaa1b95ee86d51a613f208f9f4edf150 (patch)
treeb17720ae711a5f638f0fc7145f886d73e67bdf46 /ethosu/vela/register_command_stream_generator.py
parentecd2052d8106bc81a866f4d80ed5906d99437eec (diff)
downloadethos-u-vela-e99b893beaa1b95ee86d51a613f208f9f4edf150.tar.gz
MLBEDSW-2639: Moved the IFM/IFM2 order switch to register cmd stream generator
For binary elementwise ops with broadcasting in first IFM. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: I25af67be8d3a852247989bc3ddc8e08e946f6bfa
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--ethosu/vela/register_command_stream_generator.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index e0f114eb..471953d9 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -330,6 +330,21 @@ def get_op_padding_lt(cmd):
return (explicit_padding[1], explicit_padding[0])
+def ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
+ if ifm_shape == []:
+ # Scalar needs to be in IFM2
+ return False
+ elif ifm2_shape == []:
+ return True
+
+ for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
+ if ifm != ifm2 and ifm == 1:
+ # Broadcasted FM needs to be in IFM2
+ return False
+
+ return True
+
+
def generate_register_command_stream(nng, sg, arch, verbose=False):
emit = CommandStreamEmitter()
@@ -472,7 +487,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
IFM2Broadcast.ReverseOperandOrder if primary_op.attrs.get("reverse_op_order", False) else 0
)
- if cmd.ifm_tensor.shape == []:
+ if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
# The scalar has to be the ifm2 tensor so switch the ifms
cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box