From bee466b5eac4ec39d4032d946c9a4aee051f2b31 Mon Sep 17 00:00:00 2001 From: steniu01 Date: Wed, 21 Jun 2017 16:45:41 +0100 Subject: COMPID-345 Add caffe_data_extractor.py script and the instructions Change-Id: Ibb84b2060c4d6362be9ce4b1757e273e013de618 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78630 Tested-by: Kaizen Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier --- utils/Utils.h | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) (limited to 'utils/Utils.h') diff --git a/utils/Utils.h b/utils/Utils.h index b519f83a83..3c84c824da 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -28,6 +28,7 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/Window.h" #include "arm_compute/runtime/Tensor.h" #ifdef ARM_COMPUTE_CL @@ -320,6 +321,68 @@ void save_to_ppm(T &tensor, const std::string &ppm_filename) ARM_COMPUTE_ERROR("Writing %s: (%s)", ppm_filename.c_str(), e.what()); } } + +/** Load the tensor with pre-trained data from a binary file + * + * @param[in] tensor The tensor to be filled. Data type supported: F32. + * @param[in] filename Filename of the binary file to load from. + */ +template +void load_trained_data(T &tensor, const std::string &filename) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + + std::ifstream fs; + + try + { + fs.exceptions(std::ofstream::failbit | std::ofstream::badbit | std::ofstream::eofbit); + // Open file + fs.open(filename, std::ios::in | std::ios::binary); + + if(!fs.good()) + { + throw std::runtime_error("Could not load binary data: " + filename); + } + +#ifdef ARM_COMPUTE_CL + // Map buffer if creating a CLTensor + if(std::is_same::type, arm_compute::CLTensor>::value) + { + tensor.map(); + } +#endif + Window window; + + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1)); + + for(unsigned int d = 1; d < tensor.info()->num_dimensions(); ++d) + { + window.set(d, Window::Dimension(0, tensor.info()->tensor_shape()[d], 1)); + } + + arm_compute::Iterator in(&tensor, window); + + execute_window_loop(window, [&](const Coordinates & id) + { + fs.read(reinterpret_cast(in.ptr()), tensor.info()->tensor_shape()[0] * tensor.info()->element_size()); + }, + in); + +#ifdef ARM_COMPUTE_CL + // Unmap buffer if creating a CLTensor + if(std::is_same::type, arm_compute::CLTensor>::value) + { + tensor.unmap(); + } +#endif + } + catch(const std::ofstream::failure &e) + { + ARM_COMPUTE_ERROR("Writing %s: (%s)", filename.c_str(), e.what()); + } +} + } // namespace utils } // namespace arm_compute #endif /* __UTILS_UTILS_H__*/ -- cgit v1.2.1