aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/vela.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/vela.py')
-rw-r--r--ethosu/vela/vela.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 4b43751a..5df20d22 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -31,6 +31,7 @@ from . import scheduler
from . import stats_writer
from . import tflite_writer
from ._version import __version__
+from .debug_database import DebugDatabase
from .errors import InputFileError
from .nn_graph import PassPlacement
from .nn_graph import TensorAllocator
@@ -39,14 +40,18 @@ from .tensor import MemArea
from .tensor import Tensor
-def process(fname, arch, model_reader_options, compiler_options, scheduler_options):
+def process(input_name, enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options):
if compiler_options.timing:
start = time.time()
- nng = model_reader.read_model(fname, model_reader_options)
+ os.makedirs(compiler_options.output_dir, exist_ok=True)
+ output_basename = os.path.join(compiler_options.output_dir, os.path.splitext(os.path.basename(input_name))[0])
+ DebugDatabase.show_warnings = enable_debug_db
+
+ nng = model_reader.read_model(input_name, model_reader_options)
if not nng:
- raise InputFileError(fname, "input file could not be read")
+ raise InputFileError(input_name, "input file could not be read")
if compiler_options.verbose_operators:
nng.print_operators()
@@ -58,16 +63,21 @@ def process(fname, arch, model_reader_options, compiler_options, scheduler_optio
compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options)
- passes_csv_file = "%s/%s_pass-breakdown_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config)
+ passes_csv_file = "{0}_pass-breakdown_{1}.csv".format(output_basename, arch.system_config)
stats_writer.write_pass_metrics_csv(nng, passes_csv_file)
- summary_csv_file = "%s/%s_summary_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config)
+ summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config)
stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch)
stats_writer.print_performance_metrics(nng, show_cpu_operations=compiler_options.show_cpu_operations, arch=arch)
- if fname.endswith(".tflite"):
- tflite_writer.write_tflite(nng, "%s/%s_vela.tflite" % (compiler_options.output_dir, nng.name))
+ output_filename = output_basename + "_vela.tflite"
+ if input_name.endswith(".tflite"):
+ tflite_writer.write_tflite(nng, output_filename)
+
+ if enable_debug_db:
+ debug_filename = output_basename + "_debug.xml"
+ DebugDatabase.write(debug_filename, input_name, output_filename)
if compiler_options.timing:
stop = time.time()
@@ -123,6 +133,13 @@ def main(args=None):
parser.add_argument(
"--output-dir", type=str, default="output", help="Output directory to write files to (default: %(default)s)"
)
+ parser.add_argument(
+ "--enable-debug-db",
+ action="store_true",
+ default=None,
+ help="Enables the calculation and writing of a network debug database to output directory",
+ )
+
parser.add_argument("--config", type=str, help="Location of vela configuration file")
parser.add_argument("--verbose-graph", action="store_true", help="Verbose graph rewriter")
@@ -319,9 +336,7 @@ def main(args=None):
model_reader_options = model_reader.ModelReaderOptions()
- os.makedirs(args.output_dir, exist_ok=True)
-
- nng = process(args.network, arch, model_reader_options, compiler_options, scheduler_options)
+ nng = process(args.network, args.enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options)
if args.show_subgraph_io_summary:
print_subgraph_io_summary(nng)