diff options
-rw-r--r-- | ethosu/vela/tensor.py | 13 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 22 | ||||
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 12 | ||||
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 13 |
4 files changed, 42 insertions, 18 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 38b0e430..e9815845 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -823,6 +823,19 @@ class Tensor: else: return self.values.item(0) + def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]: + + elms = self.elements() + dimension_1_size = elms // dimension_2_size + # Checks if the reduction works and shape is not 1D + is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1) + + new_shape = None + if is_reducible: + new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size]) + + return new_shape + def __lt__(self, other: "Tensor") -> bool: return self.equivalence_id < other.equivalence_id diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 1e5dbd4d..2d6ca15a 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -81,11 +81,13 @@ def test_constraint_tens_quant_scale(): def test_constraint_fc_output_2d_not_supp(): - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [7, 4, 6], [3, 2, 2, 8], weights_shape=[1, 9, 1]) assert not semantic_checker.is_operator_semantic_valid(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 6, 1], [3, 7, 4], weights_shape=[1, 1, 7, 1]) assert not semantic_checker.is_operator_semantic_valid(op) - op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 1, 4, 7], [1, 9], weights_shape=[12, 3]) + assert not semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4], [9], weights_shape=[3, 2]) assert not semantic_checker.is_operator_semantic_valid(op) @@ -94,6 +96,20 @@ def test_constraint_fc_output_2d_is_supp(): assert semantic_checker.is_operator_semantic_valid(op) op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024]) assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [1, 1, 3, 2], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 1, 1], weights_shape=[12, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors( + Op.FullyConnected, [12, 1, 1, 1], [1, 1, 24], weights_shape=[12, 1, 1, 1] + ) + assert semantic_checker.is_operator_semantic_valid(op) + op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1, 3, 4], weights_shape=[1, 1, 1, 1]) + assert semantic_checker.is_operator_semantic_valid(op) def test_constraint_conv_pass(): diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index b2a34195..06395784 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -379,14 +379,12 @@ def convert_nop_split_to_identity(op, arch, nng): return op -def rewrite_fully_connected_input(op, arch, nng): - if op.type == Op.FullyConnected: - n_in_elems = op.weights.shape[-2] - elms = op.ifm.elements() - batch_size = elms // n_in_elems - assert batch_size * n_in_elems == elms +def rewrite_fully_connected_input(op: Operation, arch, nng): - op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems]) + if op.type == Op.FullyConnected: + new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2]) + assert new_shape is not None, "Tensor can not be reshaped to 2D" + op.ifm_shapes[0] = new_shape return op diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index b2644791..c811a0d4 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -295,14 +295,11 @@ class TFLiteSemantic: @staticmethod def constraint_fc_output_2d(op): - "The output tensor(s) must have 2D shape" - valid = True - extra = [] - for tens in op.outputs: - if len(tens.shape) != 2: - valid = False - extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D") - return valid, ", ".join(extra) + """The output tensor(s) must have 2D shape""" + valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None + extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else "" + + return valid, extra @staticmethod def constraint_stride_type(op): |