aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py34
1 files changed, 23 insertions, 11 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 160cf630..2f91f61c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -225,7 +225,6 @@ class Tensor:
"quantization",
"weight_compressed_offsets",
"element_size_bytes",
- "reshaped",
"block_traversal",
"offset",
"cpu_tensor",
@@ -273,8 +272,6 @@ class Tensor:
# quantization parameters
self.quantization = None
-
- self.reshaped = False
self.block_traversal = TensorBlockTraversal.Default
self.resampling_mode = resampling_mode.NONE
@@ -294,20 +291,13 @@ class Tensor:
res.values = self.values
res.quant_values = self.quant_values
- res.compressed_values = self.compressed_values
res.mem_area = self.mem_area
res.format = self.format
res.purpose = self.purpose
res.sub_purpose = self.sub_purpose
res.alignment = self.alignment
- res.weight_transpose_depthwise = self.weight_transpose_depthwise
-
- res.storage_compression_scale = self.storage_compression_scale
res.bandwidth_compression_scale = self.bandwidth_compression_scale
- res.compression_scale_for_worst_weight_stream = self.compression_scale_for_worst_weight_stream
- res.weight_compression_scales = self.weight_compression_scales
res.storage_rounding_quantum = self.storage_rounding_quantum
- res.brick_size = self.brick_size
res.address = 0
if self.quantization is not None:
@@ -317,6 +307,7 @@ class Tensor:
res.resampling_mode = self.resampling_mode
+ res.copy_compressed_weight_info(self)
return res
def clone_into_fast_storage(self, arch):
@@ -324,6 +315,19 @@ class Tensor:
res.mem_area = arch.fast_storage_mem_area
return res
+ def copy_compressed_weight_info(self, src_tens):
+ # Copies compressed values + all related weight compression info from the given tensor
+ self.compressed_values = src_tens.compressed_values
+ self.storage_shape = src_tens.storage_shape
+ self.brick_size = src_tens.brick_size
+ self.weight_compression_scales = src_tens.weight_compression_scales
+ self.weight_compressed_offsets = src_tens.weight_compressed_offsets
+ self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
+ self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
+ self.storage_compression_scale = src_tens.storage_compression_scale
+ self.block_traversal = src_tens.block_traversal
+ self.weight_compression_config = src_tens.weight_compression_config
+
def set_format(self, fmt, arch):
self.format = fmt
shape_len = 0
@@ -527,6 +531,14 @@ class Tensor:
return strides
+ def needs_dma(self):
+ return len(self.ops) == 1 and self.ops[0].type == "DMA"
+
+ def get_dma_src_tensor(self):
+ # For weight tensors that need DMA: returns the source tensor in Flash, else None
+ # Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
+ return self.ops[0].inputs[0] if self.needs_dma() else None
+
def compressed_stream_index_from_coord(self, coord):
assert self.format == TensorFormat.WeightsCompressed
assert len(self.compressed_values) > 0
@@ -575,7 +587,7 @@ class Tensor:
if len(self.weight_compressed_offsets) == 0:
return 0
- if len(self.ops) == 1 and self.ops[0].type == "DMA" and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
+ if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
depth = orig_coord[-1]
brick_depth = self.brick_size[-1]
# Clamp position at final element index