diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-10-31 14:39:02 +0100 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-11-01 17:16:14 +0100 |
commit | b9f8159af8bb4e1eeb522f26852d8cbf90cb87a3 (patch) | |
tree | 462921b7ff59cd080dab6959cefb1a0aa23ff3fd /ethosu | |
parent | 48e5159e8b34abe91f331d76e746c25b4017a96e (diff) | |
download | ethos-u-vela-b9f8159af8bb4e1eeb522f26852d8cbf90cb87a3.tar.gz |
MLBEDSW-7077: Store original tensor shape in optimized file
- CPU side always needs to work we the original tensor shape.
Due to a bypass memory optimization the IFM, produced by CPU,
was stored with the wrong shape in the optimized file.
- Store the original tensor shape so it can be correctly
written to the optimized file.
Change-Id: I666dbcb0acd806ad208c0f925a51dfc25421688b
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/tensor.py | 6 | ||||
-rw-r--r-- | ethosu/vela/tflite_writer.py | 8 |
2 files changed, 13 insertions, 1 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index ba385886..673208ac 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -340,6 +340,7 @@ class TensorAddressMap: class Tensor: __slots__ = ( "shape", + "_original_shape", "storage_shape", "bandwidth_shape", "dtype", @@ -379,6 +380,7 @@ class Tensor: def __init__(self, shape: Shape, dtype: DataType, name: str): self.shape = shape + self._original_shape = shape self.storage_shape = shape self.bandwidth_shape = shape self.dtype = dtype @@ -425,6 +427,10 @@ class Tensor: self.src_tensor: Optional[Tensor] = None @property + def original_shape(self): + return self._original_shape + + @property def address(self) -> int: return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type) diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index ce53f9b1..6fdfe019 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -26,6 +26,7 @@ from .nn_graph import PassPlacement from .operation import Op from .reader_util import align_inputs_indices from .tensor import MemType +from .tensor import shape_num_elements from .tensor import TensorPurpose from .tflite import Buffer from .tflite import Metadata @@ -248,7 +249,12 @@ class TFLiteSerialiser: def serialise_tensor(self, tens): builder = self.builder - tens_shape = tens.shape + if shape_num_elements(tens.original_shape) != shape_num_elements(tens.shape): + # shapes have changed size, therefore assume that the latest (modified) shape is correct + tens_shape = tens.shape + else: + # shapes have not changed size, therefore the original shape is valid + tens_shape = tens.original_shape values = tens.values if values is None: |