diff options
Diffstat (limited to 'tests/AssetsLibrary.cpp')
-rw-r--r-- | tests/AssetsLibrary.cpp | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/tests/AssetsLibrary.cpp b/tests/AssetsLibrary.cpp index c6d86d1c1a..eafa6314b1 100644 --- a/tests/AssetsLibrary.cpp +++ b/tests/AssetsLibrary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,11 @@ #include "arm_compute/core/ITensor.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include "libnpy/npy.hpp" +#pragma GCC diagnostic pop + #include <cctype> #include <fstream> #include <limits> @@ -511,5 +516,42 @@ RawTensor AssetsLibrary::get(const std::string &name, Format format, Channel cha { return RawTensor(find_or_create_raw_tensor(name, format, channel)); } + +namespace detail +{ +inline void validate_npy_header(std::ifstream &stream, const std::string &expect_typestr, const TensorShape &expect_shape) +{ + ARM_COMPUTE_UNUSED(expect_typestr); + ARM_COMPUTE_UNUSED(expect_shape); + + std::string header = npy::read_header(stream); + + // Parse header + std::vector<unsigned long> shape; + bool fortran_order = false; + std::string typestr; + npy::parse_header(header, typestr, fortran_order, shape); + + // Check if the typestring matches the given one + ARM_COMPUTE_ERROR_ON_MSG(typestr != expect_typestr, "Typestrings mismatch"); + + // Validate tensor shape + ARM_COMPUTE_ERROR_ON_MSG(shape.size() != expect_shape.num_dimensions(), "Tensor ranks mismatch"); + if(fortran_order) + { + for(size_t i = 0; i < shape.size(); ++i) + { + ARM_COMPUTE_ERROR_ON_MSG(expect_shape[i] != shape[i], "Tensor dimensions mismatch"); + } + } + else + { + for(size_t i = 0; i < shape.size(); ++i) + { + ARM_COMPUTE_ERROR_ON_MSG(expect_shape[i] != shape[shape.size() - i - 1], "Tensor dimensions mismatch"); + } + } +} +} // namespace detail } // namespace test } // namespace arm_compute |