diff options
author | steniu01 <steven.niu@arm.com> | 2017-06-21 16:45:41 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-09-17 14:14:20 +0100 |
commit | bee466b5eac4ec39d4032d946c9a4aee051f2b31 (patch) | |
tree | 264e5124e7d2e1ccb3277b0ef478f0bb4a145a0a /utils/Utils.h | |
parent | 8af2dd6eb230f2205070dce50c2a22bdf2d55e46 (diff) | |
download | ComputeLibrary-bee466b5eac4ec39d4032d946c9a4aee051f2b31.tar.gz |
COMPID-345 Add caffe_data_extractor.py script and the instructions
Change-Id: Ibb84b2060c4d6362be9ce4b1757e273e013de618
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78630
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'utils/Utils.h')
-rw-r--r-- | utils/Utils.h | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/utils/Utils.h b/utils/Utils.h index b519f83a83..3c84c824da 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -28,6 +28,7 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/core/Window.h" #include "arm_compute/runtime/Tensor.h" #ifdef ARM_COMPUTE_CL @@ -320,6 +321,68 @@ void save_to_ppm(T &tensor, const std::string &ppm_filename) ARM_COMPUTE_ERROR("Writing %s: (%s)", ppm_filename.c_str(), e.what()); } } + +/** Load the tensor with pre-trained data from a binary file + * + * @param[in] tensor The tensor to be filled. Data type supported: F32. + * @param[in] filename Filename of the binary file to load from. + */ +template <typename T> +void load_trained_data(T &tensor, const std::string &filename) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + + std::ifstream fs; + + try + { + fs.exceptions(std::ofstream::failbit | std::ofstream::badbit | std::ofstream::eofbit); + // Open file + fs.open(filename, std::ios::in | std::ios::binary); + + if(!fs.good()) + { + throw std::runtime_error("Could not load binary data: " + filename); + } + +#ifdef ARM_COMPUTE_CL + // Map buffer if creating a CLTensor + if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value) + { + tensor.map(); + } +#endif + Window window; + + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1)); + + for(unsigned int d = 1; d < tensor.info()->num_dimensions(); ++d) + { + window.set(d, Window::Dimension(0, tensor.info()->tensor_shape()[d], 1)); + } + + arm_compute::Iterator in(&tensor, window); + + execute_window_loop(window, [&](const Coordinates & id) + { + fs.read(reinterpret_cast<std::fstream::char_type *>(in.ptr()), tensor.info()->tensor_shape()[0] * tensor.info()->element_size()); + }, + in); + +#ifdef ARM_COMPUTE_CL + // Unmap buffer if creating a CLTensor + if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value) + { + tensor.unmap(); + } +#endif + } + catch(const std::ofstream::failure &e) + { + ARM_COMPUTE_ERROR("Writing %s: (%s)", filename.c_str(), e.what()); + } +} + } // namespace utils } // namespace arm_compute #endif /* __UTILS_UTILS_H__*/ |