aboutsummaryrefslogtreecommitdiff
path: root/tests/AssetsLibrary.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r--tests/AssetsLibrary.h43
1 files changed, 11 insertions, 32 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index e625c37505..84653ed089 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -32,10 +32,6 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/Random.h"
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wunused-parameter"
-#include "libnpy/npy.hpp"
-#pragma GCC diagnostic pop
#include "tests/RawTensor.h"
#include "tests/TensorCache.h"
#include "tests/Utils.h"
@@ -469,6 +465,16 @@ inline std::vector<std::pair<T, T>> convert_range_pair(const std::vector<AssetsL
});
return converted;
}
+
+/* Read npy header and check the payload is suitable for the specified type and shape
+ *
+ * @param[in] stream ifstream of the npy file
+ * @param[in] expect_typestr Expected typestr
+ * @param[in] expect_shape Shape of tensor expected to receive the data
+ *
+ * @note Advances stream to the beginning of the data payload
+ */
+void validate_npy_header(std::ifstream &stream, const std::string &expect_typestr, const TensorShape &expect_shape);
} // namespace detail
template <typename T, typename D>
@@ -959,41 +965,14 @@ void AssetsLibrary::fill_layer_data(T &&tensor, std::string name) const
#endif /* _WIN32 */
const std::string path = _library_path + path_separator + name;
- std::vector<unsigned long> shape;
-
// Open file
std::ifstream stream(path, std::ios::in | std::ios::binary);
if(!stream.good())
{
throw framework::FileNotFound("Could not load npy file: " + path);
}
- std::string header = npy::read_header(stream);
-
- // Parse header
- bool fortran_order = false;
- std::string typestr;
- npy::parse_header(header, typestr, fortran_order, shape);
- // Check if the typestring matches the given one
- std::string expect_typestr = get_typestring(tensor.data_type());
- ARM_COMPUTE_ERROR_ON_MSG(typestr != expect_typestr, "Typestrings mismatch");
-
- // Validate tensor shape
- ARM_COMPUTE_ERROR_ON_MSG(shape.size() != tensor.shape().num_dimensions(), "Tensor ranks mismatch");
- if(fortran_order)
- {
- for(size_t i = 0; i < shape.size(); ++i)
- {
- ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[i], "Tensor dimensions mismatch");
- }
- }
- else
- {
- for(size_t i = 0; i < shape.size(); ++i)
- {
- ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[shape.size() - i - 1], "Tensor dimensions mismatch");
- }
- }
+ validate_npy_header(stream, tensor.data_type(), tensor.shape());
// Read data
if(tensor.padding().empty())