diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 22 |
1 files changed, 4 insertions, 18 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 0d299e15..3601c929 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -236,11 +236,9 @@ class QuantizationParameters: # it implies that the tensor it belongs to is not quantised. otherwise, # it depends upon whether the scale and zero point are equal - if other is None: + if not isinstance(other, QuantizationParameters): return False - assert isinstance(other, QuantizationParameters) - return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point def is_valid(self): @@ -793,13 +791,10 @@ class Tensor: def is_quantized(self): # a tensor is quantized if it has an integral type and it contains valid quantization params - if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None: + if not isinstance(self.quantization, QuantizationParameters): return False - assert isinstance(self.quantization, QuantizationParameters) - assert self.quantization.is_valid() - - return True + return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid() def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) @@ -807,16 +802,7 @@ class Tensor: __repr__ = __str__ -def check_tens_quantized(tens): - # checks that a tensor is quantized - - return isinstance(tens, Tensor) and tens.is_quantized() - - def check_quantized_tens_scaling_equal(tens_a, tens_b): # checks that the scaling of two quantized tensors are equal - assert check_tens_quantized(tens_a) - assert check_tens_quantized(tens_b) - - return tens_a.quantization.is_scaling_equal(tens_b.quantization) + return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization) |