aboutsummaryrefslogtreecommitdiff
path: root/utils/Utils.h
diff options
context:
space:
mode:
authorsteniu01 <steven.niu@arm.com>2017-06-21 16:45:41 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:14:20 +0100
commitbee466b5eac4ec39d4032d946c9a4aee051f2b31 (patch)
tree264e5124e7d2e1ccb3277b0ef478f0bb4a145a0a /utils/Utils.h
parent8af2dd6eb230f2205070dce50c2a22bdf2d55e46 (diff)
downloadComputeLibrary-bee466b5eac4ec39d4032d946c9a4aee051f2b31.tar.gz
COMPID-345 Add caffe_data_extractor.py script and the instructions
Change-Id: Ibb84b2060c4d6362be9ce4b1757e273e013de618 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78630 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'utils/Utils.h')
-rw-r--r--utils/Utils.h63
1 files changed, 63 insertions, 0 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index b519f83a83..3c84c824da 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -28,6 +28,7 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
#include "arm_compute/runtime/Tensor.h"
#ifdef ARM_COMPUTE_CL
@@ -320,6 +321,68 @@ void save_to_ppm(T &tensor, const std::string &ppm_filename)
ARM_COMPUTE_ERROR("Writing %s: (%s)", ppm_filename.c_str(), e.what());
}
}
+
+/** Load the tensor with pre-trained data from a binary file
+ *
+ * @param[in] tensor The tensor to be filled. Data type supported: F32.
+ * @param[in] filename Filename of the binary file to load from.
+ */
+template <typename T>
+void load_trained_data(T &tensor, const std::string &filename)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32);
+
+ std::ifstream fs;
+
+ try
+ {
+ fs.exceptions(std::ofstream::failbit | std::ofstream::badbit | std::ofstream::eofbit);
+ // Open file
+ fs.open(filename, std::ios::in | std::ios::binary);
+
+ if(!fs.good())
+ {
+ throw std::runtime_error("Could not load binary data: " + filename);
+ }
+
+#ifdef ARM_COMPUTE_CL
+ // Map buffer if creating a CLTensor
+ if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value)
+ {
+ tensor.map();
+ }
+#endif
+ Window window;
+
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1));
+
+ for(unsigned int d = 1; d < tensor.info()->num_dimensions(); ++d)
+ {
+ window.set(d, Window::Dimension(0, tensor.info()->tensor_shape()[d], 1));
+ }
+
+ arm_compute::Iterator in(&tensor, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ fs.read(reinterpret_cast<std::fstream::char_type *>(in.ptr()), tensor.info()->tensor_shape()[0] * tensor.info()->element_size());
+ },
+ in);
+
+#ifdef ARM_COMPUTE_CL
+ // Unmap buffer if creating a CLTensor
+ if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value)
+ {
+ tensor.unmap();
+ }
+#endif
+ }
+ catch(const std::ofstream::failure &e)
+ {
+ ARM_COMPUTE_ERROR("Writing %s: (%s)", filename.c_str(), e.what());
+ }
+}
+
} // namespace utils
} // namespace arm_compute
#endif /* __UTILS_UTILS_H__*/