diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 51 |
1 files changed, 27 insertions, 24 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index b6f97963..d42caf58 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -20,6 +20,7 @@ from collections import defaultdict import numpy as np from .data_type import DataType +from .numeric_util import full_shape from .operation import Op from .operation import Padding from .supported_operators_util import docstring_format_args @@ -206,9 +207,20 @@ class TFLiteSupportedOperators: self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_int32_ops) self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_dimension) self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_quant_per_axis) + self.generic_constraints.append(TFLiteSupportedOperators.constraint_batch_size) self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf) self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf_type) + # Setup generic constraint exceptions + self.generic_constraints_exceptions = defaultdict(list) + self.generic_constraints_exceptions[Op.FullyConnected].append(TFLiteSupportedOperators.constraint_batch_size) + self.generic_constraints_exceptions[Op.Softmax].append(TFLiteSupportedOperators.constraint_batch_size) + self.generic_constraints_exceptions[Op.Reshape].append(TFLiteSupportedOperators.constraint_batch_size) + self.generic_constraints_exceptions[Op.Shape].append(TFLiteSupportedOperators.constraint_batch_size) + self.generic_constraints_exceptions[Op.Squeeze].append(TFLiteSupportedOperators.constraint_batch_size) + for op_type in TFLiteSupportedOperators.split_ops - set((Op.UnpackReshaped,)): + self.generic_constraints_exceptions[op_type].append(TFLiteSupportedOperators.constraint_batch_size) + # Setup specific constraints. Note: the order matters self.specific_constraints = defaultdict(list) @@ -223,7 +235,6 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_limit) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit) - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size) # Depthwise Conv specific checks: for op_type in TFLiteSupportedOperators.depthwise_convolution_ops: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier) @@ -235,7 +246,6 @@ class TFLiteSupportedOperators: # Pooling checks: for op_type in TFLiteSupportedOperators.pooling_ops: - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range) # AVG pooling specific checks: for op_type in TFLiteSupportedOperators.avg_pooling_ops: @@ -268,9 +278,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit) - # Element-wise checks: - for op_type in TFLiteSupportedOperators.elem_wise_main_ops: - self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_elemwise_batch_size) + # Element-wise checks # Binary Min/Max specific checks: for op_type in TFLiteSupportedOperators.binary_elem_wise_min_max_ops: self.specific_constraints[op_type].append( @@ -302,7 +310,6 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type) # Mean specific checks: - self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_batch_size) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8) @@ -319,7 +326,10 @@ class TFLiteSupportedOperators: print(f"Info: {ext_type} '{op.name}' is a CPU only op") return False - for constraint in self.generic_constraints + self.specific_constraints[op.type]: + op_exceptions = self.generic_constraints_exceptions[op.type] + generic_constraints = [constraint for constraint in self.generic_constraints if constraint not in op_exceptions] + + for constraint in generic_constraints + self.specific_constraints[op.type]: valid, extra = constraint(op) if not valid: print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU. Placing on CPU instead") @@ -497,9 +507,16 @@ class TFLiteSupportedOperators: @staticmethod def constraint_batch_size(op): "IFM Tensor batch size must be 1" - ifm = op.ifm - valid = ifm.shape[0] == 1 - return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}" + valid = True + extra = [] + for tens in (op.ifm, op.ifm2): + if tens is not None: + batch_size = full_shape(4, tens.shape, 1)[0] + if batch_size != 1: + valid = False + extra.append(f"Tensor '{tens.name}' has batch size: {batch_size}") + extra = "\n ".join(extra) + return valid, extra @staticmethod def constraint_depth_multiplier(op): @@ -753,20 +770,6 @@ class TFLiteSupportedOperators: return valid, f"Op has tensors with different quantization parameters to the OFM '{op.ofm.name}': {extra}" @staticmethod - def constraint_elemwise_batch_size(op): - "Batch size must be 1 for Input tensors with more than 2 dimensions" - valid = True - extra = [] - for tens in (op.ifm, op.ifm2): - # Unary ops have ifm2 as None - if tens is not None: - if (len(tens.shape) > 2) and (tens.shape[0] != 1): - valid = False - extra.append(tens.name) - extra = ", ".join(extra) - return valid, f"Op has invalid input tensors: {extra}" - - @staticmethod def constraint_broadcast_shapes(op): "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2" ifm_shape = op.ifm.shape |