Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--  ethosu/vela/tflite_supported_operators.py  51
1 file changed, 27 insertions(+), 24 deletions(-)
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index b6f97963..d42caf58 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -20,6 +20,7 @@ from collections import defaultdict
import numpy as np
from .data_type import DataType
+from .numeric_util import full_shape
from .operation import Op
from .operation import Padding
from .supported_operators_util import docstring_format_args
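The helper imported above left-pads a shape with ones up to a given rank, which is what lets the reworked batch check further down treat low-rank tensors uniformly. A minimal sketch of the assumed semantics (the real implementation lives in ethosu/vela/numeric_util.py):

    def full_shape(dim, shape, fill):
        # Assumed behaviour: left-pad `shape` with `fill` up to `dim` entries.
        # full_shape(4, [10, 20], 1) -> [1, 1, 10, 20]
        return [fill] * (dim - len(shape)) + list(shape)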
@@ -206,9 +207,20 @@ class TFLiteSupportedOperators:
self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_int32_ops)
self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_dimension)
self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_quant_per_axis)
+ self.generic_constraints.append(TFLiteSupportedOperators.constraint_batch_size)
self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf)
self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf_type)
+ # Setup generic constraint exceptions
+ self.generic_constraints_exceptions = defaultdict(list)
+ self.generic_constraints_exceptions[Op.FullyConnected].append(TFLiteSupportedOperators.constraint_batch_size)
+ self.generic_constraints_exceptions[Op.Softmax].append(TFLiteSupportedOperators.constraint_batch_size)
+ self.generic_constraints_exceptions[Op.Reshape].append(TFLiteSupportedOperators.constraint_batch_size)
+ self.generic_constraints_exceptions[Op.Shape].append(TFLiteSupportedOperators.constraint_batch_size)
+ self.generic_constraints_exceptions[Op.Squeeze].append(TFLiteSupportedOperators.constraint_batch_size)
+ for op_type in TFLiteSupportedOperators.split_ops - set((Op.UnpackReshaped,)):
+ self.generic_constraints_exceptions[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
+
# Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
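The new generic_constraints_exceptions table built above maps an op type to the generic constraints that should be skipped for it; is_operator_supported() below filters against it. A self-contained sketch of that mechanism, using hypothetical stand-in names rather than the real constraint methods:

    from collections import defaultdict

    generic = ["batch_size", "tensor_dims"]  # stand-ins for the bound check methods
    exceptions = defaultdict(list)
    exceptions["FullyConnected"].append("batch_size")

    def checks_for(op_type):
        # Generic checks minus the per-op exceptions, mirroring the
        # filtering done in is_operator_supported().
        return [c for c in generic if c not in exceptions[op_type]]

    assert checks_for("FullyConnected") == ["tensor_dims"]        # batch check skipped
    assert checks_for("Conv2D") == ["batch_size", "tensor_dims"]  # no exceptions registered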
@@ -223,7 +235,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_limit)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
# Depthwise Conv specific checks:
for op_type in TFLiteSupportedOperators.depthwise_convolution_ops:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
@@ -235,7 +246,6 @@ class TFLiteSupportedOperators:
# Pooling checks:
for op_type in TFLiteSupportedOperators.pooling_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
# AVG pooling specific checks:
for op_type in TFLiteSupportedOperators.avg_pooling_ops:
@@ -268,9 +278,7 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
- # Element-wise checks:
- for op_type in TFLiteSupportedOperators.elem_wise_main_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_elemwise_batch_size)
+ # Element-wise checks
# Binary Min/Max specific checks:
for op_type in TFLiteSupportedOperators.binary_elem_wise_min_max_ops:
self.specific_constraints[op_type].append(
@@ -302,7 +310,6 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type)
# Mean specific checks:
- self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_batch_size)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
@@ -319,7 +326,10 @@ class TFLiteSupportedOperators:
print(f"Info: {ext_type} '{op.name}' is a CPU only op")
return False
- for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+ op_exceptions = self.generic_constraints_exceptions[op.type]
+ generic_constraints = [constraint for constraint in self.generic_constraints if constraint not in op_exceptions]
+
+ for constraint in generic_constraints + self.specific_constraints[op.type]:
valid, extra = constraint(op)
if not valid:
print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
@@ -497,9 +507,16 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_batch_size(op):
"IFM Tensor batch size must be 1"
- ifm = op.ifm
- valid = ifm.shape[0] == 1
- return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}"
+ valid = True
+ extra = []
+ for tens in (op.ifm, op.ifm2):
+ if tens is not None:
+ batch_size = full_shape(4, tens.shape, 1)[0]
+ if batch_size != 1:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has batch size: {batch_size}")
+ extra = "\n ".join(extra)
+ return valid, extra
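Because the rewritten check pads every input shape to rank 4 before reading the batch dimension, rank-1 and rank-2 tensors implicitly pass with a batch of 1, which is what makes the separate element-wise variant removed below redundant. A hypothetical walk-through, redefining full_shape per the sketch above so the snippet runs stand-alone:

    def full_shape(dim, shape, fill):
        return [fill] * (dim - len(shape)) + list(shape)

    for shape in ([2, 8], [1, 8, 8, 16], [4, 8, 8, 16]):
        batch = full_shape(4, shape, 1)[0]
        print(shape, "-> batch", batch, "passes" if batch == 1 else "rejected")
    # [2, 8]        -> batch 1, passes (padded to [1, 1, 2, 8])
    # [1, 8, 8, 16] -> batch 1, passes
    # [4, 8, 8, 16] -> batch 4, rejected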
@staticmethod
def constraint_depth_multiplier(op):
@@ -753,20 +770,6 @@ class TFLiteSupportedOperators:
return valid, f"Op has tensors with different quantization parameters to the OFM '{op.ofm.name}': {extra}"
@staticmethod
- def constraint_elemwise_batch_size(op):
- "Batch size must be 1 for Input tensors with more than 2 dimensions"
- valid = True
- extra = []
- for tens in (op.ifm, op.ifm2):
- # Unary ops have ifm2 as None
- if tens is not None:
- if (len(tens.shape) > 2) and (tens.shape[0] != 1):
- valid = False
- extra.append(tens.name)
- extra = ", ".join(extra)
- return valid, f"Op has invalid input tensors: {extra}"
-
- @staticmethod
def constraint_broadcast_shapes(op):
"Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2"
ifm_shape = op.ifm.shape