aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_supported_operators.py')
-rw-r--r--ethosu/vela/tosa_supported_operators.py36
1 files changed, 35 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index c619f2f9..d3686160 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -54,11 +54,13 @@ class TosaSupportedOperators:
# Supported data types
# TODO will differ compared to TensorFlow Lite, currently set to the same
supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32)) # TODO add bool
+ tens_dim_range = (1, 65535) # TODO HW limitation, that is to be resolved in SW
def __init__(self):
# Setup the generic constraints. Note: the order matters
self.generic_constraints = []
self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
+ self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
@@ -69,6 +71,10 @@ class TosaSupportedOperators:
for op_type in TosaSupportedOperators.depthwise_convolution_ops:
self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_depth_multiplier)
+ # Avgpool specific checks
+ for op_type in TosaSupportedOperators.avg_pooling_ops:
+ self.specific_constraints[op_type].append(TosaSupportedOperators.constraint_padding)
+
def is_operator_supported(self, op):
ext_type = optype_to_tosa_op_type(op.type)
if op.type not in TosaSupportedOperators.supported_operators:
@@ -103,13 +109,41 @@ class TosaSupportedOperators:
extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
return valid, ", ".join(extra)
+ # TODO Duplicates check present for TFLite. But it is only temporarily added
+ # This is for a HW limitation, that is to be resolved in SW later on
+ @classmethod
+ @docstring_format_args(tens_dim_range)
+ def constraint_tens_dimension(cls, op):
+ "Tensor dimensions must be in the range [{}, {}]"
+ tens_min, tens_max = cls.tens_dim_range
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if not all(tens_min <= dim <= tens_max for dim in tens.shape):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ return valid, ", ".join(extra)
+
@staticmethod
def constraint_ifm_producer(cls, op):
"Input must be constant data"
valid = op.ifm.ops and op.ifm.ops[0].type == Op.Const
return valid, "Op has ifm with non-constant data"
- # TODO duplicates TFLite_supported operators, but support for depth multiplier should be added at a later stage
+ @staticmethod
+ def constraint_padding(op):
+ # TODO Only support for when global scaling can be used.
+ # That is when there is padding no padding
+ "Avgpool only supported for no padding"
+ top, left, _, _ = op.attrs["explicit_padding"]
+ valid = top == 0 and left == 0
+
+ return valid, "Avgpool with pad_top {top} and pad_left {left}"
+
+ # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage
@staticmethod
def constraint_depth_multiplier(op):
"For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"