diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index f2a8c803..b1a56605 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1408,10 +1408,8 @@ def optimise_quantize(op: Operation, arch, nng): if input_values.ndim == 0: input_values = np.array([input_values]) - # requantized int8 to int8 - if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or ( - ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16 - ): + # requantized int8 to int8 or int16 to int16 + if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16: # scale needs to use double precision to match TFLite reference kernel effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32) @@ -1495,7 +1493,7 @@ def supported_operator_check(op, arch, nng): def tflite_optimise_graph(nng, arch): - # Compile time optimisations + # Compile time static optimisations optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor] for idx, sg in enumerate(nng.subgraphs): |