diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 6328a4e5..25a34e82 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -39,7 +39,12 @@ def _optype_formatter(op_list): class TFLiteSupportedOperators: # Categorised lists of supported operators - npu_pre_ops = set((Op.SplitSliceRead,)) + npu_pre_ops = set( + ( + Op.SplitSliceRead, + Op.Shape, + ) + ) convolution_ops = set( ( Op.Conv2DBias, @@ -103,6 +108,7 @@ class TFLiteSupportedOperators: ( Op.ReduceSum, Op.CLZ, + Op.Shape, ) ) | binary_elem_wise_add_mul_sub @@ -363,7 +369,7 @@ class TFLiteSupportedOperators: if op.type not in cls.per_axis_quant_ops: tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] for tens in tensors: - if tens.quantization.is_per_axis(): + if tens.quantization and tens.quantization.is_per_axis(): valid = False extra.append(tens.name) return valid, "The following tensor(s) have per-axis quantization parameters: " + ", ".join(extra) |