diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 43 |
1 files changed, 39 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 98dfa3d3..84af8edb 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -23,6 +23,7 @@ from functools import lru_cache import numpy as np from . import numeric_util +from .data_type import BaseType from .data_type import DataType from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .operation import Op @@ -229,11 +230,22 @@ class QuantizationParameters: return res def is_scaling_equal(self, other): - if other is None or not isinstance(other, QuantizationParameters): + # quantisation parameter scaling is not equal if 'other' is None because + # it implies that the tensor it belongs to is not quantised. otherwise, + # it depends upon whether the scale and zero point are equal + + if other is None: return False + assert isinstance(other, QuantizationParameters) + return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point + def is_valid(self): + # quantisation parameters are consider valid if they have a scale and zero point + + return None not in (self.scale_f32, self.zero_point) + def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None): # Tensor @@ -765,9 +777,6 @@ class Tensor: return True return False - def is_scaling_equal(self, tens): - return self.quantization.is_scaling_equal(tens.quantization) - def equivalent(self, tens): return self.equivalence_id == tens.equivalence_id @@ -785,7 +794,33 @@ class Tensor: else: return self.shape.copy() + def is_quantized(self): + # a tensor is quantized if it has an integral type and it contains valid quantization params + + if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None: + return False + + assert isinstance(self.quantisation, QuantizationParameters) + assert self.quantization.is_valid() + + return True + def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) __repr__ = __str__ + + +def check_tens_quantized(tens): + # checks that a tensor is quantized + + return isinstance(tens, Tensor) and tens.is_quantized() + + +def check_quantized_tens_scaling_equal(tens_a, tens_b): + # checks that the scaling of two quantized tensors are equal + + assert check_tens_quantized(tens_a) + assert check_tens_quantized(tens_b) + + return tens_a.quantization.is_scaling_equal(tens_b.quantization) |