aboutsummaryrefslogtreecommitdiff
path: root/examples/gemm_tuner/GemmTuner.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gemm_tuner/GemmTuner.py')
-rw-r--r--examples/gemm_tuner/GemmTuner.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/examples/gemm_tuner/GemmTuner.py b/examples/gemm_tuner/GemmTuner.py
index aab2d55e89..1361f90929 100644
--- a/examples/gemm_tuner/GemmTuner.py
+++ b/examples/gemm_tuner/GemmTuner.py
@@ -48,10 +48,11 @@ class GEMMParam(NamedTuple):
N: int # Number of rhs matrix columns
K: int # Number of lhs matrix columns/rhs matrix rows
B: int # Batch size
+ data_type: str # Data type
@staticmethod
- def parse_from_strs(*args):
- return GEMMParam(*map(int, args))
+ def parse_from_strs(*M_N_K_B, data_type):
+ return GEMMParam(*map(int, M_N_K_B),str(data_type))
def __str__(self):
return ",".join(map(str, self))
@@ -441,9 +442,10 @@ EXAMPLE_FILE_2_STRATEGY = {
# <-GEMMParam-><-------------GEMMConfig-------------->
# Note that the test strategy_name == strategy.name is in place to avoid unwanted enum aliases
GEMM_EXAMPLE_ARGS_FACTORY = {
+ # We ignore the data type field from GEMMParam as that is extracted separately
strategy: namedtuple(
"{}_Gemm_Example_Args".format(strategy_name),
- GEMMParam._fields + GEMM_CONFIG_FACTORY[strategy]._fields,
+ GEMMParam._fields[:-1] + GEMM_CONFIG_FACTORY[strategy]._fields,
)
for strategy_name, strategy in Strategy.__members__.items()
if strategy_name == strategy.name
@@ -460,6 +462,9 @@ BENCHMARK_RESULT_JSON_EXTENSION = "gemmtuner_benchmark"
def parse_benchmark_commandline(commandline: str) -> Dict[str, str]:
""" Parse the benchmark example command-line string into a dictionary of command-line agruments
"""
+ # Separate the data type option from the example_args portion of the string
+ commandline = commandline.replace(",--type=", " --type=")
+
args = commandline.split()
# Discard program name
args = args[1:]
@@ -502,9 +507,11 @@ def extract_benchmark_results(
example_args = Gemm_Example_Args_T(
*(benchmark_args["example_args"].split(",")))
# Gemm_Example_Arg consists of GEMMParam first and then GEMMConfig (in that order)
- gemm_param_fields_len = len(GEMMParam._fields)
+ # However data type option is parsed separately from end of options, hence -1 is applied to fields length
+ gemm_param_fields_len = len(GEMMParam._fields) - 1
gemm_param = GEMMParam.parse_from_strs(
- *example_args[:gemm_param_fields_len])
+ *example_args[:gemm_param_fields_len],
+ data_type = benchmark_args["type"])
GEMMConfig = GEMM_CONFIG_FACTORY[strategy]
gemm_config = GEMMConfig.parse_from_strs(
*example_args[gemm_param_fields_len:])