aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py39
1 files changed, 34 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 3b7f248a..b2644791 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -41,7 +41,13 @@ def _optype_formatter(op_list):
class TFLiteSemantic:
# Categorised lists of operators
- convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
+ convolution_ops = set(
+ (
+ Op.Conv2DBias,
+ Op.Conv2D,
+ Op.QuantizedConv2D,
+ )
+ )
depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
@@ -49,13 +55,36 @@ class TFLiteSemantic:
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
- binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
- binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
- binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
+ binary_elem_wise_min_max_ops = set(
+ (
+ Op.Minimum,
+ Op.Maximum,
+ )
+ )
+ binary_elem_wise_shift_ops = set(
+ (
+ Op.SHL,
+ Op.SHR,
+ )
+ )
+ binary_elem_wise_add_mul_sub = set(
+ (
+ Op.Add,
+ Op.Mul,
+ Op.Sub,
+ )
+ )
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims))
- reshape_ops = set((Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims,))
+ reshape_ops = set(
+ (
+ Op.Reshape,
+ Op.QuantizedReshape,
+ Op.Squeeze,
+ Op.ExpandDims,
+ )
+ )
def __init__(self):
# Setup the generic constraints. Note: the order matters