Diffstat (limited to 'ethosu/vela/tensor.py')
 ethosu/vela/tensor.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 86306cad..51c7592e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -397,7 +397,7 @@ class Tensor:
"block_traversal",
"equivalence_id",
"src_tensor",
- "needs_linear_format",
+ "force_linear_format",
"ifm_write_protected",
)
AllocationQuantum = 16
@@ -444,13 +444,19 @@ class Tensor:
         self.quantization: Optional[QuantizationParameters] = None
         self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

-        self.needs_linear_format = True
+        # Keep track of whether the linear format should be enforced
+        self.force_linear_format: Optional[bool] = None
         self.ifm_write_protected = False

         # Reference to parent-tensor if this tensor is a clone
         self.src_tensor: Optional[Tensor] = None
     @property
+    def use_linear_format(self) -> bool:
+        """Return whether the tensor should use linear format or not."""
+        return self.force_linear_format in (True, None)
+
+    @property
     def original_shape(self):
         return self._original_shape
@@ -545,7 +551,7 @@ class Tensor:
         if shape_len > 4:
             return
-        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
+        assert not (self.use_linear_format and fmt == TensorFormat.NHCWB16)
         self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
         self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
         self.brick_size = arch.brick_sizes[self.format]
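
The change replaces the plain boolean needs_linear_format (which defaulted to True) with a tri-state force_linear_format: Optional[bool], so later passes can distinguish "no preference recorded" (None) from "linear format explicitly forced" (True), while the use_linear_format property preserves the old behavior of defaulting an unset flag to linear. The sketch below is a minimal, self-contained illustration of that tri-state logic; the TensorFormat stub and the driver lines at the end are illustrative stand-ins, not vela's actual definitions.

    from enum import Enum
    from typing import Optional


    class TensorFormat(Enum):
        # Stand-in for vela's TensorFormat; only the members used here.
        Unknown = 0
        NHWC = 1
        NHCWB16 = 2


    class Tensor:
        """Minimal sketch of the tri-state flag introduced by the diff."""

        def __init__(self) -> None:
            # None  -> no preference recorded yet (defaults to linear)
            # True  -> a pass has explicitly forced linear format
            # False -> linear format explicitly not required (NHCWB16 allowed)
            self.force_linear_format: Optional[bool] = None
            self.format = TensorFormat.Unknown

        @property
        def use_linear_format(self) -> bool:
            # Same predicate as the diff: only an explicit False opts out.
            return self.force_linear_format in (True, None)

        def set_format(self, fmt: TensorFormat) -> None:
            # Mirrors the updated assert: NHCWB16 is only legal once a pass
            # has explicitly cleared the linear-format requirement.
            assert not (self.use_linear_format and fmt == TensorFormat.NHCWB16)
            self.format = fmt


    t = Tensor()
    assert t.use_linear_format          # unset (None) still means linear
    t.force_linear_format = False       # explicitly allow brick format
    t.set_format(TensorFormat.NHCWB16)  # now permitted

One design point worth noting: because use_linear_format treats None the same as True, all existing call sites such as the assert in set_format keep their old default-safe behavior, and only code that deliberately sets the flag to False can enable the NHCWB16 brick format.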