diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-04-04 11:14:06 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-04-04 16:51:08 +0100 |
commit | 8f9e2842ce7d25645233ad4f6fa406be982346ae (patch) | |
tree | d540cb38fca9733ec983afafb96ed2ccfc626cdc | |
parent | ad78daaf0fa1e41742cbed314459c3dbbb483c20 (diff) | |
download | serialization_lib-8f9e2842ce7d25645233ad4f6fa406be982346ae.tar.gz |
Fix rank 0 support in serialization_lib
Numpy rank 0 files correctly written as shape () instead of (1)
Constant tensors of rank 0 now have data written out
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ie4bad8f798674cdb0484955e9db684f7f4100145
-rw-r--r-- | python/serializer/tosa_serializer.py | 9 | ||||
-rw-r--r-- | src/numpy_utils.cpp | 9 |
2 files changed, 10 insertions, 8 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 9658edf..e6ab3d0 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -407,12 +407,15 @@ class TosaSerializerTensor: if isinstance(data, np.ndarray): data = data.flatten().astype(fntype).tolist() data = list(map(fntype, data)) - self.data = data elif isinstance(data, list): data = list(map(fntype, data)) - self.data = data + elif data is not None: + # Assume data is rank 0 data type + data = list(map(fntype, [data])) else: - self.data = None + data = None + + self.data = data # Filename for placeholder tensors. These get generated by the test generation # process and are written to disk, but are considered input tensors by the diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 5fe0490..e4171d7 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -432,12 +432,11 @@ NumpyUtilities::NPError // Output the format dictionary // Hard-coded for I32 for now - headerPos += - snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,", - dtype_str, shape.empty() ? 1 : shape[0]); + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, + "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str); - // Remainder of shape array - for (i = 1; i < shape.size(); i++) + // Add shape contents (if any - as this will be empty for rank 0) + for (i = 0; i < shape.size(); i++) { headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); } |