aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py33
1 files changed, 27 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 597e0a2c..7d544004 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,7 @@ class TFLiteSupportedOperators:
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
- mean_width_size = 64 * 64
+ mean_reduced_axis_max_size = 64 * 64
mean_kernel_product_int8 = 2 ** (24)
mean_kernel_product_uint8 = 2 ** (23)
mean_kernel_product_int16 = 2 ** (16)
@@ -315,6 +315,7 @@ class TFLiteSupportedOperators:
# Mean specific checks:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_depth)
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -844,9 +845,9 @@ class TFLiteSupportedOperators:
@docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
def constraint_mean_height_width_product(cls, op):
"""Product of reduced axes must be no greater than:
- - {} for signed 8-bit inputs
- - {} for unsigned 8-bit inputs
- - {} for signed 16-bit inputs"""
+ - {} for signed 8-bit inputs.
+ - {} for unsigned 8-bit inputs.
+ - {} for signed 16-bit inputs."""
shape = op.inputs[0].shape
if op.inputs[1].shape == []:
axis = [int(op.inputs[1].values)]
@@ -869,15 +870,35 @@ class TFLiteSupportedOperators:
return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}"
@classmethod
- @docstring_format_args([mean_width_size])
+ @docstring_format_args([mean_reduced_axis_max_size])
def constraint_mean_width(cls, op):
"""If Width axis is reduced its shape must be no greater than {}."""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
- max_width = cls.mean_width_size
+ max_width = cls.mean_reduced_axis_max_size
return w <= max_width, f"Width is {w}"
+ @classmethod
+ @docstring_format_args([mean_reduced_axis_max_size])
+ def constraint_mean_depth(cls, op):
+ """If Depth axis is reduced its shape must be no greater than {}."""
+ max_depth = cls.mean_reduced_axis_max_size
+ shape = op.inputs[0].shape
+
+ if op.inputs[1].shape == []:
+ axis = [int(op.inputs[1].values)]
+ else:
+ axis = list(op.inputs[1].values)
+
+ depth_idx = len(shape) - 1
+
+ supported = True
+ if depth_idx in axis and shape[-1] > max_depth:
+ supported = False
+
+ return supported, f"Depth is {shape[-1]}, shape is {shape}, axis is {axis}"
+
@staticmethod
def constraint_reshape_shape_constant(op):
"Shape must be constant"