diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index d62ebc8e..8c5e277a 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -209,7 +209,17 @@ def create_equivalence_id(key) -> UUID: class QuantizationParameters: - __slots__ = "min", "max", "num_bits", "narrow_range", "scale_f32", "zero_point", "quant_min", "quant_max" + __slots__ = ( + "min", + "max", + "num_bits", + "narrow_range", + "scale_f32", + "zero_point", + "quant_min", + "quant_max", + "quant_dim", + ) def __init__( self, @@ -228,6 +238,7 @@ class QuantizationParameters: self.zero_point: Union[int, np.ndarray, None] = None self.quant_min: Optional[float] = None self.quant_max: Optional[float] = None + self.quant_dim: Optional[int] = None def __str__(self): return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % ( @@ -252,6 +263,7 @@ class QuantizationParameters: res.zero_point = self.zero_point res.quant_min = self.quant_min res.quant_max = self.quant_max + res.quant_dim = self.quant_dim return res def dequantize(self, values) -> np.ndarray: |