aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-04-04 11:14:06 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2024-04-04 16:51:08 +0100
commit8f9e2842ce7d25645233ad4f6fa406be982346ae (patch)
treed540cb38fca9733ec983afafb96ed2ccfc626cdc
parentad78daaf0fa1e41742cbed314459c3dbbb483c20 (diff)
downloadserialization_lib-8f9e2842ce7d25645233ad4f6fa406be982346ae.tar.gz
Fix rank 0 support in serialization_lib
Numpy rank 0 files correctly written as shape () instead of (1) Constant tensors of rank 0 now have data written out Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ie4bad8f798674cdb0484955e9db684f7f4100145
-rw-r--r--python/serializer/tosa_serializer.py9
-rw-r--r--src/numpy_utils.cpp9
2 files changed, 10 insertions, 8 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 9658edf..e6ab3d0 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -407,12 +407,15 @@ class TosaSerializerTensor:
if isinstance(data, np.ndarray):
data = data.flatten().astype(fntype).tolist()
data = list(map(fntype, data))
- self.data = data
elif isinstance(data, list):
data = list(map(fntype, data))
- self.data = data
+ elif data is not None:
+ # Assume data is rank 0 data type
+ data = list(map(fntype, [data]))
else:
- self.data = None
+ data = None
+
+ self.data = data
# Filename for placeholder tensors. These get generated by the test generation
# process and are written to disk, but are considered input tensors by the
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 5fe0490..e4171d7 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -432,12 +432,11 @@ NumpyUtilities::NPError
// Output the format dictionary
// Hard-coded for I32 for now
- headerPos +=
- snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
- dtype_str, shape.empty() ? 1 : shape[0]);
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
+ "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str);
- // Remainder of shape array
- for (i = 1; i < shape.size(); i++)
+ // Add shape contents (if any - as this will be empty for rank 0)
+ for (i = 0; i < shape.size(); i++)
{
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
}