aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_vgg16.cpp
diff options
context:
space:
mode:
authorGian Marco <gianmarco.iodice@arm.com>2018-02-08 16:21:54 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:47:18 +0000
commit5ca7409bc02ea1ac8ea34f0779f18221880fa6ac (patch)
tree3018aa85c14bb8f89979b13bca76c60d249e452e /examples/graph_vgg16.cpp
parent284cfe2e3a44e5b20978e561c96c94d2193e93a1 (diff)
downloadComputeLibrary-5ca7409bc02ea1ac8ea34f0779f18221880fa6ac.tar.gz
COMPMID-765 - Used GEMM-based convolution in VGG16
In order to use GEMM-based convolution in VGG16, it has been created a function which allocates 1.8 GB. If the function fails, will be used DIRECT convolution instead Change-Id: Ibec8928ee6fe6684d6dc24b7df380beeb671bf27 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/119490 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'examples/graph_vgg16.cpp')
-rw-r--r--examples/graph_vgg16.cpp21
1 files changed, 19 insertions, 2 deletions
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index d97c5b5d02..c3eb922f0e 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -33,6 +33,20 @@ using namespace arm_compute::utils;
using namespace arm_compute::graph;
using namespace arm_compute::graph_utils;
+namespace
+{
+/** This function checks if we can use GEMM-based convolution trying to allocate a memory of size "size_in_bytes"
+ *
+ * @param[in] size_in_bytes Memory size in bytes needed for VGG-16
+ *
+ * @return The convolution layer hint
+ */
+ConvolutionMethodHint convolution_hint_vgg16(size_t size_in_bytes)
+{
+ return ((get_mem_free_from_meminfo() * 1024) >= size_in_bytes) ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
+}
+} // namespace
+
/** Example demonstrating how to implement VGG16's network using the Compute Library's graph API
*
* @param[in] argc Number of arguments
@@ -52,8 +66,11 @@ public:
constexpr float mean_b = 103.939f; /* Mean value to subtract from blue channel */
// Set target. 0 (NEON), 1 (OpenCL). By default it is NEON
- TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
- ConvolutionMethodHint convolution_hint = ConvolutionMethodHint::DIRECT;
+ TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
+
+ // Check if we can use GEMM-based convolutions evaluating if the platform has at least 1.8 GB of available memory
+ const size_t memory_required = 1932735283L;
+ ConvolutionMethodHint convolution_hint = convolution_hint_vgg16(memory_required);
// Parse arguments
if(argc < 2)