diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 86306cad..51c7592e 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -397,7 +397,7 @@ class Tensor: "block_traversal", "equivalence_id", "src_tensor", - "needs_linear_format", + "force_linear_format", "ifm_write_protected", ) AllocationQuantum = 16 @@ -444,13 +444,19 @@ class Tensor: self.quantization: Optional[QuantizationParameters] = None self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default - self.needs_linear_format = True + # Keep track of whether the linear format should be enforced + self.force_linear_format: Optional[bool] = None self.ifm_write_protected = False # Reference to parent-tensor if this tensor is a clone self.src_tensor: Optional[Tensor] = None @property + def use_linear_format(self) -> bool: + """Return whether the tensor should use linear format or not.""" + return self.force_linear_format in (True, None) + + @property def original_shape(self): return self._original_shape @@ -545,7 +551,7 @@ class Tensor: if shape_len > 4: return - assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16) + assert not (self.use_linear_format and fmt == TensorFormat.NHCWB16) self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format] self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:]) self.brick_size = arch.brick_sizes[self.format] |