diff options
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r-- | tests/AssetsLibrary.h | 43 |
1 files changed, 11 insertions, 32 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h index e625c37505..84653ed089 100644 --- a/tests/AssetsLibrary.h +++ b/tests/AssetsLibrary.h @@ -32,10 +32,6 @@ #include "arm_compute/core/Types.h" #include "arm_compute/core/Window.h" #include "arm_compute/core/utils/misc/Random.h" -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#include "libnpy/npy.hpp" -#pragma GCC diagnostic pop #include "tests/RawTensor.h" #include "tests/TensorCache.h" #include "tests/Utils.h" @@ -469,6 +465,16 @@ inline std::vector<std::pair<T, T>> convert_range_pair(const std::vector<AssetsL }); return converted; } + +/* Read npy header and check the payload is suitable for the specified type and shape + * + * @param[in] stream ifstream of the npy file + * @param[in] expect_typestr Expected typestr + * @param[in] expect_shape Shape of tensor expected to receive the data + * + * @note Advances stream to the beginning of the data payload + */ +void validate_npy_header(std::ifstream &stream, const std::string &expect_typestr, const TensorShape &expect_shape); } // namespace detail template <typename T, typename D> @@ -959,41 +965,14 @@ void AssetsLibrary::fill_layer_data(T &&tensor, std::string name) const #endif /* _WIN32 */ const std::string path = _library_path + path_separator + name; - std::vector<unsigned long> shape; - // Open file std::ifstream stream(path, std::ios::in | std::ios::binary); if(!stream.good()) { throw framework::FileNotFound("Could not load npy file: " + path); } - std::string header = npy::read_header(stream); - - // Parse header - bool fortran_order = false; - std::string typestr; - npy::parse_header(header, typestr, fortran_order, shape); - // Check if the typestring matches the given one - std::string expect_typestr = get_typestring(tensor.data_type()); - ARM_COMPUTE_ERROR_ON_MSG(typestr != expect_typestr, "Typestrings mismatch"); - - // Validate tensor shape - ARM_COMPUTE_ERROR_ON_MSG(shape.size() != tensor.shape().num_dimensions(), "Tensor ranks mismatch"); - if(fortran_order) - { - for(size_t i = 0; i < shape.size(); ++i) - { - ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[i], "Tensor dimensions mismatch"); - } - } - else - { - for(size_t i = 0; i < shape.size(); ++i) - { - ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[shape.size() - i - 1], "Tensor dimensions mismatch"); - } - } + validate_npy_header(stream, tensor.data_type(), tensor.shape()); // Read data if(tensor.padding().empty()) |