aboutsummaryrefslogtreecommitdiff
path: root/verif/runner
diff options
context:
space:
mode:
Diffstat (limited to 'verif/runner')
-rw-r--r--verif/runner/run_command.py61
-rw-r--r--verif/runner/tosa_ref_run.py78
-rw-r--r--verif/runner/tosa_refmodel_sut_run.py73
-rw-r--r--verif/runner/tosa_test_runner.py212
-rw-r--r--verif/runner/tosa_verif_run_ref.py267
-rw-r--r--verif/runner/tosa_verif_run_tests.py375
6 files changed, 676 insertions, 390 deletions
diff --git a/verif/runner/run_command.py b/verif/runner/run_command.py
new file mode 100644
index 0000000..eef5a76
--- /dev/null
+++ b/verif/runner/run_command.py
@@ -0,0 +1,61 @@
+"""Shell command runner function."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import shlex
+import subprocess
+
+
+class RunShCommandError(Exception):
+ """Exception raised for errors running the shell command.
+
+ Attributes:
+ return_code - non-zero return code from running command
+ full_cmd_esc - command and arguments list (pre-escaped)
+ stderr - (optional) - standard error output
+ """
+
+ def __init__(self, return_code, full_cmd_esc, stderr=None, stdout=None):
+ """Initialize run shell command error."""
+ self.return_code = return_code
+ self.full_cmd_esc = full_cmd_esc
+ self.stderr = stderr
+ self.stdout = stdout
+ self.message = "Error {} running command: {}".format(
+ self.return_code, " ".join(self.full_cmd_esc)
+ )
+ if stdout:
+ self.message = "{}\n{}".format(self.message, self.stdout)
+ if stderr:
+ self.message = "{}\n{}".format(self.message, self.stderr)
+ super().__init__(self.message)
+
+
+def run_sh_command(full_cmd, verbose=False, capture_output=False):
+ """Run an external shell command.
+
+ full_cmd: array containing shell command and its arguments
+ verbose: optional flag that enables verbose output
+ capture_output: optional flag to return captured stdout/stderr
+ """
+ # Quote the command line for printing
+ full_cmd_esc = [shlex.quote(x) for x in full_cmd]
+
+ if verbose:
+ print("### Running {}".format(" ".join(full_cmd_esc)))
+
+ if capture_output:
+ rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout = rc.stdout.decode("utf-8")
+ stderr = rc.stderr.decode("utf-8")
+ if verbose:
+ if stdout:
+ print(stdout, end="")
+ if stderr:
+ print(stderr, end="")
+ else:
+ stdout, stderr = None, None
+ rc = subprocess.run(full_cmd)
+
+ if rc.returncode != 0:
+ raise RunShCommandError(rc.returncode, full_cmd_esc, stderr, stdout)
+ return (stdout, stderr)
diff --git a/verif/runner/tosa_ref_run.py b/verif/runner/tosa_ref_run.py
deleted file mode 100644
index c1d5e79..0000000
--- a/verif/runner/tosa_ref_run.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) 2020-2021, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import json
-import shlex
-import subprocess
-from enum import Enum, IntEnum, unique
-from runner.tosa_test_runner import TosaTestRunner, run_sh_command
-
-
-@unique
-class TosaReturnCode(IntEnum):
- VALID = 0
- UNPREDICTABLE = 1
- ERROR = 2
-
-
-class TosaRefRunner(TosaTestRunner):
- def __init__(self, args, runnerArgs, testDir):
- super().__init__(args, runnerArgs, testDir)
-
- def runModel(self):
- # Build up the TOSA reference command line
- # Uses arguments from the argParser args, not the runnerArgs
- args = self.args
-
- ref_cmd = [
- args.ref_model_path,
- "-Ctest_desc={}".format(os.path.join(self.testDir, "desc.json")),
- ]
-
- if args.ref_debug:
- ref_cmd.extend(["-dALL", "-l{}".format(args.ref_debug)])
-
- if args.ref_intermediates:
- ref_cmd.extend(["-Ddump_intermediates=1"])
-
- expectedReturnCode = self.testDesc["expected_return_code"]
-
- try:
- rc = run_sh_command(self.args, ref_cmd)
- if rc == TosaReturnCode.VALID:
- if expectedReturnCode == TosaReturnCode.VALID:
- result = TosaTestRunner.Result.EXPECTED_PASS
- else:
- result = TosaTestRunner.Result.UNEXPECTED_PASS
- elif rc == TosaReturnCode.ERROR:
- if expectedReturnCode == TosaReturnCode.ERROR:
- result = TosaTestRunner.Result.EXPECTED_FAILURE
- else:
- result = TosaTestRunner.Result.UNEXPECTED_FAILURE
- elif rc == TosaReturnCode.UNPREDICTABLE:
- if expectedReturnCode == TosaReturnCode.UNPREDICTABLE:
- result = TosaTestRunner.Result.EXPECTED_FAILURE
- else:
- result = TosaTestRunner.Result.UNEXPECTED_FAILURE
- elif rc < 0:
- # Unix signal caught (e.g., SIGABRT, SIGSEGV, SIGFPE, etc)
- result = TosaTestRunner.Result.INTERNAL_ERROR
- else:
- raise Exception(f"Return code ({rc}) unknown.")
-
- except Exception as e:
- raise Exception("Runtime Error when running: {}".format(" ".join(ref_cmd)))
-
- return result
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
new file mode 100644
index 0000000..b9a9575
--- /dev/null
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -0,0 +1,73 @@
+"""TOSA test runner module for the Reference Model."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from enum import IntEnum
+from enum import unique
+
+from runner.run_command import run_sh_command
+from runner.run_command import RunShCommandError
+from runner.tosa_test_runner import TosaTestRunner
+
+
+@unique
+class TosaRefReturnCode(IntEnum):
+ """Return codes from the Tosa Reference Model."""
+
+ VALID = 0
+ UNPREDICTABLE = 1
+ ERROR = 2
+
+
+class TosaSUTRunner(TosaTestRunner):
+ """TOSA Reference Model runner."""
+
+ def __init__(self, args, runnerArgs, testDir):
+ """Initialize using the given test details."""
+ super().__init__(args, runnerArgs, testDir)
+
+ def runTestGraph(self):
+ """Run the test on the reference model."""
+ # Build up the TOSA reference command line
+ # Uses arguments from the argParser args, not the runnerArgs
+ args = self.args
+
+ # Call Reference model with description file to provide all file details
+ cmd = [
+ args.ref_model_path,
+ "-Coperator_fbs={}".format(args.operator_fbs),
+ "-Ctest_desc={}".format(self.descFile),
+ ]
+
+ # Specific debug options for reference model
+ if args.ref_debug:
+ cmd.extend(["-dALL", "-l{}".format(args.ref_debug)])
+
+ if args.ref_intermediates:
+ cmd.extend(["-Ddump_intermediates=1"])
+
+ # Run command and interpret tosa graph result via process return codes
+ graphMessage = None
+ try:
+ run_sh_command(cmd, self.args.verbose, capture_output=True)
+ graphResult = TosaTestRunner.TosaGraphResult.TOSA_VALID
+ except RunShCommandError as e:
+ graphMessage = e.stderr
+ if e.return_code == TosaRefReturnCode.ERROR:
+ graphResult = TosaTestRunner.TosaGraphResult.TOSA_ERROR
+ elif e.return_code == TosaRefReturnCode.UNPREDICTABLE:
+ graphResult = TosaTestRunner.TosaGraphResult.TOSA_UNPREDICTABLE
+ else:
+ graphResult = TosaTestRunner.TosaGraphResult.OTHER_ERROR
+ if (
+ self.args.verbose
+ or graphResult == TosaTestRunner.TosaGraphResult.OTHER_ERROR
+ ):
+ print(e)
+
+ except Exception as e:
+ print(e)
+ graphMessage = str(e)
+ graphResult = TosaTestRunner.TosaGraphResult.OTHER_ERROR
+
+ # Return graph result and message
+ return graphResult, graphMessage
diff --git a/verif/runner/tosa_test_runner.py b/verif/runner/tosa_test_runner.py
index e8f921d..0fd7f13 100644
--- a/verif/runner/tosa_test_runner.py
+++ b/verif/runner/tosa_test_runner.py
@@ -1,68 +1,190 @@
-import os
-
-# Copyright (c) 2020, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
+"""Template test runner class for running TOSA tests."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
import json
-import shlex
-import subprocess
-from enum import IntEnum, unique
-
+from enum import IntEnum
+from pathlib import Path
-def run_sh_command(args, full_cmd, capture_output=False):
- """Utility function to run an external command. Optionally return captured stdout/stderr"""
+from checker.tosa_result_checker import LogColors
+from checker.tosa_result_checker import print_color
+from checker.tosa_result_checker import test_check
+from json2fbbin import json2fbbin
- # Quote the command line for printing
- full_cmd_esc = [shlex.quote(x) for x in full_cmd]
- if args.verbose:
- print("### Running {}".format(" ".join(full_cmd_esc)))
+class TosaTestInvalid(Exception):
+ """Exception raised for errors loading test description.
- if capture_output:
- rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- if rc.returncode != 0:
- print(rc.stdout.decode("utf-8"))
- print(rc.stderr.decode("utf-8"))
- raise Exception(
- "Error running command: {}.\n{}".format(
- " ".join(full_cmd_esc), rc.stderr.decode("utf-8")
- )
- )
- return (rc.stdout, rc.stderr)
- else:
- rc = subprocess.run(full_cmd)
+ Attributes:
+ path - full path to missing test description file
+ exception = underlying exception
+ """
- return rc.returncode
+ def __init__(self, path, exception):
+ """Initialize test not found error."""
+ self.path = path
+ self.exception = exception
+ self.message = "Invalid test, could not read test description {}: {}".format(
+ self.path, str(self.exception)
+ )
+ super().__init__(self.message)
class TosaTestRunner:
- def __init__(self, args, runnerArgs, testDir):
+ """TOSA Test Runner template class for systems under test."""
+ def __init__(self, args, runnerArgs, testDir):
+ """Initialize and load JSON meta data file."""
self.args = args
self.runnerArgs = runnerArgs
self.testDir = testDir
+ self.testName = Path(self.testDir).name
+
+ # Check if we want to run binary and if its already converted
+ descFilePath = Path(testDir, "desc.json")
+ descBinFilePath = Path(testDir, "desc_binary.json")
+ if args.binary:
+ if descBinFilePath.is_file():
+ descFilePath = descBinFilePath
+
+ try:
+ # Load the json test file
+ with open(descFilePath, "r") as fd:
+ self.testDesc = json.load(fd)
+ except Exception as e:
+ raise TosaTestInvalid(str(descFilePath), e)
+
+ # Convert to binary if needed
+ tosaFilePath = Path(testDir, self.testDesc["tosa_file"])
+ if args.binary and tosaFilePath.suffix == ".json":
+ # Convert tosa JSON to binary
+ json2fbbin.json_to_fbbin(
+ Path(args.flatc_path),
+ Path(args.operator_fbs),
+ tosaFilePath,
+ Path(testDir),
+ )
+ # Write new desc_binary file
+ self.testDesc["tosa_file"] = tosaFilePath.stem + ".tosa"
+ with open(descBinFilePath, "w") as fd:
+ json.dump(self.testDesc, fd, indent=2)
+ descFilePath = descBinFilePath
+
+ # Set location of desc.json (or desc_binary.json) file in use
+ self.descFile = str(descFilePath)
- # Load the json test file
- with open(os.path.join(testDir, "desc.json"), "r") as fd:
- self.testDesc = json.load(fd)
+ def skipTest(self):
+ """Check if the test is skipped due to test type selection."""
+ expectedFailure = self.testDesc["expected_failure"]
+ if self.args.test_type == "negative" and not expectedFailure:
+ return True
+ elif self.args.test_type == "positive" and expectedFailure:
+ return True
+ return False
- def runModel(self):
+ def runTestGraph(self):
+ """Override with function that calls system under test."""
pass
+ def testResult(self, tosaGraphResult, graphMessage=None):
+ """Work out test result based on graph result and output files."""
+ expectedFailure = self.testDesc["expected_failure"]
+ print_result_line = True
+
+ if tosaGraphResult == TosaTestRunner.TosaGraphResult.TOSA_VALID:
+ if expectedFailure:
+ result = TosaTestRunner.Result.UNEXPECTED_PASS
+ resultMessage = "Expected failure test incorrectly passed"
+ else:
+ # Work through all the results produced by the testing, assuming success
+ # but overriding this with any failures found
+ result = TosaTestRunner.Result.EXPECTED_PASS
+ messages = []
+ for resultNum, resultFileName in enumerate(self.testDesc["ofm_file"]):
+ if "expected_result_file" in self.testDesc:
+ try:
+ conformanceFile = Path(
+ self.testDir,
+ self.testDesc["expected_result_file"][resultNum],
+ )
+ except IndexError:
+ result = TosaTestRunner.Result.INTERNAL_ERROR
+ msg = "Internal error: Missing expected_result_file {} in {}".format(
+ resultNum, self.descFile
+ )
+ messages.append(msg)
+ print(msg)
+ break
+ else:
+ conformanceFile = None
+ resultFile = Path(self.testDir, resultFileName)
+
+ if conformanceFile:
+ print_result_line = False # Checker will print one for us
+ chkResult, tolerance, msg = test_check(
+ str(conformanceFile),
+ str(resultFile),
+ test_name=self.testName,
+ )
+ # Change EXPECTED_PASS assumption if we have any failures
+ if chkResult != 0:
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+ messages.append(msg)
+ if self.args.verbose:
+ print(msg)
+ else:
+ # No conformance file to verify, just check results file exists
+ if not resultFile.is_file():
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+ msg = "Results file is missing: {}".format(resultFile)
+ messages.append(msg)
+ print(msg)
+
+ if resultFile.is_file():
+ # Move the resultFile to allow subsequent system under
+ # tests to create them and to test they have been created
+ resultFile = resultFile.rename(
+ resultFile.with_suffix(
+ ".{}{}".format(self.__module__, resultFile.suffix)
+ )
+ )
+
+ resultMessage = "\n".join(messages) if len(messages) > 0 else None
+ else:
+ if (
+ expectedFailure
+ and tosaGraphResult == TosaTestRunner.TosaGraphResult.TOSA_ERROR
+ ):
+ result = TosaTestRunner.Result.EXPECTED_FAILURE
+ resultMessage = None
+ else:
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+ resultMessage = graphMessage
+
+ if print_result_line:
+ if (
+ result == TosaTestRunner.Result.EXPECTED_FAILURE
+ or result == TosaTestRunner.Result.EXPECTED_PASS
+ ):
+ print_color(LogColors.GREEN, "Results PASS {}".format(self.testName))
+ else:
+ print_color(LogColors.RED, "Results FAIL {}".format(self.testName))
+
+ return result, resultMessage
+
class Result(IntEnum):
+ """Test result codes."""
+
EXPECTED_PASS = 0
EXPECTED_FAILURE = 1
UNEXPECTED_PASS = 2
UNEXPECTED_FAILURE = 3
INTERNAL_ERROR = 4
+ SKIPPED = 5
+
+ class TosaGraphResult(IntEnum):
+ """The tosa_graph_result codes."""
+
+ TOSA_VALID = 0
+ TOSA_UNPREDICTABLE = 1
+ TOSA_ERROR = 2
+ OTHER_ERROR = 3
diff --git a/verif/runner/tosa_verif_run_ref.py b/verif/runner/tosa_verif_run_ref.py
deleted file mode 100644
index 626819f..0000000
--- a/verif/runner/tosa_verif_run_ref.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# Copyright (c) 2020-2021, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import sys
-import re
-import os
-import subprocess
-import shlex
-import json
-import glob
-import math
-import queue
-import threading
-import traceback
-import importlib
-
-
-from enum import IntEnum, Enum, unique
-from datetime import datetime
-
-from xunit import xunit
-
-from runner.tosa_test_runner import TosaTestRunner
-
-no_color_printing = False
-# from run_tf_unit_test import LogColors, print_color, run_sh_command
-
-
-def parseArgs():
-
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-t",
- "--test",
- dest="test",
- type=str,
- nargs="+",
- help="Test(s) to run",
- required=True,
- )
- parser.add_argument(
- "--seed",
- dest="random_seed",
- default=42,
- type=int,
- help="Random seed for test generation",
- )
- parser.add_argument(
- "--ref-model-path",
- dest="ref_model_path",
- default="build/reference_model/tosa_reference_model",
- type=str,
- help="Path to reference model executable",
- )
- parser.add_argument(
- "--ref-debug",
- dest="ref_debug",
- default="",
- type=str,
- help="Reference debug flag (low, med, high)",
- )
- parser.add_argument(
- "--ref-intermediates",
- dest="ref_intermediates",
- default=0,
- type=int,
- help="Reference model dumps intermediate tensors",
- )
- parser.add_argument(
- "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
- )
- parser.add_argument(
- "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
- )
- parser.add_argument(
- "--sut-module",
- "-s",
- dest="sut_module",
- type=str,
- nargs="+",
- default=["runner.tosa_ref_run"],
- help="System under test module to load (derives from TosaTestRunner). May be repeated",
- )
- parser.add_argument(
- "--sut-module-args",
- dest="sut_module_args",
- type=str,
- nargs="+",
- default=[],
- help="System under test module arguments. Use sutmodulename:argvalue to pass an argument. May be repeated.",
- )
- parser.add_argument(
- "--xunit-file",
- dest="xunit_file",
- type=str,
- default="result.xml",
- help="XUnit output file",
- )
-
- args = parser.parse_args()
-
- # Autodetect CPU count
- if args.jobs <= 0:
- args.jobs = os.cpu_count()
-
- return args
-
-
-def workerThread(task_queue, runnerList, args, result_queue):
- while True:
- try:
- test = task_queue.get(block=False)
- except queue.Empty:
- break
-
- if test is None:
- break
-
- msg = ""
- start_time = datetime.now()
- try:
-
- for runnerModule, runnerArgs in runnerList:
- if args.verbose:
- print(
- "Running runner {} with test {}".format(
- runnerModule.__name__, test
- )
- )
- runner = runnerModule.TosaRefRunner(args, runnerArgs, test)
- try:
- rc = runner.runModel()
- except Exception as e:
- rc = TosaTestRunner.Result.INTERNAL_ERROR
- print(f"runner.runModel Exception: {e}")
- print(
- "".join(
- traceback.format_exception(
- etype=type(e), value=e, tb=e.__traceback__
- )
- )
- )
- except Exception as e:
- print("Internal regression error: {}".format(e))
- print(
- "".join(
- traceback.format_exception(
- etype=type(e), value=e, tb=e.__traceback__
- )
- )
- )
- rc = TosaTestRunner.Result.INTERNAL_ERROR
-
- end_time = datetime.now()
-
- result_queue.put((test, rc, msg, end_time - start_time))
- task_queue.task_done()
-
- return True
-
-
-def loadRefModules(args):
- # Returns a tuple of (runner_module, [argument list])
- runnerList = []
- for r in args.sut_module:
- if args.verbose:
- print("Loading module {}".format(r))
-
- runner = importlib.import_module(r)
-
- # Look for arguments associated with this runner
- runnerArgPrefix = "{}:".format(r)
- runnerArgList = []
- for a in args.sut_module_args:
- if a.startswith(runnerArgPrefix):
- runnerArgList.append(a[len(runnerArgPrefix) :])
- runnerList.append((runner, runnerArgList))
-
- return runnerList
-
-
-def main():
- args = parseArgs()
-
- runnerList = loadRefModules(args)
-
- threads = []
- taskQueue = queue.Queue()
- resultQueue = queue.Queue()
-
- for t in args.test:
- taskQueue.put((t))
-
- print("Running {} tests ".format(taskQueue.qsize()))
-
- for i in range(args.jobs):
- t = threading.Thread(
- target=workerThread, args=(taskQueue, runnerList, args, resultQueue)
- )
- t.setDaemon(True)
- t.start()
- threads.append(t)
-
- taskQueue.join()
-
- resultList = []
- results = [0] * len(TosaTestRunner.Result)
-
- while True:
- try:
- test, rc, msg, time_delta = resultQueue.get(block=False)
- except queue.Empty:
- break
-
- resultList.append((test, rc, msg, time_delta))
- results[rc] = results[rc] + 1
-
- xunit_result = xunit.xunit_results("Regressions")
- xunit_suite = xunit_result.create_suite("Unit tests")
-
- # Sort by test name
- for test, rc, msg, time_delta in sorted(resultList, key=lambda tup: tup[0]):
- test_name = test
- xt = xunit.xunit_test(test_name, "reference")
-
- xt.time = str(
- float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
- )
-
- if (
- rc == TosaTestRunner.Result.EXPECTED_PASS
- or rc == TosaTestRunner.Result.EXPECTED_FAILURE
- ):
- if args.verbose:
- print("{} {}".format(rc.name, test_name))
- else:
- xt.failed(msg)
- print("{} {}".format(rc.name, test_name))
-
- xunit_suite.tests.append(xt)
- resultQueue.task_done()
-
- xunit_result.write_results(args.xunit_file)
-
- print("Totals: ", end="")
- for result in TosaTestRunner.Result:
- print("{} {}, ".format(results[result], result.name.lower()), end="")
- print()
-
- return 0
-
-
-if __name__ == "__main__":
- exit(main())
diff --git a/verif/runner/tosa_verif_run_tests.py b/verif/runner/tosa_verif_run_tests.py
new file mode 100644
index 0000000..dd86950
--- /dev/null
+++ b/verif/runner/tosa_verif_run_tests.py
@@ -0,0 +1,375 @@
+"""TOSA verification runner script."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import glob
+import importlib
+import os
+import queue
+import threading
+import traceback
+from datetime import datetime
+from pathlib import Path
+
+from json2numpy import json2numpy
+from runner.tosa_test_runner import TosaTestInvalid
+from runner.tosa_test_runner import TosaTestRunner
+from xunit import xunit
+
+TOSA_REFMODEL_RUNNER = "runner.tosa_refmodel_sut_run"
+MAX_XUNIT_TEST_MESSAGE = 1000
+
+
+def parseArgs(argv):
+ """Parse the arguments and return the settings."""
+ parser = argparse.ArgumentParser()
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "-t",
+ "--test",
+ dest="test",
+ type=str,
+ nargs="+",
+ help="Test(s) to run",
+ )
+ group.add_argument(
+ "-T",
+ "--test-list",
+ dest="test_list_file",
+ type=Path,
+ help="File containing list of tests to run (one per line)",
+ )
+ parser.add_argument(
+ "--operator-fbs",
+ dest="operator_fbs",
+ default="conformance_tests/third_party/serialization_lib/schema/tosa.fbs",
+ type=str,
+ help="flat buffer syntax file",
+ )
+ parser.add_argument(
+ "--ref-model-path",
+ dest="ref_model_path",
+ default="reference_model/build/reference_model/tosa_reference_model",
+ type=str,
+ help="Path to reference model executable",
+ )
+ parser.add_argument(
+ "--flatc-path",
+ dest="flatc_path",
+ default="reference_model/build/thirdparty/serialization_lib/third_party/flatbuffers/flatc",
+ type=str,
+ help="Path to flatc compiler executable",
+ )
+ parser.add_argument(
+ "--ref-debug",
+ dest="ref_debug",
+ default="",
+ type=str,
+ help="Reference debug flag (low, med, high)",
+ )
+ parser.add_argument(
+ "--ref-intermediates",
+ dest="ref_intermediates",
+ default=0,
+ type=int,
+ help="Reference model dumps intermediate tensors",
+ )
+ parser.add_argument(
+ "-b",
+ "--binary",
+ dest="binary",
+ action="store_true",
+ help="Convert to using binary flatbuffers instead of JSON",
+ )
+ parser.add_argument(
+ "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
+ )
+ parser.add_argument(
+ "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
+ )
+ parser.add_argument(
+ "--sut-module",
+ "-s",
+ dest="sut_module",
+ type=str,
+ nargs="+",
+ default=[TOSA_REFMODEL_RUNNER],
+ help="System under test module to load (derives from TosaTestRunner). May be repeated",
+ )
+ parser.add_argument(
+ "--sut-module-args",
+ dest="sut_module_args",
+ type=str,
+ nargs="+",
+ default=[],
+ help="System under test module arguments. Use sutmodulename:argvalue to pass an argument. May be repeated.",
+ )
+ parser.add_argument(
+ "--xunit-file",
+ dest="xunit_file",
+ type=str,
+ default="result.xml",
+ help="XUnit output file",
+ )
+ parser.add_argument(
+ "--test-type",
+ dest="test_type",
+ type=str,
+ default="both",
+ choices=["positive", "negative", "both"],
+ help="Filter tests based on expected failure status (positive, negative or both)",
+ )
+
+ args = parser.parse_args(argv)
+
+ # Autodetect CPU count
+ if args.jobs <= 0:
+ args.jobs = os.cpu_count()
+
+ return args
+
+
+EXCLUSION_PREFIX = ["test", "model", "desc"]
+
+
+def convert2Numpy(testDir):
+ """Convert all the JSON numpy files back into binary numpy."""
+ jsons = glob.glob(os.path.join(testDir, "*.json"))
+ for json in jsons:
+ for exclude in EXCLUSION_PREFIX:
+ if os.path.basename(json).startswith(exclude):
+ json = ""
+ if json:
+ # debug print("Converting " + json)
+ json2numpy.json_to_npy(Path(json))
+
+
+def workerThread(task_queue, runnerList, args, result_queue):
+ """Worker thread that runs the next test from the queue."""
+ while True:
+ try:
+ test = task_queue.get(block=False)
+ except queue.Empty:
+ break
+
+ if test is None:
+ break
+
+ msg = ""
+ converted = False
+ for runnerModule, runnerArgs in runnerList:
+ try:
+ start_time = datetime.now()
+ # Set up system under test runner
+ runnerName = runnerModule.__name__
+ runner = runnerModule.TosaSUTRunner(args, runnerArgs, test)
+
+ if runner.skipTest():
+ msg = "Skipping non-{} test".format(args.test_type)
+ print("{} {}".format(msg, test))
+ rc = TosaTestRunner.Result.SKIPPED
+ else:
+ # Convert JSON data files into numpy format on first pass
+ if not converted:
+ convert2Numpy(test)
+ converted = True
+
+ if args.verbose:
+ print("Running runner {} with test {}".format(runnerName, test))
+ try:
+ grc, gmsg = runner.runTestGraph()
+ rc, msg = runner.testResult(grc, gmsg)
+ except Exception as e:
+ msg = "System Under Test error: {}".format(e)
+ print(msg)
+ print(
+ "".join(
+ traceback.format_exception(
+ etype=type(e), value=e, tb=e.__traceback__
+ )
+ )
+ )
+ rc = TosaTestRunner.Result.INTERNAL_ERROR
+ except Exception as e:
+ msg = "Internal error: {}".format(e)
+ print(msg)
+ if not isinstance(e, TosaTestInvalid):
+ # Show stack trace on unexpected exceptions
+ print(
+ "".join(
+ traceback.format_exception(
+ etype=type(e), value=e, tb=e.__traceback__
+ )
+ )
+ )
+ rc = TosaTestRunner.Result.INTERNAL_ERROR
+ finally:
+ end_time = datetime.now()
+ result_queue.put((runnerName, test, rc, msg, end_time - start_time))
+
+ task_queue.task_done()
+
+ return True
+
+
+def loadSUTRunnerModules(args):
+ """Load in the system under test modules.
+
+ Returns a list of tuples of (runner_module, [argument list])
+ """
+ runnerList = []
+ # Remove any duplicates from the list
+ sut_module_list = list(set(args.sut_module))
+ for r in sut_module_list:
+ if args.verbose:
+ print("Loading module {}".format(r))
+
+ runner = importlib.import_module(r)
+
+ # Look for arguments associated with this runner
+ runnerArgPrefix = "{}:".format(r)
+ runnerArgList = []
+ for a in args.sut_module_args:
+ if a.startswith(runnerArgPrefix):
+ runnerArgList.append(a[len(runnerArgPrefix) :])
+ runnerList.append((runner, runnerArgList))
+
+ return runnerList
+
+
+def createXUnitResults(xunitFile, runnerList, resultLists, verbose):
+ """Create the xunit results file."""
+ xunit_result = xunit.xunit_results()
+
+ for runnerModule, _ in runnerList:
+ # Create test suite per system under test (runner)
+ runner = runnerModule.__name__
+ xunit_suite = xunit_result.create_suite(runner)
+
+ # Sort by test name
+ for test, rc, msg, time_delta in sorted(
+ resultLists[runner], key=lambda tup: tup[0]
+ ):
+ test_name = test
+ xt = xunit.xunit_test(test_name, runner)
+
+ xt.time = str(
+ float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6)
+ )
+
+ testMsg = rc.name if not msg else "{}: {}".format(rc.name, msg)
+
+ if (
+ rc == TosaTestRunner.Result.EXPECTED_PASS
+ or rc == TosaTestRunner.Result.EXPECTED_FAILURE
+ ):
+ if verbose:
+ print("{} {} ({})".format(rc.name, test_name, runner))
+ elif rc == TosaTestRunner.Result.SKIPPED:
+ xt.skipped()
+ if verbose:
+ print("{} {} ({})".format(rc.name, test_name, runner))
+ else:
+ xt.failed(testMsg)
+ print("{} {} ({})".format(rc.name, test_name, runner))
+
+ xunit_suite.tests.append(xt)
+
+ xunit_result.write_results(xunitFile)
+
+
+def main(argv=None):
+ """Start worker threads to do the testing and outputs the results."""
+ args = parseArgs(argv)
+
+ if TOSA_REFMODEL_RUNNER in args.sut_module and not os.path.isfile(
+ args.ref_model_path
+ ):
+ print(
+ "Argument error: Reference Model not found ({})".format(args.ref_model_path)
+ )
+ exit(2)
+
+ if args.test_list_file:
+ try:
+ with open(args.test_list_file) as f:
+ args.test = f.read().splitlines()
+ except Exception as e:
+ print(
+ "Argument error: Cannot read list of tests in {}\n{}".format(
+ args.test_list_file, e
+ )
+ )
+ exit(2)
+
+ runnerList = loadSUTRunnerModules(args)
+
+ threads = []
+ taskQueue = queue.Queue()
+ resultQueue = queue.Queue()
+
+ for t in args.test:
+ if os.path.isfile(t):
+ if not os.path.basename(t) == "README":
+ print("Warning: Skipping test {} as not a valid directory".format(t))
+ else:
+ taskQueue.put((t))
+
+ print(
+ "Running {} tests on {} system{} under test".format(
+ taskQueue.qsize(), len(runnerList), "s" if len(runnerList) > 1 else ""
+ )
+ )
+
+ for i in range(args.jobs):
+ t = threading.Thread(
+ target=workerThread, args=(taskQueue, runnerList, args, resultQueue)
+ )
+ t.setDaemon(True)
+ t.start()
+ threads.append(t)
+
+ taskQueue.join()
+
+ # Set up results lists for each system under test
+ resultLists = {}
+ results = {}
+ for runnerModule, _ in runnerList:
+ runner = runnerModule.__name__
+ resultLists[runner] = []
+ results[runner] = [0] * len(TosaTestRunner.Result)
+
+ while True:
+ try:
+ runner, test, rc, msg, time_delta = resultQueue.get(block=False)
+ resultQueue.task_done()
+ except queue.Empty:
+ break
+
+ # Limit error messages to make results easier to digest
+ if msg and len(msg) > MAX_XUNIT_TEST_MESSAGE:
+ half = int(MAX_XUNIT_TEST_MESSAGE / 2)
+ trimmed = len(msg) - MAX_XUNIT_TEST_MESSAGE
+ msg = "{} ...\nskipped {} bytes\n... {}".format(
+ msg[:half], trimmed, msg[-half:]
+ )
+ resultLists[runner].append((test, rc, msg, time_delta))
+ results[runner][rc] += 1
+
+ createXUnitResults(args.xunit_file, runnerList, resultLists, args.verbose)
+
+ # Print out results for each system under test
+ for runnerModule, _ in runnerList:
+ runner = runnerModule.__name__
+ resultSummary = []
+ for result in TosaTestRunner.Result:
+ resultSummary.append(
+ "{} {}".format(results[runner][result], result.name.lower())
+ )
+ print("Totals ({}): {}".format(runner, ", ".join(resultSummary)))
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())