From 4841c97170b85be0706b65d424e967e561cef932 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Wed, 3 Feb 2021 12:17:35 +0000 Subject: Add mlgo to graph examples Resolves COMPMID-3847 Change-Id: I99f73bfc8eda66e8ce1dd1f2a18be76e9d826569 Signed-off-by: SiCong Li Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5033 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- arm_compute/graph/Types.h | 3 ++- arm_compute/graph/backends/CL/CLDeviceBackend.h | 12 +++++++----- examples/graph_alexnet.cpp | 3 ++- examples/graph_deepspeech_v0_4_1.cpp | 3 ++- examples/graph_edsr.cpp | 3 ++- examples/graph_googlenet.cpp | 3 ++- examples/graph_inception_resnet_v1.cpp | 3 ++- examples/graph_inception_resnet_v2.cpp | 3 ++- examples/graph_inception_v3.cpp | 3 ++- examples/graph_inception_v4.cpp | 3 ++- examples/graph_lenet.cpp | 3 ++- examples/graph_mnist.cpp | 3 ++- examples/graph_mobilenet.cpp | 4 +++- examples/graph_mobilenet_v2.cpp | 3 ++- examples/graph_resnet12.cpp | 3 ++- examples/graph_resnet50.cpp | 3 ++- examples/graph_resnet_v2_50.cpp | 3 ++- examples/graph_resnext50.cpp | 3 ++- examples/graph_shufflenet.cpp | 3 ++- examples/graph_squeezenet.cpp | 3 ++- examples/graph_squeezenet_v1_1.cpp | 3 ++- examples/graph_srcnn955.cpp | 3 ++- examples/graph_ssd_mobilenet.cpp | 3 ++- examples/graph_vgg16.cpp | 3 ++- examples/graph_vgg19.cpp | 3 ++- examples/graph_vgg_vdsr.cpp | 3 ++- examples/graph_yolov3.cpp | 3 ++- src/graph/backends/CL/CLDeviceBackend.cpp | 8 ++++++-- tests/benchmark_examples/RunExample.cpp | 6 ++++-- utils/CommonGraphOptions.cpp | 8 ++++++-- utils/CommonGraphOptions.h | 4 +++- 31 files changed, 79 insertions(+), 38 deletions(-) diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index c5d3d17a9b..b891c1772f 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -87,6 +87,7 @@ struct GraphConfig CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */ int num_threads{ -1 }; /**< Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is. */ std::string tuner_file{ "acl_tuner.csv" }; /**< File to load/store tuning values from */ + std::string mlgo_file{ "heuristics.mlgo" }; /**< Filename to load MLGO heuristics from */ }; /**< Device target types */ diff --git a/arm_compute/graph/backends/CL/CLDeviceBackend.h b/arm_compute/graph/backends/CL/CLDeviceBackend.h index a8ee25d7e2..82c0eacd11 100644 --- a/arm_compute/graph/backends/CL/CLDeviceBackend.h +++ b/arm_compute/graph/backends/CL/CLDeviceBackend.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/graph/IDeviceBackend.h" #include "arm_compute/runtime/CL/CLBufferAllocator.h" +#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h" #include "arm_compute/runtime/CL/CLTuner.h" namespace arm_compute @@ -70,10 +71,11 @@ public: std::shared_ptr create_weights_manager() override; private: - int _context_count; /**< Counts how many contexts are currently using the backend */ - CLTuner _tuner; /**< CL kernel tuner */ - std::unique_ptr _allocator; /**< CL buffer affinity allocator */ - std::string _tuner_file; /**< Filename to load/store the tuner's values from */ + int _context_count; /**< Counts how many contexts are currently using the backend */ + CLTuner _tuner; /**< CL kernel tuner */ + CLGEMMHeuristicsHandle _gemm_heuristics; /**< GEMM heuristics */ + std::unique_ptr _allocator; /**< CL buffer affinity allocator */ + std::string _tuner_file; /**< Filename to load/store the tuner's values from */ }; } // namespace backends } // namespace graph diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp index ce398be6cf..7f4e75aaf8 100644 --- a/examples/graph_alexnet.cpp +++ b/examples/graph_alexnet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -159,6 +159,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; // Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed // compilation won't be required. diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp index 4a8a8b15a9..a5658625c7 100644 --- a/examples/graph_deepspeech_v0_4_1.cpp +++ b/examples/graph_deepspeech_v0_4_1.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -211,6 +211,7 @@ public: config.num_threads = common_params.threads; config.use_tuner = common_params.enable_tuner; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_edsr.cpp b/examples/graph_edsr.cpp index 77783d97ed..0e41f12155 100644 --- a/examples/graph_edsr.cpp +++ b/examples/graph_edsr.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -75,6 +75,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; context.set_config(config); diff --git a/examples/graph_googlenet.cpp b/examples/graph_googlenet.cpp index 0a53355611..7555d805c1 100644 --- a/examples/graph_googlenet.cpp +++ b/examples/graph_googlenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -130,6 +130,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_inception_resnet_v1.cpp b/examples/graph_inception_resnet_v1.cpp index 7a55733a20..6ae5b5dc77 100644 --- a/examples/graph_inception_resnet_v1.cpp +++ b/examples/graph_inception_resnet_v1.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -215,6 +215,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_inception_resnet_v2.cpp b/examples/graph_inception_resnet_v2.cpp index 60236d0780..ae37ee507d 100644 --- a/examples/graph_inception_resnet_v2.cpp +++ b/examples/graph_inception_resnet_v2.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -196,6 +196,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_inception_v3.cpp b/examples/graph_inception_v3.cpp index 5cacbcb6e1..8ceeb5c68e 100644 --- a/examples/graph_inception_v3.cpp +++ b/examples/graph_inception_v3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -201,6 +201,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_inception_v4.cpp b/examples/graph_inception_v4.cpp index db2a31047e..cafa5c9f10 100644 --- a/examples/graph_inception_v4.cpp +++ b/examples/graph_inception_v4.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -156,6 +156,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); // Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed diff --git a/examples/graph_lenet.cpp b/examples/graph_lenet.cpp index e5783078f1..6560a980cc 100644 --- a/examples/graph_lenet.cpp +++ b/examples/graph_lenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -111,6 +111,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_mnist.cpp b/examples/graph_mnist.cpp index 85ab0ab972..4ef96cc596 100644 --- a/examples/graph_mnist.cpp +++ b/examples/graph_mnist.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -140,6 +140,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp index b73f7a2abd..09b6e6e097 100644 --- a/examples/graph_mobilenet.cpp +++ b/examples/graph_mobilenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -100,6 +100,8 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_mobilenet_v2.cpp b/examples/graph_mobilenet_v2.cpp index fa16c94645..b1b33be2f5 100644 --- a/examples/graph_mobilenet_v2.cpp +++ b/examples/graph_mobilenet_v2.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -91,6 +91,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_resnet12.cpp b/examples/graph_resnet12.cpp index ebd2e5dd16..8818cf742a 100644 --- a/examples/graph_resnet12.cpp +++ b/examples/graph_resnet12.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -136,6 +136,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_resnet50.cpp b/examples/graph_resnet50.cpp index 47d258ede7..b585284c60 100644 --- a/examples/graph_resnet50.cpp +++ b/examples/graph_resnet50.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -115,6 +115,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_resnet_v2_50.cpp b/examples/graph_resnet_v2_50.cpp index 921fb145d6..472bf02b47 100644 --- a/examples/graph_resnet_v2_50.cpp +++ b/examples/graph_resnet_v2_50.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -118,6 +118,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_resnext50.cpp b/examples/graph_resnext50.cpp index 1d9ed8dc89..ec87e0b882 100644 --- a/examples/graph_resnext50.cpp +++ b/examples/graph_resnext50.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -102,6 +102,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp index 300d0f15a1..f90f36149d 100644 --- a/examples/graph_shufflenet.cpp +++ b/examples/graph_shufflenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -148,6 +148,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_squeezenet.cpp b/examples/graph_squeezenet.cpp index 2e72c14763..3d32794e8d 100644 --- a/examples/graph_squeezenet.cpp +++ b/examples/graph_squeezenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -168,6 +168,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_squeezenet_v1_1.cpp b/examples/graph_squeezenet_v1_1.cpp index 1708ac2f5a..6d4ffee994 100644 --- a/examples/graph_squeezenet_v1_1.cpp +++ b/examples/graph_squeezenet_v1_1.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -168,6 +168,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_srcnn955.cpp b/examples/graph_srcnn955.cpp index bcc3824c60..f4ffc02130 100644 --- a/examples/graph_srcnn955.cpp +++ b/examples/graph_srcnn955.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -119,6 +119,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_ssd_mobilenet.cpp b/examples/graph_ssd_mobilenet.cpp index f5af84f4d4..c0859227ab 100644 --- a/examples/graph_ssd_mobilenet.cpp +++ b/examples/graph_ssd_mobilenet.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -97,6 +97,7 @@ public: config.num_threads = common_params.threads; config.use_tuner = common_params.enable_tuner; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp index a4c5e6bbd2..83e663798b 100644 --- a/examples/graph_vgg16.cpp +++ b/examples/graph_vgg16.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -216,6 +216,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg19.cpp b/examples/graph_vgg19.cpp index c95fb03368..03f7e1606c 100644 --- a/examples/graph_vgg19.cpp +++ b/examples/graph_vgg19.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -227,6 +227,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg_vdsr.cpp b/examples/graph_vgg_vdsr.cpp index 3fa7dd1330..bdb898081d 100644 --- a/examples/graph_vgg_vdsr.cpp +++ b/examples/graph_vgg_vdsr.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -140,6 +140,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); graph.finalize(common_params.target, config); diff --git a/examples/graph_yolov3.cpp b/examples/graph_yolov3.cpp index 54aaf201cb..3c8ddbffd8 100644 --- a/examples/graph_yolov3.cpp +++ b/examples/graph_yolov3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -401,6 +401,7 @@ public: config.use_tuner = common_params.enable_tuner; config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp index 50dd799ee1..f8e22ca7a0 100644 --- a/src/graph/backends/CL/CLDeviceBackend.cpp +++ b/src/graph/backends/CL/CLDeviceBackend.cpp @@ -65,7 +65,7 @@ bool file_exists(const std::string &filename) static detail::BackendRegistrar CLDeviceBackend_registrar(Target::CL); CLDeviceBackend::CLDeviceBackend() - : _context_count(0), _tuner(), _allocator(nullptr), _tuner_file() + : _context_count(0), _tuner(), _gemm_heuristics(), _allocator(nullptr), _tuner_file() { } @@ -87,7 +87,7 @@ void CLDeviceBackend::set_kernel_tuning_mode(CLTunerMode tuning_mode) void CLDeviceBackend::initialize_backend() { // Setup Scheduler - CLScheduler::get().default_init(&_tuner); + CLScheduler::get().default_init(&_tuner, &_gemm_heuristics); // Create allocator with new context _allocator = std::make_unique(nullptr /* legacy path for CLCoreRuntimeContext */); } @@ -123,6 +123,10 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx) set_kernel_tuning(ctx.config().use_tuner); set_kernel_tuning_mode(ctx.config().tuner_mode); + // Attempt to load mlgo heuristics + ARM_COMPUTE_ERROR_ON(CLScheduler::get().gemm_heuristics() == nullptr); + CLScheduler::get().gemm_heuristics()->reload_from_file(ctx.config().mlgo_file); + // Setup a management backend if(ctx.memory_management_ctx(Target::CL) == nullptr) { diff --git a/tests/benchmark_examples/RunExample.cpp b/tests/benchmark_examples/RunExample.cpp index 925daaf156..8adcd95ff6 100644 --- a/tests/benchmark_examples/RunExample.cpp +++ b/tests/benchmark_examples/RunExample.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #include "utils/command_line/CommandLineParser.h" #ifdef ARM_COMPUTE_CL +#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h" #include "arm_compute/runtime/CL/CLHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" #endif /* ARM_COMPUTE_CL */ @@ -127,12 +128,13 @@ int run_example(int argc, char **argv, std::unique_ptr example) } #ifdef ARM_COMPUTE_CL + CLGEMMHeuristicsHandle gemm_h; if(opencl_is_available()) { auto ctx_dev_err = create_opencl_context_and_device(); ARM_COMPUTE_ERROR_ON_MSG(std::get<2>(ctx_dev_err) != CL_SUCCESS, "Failed to create OpenCL context"); CLScheduler::get() - .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err)); + .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err), nullptr, &gemm_h); } #endif /* ARM_COMPUTE_CL */ diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp index d262ea86e9..44d66fa91b 100644 --- a/utils/CommonGraphOptions.cpp +++ b/utils/CommonGraphOptions.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -87,6 +87,7 @@ namespace utils os << "Cache enabled? : " << (common_params.enable_cl_cache ? true_str : false_str) << std::endl; os << "Tuner mode : " << common_params.tuner_mode << std::endl; os << "Tuner file : " << common_params.tuner_file << std::endl; + os << "MLGO file : " << common_params.mlgo_file << std::endl; os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) << std::endl; if(!common_params.data_path.empty()) { @@ -129,7 +130,8 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) validation_file(parser.add_option>("validation-file")), validation_path(parser.add_option>("validation-path")), validation_range(parser.add_option>("validation-range")), - tuner_file(parser.add_option>("tuner-file")) + tuner_file(parser.add_option>("tuner-file")), + mlgo_file(parser.add_option>("mlgo-file")) { std::set supported_targets { @@ -183,6 +185,7 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) validation_path->set_help("Path to the validation data"); validation_range->set_help("Range of the images to validate for (Format : start,end)"); tuner_file->set_help("File to load/save CLTuner values"); + mlgo_file->set_help("File to load MLGO heuristics"); } CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options) @@ -211,6 +214,7 @@ CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options) common_params.validation_range_start = validation_range.first; common_params.validation_range_end = validation_range.second; common_params.tuner_file = options.tuner_file->value(); + common_params.mlgo_file = options.mlgo_file->value(); return common_params; } diff --git a/utils/CommonGraphOptions.h b/utils/CommonGraphOptions.h index dac2e10b19..13cd653e46 100644 --- a/utils/CommonGraphOptions.h +++ b/utils/CommonGraphOptions.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -108,6 +108,7 @@ struct CommonGraphParams std::string validation_file{}; std::string validation_path{}; std::string tuner_file{}; + std::string mlgo_file{}; unsigned int validation_range_start{ 0 }; unsigned int validation_range_end{ std::numeric_limits::max() }; }; @@ -165,6 +166,7 @@ public: SimpleOption *validation_path; /**< Validation data path */ SimpleOption *validation_range; /**< Validation range */ SimpleOption *tuner_file; /**< File to load/store the tuner's values from */ + SimpleOption *mlgo_file; /**< File to load the MLGO heuristics from */ }; /** Consumes the common graph options and creates a structure containing any information -- cgit v1.2.1