From 5ca7409bc02ea1ac8ea34f0779f18221880fa6ac Mon Sep 17 00:00:00 2001 From: Gian Marco Date: Thu, 8 Feb 2018 16:21:54 +0000 Subject: COMPMID-765 - Used GEMM-based convolution in VGG16 In order to use GEMM-based convolution in VGG16, it has been created a function which allocates 1.8 GB. If the function fails, will be used DIRECT convolution instead Change-Id: Ibec8928ee6fe6684d6dc24b7df380beeb671bf27 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/119490 Tested-by: Jenkins Reviewed-by: Michalis Spyrou Reviewed-by: Gian Marco Iodice Reviewed-by: Anthony Barbier --- examples/graph_vgg16.cpp | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) (limited to 'examples/graph_vgg16.cpp') diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp index d97c5b5d02..c3eb922f0e 100644 --- a/examples/graph_vgg16.cpp +++ b/examples/graph_vgg16.cpp @@ -33,6 +33,20 @@ using namespace arm_compute::utils; using namespace arm_compute::graph; using namespace arm_compute::graph_utils; +namespace +{ +/** This function checks if we can use GEMM-based convolution trying to allocate a memory of size "size_in_bytes" + * + * @param[in] size_in_bytes Memory size in bytes needed for VGG-16 + * + * @return The convolution layer hint + */ +ConvolutionMethodHint convolution_hint_vgg16(size_t size_in_bytes) +{ + return ((get_mem_free_from_meminfo() * 1024) >= size_in_bytes) ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT; +} +} // namespace + /** Example demonstrating how to implement VGG16's network using the Compute Library's graph API * * @param[in] argc Number of arguments @@ -52,8 +66,11 @@ public: constexpr float mean_b = 103.939f; /* Mean value to subtract from blue channel */ // Set target. 0 (NEON), 1 (OpenCL). By default it is NEON - TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0); - ConvolutionMethodHint convolution_hint = ConvolutionMethodHint::DIRECT; + TargetHint target_hint = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0); + + // Check if we can use GEMM-based convolutions evaluating if the platform has at least 1.8 GB of available memory + const size_t memory_required = 1932735283L; + ConvolutionMethodHint convolution_hint = convolution_hint_vgg16(memory_required); // Parse arguments if(argc < 2) -- cgit v1.2.1