-rw-r--r--  examples/graph_vgg16.cpp  21
-rw-r--r--  utils/Utils.cpp           36
-rw-r--r--  utils/Utils.h              6
3 files changed, 59 insertions(+), 4 deletions(-)
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index d97c5b5d02..c3eb922f0e 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -33,6 +33,20 @@ using namespace arm_compute::utils;
using namespace arm_compute::graph;
using namespace arm_compute::graph_utils;
+namespace
+{
+/** This function checks if we can use GEMM-based convolution, i.e. if at least "size_in_bytes" bytes of memory are free
+ *
+ * @param[in] size_in_bytes Memory size in bytes needed for VGG-16
+ *
+ * @return The convolution layer hint
+ */
+ConvolutionMethodHint convolution_hint_vgg16(size_t size_in_bytes)
+{
+ return ((get_mem_free_from_meminfo() * 1024) >= size_in_bytes) ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
+}
+} // namespace
+
/** Example demonstrating how to implement VGG16's network using the Compute Library's graph API
*
* @param[in] argc Number of arguments
@@ -52,8 +66,11 @@ public:
constexpr float mean_b = 103.939f; /* Mean value to subtract from blue channel */
// Set target. 0 (NEON), 1 (OpenCL). By default it is NEON
- TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
- ConvolutionMethodHint convolution_hint = ConvolutionMethodHint::DIRECT;
+ TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
+
+ // Check if we can use GEMM-based convolution by verifying that the platform has at least 1.8 GB of available memory
+ const size_t memory_required = 1932735283L;
+ ConvolutionMethodHint convolution_hint = convolution_hint_vgg16(memory_required);
// Parse arguments
if(argc < 2)
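
For reference, get_mem_free_from_meminfo() reports free memory in kB, so convolution_hint_vgg16() multiplies it by 1024 before comparing it with memory_required (1932735283 bytes, i.e. 1.8 GiB, since 1.8 * 1024^3 ≈ 1932735283). A minimal standalone sketch of the same check, using a hypothetical free-memory reading rather than /proc/meminfo, would be:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        const uint64_t mem_free_kb     = 2048000ULL;    // hypothetical value, in kB, as /proc/meminfo reports it
        const uint64_t memory_required = 1932735283ULL; // ~1.8 GiB needed to run VGG-16 with GEMM-based convolution
        const bool     use_gemm        = (mem_free_kb * 1024) >= memory_required;
        std::cout << (use_gemm ? "GEMM" : "DIRECT") << std::endl; // prints "GEMM" (2048000 kB ≈ 2.1 GB free)
        return 0;
    }
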
diff --git a/utils/Utils.cpp b/utils/Utils.cpp
index 32d5e3a6c0..8a2d11814e 100644
--- a/utils/Utils.cpp
+++ b/utils/Utils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -191,5 +191,39 @@ std::tuple<std::vector<unsigned long>, bool, std::string> parse_npy_header(std::
return std::make_tuple(shape, fortran_order, typestr);
}
+
+/** This function returns the amount of free memory, as read from /proc/meminfo
+ *
+ * @return The free memory in kB
+ */
+uint64_t get_mem_free_from_meminfo()
+{
+    std::string   line_attribute;
+    std::ifstream file_meminfo("/proc/meminfo");
+
+    if(file_meminfo.is_open())
+    {
+        while(!(file_meminfo >> line_attribute).fail())
+        {
+            // Test if this is the line containing MemFree
+            if(line_attribute == "MemFree:")
+            {
+                uint64_t mem_available;
+                if(!(file_meminfo >> mem_available).fail())
+                {
+                    return mem_available;
+                }
+                else
+                {
+                    return 0;
+                }
+            }
+            // If it's not MemFree, ignore the rest of the line
+            file_meminfo.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
+        }
+    }
+    // Nothing was found, or the file could not be opened
+    return 0;
+}
} // namespace utils
} // namespace arm_compute
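
The parser added above scans /proc/meminfo token by token; on a typical Linux system the file starts with lines such as "MemTotal:", "MemFree:" and "MemAvailable:", each followed by a value in kB. A rough self-contained sketch of the same token-based parse, run against an in-memory buffer with hypothetical values instead of the real file:

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <sstream>
    #include <string>

    int main()
    {
        // Hypothetical /proc/meminfo contents; only the value after "MemFree:" is of interest
        std::istringstream meminfo("MemTotal: 3956728 kB\nMemFree: 1842332 kB\nBuffers: 85000 kB\n");
        std::string        attribute;
        uint64_t           mem_free_kb = 0;

        while(!(meminfo >> attribute).fail())
        {
            if(attribute == "MemFree:")
            {
                meminfo >> mem_free_kb; // the value is reported in kB
                break;
            }
            // Not the MemFree line: skip the rest of it
            meminfo.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
        }
        std::cout << mem_free_kb << " kB free" << std::endl; // prints "1842332 kB free"
        return 0;
    }
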
diff --git a/utils/Utils.h b/utils/Utils.h
index ff4c4c99fd..8822fe7121 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -895,7 +895,11 @@ void init_sgemm_output(T &dst, T &src0, T &src1, arm_compute::DataType dt)
{
dst.allocator()->init(TensorInfo(TensorShape(src1.info()->dimension(0), src0.info()->dimension(1)), 1, dt));
}
-
+/** This function returns the amount of free memory, as read from /proc/meminfo
+ *
+ * @return The free memory in kB
+ */
+uint64_t get_mem_free_from_meminfo();
} // namespace utils
} // namespace arm_compute
#endif /* __UTILS_UTILS_H__*/
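
Note that /proc/meminfo is Linux-specific: on a platform where the file cannot be opened, get_mem_free_from_meminfo() returns 0, so convolution_hint_vgg16() evaluates (0 * 1024) >= 1932735283 as false and the example falls back to ConvolutionMethodHint::DIRECT, which was the previous hard-coded default.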