diff options
Diffstat (limited to 'ethosu/vela/vela.py')
-rw-r--r-- | ethosu/vela/vela.py | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 8a808276..63cccc5b 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -37,7 +37,6 @@ from .debug_database import DebugDatabase from .errors import InputFileError from .errors import VelaError from .nn_graph import NetworkType -from .nn_graph import PassPlacement from .nn_graph import TensorAllocator from .tensor import MemArea from .tensor import Tensor @@ -140,30 +139,34 @@ def print_subgraph_io_summary(nng): print("Subgraph IO Summary") print("-------------------") - print("NNG: {0}".format(nng.name)) + print(f"NNG: {nng.name}") max_sg_size = 0 for sg in reversed(nng.subgraphs): - print(" Subgraph: {0} = {1}".format(sg.name, sg.placement)) + print(f" NNG Subgraph: {sg.name} = {sg.placement}") sg_size = 0 - if sg.placement == PassPlacement.Npu: - for tens in sg.input_tensors + [sg.scratch_tensor] + sg.output_tensors: - if tens in sg.input_tensors: - tens_dir = "In" - elif tens in sg.output_tensors: - tens_dir = "Out" - else: - tens_dir = "In/Out" - - size = tens.elements() * tens.element_size() / 1024.0 - sg_size = sg_size + size - print(" Tensor [{0}]: {1} = {2} KiB".format(tens_dir, tens.name, size)) - - print(" Total Size = {0} KiB".format(sg_size)) - print(" SRAM Memory Used = {0} KiB".format(sg.memory_used.get(MemArea.Sram, 0) / 1024.0)) + if hasattr(sg, "scratch_tensor") and sg.scratch_tensor is not None: + sg_tensors = sg.input_tensors + [sg.scratch_tensor] + sg.output_tensors + else: + sg_tensors = sg.input_tensors + sg.output_tensors + + for tens in sg_tensors: + if tens in sg.input_tensors: + tens_dir = "In" + elif tens in sg.output_tensors: + tens_dir = "Out" + else: + tens_dir = "In/Out" + + size = tens.elements() * tens.element_size() / 1024.0 + sg_size = sg_size + size + print(f" Tensor [{tens_dir}]: {tens.name} = {size} KiB") + + print(f" Total Size = {sg_size} KiB") + print(f" SRAM Memory Used = {sg.memory_used.get(MemArea.Sram, 0) / 1024.0} KiB") max_sg_size = max(sg_size, max_sg_size) - print(" Maximum Subgraph Size = {0} KiB".format(max_sg_size)) + print(f" Maximum NNG Subgraph Size = {max_sg_size} KiB") def generate_supported_ops(): |