From 8956761a84f413e6f4c9c7d6e4409b145f81c289 Mon Sep 17 00:00:00 2001
From: Tim Hall
Date: Tue, 27 Oct 2020 11:57:57 +0000
Subject: vela: Improve the scaling is equal check

- Improved tensor and scaling query functions
- Fixed bug in convert_batched_fc_to_conv

Signed-off-by: Tim Hall
Change-Id: Ibc3d14036540f27cf5e993beb2163d3e0f5e5933
---
 ethosu/vela/graph_optimiser.py |  2 +-
 ethosu/vela/tensor.py          | 22 ++++------------------
 2 files changed, 5 insertions(+), 19 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 35932d49..4f473ddb 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -333,7 +333,7 @@ def convert_batched_fc_shape(op, arch, nng):
                 # There is a preceding Reshape
                 # Compare input of prev_op and input of op, to see if prev_op can be removed
                 ifm_prev_op = prev_op.inputs[0]
-                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm.quantization):
+                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
                     # prev_op can be removed
                     op.set_input_tensor(ifm_prev_op, 0)
                 else:
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 0d299e15..3601c929 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -236,11 +236,9 @@ class QuantizationParameters:
         # it implies that the tensor it belongs to is not quantised. otherwise,
         # it depends upon whether the scale and zero point are equal
 
-        if other is None:
+        if not isinstance(other, QuantizationParameters):
             return False
 
-        assert isinstance(other, QuantizationParameters)
-
         return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
 
     def is_valid(self):
@@ -793,13 +791,10 @@ class Tensor:
 
     def is_quantized(self):
         # a tensor is quantized if it has an integral type and it contains valid quantization params
-        if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+        if not isinstance(self.quantization, QuantizationParameters):
             return False
 
-        assert isinstance(self.quantization, QuantizationParameters)
-        assert self.quantization.is_valid()
-
-        return True
+        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
 
     def __str__(self):
         return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
@@ -807,16 +802,7 @@ class Tensor:
     __repr__ = __str__
 
 
-def check_tens_quantized(tens):
-    # checks that a tensor is quantized
-
-    return isinstance(tens, Tensor) and tens.is_quantized()
-
-
 def check_quantized_tens_scaling_equal(tens_a, tens_b):
     # checks that the scaling of two quantized tensors are equal
-    assert check_tens_quantized(tens_a)
-    assert check_tens_quantized(tens_b)
-
-    return tens_a.quantization.is_scaling_equal(tens_b.quantization)
+    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
 
-- 
cgit v1.2.1