aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index d62ebc8e..8c5e277a 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -209,7 +209,17 @@ def create_equivalence_id(key) -> UUID:
class QuantizationParameters:
- __slots__ = "min", "max", "num_bits", "narrow_range", "scale_f32", "zero_point", "quant_min", "quant_max"
+ __slots__ = (
+ "min",
+ "max",
+ "num_bits",
+ "narrow_range",
+ "scale_f32",
+ "zero_point",
+ "quant_min",
+ "quant_max",
+ "quant_dim",
+ )
def __init__(
self,
@@ -228,6 +238,7 @@ class QuantizationParameters:
self.zero_point: Union[int, np.ndarray, None] = None
self.quant_min: Optional[float] = None
self.quant_max: Optional[float] = None
+ self.quant_dim: Optional[int] = None
def __str__(self):
return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
@@ -252,6 +263,7 @@ class QuantizationParameters:
res.zero_point = self.zero_point
res.quant_min = self.quant_min
res.quant_max = self.quant_max
+ res.quant_dim = self.quant_dim
return res
def dequantize(self, values) -> np.ndarray: