diff options
author | Tim Hall <tim.hall@arm.com> | 2020-11-09 16:46:37 +0000 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-11-11 11:38:41 +0000 |
commit | e6ccd87a2f40877cacdd9721a5116a6853dfe573 (patch) | |
tree | 8e22dacc02e82df59cb460b68d39e5fd338abf4d /ethosu/vela/vela.py | |
parent | e168b969dc75fc3057413a80fdf0e164ab936910 (diff) | |
download | ethos-u-vela-e6ccd87a2f40877cacdd9721a5116a6853dfe573.tar.gz |
MLBEDSW-3019: Add profiling debug database
- Added mechanism to track input to output graph transforms for
debugging the resultant command stream.
- Provides base implementation for MLBEDSW-2661
Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I2dfe8a409fbde7ad0282bfab5acb11ba1c8b82d8
Diffstat (limited to 'ethosu/vela/vela.py')
-rw-r--r-- | ethosu/vela/vela.py | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 4b43751a..5df20d22 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -31,6 +31,7 @@ from . import scheduler from . import stats_writer from . import tflite_writer from ._version import __version__ +from .debug_database import DebugDatabase from .errors import InputFileError from .nn_graph import PassPlacement from .nn_graph import TensorAllocator @@ -39,14 +40,18 @@ from .tensor import MemArea from .tensor import Tensor -def process(fname, arch, model_reader_options, compiler_options, scheduler_options): +def process(input_name, enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options): if compiler_options.timing: start = time.time() - nng = model_reader.read_model(fname, model_reader_options) + os.makedirs(compiler_options.output_dir, exist_ok=True) + output_basename = os.path.join(compiler_options.output_dir, os.path.splitext(os.path.basename(input_name))[0]) + DebugDatabase.show_warnings = enable_debug_db + + nng = model_reader.read_model(input_name, model_reader_options) if not nng: - raise InputFileError(fname, "input file could not be read") + raise InputFileError(input_name, "input file could not be read") if compiler_options.verbose_operators: nng.print_operators() @@ -58,16 +63,21 @@ def process(fname, arch, model_reader_options, compiler_options, scheduler_optio compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options) - passes_csv_file = "%s/%s_pass-breakdown_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config) + passes_csv_file = "{0}_pass-breakdown_{1}.csv".format(output_basename, arch.system_config) stats_writer.write_pass_metrics_csv(nng, passes_csv_file) - summary_csv_file = "%s/%s_summary_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config) + summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config) stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch) stats_writer.print_performance_metrics(nng, show_cpu_operations=compiler_options.show_cpu_operations, arch=arch) - if fname.endswith(".tflite"): - tflite_writer.write_tflite(nng, "%s/%s_vela.tflite" % (compiler_options.output_dir, nng.name)) + output_filename = output_basename + "_vela.tflite" + if input_name.endswith(".tflite"): + tflite_writer.write_tflite(nng, output_filename) + + if enable_debug_db: + debug_filename = output_basename + "_debug.xml" + DebugDatabase.write(debug_filename, input_name, output_filename) if compiler_options.timing: stop = time.time() @@ -123,6 +133,13 @@ def main(args=None): parser.add_argument( "--output-dir", type=str, default="output", help="Output directory to write files to (default: %(default)s)" ) + parser.add_argument( + "--enable-debug-db", + action="store_true", + default=None, + help="Enables the calculation and writing of a network debug database to output directory", + ) + parser.add_argument("--config", type=str, help="Location of vela configuration file") parser.add_argument("--verbose-graph", action="store_true", help="Verbose graph rewriter") @@ -319,9 +336,7 @@ def main(args=None): model_reader_options = model_reader.ModelReaderOptions() - os.makedirs(args.output_dir, exist_ok=True) - - nng = process(args.network, arch, model_reader_options, compiler_options, scheduler_options) + nng = process(args.network, args.enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options) if args.show_subgraph_io_summary: print_subgraph_io_summary(nng) |