aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/vela.py
diff options
context:
space:
mode:
authorDiqing Zhong <diqing.zhong@arm.com>2021-08-16 17:24:09 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-09-01 12:23:28 +0000
commit5e5a7847b8fc1eb261c7561f44585d2f6b524df3 (patch)
tree586d7a83b6b4da362f879620940b0146ca4428a7 /ethosu/vela/vela.py
parente389652ccb9821f8e959d533efd553f0d5e200d9 (diff)
downloadethos-u-vela-5e5a7847b8fc1eb261c7561f44585d2f6b524df3.tar.gz
TOSA raw data output
- Add TOSA output generation in npz format Change-Id: I97822e3a93a8fef1a95a990f23ef2c4ca5a8f73a Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Diffstat (limited to 'ethosu/vela/vela.py')
-rw-r--r--ethosu/vela/vela.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 9e237f84..7400b8e9 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -27,6 +27,7 @@ import flatbuffers
from . import architecture_features
from . import compiler_driver
from . import model_reader
+from . import rawdata_writer
from . import scheduler
from . import stats_writer
from . import tflite_writer
@@ -83,18 +84,20 @@ def process(input_name, enable_debug_db, arch, model_reader_options, compiler_op
arch=arch,
)
- output_filename = output_basename + "_vela.tflite"
+ output_tfl_filename = output_basename + "_vela.tflite"
if input_name.endswith(".tflite"):
- tflite_writer.write_tflite(nng, output_filename)
+ tflite_writer.write_tflite(nng, output_tfl_filename)
+ elif input_name.endswith(".tosa"):
+ rawdata_writer.write_rawdata_output(nng, arch, output_basename)
if enable_debug_db:
- file_offsets = calculate_operator_file_offsets(output_filename)
+ file_offsets = calculate_operator_file_offsets(output_tfl_filename)
for idx, offset in enumerate(sorted(file_offsets)):
sg = find_subgraph_with_command_stream_order(nng, idx)
if sg is not None:
DebugDatabase.set_stream_offset(sg, offset)
debug_filename = output_basename + "_debug.xml"
- DebugDatabase.write(debug_filename, input_name, output_filename)
+ DebugDatabase.write(debug_filename, input_name, output_tfl_filename)
if compiler_options.timing:
stop = time.time()