aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py43
1 files changed, 39 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 98dfa3d3..84af8edb 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -23,6 +23,7 @@ from functools import lru_cache
import numpy as np
from . import numeric_util
+from .data_type import BaseType
from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Op
@@ -229,11 +230,22 @@ class QuantizationParameters:
return res
def is_scaling_equal(self, other):
- if other is None or not isinstance(other, QuantizationParameters):
+ # quantisation parameter scaling is not equal if 'other' is None because
+ # it implies that the tensor it belongs to is not quantised. otherwise,
+ # it depends upon whether the scale and zero point are equal
+
+ if other is None:
return False
+ assert isinstance(other, QuantizationParameters)
+
return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
+ def is_valid(self):
+ # quantisation parameters are consider valid if they have a scale and zero point
+
+ return None not in (self.scale_f32, self.zero_point)
+
def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
# Tensor
@@ -765,9 +777,6 @@ class Tensor:
return True
return False
- def is_scaling_equal(self, tens):
- return self.quantization.is_scaling_equal(tens.quantization)
-
def equivalent(self, tens):
return self.equivalence_id == tens.equivalence_id
@@ -785,7 +794,33 @@ class Tensor:
else:
return self.shape.copy()
+ def is_quantized(self):
+ # a tensor is quantized if it has an integral type and it contains valid quantization params
+
+ if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+ return False
+
+ assert isinstance(self.quantisation, QuantizationParameters)
+ assert self.quantization.is_valid()
+
+ return True
+
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
__repr__ = __str__
+
+
+def check_tens_quantized(tens):
+ # checks that a tensor is quantized
+
+ return isinstance(tens, Tensor) and tens.is_quantized()
+
+
+def check_quantized_tens_scaling_equal(tens_a, tens_b):
+ # checks that the scaling of two quantized tensors are equal
+
+ assert check_tens_quantized(tens_a)
+ assert check_tens_quantized(tens_b)
+
+ return tens_a.quantization.is_scaling_equal(tens_b.quantization)