diff options
author | SiCong Li <sicong.li@arm.com> | 2021-01-08 15:16:02 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-02-02 22:05:34 +0000 |
commit | 7061eb283969f9a020c08349454447564e4dd5b3 (patch) | |
tree | cf11604aaf57ee86ff6e4980a3ddf95bb49167e8 /src/runtime/CL/mlgo/MLGOHeuristics.cpp | |
parent | 74a142c11ec0b2f2b3fe1feb3fdfd98e9190762e (diff) | |
download | ComputeLibrary-7061eb283969f9a020c08349454447564e4dd5b3.tar.gz |
Implement MLGO module
* Implement MLGOHeuristics which provides a query and a loading
interface
* Implement a top-down parser MLGOParser for parsing dotmlgo
* Add validation tests for MLGOHeuristics
Resolves COMPMID-3840, COMPMID-3841
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Iae96d2779524b2dd83623d1a3a30ef57823ae084
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4941
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/mlgo/MLGOHeuristics.cpp')
-rw-r--r-- | src/runtime/CL/mlgo/MLGOHeuristics.cpp | 242 |
1 files changed, 242 insertions, 0 deletions
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp new file mode 100644 index 0000000000..ec99b50488 --- /dev/null +++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/mlgo/MLGOHeuristics.h" +#include "arm_compute/core/Log.h" +#include "src/runtime/CL/mlgo/MLGOParser.h" + +#include <fstream> + +namespace arm_compute +{ +namespace mlgo +{ +bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0); +} +bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, + rhs.export_cl_image); +} +bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, + rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); +} + +constexpr size_t MLGOHeuristics::_max_num_trees; + +MLGOHeuristics::MLGOHeuristics() + : _indices{}, _trees{}, _tree_valid{}, _valid{ false } +{ +} + +std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm type"); + const auto invalid = GEMMType::RESHAPED; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query<GEMMType>(shape_query); +} +std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config native"); + const auto invalid = GEMMConfigNative{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query<GEMMConfigNative>(shape_query); +} +std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped only rhs"); + const auto invalid = GEMMConfigReshapedOnlyRHS{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query); +} +std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped"); + const auto invalid = GEMMConfigReshaped{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query<GEMMConfigReshaped>(shape_query); +} + +bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) +{ + bool status; + HeuristicTree *tree{ nullptr }; + std::tie(status, tree) = get_heuristic_tree(id); + if(!status) + { + return status; + } + status = tree->check(); + if(!status) + { + return status; + } + _tree_valid[id] = true; + return true; +} + +bool MLGOHeuristics::check_all() const +{ + // Tree validities are already checked and cached. + bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) + { + return !v.second; + }) + == _tree_valid.end(); + if(!all_trees_are_checked) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo"); + return false; + } + + // Other top level checks... + + return true; +} + +std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id) +{ + if(_indices.find(id) == _indices.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id); + return std::make_pair(false, nullptr); + } + const auto index = _indices[id]; + + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return std::make_pair(false, nullptr); + } + auto &t = _trees[index]; + + return std::make_pair(true, &t); +} + +bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) +{ + if(_indices.size() >= _max_num_trees) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees); + return false; + } + // PRE: correctness of t is guaranteed by the tree construction process + // Ensure unique id + const auto id = t.id(); + if(_indices.find(id) != _indices.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id); + return false; + } + + // Ensure unique index + const auto index = t.index(); + if(_trees.find(index) != _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists"); + return false; + } + + _indices[id] = index; + _trees[index] = std::move(t); + _tree_valid[id] = false; + return true; +} + +bool MLGOHeuristics::reload_from_file(const std::string &filename) +{ + std::ifstream fs; + fs.exceptions(std::ifstream::badbit); + fs.open(filename, std::ios::in); + if(!fs.is_open()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str()); + return _valid = false; + } + return reload_from_stream(fs); +} + +bool MLGOHeuristics::reload_from_stream(std::istream &in) +{ + auto parsed = parser::parse_mlgo(in); + if(!parsed.first) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead"); + return _valid = false; + } + *this = std::move(parsed.second); + ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully"); + return _valid = true; +} + +} // namespace mlgo +} // namespace arm_compute
\ No newline at end of file |