aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2020-11-24 13:45:50 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-11-26 07:19:01 +0000
commit0dd21c79ac6ef588e23393064d25e402e16cc2dd (patch)
tree6933d6bd1df37485f7537deed4b19c2e0af805f3 /ethosu/vela/test/test_supported_operators.py
parent933f55ea6f686d0cf390f4767e87a391686c3df8 (diff)
downloadethos-u-vela-0dd21c79ac6ef588e23393064d25e402e16cc2dd.tar.gz
MLBEDSW-3558: Put FC on CPU when OFM != 2D
This commit adds a constraint to FullyConnected ops in supported_operators.py that puts any such op on the CPU if tensor dimensions of the output(s) are not 2D. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: I8c898a780b40fc4a1383c09213f0696ea6699b7d
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 72ccad24..f132eef7 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -122,6 +122,22 @@ def test_constraint_tens_quant_per_axis_is_supp():
assert support.is_operator_supported(op)
+def test_constraint_fc_output_2d_not_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_fc_output_2d_is_supp():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ assert support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
+ assert support.is_operator_supported(op)
+
+
def test_constraint_faf():
# Fused activation functions, if set, must be a valid op type
op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])