From 8f9e2842ce7d25645233ad4f6fa406be982346ae Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 4 Apr 2024 11:14:06 +0100 Subject: Fix rank 0 support in serialization_lib Numpy rank 0 files correctly written as shape () instead of (1) Constant tensors of rank 0 now have data written out Signed-off-by: Jeremy Johnson Change-Id: Ie4bad8f798674cdb0484955e9db684f7f4100145 --- src/numpy_utils.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 5fe0490..e4171d7 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -432,12 +432,11 @@ NumpyUtilities::NPError // Output the format dictionary // Hard-coded for I32 for now - headerPos += - snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,", - dtype_str, shape.empty() ? 1 : shape[0]); + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, + "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str); - // Remainder of shape array - for (i = 1; i < shape.size(); i++) + // Add shape contents (if any - as this will be empty for rank 0) + for (i = 0; i < shape.size(); i++) { headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); } -- cgit v1.2.1