aboutsummaryrefslogtreecommitdiff
path: root/utils/Utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/Utils.h')
-rw-r--r--utils/Utils.h15
1 files changed, 9 insertions, 6 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index d46fbc3633..e3a5bb2c3c 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -143,7 +143,7 @@ std::tuple<unsigned int, unsigned int, int> parse_ppm_header(std::ifstream &fs);
*
* @return The width and height stored in the header of the NPY file
*/
-std::tuple<std::vector<unsigned long>, bool, std::string> parse_npy_header(std::ifstream &fs);
+npy::header_t parse_npy_header(std::ifstream &fs);
/** Obtain numpy type string from DataType.
*
@@ -305,7 +305,10 @@ public:
_fs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
_file_layout = file_layout;
- std::tie(_shape, _fortran_order, _typestring) = parse_npy_header(_fs);
+ npy::header_t header = parse_npy_header(_fs);
+ _shape = header.shape;
+ _fortran_order = header.fortran_order;
+ _typestring = header.dtype.str();
}
catch(const std::ifstream::failure &e)
{
@@ -603,11 +606,11 @@ void save_to_npy(T &tensor, const std::string &npy_filename, bool fortran_order)
using typestring_type = typename std::conditional<std::is_floating_point<U>::value, float, qasymm8_t>::type;
std::vector<typestring_type> tmp; /* Used only to get the typestring */
- npy::Typestring typestring_o{ tmp };
- std::string typestring = typestring_o.str();
+ const npy::dtype_t dtype = npy::dtype_map.at(std::type_index(typeid(tmp)));
std::ofstream stream(npy_filename, std::ofstream::binary);
- npy::write_header(stream, typestring, fortran_order, shape);
+ npy::header_t header{ dtype, fortran_order, shape };
+ npy::write_header(stream, header);
arm_compute::Window window;
window.use_tensor_dimensions(tensor.info()->tensor_shape());