From a2ec5aa747633da72b6310ce7e5552c39f7f54bb Mon Sep 17 00:00:00 2001 From: Ayaan Masood Date: Thu, 21 Apr 2022 14:28:03 +0100 Subject: MLBEDSW-5384 FC layers run on NPU if underlying shape is 2D *Added generic function which checks if underlying shape of FullyConnected operation is 2D and performs shape reduction *Fully connected operation >2 dimensions now run on NPU if the above case is satisfied *constraint_fc_output_2d and rewrite_fully_connected_input refactored *Added unit test to confirm this functionality Signed-off-by: Ayaan Masood Change-Id: I0e29c767e5b84841eb53bbc44464b36a454f7b38 --- ethosu/vela/test/test_tflite_model_semantic.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) (limited to 'ethosu/vela/test/test_tflite_model_semantic.py') diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 1e5dbd4d..2d6ca15a 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -81,11 +81,13 @@ def test_constraint_tens_quant_scale(): def test_constraint_fc_output_2d_not_supp(): - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [7, 4, 6], [3, 2, 2, 8], weights_shape=[1, 9, 1]) assert not semantic_checker.is_operator_semantic_valid(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 6, 1], [3, 7, 4], weights_shape=[1, 1, 7, 1]) assert not semantic_checker.is_operator_semantic_valid(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 1, 4, 7], [1, 9], weights_shape=[12, 3]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4], [9], weights_shape=[3, 2]) assert not semantic_checker.is_operator_semantic_valid(op) @@ -94,6 +96,20 @@ def test_constraint_fc_output_2d_is_supp(): assert semantic_checker.is_operator_semantic_valid(op) op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024]) assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [1, 1, 3, 2], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 1, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors( + Op.FullyConnected, [12, 1, 1, 1], [1, 1, 24], weights_shape=[12, 1, 1, 1] + ) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1, 3, 4], weights_shape=[1, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) def test_constraint_conv_pass(): -- cgit v1.2.1