diff options
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index ce53f9b1..6fdfe019 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -26,6 +26,7 @@ from .nn_graph import PassPlacement from .operation import Op from .reader_util import align_inputs_indices from .tensor import MemType +from .tensor import shape_num_elements from .tensor import TensorPurpose from .tflite import Buffer from .tflite import Metadata @@ -248,7 +249,12 @@ class TFLiteSerialiser: def serialise_tensor(self, tens): builder = self.builder - tens_shape = tens.shape + if shape_num_elements(tens.original_shape) != shape_num_elements(tens.shape): + # shapes have changed size, therefore assume that the latest (modified) shape is correct + tens_shape = tens.shape + else: + # shapes have not changed size, therefore the original shape is valid + tens_shape = tens.original_shape values = tens.values if values is None: |