aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-31 14:39:02 +0100
committerJohan Alfvén <johan.alfven@arm.com>2022-11-01 17:16:14 +0100
commitb9f8159af8bb4e1eeb522f26852d8cbf90cb87a3 (patch)
tree462921b7ff59cd080dab6959cefb1a0aa23ff3fd
parent48e5159e8b34abe91f331d76e746c25b4017a96e (diff)
downloadethos-u-vela-b9f8159af8bb4e1eeb522f26852d8cbf90cb87a3.tar.gz
MLBEDSW-7077: Store original tensor shape in optimized file
- CPU side always needs to work we the original tensor shape. Due to a bypass memory optimization the IFM, produced by CPU, was stored with the wrong shape in the optimized file. - Store the original tensor shape so it can be correctly written to the optimized file. Change-Id: I666dbcb0acd806ad208c0f925a51dfc25421688b Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--ethosu/vela/tensor.py6
-rw-r--r--ethosu/vela/tflite_writer.py8
2 files changed, 13 insertions, 1 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index ba385886..673208ac 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -340,6 +340,7 @@ class TensorAddressMap:
class Tensor:
__slots__ = (
"shape",
+ "_original_shape",
"storage_shape",
"bandwidth_shape",
"dtype",
@@ -379,6 +380,7 @@ class Tensor:
def __init__(self, shape: Shape, dtype: DataType, name: str):
self.shape = shape
+ self._original_shape = shape
self.storage_shape = shape
self.bandwidth_shape = shape
self.dtype = dtype
@@ -425,6 +427,10 @@ class Tensor:
self.src_tensor: Optional[Tensor] = None
@property
+ def original_shape(self):
+ return self._original_shape
+
+ @property
def address(self) -> int:
return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index ce53f9b1..6fdfe019 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -26,6 +26,7 @@ from .nn_graph import PassPlacement
from .operation import Op
from .reader_util import align_inputs_indices
from .tensor import MemType
+from .tensor import shape_num_elements
from .tensor import TensorPurpose
from .tflite import Buffer
from .tflite import Metadata
@@ -248,7 +249,12 @@ class TFLiteSerialiser:
def serialise_tensor(self, tens):
builder = self.builder
- tens_shape = tens.shape
+ if shape_num_elements(tens.original_shape) != shape_num_elements(tens.shape):
+ # shapes have changed size, therefore assume that the latest (modified) shape is correct
+ tens_shape = tens.shape
+ else:
+ # shapes have not changed size, therefore the original shape is valid
+ tens_shape = tens.original_shape
values = tens.values
if values is None: