aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index f2a8c803..b1a56605 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1408,10 +1408,8 @@ def optimise_quantize(op: Operation, arch, nng):
if input_values.ndim == 0:
input_values = np.array([input_values])
- # requantized int8 to int8
- if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or (
- ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16
- ):
+ # requantized int8 to int8 or int16 to int16
+ if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
# scale needs to use double precision to match TFLite reference kernel
effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
@@ -1495,7 +1493,7 @@ def supported_operator_check(op, arch, nng):
def tflite_optimise_graph(nng, arch):
- # Compile time optimisations
+ # Compile time static optimisations
optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
for idx, sg in enumerate(nng.subgraphs):