Diffstat (limited to 'examples/gemm_tuner/GemmTuner.py')
 examples/gemm_tuner/GemmTuner.py | 159 +++++++++++++++++++++++++++------------------
 1 file changed, 98 insertions(+), 61 deletions(-)
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py
index 8093ad0e11..29c414cfe8 100644
--- a/examples/gemm_tuner/GemmTuner.py
+++ b/examples/gemm_tuner/GemmTuner.py
@@ -31,7 +31,7 @@ import os
from collections import Counter, defaultdict, deque, namedtuple
from enum import Enum
from pathlib import Path
-from typing import Deque, Dict, Generator, List, NamedTuple, Tuple, Union
+from typing import Deque, Dict, Generator, List, NamedTuple, Set, Tuple, Union
################################################################################
# Types
@@ -51,6 +51,9 @@ class GEMMParam(NamedTuple):
def parse_from_strs(*args):
return GEMMParam(*map(int, args))
+ def __str__(self):
+ return "-".join(map(str, self))
+
# Gemm configuration for strategy Native
class NativeGEMMConfig(NamedTuple):
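Each GEMMParam/GEMMConfig NamedTuple gains the same __str__ in this and the following hunks, rendering the tuple as a dash-joined value string for the new logging and CSV output. A minimal sketch of the pattern (the field names below are illustrative, not taken from this file):

    from typing import NamedTuple

    class GEMMParam(NamedTuple):
        M: int  # illustrative fields; the real class defines its own
        N: int
        K: int
        B: int

        @staticmethod
        def parse_from_strs(*args):
            return GEMMParam(*map(int, args))

        def __str__(self):
            # A NamedTuple iterates over its field values, so this joins them all
            return "-".join(map(str, self))

    param = GEMMParam.parse_from_strs("1024", "1024", "64", "3")
    print(param)  # -> 1024-1024-64-3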
@@ -63,6 +66,9 @@ class NativeGEMMConfig(NamedTuple):
*mnk, = map(int, args)
return NativeGEMMConfig(*mnk)
+ def __str__(self):
+ return "-".join(map(str, self))
+
# Gemm configuration for strategy Reshaped Only RHS
class ReshapedOnlyRHSGEMMConfig(NamedTuple):
@@ -80,6 +86,9 @@ class ReshapedOnlyRHSGEMMConfig(NamedTuple):
transpose_rhs = transpose_rhs == 1
return ReshapedOnlyRHSGEMMConfig(*mnkh, interleave_rhs, transpose_rhs)
+ def __str__(self):
+ return "-".join(map(str, self))
+
# Gemm configuration for strategy Reshaped
class ReshapedGEMMConfig(NamedTuple):
@@ -100,13 +109,19 @@ class ReshapedGEMMConfig(NamedTuple):
transpose_rhs = transpose_rhs == 1
return ReshapedGEMMConfig(*mnkvh, interleave_lhs, interleave_rhs, transpose_rhs)
+ def __str__(self):
+ return "-".join(map(str, self))
+
# Measurement we take from the benchmark result.
class Measurement(NamedTuple):
opencl_timer_ms: float
- def is_better_than(self, other):
- return self < other
+ def is_close_to(self, other, tol):
+ return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol
+
+ def is_better_than(self, other, tol):
+ return self < other and not self.is_close_to(other, tol)
def __add__(self, other):
return Measurement(self.opencl_timer_ms + other.opencl_timer_ms)
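The tolerance-aware Measurement treats two timings within tol milliseconds of each other as equivalent, and is_better_than only holds when the gap exceeds the tolerance. A standalone sketch of the intended semantics (note that the tol argument must be threaded through to is_close_to):

    import math
    from typing import NamedTuple

    class Measurement(NamedTuple):
        opencl_timer_ms: float

        def is_close_to(self, other, tol):
            return math.fabs(self.opencl_timer_ms - other.opencl_timer_ms) < tol

        def is_better_than(self, other, tol):
            # Strictly faster, and by more than the tolerance
            return self < other and not self.is_close_to(other, tol)

    a, b = Measurement(1.000), Measurement(1.005)
    print(a.is_close_to(b, 0.01))     # True: within 0.01 ms of each other
    print(a.is_better_than(b, 0.01))  # False: faster, but inside the tolerance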
@@ -163,37 +178,59 @@ class GEMMBenchmarkResultRecorder:
SummaryLevel = Enum("SummaryLevel", ["Short", "Detailed"])
- def __init__(self):
+ def __init__(self, tol=0.01):
""" Initializer
"""
- # Record that holds all recorded benchmark results.
- # Indexed by (GEMMParam, Strategy) and each such pair maps to a deque of (GEMMConfig, Measurements),
- # with the best one always at the front (index 0) of the deque
- self._benchmark_result_record: Dict[
- Tuple[GEMMParam, Strategy], Deque[Tuple[GEMMConfig, Measurements]]
- ] = {}
+ self._benchmark_result_record: List[BenchmarkResult] = []
# Strategies recorded
self._strategies = set()
+ self._tol = tol
def add(self, benchmark_result: BenchmarkResult):
""" Add a benchmark result to the record.
- Keep the best gemm config at the front of the deque.
"""
gemm_param, strategy, gemm_config, measurement = benchmark_result
# Update strategies encountered
self._strategies.add(strategy)
- # Update the best configuration of the given gemm param
- configs_with_measurements = self._benchmark_result_record.setdefault(
- (gemm_param, strategy), deque([])
- )
- if len(configs_with_measurements) == 0:
- configs_with_measurements.append((gemm_config, measurement))
- else:
- best_config, best_measurement = configs_with_measurements[0]
- if measurement.is_better_than(best_measurement):
- configs_with_measurements.appendleft((gemm_config, measurement))
- else:
- configs_with_measurements.append((gemm_config, measurement))
+
+ self._benchmark_result_record.append(benchmark_result)
+
+ def get_record(self) -> Generator[BenchmarkResult, None, None]:
+ """ Return an iterator that iterates over the record.
+ """
+ yield from self._benchmark_result_record
+
+ def get_best_gemm_configs(self):
+ """ Get the best GEMMConfig set per GEMMParam per Strategy
+ """
+ best_gc_sets: Dict[
+ Tuple[GEMMParam, Strategy], List[Tuple[GEMMConfig, Measurement]]
+ ] = defaultdict(list)
+ for gemm_param, strategy, gemm_config, measurement in self.get_record():
+ best_gc_set = best_gc_sets.setdefault((gemm_param, strategy), [])
+ best_gc_set.append((gemm_config, measurement))
+ # Sort the best config set (list)
+ best_gc_set = sorted(best_gc_set, key=lambda gc_and_m: gc_and_m[1])
+ # Filter out configs whose measurements are not within tolerance of the best GEMMConfig's
+ best_gc, best_m = best_gc_set[0]
+ best_gc_set_new = [
+ (gemm_config, measurement)
+ for gemm_config, measurement in best_gc_set[1:]
+ if measurement.is_close_to(best_m, self._tol)
+ ]
+ # Add back the best config
+ best_gc_set_new.insert(0, (best_gc, best_m))
+ best_gc_sets[(gemm_param, strategy)] = best_gc_set_new
+
+ return best_gc_sets
+
+ def get_best_gemm_configs_as_sequence(self):
+ """ Get the best GEMMConfig set per GEMMParam per Strategy, and flatten the result into a sequence
+ of BenchmarkResults
+ """
+ for (gemm_param, strategy), best_gc_sets in self.get_best_gemm_configs().items():
+ for best_gemm_config, best_measurement in best_gc_sets:
+ yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)
def get_config_distributions(self):
""" Return GEMMConfigDistribution for each strategy
@@ -201,9 +238,10 @@ class GEMMBenchmarkResultRecorder:
gemm_config_distributions: Dict[Strategy, GEMMConfigDistribution] = defaultdict(
GEMMConfigDistribution
)
- for benchmark_result in self.get_record(only_best_config=True):
- gemm_param, strategy, gemm_config, measurement = benchmark_result
+ for benchmark_result in self.get_best_gemm_configs_as_sequence():
+ _, strategy, _, _ = benchmark_result
gemm_config_distributions[strategy].add(benchmark_result)
+
return gemm_config_distributions
def save_to_csvs(self, out_dir, only_best_config=True):
@@ -226,12 +264,15 @@ class GEMMBenchmarkResultRecorder:
logging.info("Skipping {}".format(out_csv_path))
continue
logging.info("Saving csv file to {}".format(out_csv_path))
+ record = (
+ self.get_best_gemm_configs_as_sequence() if only_best_config else self.get_record()
+ )
with open(out_csv_path, "w") as f:
csv_writer = csv.DictWriter(f, fieldnames=BenchmarkResultCSVRow._fields)
csv_writer.writeheader()
csv_writer.writerows(
benchmark_result_2_csv_row(res)._asdict()
- for res in self.get_record(only_best_config)
+ for res in record
if res.strategy == strategy
)
logging.info("Saved")
@@ -239,9 +280,9 @@ class GEMMBenchmarkResultRecorder:
def summary(self, sum_level=SummaryLevel.Short):
""" Return the summary string of the record
"""
- num_raw_records = sum(1 for _ in self.get_record(only_best_config=False))
+ num_raw_records = sum(1 for _ in self.get_record())
gemm_params_per_strategy = defaultdict(list)
- for gemm_param, strategy, _, _ in self.get_record(only_best_config=True):
+ for gemm_param, strategy in self.get_best_gemm_configs().keys():
gemm_params_per_strategy[strategy].append(gemm_param)
global_summary = f"""
=== {self.__class__.__name__} Summary ===
@@ -265,20 +306,6 @@ GEMM parameters:
strategy_summaries.append(summary)
return global_summary + "".join(strategy_summaries)
- def get_record(self, only_best_config=True) -> Generator[BenchmarkResult, None, None]:
- """ Return an iterator that iterates over the record.
- """
- for (
- (gemm_param, strategy),
- configs_with_measurements,
- ) in self._benchmark_result_record.items():
- if only_best_config:
- best_gemm_config, best_measurement = configs_with_measurements[0]
- yield BenchmarkResult(gemm_param, strategy, best_gemm_config, best_measurement)
- else:
- for gemm_config, measurement in configs_with_measurements:
- yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
-
class GEMMConfigDistribution:
""" A representation of the GEMM Configuration distribution produced by the GEMMBenchmarkResultRecorder.
@@ -299,18 +326,13 @@ class GEMMConfigDistribution:
self._gemm_config_dist[gemm_config].append((gemm_param, measurement))
self._gemm_config_freq[gemm_config] += 1
- def get_measurement(self, gemm_config, measure=min):
- """ Get measurement of a gemm_config
- """
- return measure(list(zip(*self._gemm_config_dist[gemm_config]))[1])
-
def distribution(self):
return self._gemm_config_dist
def frequency(self):
""" Get the frequency of each (best) gemm config recorded
"""
- return self._gemm_config_freq.copy()
+ return self._gemm_config_freq.most_common()
def best_config(self):
""" Get the overall best config, as voted by all benchmark results.
@@ -392,7 +414,9 @@ def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
return {transform(name): val for name, val in args}
-def extract_benchmark_results(json_results: Dict) -> Generator[BenchmarkResult, None, None]:
+def extract_benchmark_results(
+ json_results: Dict, measurement_method="avg"
+) -> Generator[BenchmarkResult, None, None]:
""" Parse the benchmark result and extract relevant information, namely:
GEMM param,
Strategy,
@@ -430,8 +454,14 @@ def extract_benchmark_results(json_results: Dict) -> Generator[BenchmarkResult,
# Get instrument name and assert that it is the one we expect
measurement_instrument_name = measurement_instrument.split("/")[0]
assert measurement_instrument_name == "OpenCLTimer"
- # Take the MINIMUM of the raw data as the measurement value
- measurement_val = min(data["raw"])
+ # Take either the minimum or the average of the raw data as the measurement value
+ if measurement_method == "min":
+ measurement_val = min(data["raw"])
+ elif measurement_method == "avg":
+ measurement_val = sum(data["raw"]) / len(data["raw"])
+ else:
+ raise ValueError("Invalid measurement method: {}".format(measurement_method))
+
measurement = Measurement(measurement_val)
yield BenchmarkResult(gemm_param, strategy, gemm_config, measurement)
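Switching the default from the minimum to the average of the raw OpenCLTimer samples makes the measurement less sensitive to a single lucky fast run, at the cost of admitting outliers. The two aggregations side by side, on illustrative data:

    raw = [1.20, 1.18, 1.35, 1.19]  # illustrative OpenCLTimer samples, in ms

    def aggregate(raw, measurement_method="avg"):
        if measurement_method == "min":
            return min(raw)
        elif measurement_method == "avg":
            return sum(raw) / len(raw)
        raise ValueError("Invalid measurement method: {}".format(measurement_method))

    print(aggregate(raw, "min"))  # 1.18
    print(aggregate(raw, "avg"))  # 1.23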
@@ -456,7 +486,7 @@ def main(args):
benchmark_results = extract_benchmark_results(parse_json(args.benchmark_results_dir))
# Add all benchmark results to the recorder
- benchmark_result_recorder = GEMMBenchmarkResultRecorder()
+ benchmark_result_recorder = GEMMBenchmarkResultRecorder(tol=args.tolerance)
for benchmark_result in benchmark_results:
benchmark_result_recorder.add(benchmark_result)
@@ -466,20 +496,18 @@ def main(args):
recorder_sum_level = GEMMBenchmarkResultRecorder.SummaryLevel.Short
# Print overall summary of the recorded results
- print(benchmark_result_recorder.summary(sum_level=recorder_sum_level))
+ logging.info(benchmark_result_recorder.summary(sum_level=recorder_sum_level))
# Get GEMM configuration distributions for each strategy
all_config_dists = benchmark_result_recorder.get_config_distributions()
- print("=== Result ===")
+ logging.info("=== Result ===")
for strategy, config_dist in all_config_dists.items():
- print("Strategy: {}".format(strategy.name))
- print("GEMM Config votes: ")
- print("GEMM Config: Best measurement, Vote")
- for config, freq in config_dist.frequency().items():
- print(config, end=": ")
- print(config_dist.get_measurement(config), freq, sep=",")
- print(
+ logging.info("Strategy: {}".format(strategy.name))
+ logging.debug("GEMM Config, Votes")
+ for config, freq in config_dist.frequency():
+ logging.debug("{}, {}".format(config, freq))
+ logging.info(
"Best GEMM Config: {} with std: {}".format(config_dist.best_config(), config_dist.std())
)
@@ -513,6 +541,15 @@ if __name__ == "__main__":
help="Path to directory that holds output csv files. One per strategy",
)
parser.add_argument(
+ "-t",
+ "--tolerance",
+ action="store",
+ type=float,
+ default=0.01,
+ help="For testing if two GEMMConfigs are equivalent in terms of performance. The tolerance is OpenCL timer in\
+ milliseconds. Recommended value: <= 0.1 ms",
+ )
+ parser.add_argument(
"-D", "--debug", dest="debug", action="store_true", help="Enable script debugging output"
)
args = parser.parse_args()
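Taken together, the flow in main() after this change is: record every result, then derive the per-strategy distributions under the tolerance. A condensed, hypothetical driver using this file's own API (the results directory name is illustrative):

    recorder = GEMMBenchmarkResultRecorder(tol=0.05)
    for result in extract_benchmark_results(parse_json("benchmark_out"), measurement_method="avg"):
        recorder.add(result)
    logging.info(recorder.summary())
    for strategy, dist in recorder.get_config_distributions().items():
        logging.info("Strategy: {}".format(strategy.name))
        logging.info("Best GEMM Config: {} with std: {}".format(dist.best_config(), dist.std()))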