aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 6328a4e5..25a34e82 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -39,7 +39,12 @@ def _optype_formatter(op_list):
class TFLiteSupportedOperators:
# Categorised lists of supported operators
- npu_pre_ops = set((Op.SplitSliceRead,))
+ npu_pre_ops = set(
+ (
+ Op.SplitSliceRead,
+ Op.Shape,
+ )
+ )
convolution_ops = set(
(
Op.Conv2DBias,
@@ -103,6 +108,7 @@ class TFLiteSupportedOperators:
(
Op.ReduceSum,
Op.CLZ,
+ Op.Shape,
)
)
| binary_elem_wise_add_mul_sub
@@ -363,7 +369,7 @@ class TFLiteSupportedOperators:
if op.type not in cls.per_axis_quant_ops:
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
for tens in tensors:
- if tens.quantization.is_per_axis():
+ if tens.quantization and tens.quantization.is_per_axis():
valid = False
extra.append(tens.name)
return valid, "The following tensor(s) have per-axis quantization parameters: " + ", ".join(extra)