author     Louis Verhaard <louis.verhaard@arm.com>    2020-12-08 10:02:31 +0100
committer  Louis Verhaard <louis.verhaard@arm.com>    2020-12-10 15:24:54 +0100
commit     93719a9b8c160de3acf047eacb9196f13167bddb (patch)
tree       6ce67117c2bb1eac22fa6d79e090079d312a8e5f
parent     5401823aed626884e08f5f2db1d6246d9e129278 (diff)
download   ethos-u-vela-93719a9b8c160de3acf047eacb9196f13167bddb.tar.gz
MLBEDSW-3653: Added type hints to tensor.py
Change-Id: I1b35e039f43471cc0f61cb46ed4d5aff5469d11d
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--   ethosu/vela/tensor.py   256
1 file changed, 138 insertions, 118 deletions
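
At a glance, the annotation style this patch introduces is a bare `Shape = List` alias plus `typing` and `uuid.UUID` imports, with hints added to function signatures and class attributes. A minimal sketch in that style (illustrative only, not code from the patch):

    from typing import List, Optional
    from uuid import UUID, uuid4

    Shape = List  # loose alias for shape lists, as introduced by the patch

    def num_elements(shp: Optional[Shape]) -> Optional[int]:
        # Hypothetical helper mirroring the annotated shape utilities in tensor.py.
        if shp is None or any(d is None for d in shp):
            return None
        n = 1
        for d in shp:
            n *= d
        return n

    example_id: UUID = uuid4()  # attribute-style annotation, as used on Tensor fields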
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index f6e628c8..d75b7879 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -20,6 +20,12 @@ import enum
import uuid
from collections import defaultdict
from functools import lru_cache
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+from uuid import UUID
import numpy as np
@@ -29,7 +35,8 @@ from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .operation import Op
from .operation import Operation
-from .range_set import MemoryRangeSet
+
+Shape = List
class MemType(enum.IntFlag):
@@ -40,12 +47,13 @@ class MemType(enum.IntFlag):
Scratch_fast = 4
Size = Scratch_fast + 1
- def display_name(self):
+ def display_name(self) -> str:
return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]
- def identifier_name(self):
+ def identifier_name(self) -> str:
return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]
+ @staticmethod
def all():
return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)
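
The `@staticmethod` decorator added to `all()` here (and to the matching helpers in the enums below) mainly makes the intent explicit: Python 3 already allows calling the self-less function on the class, but linters flag it. A hedged usage sketch, assuming this repo's import path:

    from ethosu.vela.tensor import MemType

    for mem_type in MemType.all():
        # e.g. "permanent_npu" / "Permanent_NPU"
        print(mem_type.identifier_name(), mem_type.display_name())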
@@ -62,12 +70,13 @@ class MemArea(enum.IntFlag):
Shram = 5 # for LUT
Size = Shram + 1
- def display_name(self):
+ def display_name(self) -> str:
return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]
- def identifier_name(self):
+ def identifier_name(self) -> str:
return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]
+ @staticmethod
def all():
return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)
@@ -84,12 +93,13 @@ class TensorPurpose(enum.IntFlag):
FSBias = 5
Size = 6
- def display_name(self):
+ def display_name(self) -> str:
return ("Unknown", "Weights", "FeatureMap", "Scratch", "LUT", "FastStorageBias", "Size")[self.value]
- def identifier_name(self):
+ def identifier_name(self) -> str:
return ("unknown", "weights", "feature_map", "scratch", "lut", "fast_storage_bias", "size")[self.value]
+ @staticmethod
def all():
return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)
@@ -101,12 +111,13 @@ class TensorSubPurpose(enum.Enum):
RollingBufferY = 3
RollingBufferXY = 4
- def display_name(self):
+ def display_name(self) -> str:
return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]
- def identifier_name(self):
+ def identifier_name(self) -> str:
return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]
+ @staticmethod
def all():
return (
TensorSubPurpose.Standard,
@@ -134,7 +145,7 @@ class TensorBlockTraversal(enum.Enum):
PartKernelFirst = 3
-def shape_num_elements(shp):
+def shape_num_elements(shp: Shape) -> Optional[int]:
elems = 1
if shp is None:
return None
@@ -145,7 +156,7 @@ def shape_num_elements(shp):
return elems
-def shape_fully_defined(shp):
+def shape_fully_defined(shp: Shape) -> bool:
if shp is None:
return False
for d in shp:
@@ -154,7 +165,7 @@ def shape_fully_defined(shp):
return True
-def shape_round_to_quantum(shp, quantum):
+def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
new_shp = list(shp)
# Traverse backwards using length of shape since there may be more rounding quantums than shape elements
@@ -165,7 +176,7 @@ def shape_round_to_quantum(shp, quantum):
@lru_cache(maxsize=None)
-def create_equivalence_id(key):
+def create_equivalence_id(key) -> UUID:
# Generates equivalence_id based on the given key.
return uuid.uuid4()
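
Because `create_equivalence_id` is wrapped in `lru_cache(maxsize=None)`, the freshly generated UUID is memoised per key: equal (hashable) keys always map to the same equivalence id, distinct keys to distinct ones. A small sketch with hypothetical keys:

    from ethosu.vela.tensor import create_equivalence_id

    eq_a = create_equivalence_id(("bias", 16))  # hypothetical cache key
    eq_b = create_equivalence_id(("bias", 16))
    eq_c = create_equivalence_id(("bias", 32))
    assert eq_a == eq_b   # same key -> cached UUID is reused
    assert eq_a != eq_c   # different key -> a new uuid4() is generated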
@@ -173,17 +184,23 @@ def create_equivalence_id(key):
class QuantizationParameters:
__slots__ = "min", "max", "num_bits", "narrow_range", "scale_f32", "zero_point", "quant_min", "quant_max"
- def __init__(self, min=None, max=None, num_bits=None, narrow_range=None):
+ def __init__(
+ self,
+ min: Union[float, np.ndarray, None] = None,
+ max: Union[float, np.ndarray, None] = None,
+ num_bits=None,
+ narrow_range=None,
+ ):
self.min = min
self.max = max
self.num_bits = num_bits
self.narrow_range = narrow_range
- self.scale_f32 = None
- self.zero_point = None
- self.quant_min = None
- self.quant_max = None
+ self.scale_f32: Union[float, np.ndarray, None] = None
+ self.zero_point: Union[int, np.ndarray, None] = None
+ self.quant_min: Optional[float] = None
+ self.quant_max: Optional[float] = None
def __str__(self):
return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
@@ -196,7 +213,7 @@ class QuantizationParameters:
__repr__ = __str__
- def clone(self):
+ def clone(self) -> "QuantizationParameters":
res = QuantizationParameters()
res.min = self.min
res.max = self.max
@@ -223,15 +240,9 @@ class QuantizationParameters:
# return the quantized values
return np.ndarray((values_as_float.shape))
- shape = values_as_float.shape[0]
- assert self.zero_point.size == self.scale_f32.size == shape
- res = np.ndarray(values_as_float.shape)
- for i in range(shape):
- res[i] = (values_as_float[i] - self.zero_point[i]) * self.scale_f32[i]
-
return res
- def is_scaling_equal(self, other):
+ def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
# quantisation parameter scaling is not equal if 'other' is None because
# it implies that the tensor it belongs to is not quantised. otherwise,
# it depends upon whether the scale and zero point are equal
@@ -241,12 +252,12 @@ class QuantizationParameters:
return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
- def is_valid(self):
+ def is_valid(self) -> bool:
# quantisation parameters are consider valid if they have a scale and zero point
return None not in (self.scale_f32, self.zero_point)
- def is_per_axis(self):
+ def is_per_axis(self) -> bool:
"""Returns True if either the scale, zero point, minimum or maximum values are arrays"""
for attr in ("scale_f32", "zero_point", "min", "max"):
if isinstance(getattr(self, attr), np.ndarray):
@@ -254,7 +265,15 @@ class QuantizationParameters:
return False
-def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
+def create_const_tensor(
+ name: str,
+ shape: Shape,
+ dtype: DataType,
+ values: np.ndarray,
+ value_dtype: np.dtype = None,
+ purpose: TensorPurpose = TensorPurpose.Unknown,
+ quantization: QuantizationParameters = None,
+):
# Tensor
const_tensor = Tensor(shape, dtype, name + "_0")
const_tensor.purpose = purpose
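
A hedged usage sketch of the newly annotated `create_const_tensor` factory; the shape, values and `DataType.uint8` member are illustrative assumptions (the `DataType` enum lives in `data_type.py`):

    import numpy as np
    from ethosu.vela.data_type import DataType
    from ethosu.vela.tensor import QuantizationParameters, TensorPurpose, create_const_tensor

    quant = QuantizationParameters(min=0.0, max=255.0, num_bits=8)
    quant.scale_f32 = 1.0
    quant.zero_point = 0

    shape = [1, 1, 1, 8]  # arbitrary example shape
    const = create_const_tensor(
        "example_const", shape, DataType.uint8,
        np.zeros(shape, dtype=np.uint8),
        purpose=TensorPurpose.FeatureMap, quantization=quant,
    )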
@@ -288,14 +307,14 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
- address_map = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
+ address_map: Dict = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
@classmethod
- def get_address_for_tens(cls, tens_id, mem_type):
+ def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
return cls.address_map[tens_id].get(mem_type)
@classmethod
- def set_address_for_tens(cls, tens_id, mem_type, address):
+ def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
# Check previous address if there is one
previous_address = cls.address_map[tens_id].get(mem_type)
if address is not None and previous_address is not None:
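
`TensorAddressMap` keeps one class-level dict keyed by equivalence id and then by memory type, so equivalent tensors resolve to the same address. A minimal sketch of the accessors (values are illustrative):

    import uuid
    from ethosu.vela.tensor import MemType, TensorAddressMap

    tens_id = uuid.uuid4()
    TensorAddressMap.set_address_for_tens(tens_id, MemType.Scratch, 0x100)
    assert TensorAddressMap.get_address_for_tens(tens_id, MemType.Scratch) == 0x100
    assert TensorAddressMap.get_address_for_tens(tens_id, MemType.Scratch_fast) is None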
@@ -343,58 +362,58 @@ class Tensor:
)
AllocationQuantum = 16
- def __init__(self, shape, dtype, name):
+ def __init__(self, shape: Shape, dtype: DataType, name: str):
self.shape = shape
self.storage_shape = shape
self.bandwidth_shape = shape
self.dtype = dtype
self.name = name
- self.equivalence_id = uuid.uuid4()
-
- self.ops = []
- self.consumer_list = []
-
- self.values = None
- self.quant_values = None
- self.compressed_values = None
- self.compressed_values_substream_offsets = None
- self.mem_area = MemArea.Unknown
- self.mem_type = MemType.Unknown
- self.format = TensorFormat.Unknown
- self.purpose = TensorPurpose.Unknown
- self.sub_purpose = TensorSubPurpose.Standard
- self.alignment = Tensor.AllocationQuantum
- self.weight_transpose_depthwise = False
-
- self.storage_compression_scale = 1.0
- self.bandwidth_compression_scale = 1.0
- self.compression_scale_for_worst_weight_stream = 1.0
- self.weight_compression_scales = None
+ self.equivalence_id: UUID = uuid.uuid4()
+
+ self.ops: List[Operation] = []
+ self.consumer_list: List[Operation] = []
+
+ self.values: Optional[np.ndarray] = None
+ self.quant_values: Optional[np.ndarray] = None
+ self.compressed_values: Optional[np.ndarray] = None
+ self.compressed_values_substream_offsets: Optional[List] = None
+ self.mem_area: MemArea = MemArea.Unknown
+ self.mem_type: MemType = MemType.Unknown
+ self.format: TensorFormat = TensorFormat.Unknown
+ self.purpose: TensorPurpose = TensorPurpose.Unknown
+ self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
+ self.alignment: int = Tensor.AllocationQuantum
+ self.weight_transpose_depthwise: bool = False
+
+ self.storage_compression_scale: float = 1.0
+ self.bandwidth_compression_scale: float = 1.0
+ self.compression_scale_for_worst_weight_stream: float = 1.0
+ self.weight_compression_scales: Optional[np.ndarray] = None
# if two tensors have the same weight_compression_config, then they have the same compressed values
self.weight_compression_config = None
# if two tensors have the same value_id, then they have the same values
- self.value_id = uuid.uuid4()
- self.weight_compressed_offsets = []
- self.storage_rounding_quantum = (1, 1, 1, 1)
- self.brick_size = (1, 1, 1, 1)
- self.element_size_bytes = 0
+ self.value_id: UUID = uuid.uuid4()
+ self.weight_compressed_offsets: List = []
+ self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
+ self.brick_size: Tuple = (1, 1, 1, 1)
+ self.element_size_bytes: int = 0
# quantization parameters
- self.quantization = None
- self.block_traversal = TensorBlockTraversal.Default
- self.resampling_mode = resampling_mode.NONE
+ self.quantization: Optional[QuantizationParameters] = None
+ self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default
+ self.resampling_mode: resampling_mode = resampling_mode.NONE
- self.avoid_NHCWB16 = False
+ self.avoid_NHCWB16: bool = False
@property
- def address(self):
+ def address(self) -> int:
return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
@address.setter
- def address(self, address):
+ def address(self, address: int):
TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
- def element_size(self):
+ def element_size(self) -> int:
if self.element_size_bytes == 0:
return self.dtype.size_in_bits() / 8
return self.element_size_bytes
@@ -403,7 +422,7 @@ class Tensor:
# The references to Operators will be empty when returned
# Depending on set_unique, the copy is shallow, or deep
# For set_unique==True, a new equivalence_id will be set
- def clone(self, suffix="_clone", set_unique=False):
+ def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
if set_unique:
res = copy.deepcopy(self)
res.equivalence_id = uuid.uuid4()
@@ -420,13 +439,13 @@ class Tensor:
return res
- def clone_into_fast_storage(self, arch):
+ def clone_into_fast_storage(self, arch) -> "Tensor":
res = self.clone(suffix="_fast_storage")
res.mem_area = arch.fast_storage_mem_area
res.mem_type = MemType.Scratch_fast
return res
- def copy_compressed_weight_info(self, src_tens):
+ def copy_compressed_weight_info(self, src_tens: "Tensor"):
# Copies compressed values + all related weight compression info from the given tensor
self.equivalence_id = src_tens.equivalence_id
self.compressed_values = src_tens.compressed_values
@@ -443,7 +462,7 @@ class Tensor:
self.weight_compression_config = src_tens.weight_compression_config
self.value_id = src_tens.value_id
- def set_format(self, fmt, arch):
+ def set_format(self, fmt: TensorFormat, arch):
self.format = fmt
shape_len = 0
try:
@@ -454,9 +473,9 @@ class Tensor:
if shape_len > 4:
return
self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
- self.storage_rounding_quantum = self.storage_rounding_quantum[-shape_len:]
+ self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
self.brick_size = arch.brick_sizes[self.format]
- self.brick_size = self.brick_size[-shape_len:]
+ self.brick_size = tuple(self.brick_size[-shape_len:])
if self.shape is None:
return
@@ -469,29 +488,31 @@ class Tensor:
self.bandwidth_compression_scale = compression_ratio
self.compression_scale_for_worst_weight_stream = compression_ratio
- def storage_elements(self):
+ def storage_elements(self) -> int:
elems = shape_num_elements(self.storage_shape)
if elems is None:
return 0
return elems
- def elements(self):
+ def elements(self) -> int:
elems = shape_num_elements(self.shape)
if elems is None:
return 0
return elems
- def has_fully_defined_shape(self):
+ def has_fully_defined_shape(self) -> bool:
return shape_fully_defined(self.shape)
- def storage_size(self, scale=1.0):
+ def storage_size(self, scale: float = 1.0) -> int:
raw_size = self.storage_elements() * self.element_size() * scale
if raw_size == 0:
raw_size = 1 # force it to take up space
rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
return rounded_size
- def storage_size_for_sub_purpose(self, arch, sub_purpose, param_a=None, param_b=None):
+ def storage_size_for_sub_purpose(
+ self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
+ ) -> int:
alt_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
elems = shape_num_elements(alt_shape)
if elems is None:
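
`storage_size` above rounds the raw byte count up to the tensor's allocation quantum (`Tensor.AllocationQuantum`, 16 bytes). A stand-alone arithmetic sketch of that rounding, using a local helper rather than `numeric_util`:

    def round_up(value: int, alignment: int) -> int:
        # round value up to the next multiple of alignment
        return ((value + alignment - 1) // alignment) * alignment

    # e.g. 1 x 8 x 8 x 3 uint8 elements = 192 bytes, already a multiple of 16,
    # while 1 x 5 x 5 x 3 = 75 bytes is padded up to 80.
    assert round_up(192, 16) == 192
    assert round_up(75, 16) == 80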
@@ -514,23 +535,30 @@ class Tensor:
rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
return rounded_size
- def storage_shape_for_sub_purpose(self, sub_purpose, param_a, param_b):
+ def storage_shape_for_sub_purpose(
+ self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
+ ) -> Shape:
if sub_purpose == TensorSubPurpose.DoubleBuffer:
shp = list(self.shape)
assert len(shp) >= 2
+ assert param_a is not None
shp[-1] = min(shp[-1], param_a * 2)
else:
shp = list(self.storage_shape)
if sub_purpose == TensorSubPurpose.RollingBufferX:
assert len(shp) == 4
+ assert param_a is not None
shp[0] = 1
shp[2] = min(shp[2], param_a)
elif sub_purpose == TensorSubPurpose.RollingBufferY:
assert len(shp) == 4
+ assert param_a is not None
shp[0] = 1
shp[1] = min(shp[1], param_a)
elif sub_purpose == TensorSubPurpose.RollingBufferXY:
assert len(shp) == 4
+ assert param_a is not None
+ assert param_b is not None
shp[0] = 1
shp[2] = min(shp[2], param_a)
shp[1] = min(shp[1], param_b)
@@ -541,36 +569,22 @@ class Tensor:
return shp
- def set_new_sub_purpose(self, sub_purpose, param_a=None, param_b=None):
+ def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
self.sub_purpose = sub_purpose
if sub_purpose == TensorSubPurpose.DoubleBuffer:
self.storage_compression_scale = self.compression_scale_for_worst_weight_stream
- def bandwidth(self):
+ def bandwidth(self) -> float:
elems = shape_num_elements(self.bandwidth_shape)
if elems is None:
return 0
return elems * self.element_size() * self.bandwidth_compression_scale
- def consumers(self):
+ def consumers(self) -> List[Operation]:
return self.consumer_list
- def get_address_ranges_for_coordinates(self, start_coord, end_coord):
- if self.sub_purpose in (
- TensorSubPurpose.RollingBufferX,
- TensorSubPurpose.RollingBufferY,
- TensorSubPurpose.RollingBufferXY,
- ):
- # build dummy coordinates that cover the entire buffer
- start_coord = [0] * len(start_coord)
- end_coord = [min(self.storage_shape[i], self.shape[i]) for i in range(len(end_coord))]
-
- start = self.address_for_coordinate(start_coord, is_top_box=False)
- end = self.address_for_coordinate(end_coord, is_top_box=True)
- return MemoryRangeSet(self.mem_area, start, end)
-
- def addresses_for_rolling_buffer(self, start_coord, end_coord):
+ def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape) -> Tuple:
# returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
if len(start_coord) < 4:
@@ -591,7 +605,7 @@ class Tensor:
box_height0 = crossing_y - start_coord[1]
box_width = crossing_x - start_coord[2]
- addresses = [None] * 4
+ addresses: List = [None] * 4
addresses[0] = self.address_for_coordinate(start_coord)
if end_coord[2] > crossing_x:
@@ -604,10 +618,12 @@ class Tensor:
return box_height0, box_height0, box_width, addresses
- def address_for_coordinate(self, coord, is_top_box=False):
- return self.address + self.address_offset_for_coordinate(coord, is_top_box)
+ def address_for_coordinate(self, coord: Shape, is_top_box: bool = False) -> int:
+ offset = self.address_offset_for_coordinate(coord, is_top_box)
+ assert offset is not None
+ return self.address + offset
- def get_strides_and_coord(self, coord=None):
+ def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]:
if coord is None:
coord = [0] * len(self.storage_shape)
@@ -624,7 +640,6 @@ class Tensor:
if self.format == TensorFormat.NHWC:
augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]
- stride_order = [4, 1, 3, 2, 0]
elif self.format == TensorFormat.NHCWB16:
channel_divisor = 16
@@ -642,10 +657,11 @@ class Tensor:
assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
return None, None
- strides = [0] * len(augmented_shape)
+ strides: List = [0] * len(augmented_shape)
stride = self.element_size() * self.storage_compression_scale
if self.format != TensorFormat.NHCWB16:
+ stride_order = [4, 1, 3, 2, 0]
for i in stride_order:
strides[i] = stride
stride *= augmented_shape[i]
@@ -659,30 +675,31 @@ class Tensor:
return strides, augmented_coord
- def get_strides(self):
+ def get_strides(self) -> Shape:
strides, _ = self.get_strides_and_coord()
-
+ assert strides is not None
return strides
- def needs_dma(self):
+ def needs_dma(self) -> bool:
return len(self.ops) == 1 and self.ops[0].type == Op.DMA
- def get_dma_src_tensor(self):
+ def get_dma_src_tensor(self) -> "Optional[Tensor]":
# For weight tensors that need DMA: returns the source tensor in Flash, else None
# Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
return self.ops[0].inputs[0] if self.needs_dma() else None
- def find_npu_op(self):
+ def find_npu_op(self) -> Optional[Operation]:
# Returns the NPU operator that uses this tensor, excluding DMA operators.
for op in self.consumers():
if op.type == Op.DMA:
return op.outputs[0].find_npu_op()
if op.run_on_npu:
return op
- return None
+ return None
- def compressed_stream_index_from_coord(self, coord):
+ def compressed_stream_index_from_coord(self, coord: Shape) -> int:
assert self.format == TensorFormat.WeightsCompressed
+ assert self.compressed_values is not None
assert len(self.compressed_values) > 0
assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)
@@ -704,15 +721,17 @@ class Tensor:
return index
- def size_of_compressed_stream(self, index):
+ def size_of_compressed_stream(self, index: int) -> int:
+ assert self.compressed_values is not None
assert 0 <= index < len(self.compressed_values)
return len(self.compressed_values[index])
- def is_last_index_in_compressed_stream(self, index):
+ def is_last_index_in_compressed_stream(self, index: int) -> bool:
+ assert self.compressed_values is not None
assert 0 <= index < len(self.compressed_values)
return index == len(self.compressed_values) - 1
- def address_offset_for_coordinate(self, orig_coord, is_top_box=False):
+ def address_offset_for_coordinate(self, orig_coord: Shape, is_top_box: bool = False) -> Optional[int]:
address_offset = 0
coord = orig_coord
@@ -739,6 +758,7 @@ class Tensor:
# Always round up to next boundary
index = numeric_util.round_up_divide(depth, brick_depth)
index = index % 2
+ assert self.compressed_values is not None
if len(self.compressed_values) <= 2:
if is_top_box and index == 0:
@@ -775,18 +795,18 @@ class Tensor:
assert address_offset <= self.storage_size()
return address_offset
- def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area):
+ def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))
- def equivalent(self, tens):
+ def equivalent(self, tens: "Tensor") -> bool:
return self.equivalence_id == tens.equivalence_id
- def set_all_shapes(self, shape):
+ def set_all_shapes(self, shape: Shape):
self.shape = shape
self.storage_shape = shape
self.bandwidth_shape = shape
- def get_full_shape(self):
+ def get_full_shape(self) -> Shape:
d = len(self.shape)
if d in (1, 3):
return numeric_util.full_shape(4, self.shape, 1)
@@ -795,7 +815,7 @@ class Tensor:
else:
return self.shape.copy()
- def is_quantized(self):
+ def is_quantized(self) -> bool:
# a tensor is quantized if it has an integral type and it contains valid quantization params
if not isinstance(self.quantization, QuantizationParameters):
@@ -809,7 +829,7 @@ class Tensor:
__repr__ = __str__
-def check_quantized_tens_scaling_equal(tens_a, tens_b):
+def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
# checks that the scaling of two quantized tensors are equal
return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
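
To close, a hedged end-to-end sketch of the annotated API: two int8 tensors with identical scale and zero point compare as scaling-equal, and removing the quantization makes the check fail (`DataType.int8` is assumed from `data_type.py`; shapes and scales are arbitrary):

    from ethosu.vela.data_type import DataType
    from ethosu.vela.tensor import QuantizationParameters, Tensor, check_quantized_tens_scaling_equal

    def make_quant(scale: float, zero_point: int) -> QuantizationParameters:
        q = QuantizationParameters(min=-1.0, max=1.0, num_bits=8)
        q.scale_f32 = scale
        q.zero_point = zero_point
        return q

    a = Tensor([1, 8, 8, 16], DataType.int8, "a")
    b = Tensor([1, 8, 8, 16], DataType.int8, "b")
    a.quantization = make_quant(0.05, 0)
    b.quantization = make_quant(0.05, 0)
    assert check_quantized_tens_scaling_equal(a, b)

    b.quantization = None  # unquantised tensor
    assert not check_quantized_tens_scaling_equal(a, b)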