diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 0f8170d4..eedbadad 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -18,6 +18,7 @@ import enum import uuid from collections import defaultdict +from functools import lru_cache import numpy as np @@ -159,6 +160,12 @@ def shape_round_to_quantum(shp, quantum): return new_shp +@lru_cache(maxsize=None) +def create_equivalence_id(key): + # Generates equivalence_id based on the given key. + return uuid.uuid4() + + class QuantizationParameters: __slots__ = "min", "max", "num_bits", "narrow_range", "scale_f32", "zero_point", "quant_min", "quant_max" @@ -303,6 +310,7 @@ class Tensor: "compression_scale_for_worst_weight_stream", "weight_compression_scales", "weight_compression_config", + "value_id", "storage_rounding_quantum", "brick_size", "quantization", @@ -342,7 +350,10 @@ class Tensor: self.bandwidth_compression_scale = 1.0 self.compression_scale_for_worst_weight_stream = 1.0 self.weight_compression_scales = None + # if two tensors have the same weight_compression_config, then they have the same compressed values self.weight_compression_config = None + # if two tensors have the same value_id, then they have the same values + self.value_id = uuid.uuid4() self.weight_compressed_offsets = [] self.storage_rounding_quantum = (1, 1, 1, 1) self.brick_size = (1, 1, 1, 1) @@ -375,7 +386,6 @@ class Tensor: res.ops = [] res.consumer_list = [] - res.equivalence_id = self.equivalence_id res.values = self.values res.quant_values = self.quant_values @@ -407,6 +417,7 @@ class Tensor: def copy_compressed_weight_info(self, src_tens): # Copies compressed values + all related weight compression info from the given tensor + self.equivalence_id = src_tens.equivalence_id self.compressed_values = src_tens.compressed_values self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets self.storage_shape = src_tens.storage_shape @@ -418,6 +429,7 @@ class Tensor: self.storage_compression_scale = src_tens.storage_compression_scale self.block_traversal = src_tens.block_traversal self.weight_compression_config = src_tens.weight_compression_config + self.value_id = src_tens.value_id def set_format(self, fmt, arch): self.format = fmt |