aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py35
1 file changed, 14 insertions, 21 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 7dbdcddf..677757ca 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -254,20 +254,8 @@ class QuantizationParameters:
res.quant_max = self.quant_max
return res
- def dequantize(self, values):
- if self.zero_point.size == 1 and self.scale_f32.size == 1:
- # same scale is used for all values
- res = (values.astype(np.float64) - self.zero_point) * self.scale_f32
- else:
- # a different scale is used for different sets of values
- values_as_float = values.astype(np.float64)
-
- # this is not compatible with the format of depthwise weights,
- # where input is at index 3 (Output, Kh, Kw, Input)
- # return the quantized values
- return np.ndarray((values_as_float.shape))
-
- return res
+ def dequantize(self, values) -> np.ndarray:
+ return np.subtract(values, self.zero_point) * self.scale_f32
def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
# quantisation parameter scaling is not equal if 'other' is None because
@@ -300,16 +288,12 @@ def create_const_tensor(
value_dtype: np.dtype = None,
purpose: TensorPurpose = TensorPurpose.Unknown,
quantization: QuantizationParameters = None,
- quant_value_dtype: np.dtype = None,
):
# Tensor
const_tensor = Tensor(shape, dtype, name + "_0")
const_tensor.purpose = purpose
const_tensor.quantization = quantization
const_tensor.values = np.array(values, dtype=value_dtype)
- const_tensor.quant_values = np.frombuffer(
- const_tensor.values.tobytes(), dtype=np.uint8 if not quant_value_dtype else quant_value_dtype
- )
# Operator
const_op = Operation(Op.Const, name)
const_op.set_output_tensor(const_tensor)
@@ -349,7 +333,6 @@ class Tensor:
"ops",
"consumer_list",
"values",
- "quant_values",
"compressed_values",
"compressed_values_substream_offsets",
"mem_area",
@@ -391,8 +374,7 @@ class Tensor:
self.ops: List[Operation] = []
self.consumer_list: List[Operation] = []
- self.values: Optional[np.ndarray] = None
- self.quant_values: Optional[np.ndarray] = None
+ self.values: Optional[np.ndarray] = None # elements are of type self.dtype
self.compressed_values: Optional[np.ndarray] = None
self.compressed_values_substream_offsets: Optional[List] = None
self.mem_area: MemArea = MemArea.Unknown
@@ -816,6 +798,17 @@ class Tensor:
return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
+ def get_scalar(self):
+ """
+ return: Unquantized or dequantized scalar value
+ rtype: self.dtype (if unquantized) or float (if dequantized)
+ """
+ assert self.values.size == 1, "get_scalar called on non-scalar tensor"
+ if self.is_quantized():
+ return self.quantization.dequantize(self.values).item(0)
+ else:
+ return self.values.item(0)
+
def __lt__(self, other: "Tensor") -> bool:
return self.equivalence_id < other.equivalence_id