diff options
author | SiCong Li <sicong.li@arm.com> | 2019-11-05 10:43:06 +0000 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2019-11-12 09:52:55 +0000 |
commit | 75041a1cb81c59a5a5ddd9b708476c0142362d9e (patch) | |
tree | 5ef625879aaa35c757ee4070f753f769f4362011 /examples/gemm_tuner/GemmTuner.py | |
parent | 77d42528b796f3b8f5033785d3bbb8d9cb3fc637 (diff) | |
download | ComputeLibrary-75041a1cb81c59a5a5ddd9b708476c0142362d9e.tar.gz |
COMPMID-2563 Change how the best overall GEMM configuration is selected
* Based on a specified tolerance, each GEMMParam (GEMM shape) can now
have a set of best GEMM configurations, instead of just a single one.
This improves the robustness and completeness of the tuned results, and
is consistent with how we define the GEMMParam archetypes (the main
goal of this story).
* The tuner then tries to find the best overall GEMMConfig from all the
best config sets, through the same voting mechanism: the config that
receives the most votes is the best overall GEMMConfig.
Change-Id: Ief770bb6ffc04629d91f1dc778eea69274e007f0
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2228
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'examples/gemm_tuner/GemmTuner.py')
-rw-r--r-- | examples/gemm_tuner/GemmTuner.py | 159 |
1 files changed, 98 insertions, 61 deletions
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py index 8093ad0e11..29c414cfe8 100644 --- a/examples/gemm_tuner/GemmTuner.py +++ b/examples/gemm_tuner/GemmTuner.py @@ -31,7 +31,7 @@ import os from collections import Counter, defaultdict, deque, namedtuple from enum import Enum from pathlib import Path -from typing import Deque, Dict, Generator, List, NamedTuple, Tuple, Union +from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union ################################################################################ # Types @@ -51,6 +51,9 @@ class GEMMParam(NamedTuple): def parse_from_strs(*args): return GEMMParam(*map(int, args)) + def __str__(self): + return "-".join(map(str, self)) + # Gemm configuration for strategy Native class NativeGEMMConfig(NamedTuple): @@ -63,6 +66,9 @@ class NativeGEMMConfig(NamedTuple): *mnk, = map(int, args) return NativeGEMMConfig(*mnk) + def __str__(self): + return "-".join(map(str, self)) + # Gemm configuration for strategy Reshaped Only RHS class ReshapedOnlyRHSGEMMConfig(NamedTuple): @@ -80,6 +86,9 @@ class ReshapedOnlyRHSGEMMConfig(NamedTuple): transpose_rhs = transpose_rhs == 1 return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs) + def __str__(self): + return "-".join(map(str, self)) + # Gemm configuration for strategy Reshaped class ReshapedGEMMConfig(NamedTuple): @@ -100,13 +109,19 @@ class ReshapedGEMMConfig(NamedTuple): transpose_rhs = transpose_rhs == 1 return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs) + def __str__(self): + return "-".join(map(str, self)) + # Measurement we take from the benchmark result. 
class Measurement(NamedTuple): opencl_timer_ms: float - def is_better_than(self, other): - return self < other + def is_close_to(self, other, tol): + return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol + + def is_better_than(self, other, tol): + return self < other and not self.is_close_to(other) def __add__(self, other): return Measurement(self.opencl_timer_ms + other.opencl_timer_ms) @@ -163,37 +178,59 @@ class GEMMBenchmarkResultRecorder: SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"]) - def __init__(self): + def __init__(self, tol=0.01): """ Initializer """ - # Record that holds all recorded benchmark results. - # Indexed by (GEMMParam, Strategy) and each such pair maps to a deque of (GEMMConfig, Measurements), - # with the best one always at the front (index 0) of the deque - self._benchmark_result_record: Dict[ - Tuple[GEMMParam, Strategy], Deque[Tuple[GEMMConfig, Measurements]] - ] = {} + self._benchmark_result_record: List[BenchmarkResult] = [] # Strategies recorded self._strategies = set() + self._tol = tol def add(self, benchmark_result: BenchmarkResult): """ Add a benchmark result to the record. - Keep the best gemm config at the front of the deque. 
""" gemm_param, strategy, gemm_config, measurement = benchmark_result # Update strategies encoutnered self._strategies.add(strategy) - # Update the best configuration of the given gemm param - configs_with_measurements = self._benchmark_result_record.setdefault( - (gemm_param, strategy), deque([]) - ) - if len(configs_with_measurements) == 0: - configs_with_measurements.append((gemm_config, measurement)) - else: - best_config, best_measurement = configs_with_measurements[0] - if measurement.is_better_than(best_measurement): - configs_with_measurements.appendleft((gemm_config, measurement)) - else: - configs_with_measurements.append((gemm_config, measurement)) + + self._benchmark_result_record.append(benchmark_result) + + def get_record(self) -> Generator[BenchmarkResult, None, None]: + """ Return an iterator that iterates over the record. + """ + yield from self._benchmark_result_record + + def get_best_gemm_configs(self): + """ Get the best GEMMConfig set per GEMMParam per Strategy + """ + best_gc_sets: Dict[ + Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfig, Measurement]] + ] = defaultdict(list) + for gemm_param, strategy, gemm_config, measurement in self.get_record(): + best_gc_set = best_gc_sets.setdefault((gemm_param, strategy), []) + best_gc_set.append((gemm_config, measurement)) + # Sort the best config set (list) + best_gc_set = sorted(best_gc_set, key=lambda gc_and_m: gc_and_m[1]) + # Filter out configs that are beyond tolerance to the best GEMMConfig's measurement + best_gc, best_m = best_gc_set[0] + best_gc_set_new = [ + (gemm_config, measurement) + for gemm_config, measurement in best_gc_set[1:] + if measurement.is_close_to(best_m, self._tol) + ] + # Add back the best config + best_gc_set_new.insert(0, (best_gc, best_m)) + best_gc_sets[(gemm_param, strategy)] = best_gc_set_new + + return best_gc_sets + + def get_best_gemm_configs_as_sequence(self): + """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence + 
of BenchmarkResults + """ + for (gemm_param, strategy), best_gc_sets in self.get_best_gemm_configs().items(): + for best_gemm_config, best_measurement in best_gc_sets: + yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement) def get_config_distributions(self): """ Return GEMMConfigDistribution for each strategy @@ -201,9 +238,10 @@ class GEMMBenchmarkResultRecorder: gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict( GEMMConfigDistribution ) - for benchmark_result in self.get_record(only_best_config=True): - gemm_param, strategy, gemm_config, measurement = benchmark_result + for benchmark_result in self.get_best_gemm_configs_as_sequence(): + _, strategy, _, _ = benchmark_result gemm_config_distributions[strategy].add(benchmark_result) + return gemm_config_distributions def save_to_csvs(self, out_dir, only_best_config=True): @@ -226,12 +264,15 @@ class GEMMBenchmarkResultRecorder: logging.info("Skipping {}".format(out_csv_path)) continue logging.info("Saving csv file to {}".format(out_csv_path)) + record = ( + self.get_best_gemm_configs_as_sequence() if only_best_config else self.get_record() + ) with open(out_csv_path, "w") as f: csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields) csv_writer.writeheader() csv_writer.writerows( benchmark_result_2_csv_row(res)._asdict() - for res in self.get_record(only_best_config) + for res in record if res.strategy == strategy ) logging.info("Saved") @@ -239,9 +280,9 @@ class GEMMBenchmarkResultRecorder: def summary(self, sum_level=SummaryLevel.Short): """ Return the summary string of the record """ - num_raw_records = sum(1 for _ in self.get_record(only_best_config=False)) + num_raw_records = sum(1 for _ in self.get_record()) gemm_params_per_strategy = defaultdict(list) - for gemm_param, strategy, _, _ in self.get_record(only_best_config=True): + for gemm_param, strategy in self.get_best_gemm_configs().keys(): 
gemm_params_per_strategy[strategy].append(gemm_param) global_summary = f""" === {self.__class__.__name__} Summary === @@ -265,20 +306,6 @@ GEMM parameters: strategy_summaries.append(summary) return global_summary + "".join(strategy_summaries) - def get_record(self, only_best_config=True) -> Generator[BenchmarkResult, None, None]: - """ Return an iterator that iterates over the record. - """ - for ( - (gemm_param, strategy), - configs_with_measurements, - ) in self._benchmark_result_record.items(): - if only_best_config: - best_gemm_config, best_measurement = configs_with_measurements[0] - yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement) - else: - for gemm_config, measurement in configs_with_measurements: - yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement) - class GEMMConfigDistribution: """ A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder. @@ -299,18 +326,13 @@ class GEMMConfigDistribution: self._gemm_config_dist[gemm_config].append((gemm_param, measurement)) self._gemm_config_freq[gemm_config] += 1 - def get_measurement(self, gemm_config, measure=min): - """ Get measurement of a gemm_config - """ - return measure(list(zip(*self._gemm_config_dist[gemm_config]))[1]) - def distribution(self): return self._gemm_config_dist def frequency(self): """ Get the frequency of each (best) gemm config recorded """ - return self._gemm_config_freq.copy() + return self._gemm_config_freq.most_common() def best_config(self): """ Get the overall best config, as voted by all benchmark results. 
@@ -392,7 +414,9 @@ def parse_benchmark_commandline(commandline: str) -> Dict[str, str]: return {transform(name): val for name, val in args} -def extract_benchmark_results(json_results: Dict) -> Generator[BenchmarkResult, None, None]: +def extract_benchmark_results( + json_results: Dict, measurement_method="avg" +) -> Generator[BenchmarkResult, None, None]: """ Parse the benchmark result and extract relevant information, namely: GEMM param, Strategy, @@ -430,8 +454,14 @@ def extract_benchmark_results(json_results: Dict) -> Generator[BenchmarkResult, # Get instrument name and assert that it is the one we expect measurement_instrument_name = measurement_instrument.split("/")[0] assert measurement_instrument_name == "OpenCLTimer" - # Take the MINIMUM of the raw data as the measurement value - measurement_val = min(data["raw"]) + # Take either the minimum or the average of the raw data as the measurement value + if measurement_method == "min": + measurement_val = min(data["raw"]) + elif measurement_method == "avg": + measurement_val = sum(data["raw"]) / len(data["raw"]) + else: + raise ValueError("Invalid measurement method: {}".format(measurement_method)) + measurement = Measurement(measurement_val) yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement) @@ -456,7 +486,7 @@ def main(args): benchmark_results = extract_benchmark_results(parse_json(args.benchmark_results_dir)) # Add all benchmark results to the recorder - benchmark_result_recorder = GEMMBenchmarkResultRecorder() + benchmark_result_recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance) for benchmark_result in benchmark_results: benchmark_result_recorder.add(benchmark_result) @@ -466,20 +496,18 @@ def main(args): recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short # Print overall summary of the recorded results - print(benchmark_result_recorder.summary(sum_level=recorder_sum_level)) + logging.info(benchmark_result_recorder.summary(sum_level=recorder_sum_level)) # Get GEMM 
configuration distributions for each strategy all_config_dists = benchmark_result_recorder.get_config_distributions() - print("=== Result ===") + logging.info("=== Result ===") for strategy, config_dist in all_config_dists.items(): - print("Strategy: {}".format(strategy.name)) - print("GEMM Config votes: ") - print("GEMM Config: Best measurement, Vote") - for config, freq in config_dist.frequency().items(): - print(config, end=": ") - print(config_dist.get_measurement(config), freq, sep=",") - print( + logging.info("Strategy: {}".format(strategy.name)) + logging.debug("GEMM Config, Votes") + for config, freq in config_dist.frequency(): + logging.debug("{}, {}".format(config, freq)) + logging.info( "Best GEMM Config: {} with std: {}".format(config_dist.best_config(), config_dist.std()) ) @@ -513,6 +541,15 @@ if __name__ == "__main__": help="Path to directory that holds output csv files. One per strategy", ) parser.add_argument( + "-t", + "--tolerance", + action="store", + type=float, + default=0.01, + help="For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in\ + milliseconds. Recommended value: <= 0.1 ms", + ) + parser.add_argument( "-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output" ) args = parser.parse_args() |