aboutsummaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/Utils.h63
1 files changed, 63 insertions, 0 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index b519f83a83..3c84c824da 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -28,6 +28,7 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
#include "arm_compute/runtime/Tensor.h"
#ifdef ARM_COMPUTE_CL
@@ -320,6 +321,68 @@ void save_to_ppm(T &tensor, const std::string &ppm_filename)
ARM_COMPUTE_ERROR("Writing %s: (%s)", ppm_filename.c_str(), e.what());
}
}
+
+/** Load the tensor with pre-trained data from a binary file
+ *
+ * @param[in] tensor The tensor to be filled. Data type supported: F32.
+ * @param[in] filename Filename of the binary file to load from.
+ */
+template <typename T>
+void load_trained_data(T &tensor, const std::string &filename)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32);
+
+ std::ifstream fs;
+
+ try
+ {
+ fs.exceptions(std::ofstream::failbit | std::ofstream::badbit | std::ofstream::eofbit);
+ // Open file
+ fs.open(filename, std::ios::in | std::ios::binary);
+
+ if(!fs.good())
+ {
+ throw std::runtime_error("Could not load binary data: " + filename);
+ }
+
+#ifdef ARM_COMPUTE_CL
+ // Map buffer if creating a CLTensor
+ if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value)
+ {
+ tensor.map();
+ }
+#endif
+ Window window;
+
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1));
+
+ for(unsigned int d = 1; d < tensor.info()->num_dimensions(); ++d)
+ {
+ window.set(d, Window::Dimension(0, tensor.info()->tensor_shape()[d], 1));
+ }
+
+ arm_compute::Iterator in(&tensor, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ fs.read(reinterpret_cast<std::fstream::char_type *>(in.ptr()), tensor.info()->tensor_shape()[0] * tensor.info()->element_size());
+ },
+ in);
+
+#ifdef ARM_COMPUTE_CL
+ // Unmap buffer if creating a CLTensor
+ if(std::is_same<typename std::decay<T>::type, arm_compute::CLTensor>::value)
+ {
+ tensor.unmap();
+ }
+#endif
+ }
+ catch(const std::ofstream::failure &e)
+ {
+ ARM_COMPUTE_ERROR("Writing %s: (%s)", filename.c_str(), e.what());
+ }
+}
+
} // namespace utils
} // namespace arm_compute
#endif /* __UTILS_UTILS_H__*/