about summary refs log tree commit diff
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2020-10-27 11:57:57 +0000
committertim.hall <tim.hall@arm.com>2020-11-20 08:52:14 +0000
commit8956761a84f413e6f4c9c7d6e4409b145f81c289 (patch)
tree3d6918b6ea2eb42a3d45d45098920b371bae8844
parent083f103fe612a88f41495022af89d5a12ea4aded (diff)
downloadethos-u-vela-8956761a84f413e6f4c9c7d6e4409b145f81c289.tar.gz
vela: Improve the scaling is equal check
- Improved tensor and scaling query functions
- Fixed bug in convert_batched_fc_to_conv

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Ibc3d14036540f27cf5e993beb2163d3e0f5e5933
-rw-r--r--ethosu/vela/graph_optimiser.py2
-rw-r--r--ethosu/vela/tensor.py22
2 files changed, 5 insertions, 19 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 35932d4..4f473dd 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -333,7 +333,7 @@ def convert_batched_fc_shape(op, arch, nng):
# There is a preceding Reshape
# Compare input of prev_op and input of op, to see if prev_op can be removed
ifm_prev_op = prev_op.inputs[0]
- if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm.quantization):
+ if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
# prev_op can be removed
op.set_input_tensor(ifm_prev_op, 0)
else:
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 0d299e1..3601c92 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -236,11 +236,9 @@ class QuantizationParameters:
# it implies that the tensor it belongs to is not quantised. otherwise,
# it depends upon whether the scale and zero point are equal
- if other is None:
+ if not isinstance(other, QuantizationParameters):
return False
- assert isinstance(other, QuantizationParameters)
-
return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
def is_valid(self):
@@ -793,13 +791,10 @@ class Tensor:
def is_quantized(self):
# a tensor is quantized if it has an integral type and it contains valid quantization params
- if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+ if not isinstance(self.quantization, QuantizationParameters):
return False
- assert isinstance(self.quantization, QuantizationParameters)
- assert self.quantization.is_valid()
-
- return True
+ return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
@@ -807,16 +802,7 @@ class Tensor:
__repr__ = __str__
-def check_tens_quantized(tens):
- # checks that a tensor is quantized
-
- return isinstance(tens, Tensor) and tens.is_quantized()
-
-
def check_quantized_tens_scaling_equal(tens_a, tens_b):
# checks that the scaling of two quantized tensors are equal
- assert check_tens_quantized(tens_a)
- assert check_tens_quantized(tens_b)
-
- return tens_a.quantization.is_scaling_equal(tens_b.quantization)
+ return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)