aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-03-27 11:33:50 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-03-31 12:46:51 +0000
commit56811e6d3c62ae017f6eb298fb553f7d1e77cc96 (patch)
tree26a9fb07ffa51a8cb6bf6d64cf50871d2cccfc17
parent72c6a2414205e033279f80b622cdf479c05a4f5b (diff)
downloadethos-u-vela-56811e6d3c62ae017f6eb298fb553f7d1e77cc96.tar.gz
MLBEDSW-7439: Add support for input dims < 4 for ArgMax
- Updated ARG_MAX to support IFM rank less than 4 - Regenerated SUPPORTED_OPS.md Change-Id: Icd8e72733279413cbea49021325e1ab06fdc6011 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--SUPPORTED_OPS.md3
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py3
-rw-r--r--ethosu/vela/tflite_supported_operators.py9
3 files changed, 4 insertions, 11 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 6f6167d2..ba5b7919 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.7.1.dev8+ga182a70.d20230322`
+Vela version: `3.7.1.dev10+g521c494`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -101,7 +101,6 @@ This is a list of constraints that the ADD operator must satisfy in order to be
This is a list of constraints that the ARG_MAX operator must satisfy in order to be scheduled on the NPU.
- IFM must be int8 or uint8
-- Number of input dimensions must be 4
- Operation must be performed along the depth axis
- IFM depth must be no greater than 127
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 518b6db0..e0c7fd2c 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -41,6 +41,7 @@ from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
+from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
@@ -524,7 +525,7 @@ def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
op.name = "depthwise_conv_SHL_7"
op.type = Op.DepthwiseConv2DBias
op.attrs.update(dw_op_attrs)
- n, h, w, c = ifm.shape
+ n, h, w, c = full_shape(4, ifm.shape, 1)
shape = [1, 1, 1, c]
kernel = np.dstack([2**7] * c)
op.inputs = []
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index fd9a9c20..66b9e944 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -316,7 +316,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
# ArgMax specific checks:
- self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_input_dimensions)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_axis)
self.specific_constraints[Op.ArgMax].append(TFLiteSupportedOperators.constraint_argmax_depth)
@@ -879,17 +878,11 @@ class TFLiteSupportedOperators:
inp_dims = len(op.inputs[0].shape)
axis = op.inputs[1].values
return (
- axis in (3, -1),
+ axis in (inp_dims - 1, -1),
f"Axis is {axis} and number of input dimensions is {inp_dims}",
)
@staticmethod
- def constraint_argmax_input_dimensions(op):
- "Number of input dimensions must be 4"
- inp_dims = len(op.inputs[0].shape)
- return inp_dims == 4, f"Number of input dimensions is {inp_dims}"
-
- @staticmethod
def constraint_argmax_depth(op):
"IFM depth must be no greater than 127"
ifm_depth = op.inputs[0].shape[-1]