aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py19
1 files changed, 6 insertions, 13 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 66bed59d..5fdea979 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -184,19 +184,6 @@ class QuantizationParameters:
__repr__ = __str__
- def __eq__(self, other):
- if other is None:
- return False
- if not isinstance(other, QuantizationParameters):
- return False
-
- pairs = ((getattr(self, s), getattr(other, s)) for s in QuantizationParameters.__slots__)
-
- return all(np.array_equal(a, b) for a, b in pairs)
-
- def __ne__(self, other):
- return not self == other
-
def clone(self):
res = QuantizationParameters()
res.min = self.min
@@ -232,6 +219,12 @@ class QuantizationParameters:
return res
+ def is_scaling_equal(self, other):
+ if other is None or not isinstance(other, QuantizationParameters):
+ return False
+
+ return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
+
def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
# Tensor