aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py99
1 files changed, 87 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 4d826770..6328a4e5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -40,7 +40,13 @@ def _optype_formatter(op_list):
class TFLiteSupportedOperators:
# Categorised lists of supported operators
npu_pre_ops = set((Op.SplitSliceRead,))
- convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
+ convolution_ops = set(
+ (
+ Op.Conv2DBias,
+ Op.Conv2D,
+ Op.QuantizedConv2D,
+ )
+ )
depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
@@ -48,7 +54,13 @@ class TFLiteSupportedOperators:
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
resizing_ops = set((Op.ResizeBilinear,))
- fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
+ fc_vector_products = set(
+ (
+ Op.QuantizedMatMul,
+ Op.MatMul,
+ Op.FullyConnected,
+ )
+ )
mac_main_ops = (
# RNN/LSTM/GRU
set((Op.BlockLSTM,))
@@ -64,17 +76,47 @@ class TFLiteSupportedOperators:
| set((Op.Mean,))
)
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
- binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
- binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
- binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
+ binary_elem_wise_min_max_ops = set(
+ (
+ Op.Minimum,
+ Op.Maximum,
+ )
+ )
+ binary_elem_wise_shift_ops = set(
+ (
+ Op.SHL,
+ Op.SHR,
+ )
+ )
+ binary_elem_wise_add_mul_sub = set(
+ (
+ Op.Add,
+ Op.Mul,
+ Op.Sub,
+ )
+ )
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
- set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+ set(
+ (
+ Op.ReduceSum,
+ Op.CLZ,
+ )
+ )
+ | binary_elem_wise_add_mul_sub
+ | binary_elem_wise_shift_ops
)
- relu_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip,))
+ relu_ops = set(
+ (
+ Op.Relu,
+ Op.Relu6,
+ Op.ReluN1To1,
+ Op.Clip,
+ )
+ )
activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
npu_post_ops = (
# activation functions
@@ -84,11 +126,44 @@ class TFLiteSupportedOperators:
# Quantization
| set((Op.Quantize,))
)
- split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
- concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
- memory_only_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,)) | concat_ops | split_ops
+ split_ops = set(
+ (
+ Op.Split,
+ Op.SplitV,
+ Op.StridedSlice,
+ Op.Slice,
+ Op.UnpackReshaped,
+ Op.Unpack,
+ )
+ )
+ concat_ops = set(
+ (
+ Op.Concat,
+ Op.ConcatTFLite,
+ Op.PackReshaped,
+ Op.Pack,
+ )
+ )
+ memory_only_ops = (
+ set(
+ (
+ Op.Reshape,
+ Op.QuantizedReshape,
+ Op.Squeeze,
+ Op.ExpandDims,
+ )
+ )
+ | concat_ops
+ | split_ops
+ )
per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops
- supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
+ supported_fused_activations = relu_ops | set(
+ (
+ Op.Tanh,
+ Op.Sigmoid,
+ Op.LUT,
+ )
+ )
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
# Supported data types
supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
@@ -441,7 +516,7 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_tconv_valid(op):
"""VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
- minus difference between kernel size and stride"""
+ minus difference between kernel size and stride"""
if op.attrs["padding"] == Padding.VALID:
s_w = op.kernel.stride.x
s_h = op.kernel.stride.y