aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py22
1 files changed, 4 insertions, 18 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 0d299e15..3601c929 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -236,11 +236,9 @@ class QuantizationParameters:
# it implies that the tensor it belongs to is not quantised. otherwise,
# it depends upon whether the scale and zero point are equal
- if other is None:
+ if not isinstance(other, QuantizationParameters):
return False
- assert isinstance(other, QuantizationParameters)
-
return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
def is_valid(self):
@@ -793,13 +791,10 @@ class Tensor:
def is_quantized(self):
# a tensor is quantized if it has an integral type and it contains valid quantization params
- if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+ if not isinstance(self.quantization, QuantizationParameters):
return False
- assert isinstance(self.quantization, QuantizationParameters)
- assert self.quantization.is_valid()
-
- return True
+ return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
@@ -807,16 +802,7 @@ class Tensor:
__repr__ = __str__
-def check_tens_quantized(tens):
- # checks that a tensor is quantized
-
- return isinstance(tens, Tensor) and tens.is_quantized()
-
-
def check_quantized_tens_scaling_equal(tens_a, tens_b):
# checks that the scaling of two quantized tensors are equal
- assert check_tens_quantized(tens_a)
- assert check_tens_quantized(tens_b)
-
- return tens_a.quantization.is_scaling_equal(tens_b.quantization)
+ return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)