diff options
-rw-r--r-- | SUPPORTED_OPS.md | 3 | ||||
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 3 | ||||
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 9 |
3 files changed, 4 insertions, 11 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index 6f6167d2..ba5b7919 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -1,7 +1,7 @@ # Supported Ops This file was automatically generated by Vela using the `--supported-ops-report` parameter. -Vela version: `3.7.1.dev8+ga182a70.d20230322` +Vela version: `3.7.1.dev10+g521c494` This file complies with [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md) @@ -101,7 +101,6 @@ This is a list of constraints that the ADD operator must satisfy in order to be This is a list of constraints that the ARG_MAX operator must satisfy in order to be scheduled on the NPU. - IFM must be int8 or uint8 -- Number of input dimensions must be 4 - Operation must be performed along the depth axis - IFM depth must be no greater than 127 diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 518b6db0..e0c7fd2c 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -41,6 +41,7 @@ from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence from .numeric_util import clamp_sigmoid +from .numeric_util import full_shape from .numeric_util import round_away_zero from .operation import create_activation_function from .operation import ExplicitScaling @@ -524,7 +525,7 @@ def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng): op.name = "depthwise_conv_SHL_7" op.type = Op.DepthwiseConv2DBias op.attrs.update(dw_op_attrs) - n, h, w, c = ifm.shape + n, h, w, c = full_shape(4, ifm.shape, 1) shape = [1, 1, 1, c] kernel = np.dstack([2**7] * c) op.inputs = [] diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index fd9a9c20..66b9e944 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -316,7 +316,6 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) # ArgMax specific checks: - self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions) self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis) self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth) @@ -879,17 +878,11 @@ class TFLiteSupportedOperators: inp_dims = len(op.inputs[0].shape) axis = op.inputs[1].values return ( - axis in (3, -1), + axis in (inp_dims - 1, -1), f"Axis is {axis} and number of input dimensions is {inp_dims}", ) @staticmethod - def constraint_argmax_input_dimensions(op): - "Number of input dimensions must be 4" - inp_dims = len(op.inputs[0].shape) - return inp_dims == 4, f"Number of input dimensions is {inp_dims}" - - @staticmethod def constraint_argmax_depth(op): "IFM depth must be no greater than 127" ifm_depth = op.inputs[0].shape[-1] |