From b3d941ec44d44bdbcffadb3d62bda8ae4b001655 Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Wed, 13 Oct 2021 14:06:03 +0200 Subject: MLBEDSW-5396 Vela print_graph_with_tensors crash for optional tensor Fixed crash in nn_graph.print_graph_with_tensors() and nn_graph.print_graph_with_tensor_quantization() for optional input tensors. Signed-off-by: Fredrik Svedberg Change-Id: I7a2d23892558006485c5c84842d65aa221dba44b --- ethosu/vela/nn_graph.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 97afde31..8a2517de 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -381,14 +381,17 @@ class Subgraph: for idx, op in enumerate(all_ops): print(idx, op.type, op.name) for idx, tens in enumerate(op.inputs): - print( - " Input %02d %20s %20s %20s %s" - % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens) - ) + if tens: + print( + f" Input {idx:02d}" + f" {tens.purpose.name:>20} {tens.mem_area.name:>20} {tens.mem_type.name:>20} {tens}" + ) + else: + print(f" Input {idx:02d} {'-':>20} {'-':>20} {'-':>20} {tens}") for idx, tens in enumerate(op.outputs): print( - " Output %02d %20s %20s %20s %s" - % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens) + f" Output {idx:02d}" + f" {tens.purpose.name:>20} {tens.mem_area.name:>20} {tens.mem_type.name:>20} {tens}" ) print() @@ -398,22 +401,25 @@ class Subgraph: for idx, op in enumerate(all_ops): print(idx, op.type, op.name) for idx, tens in enumerate(op.inputs): - q = tens.quantization - if q is None: - print(" Input %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name)) + if tens: + q = tens.quantization + if q is None: + print(f" Input {idx:02d} {tens.dtype!s:>10} NO QUANTIZATION INFO {tens.name}") + else: + print( + f" Input {idx:02d} {tens.dtype!s:>10}" + f" min={q.min} max={q.max} scale={q.scale_f32!s} zero_point={q.zero_point} {tens.name}" + ) else: - print( - " Input %02d %10s min=%s max=%s scale=%s zero_point=%s %s" - % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name) - ) + print(f" Input {idx:02d} {'-':>10} {tens}") for idx, tens in enumerate(op.outputs): q = tens.quantization if q is None: - print(" Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name)) + print(f" Output {idx:02d} {tens.dtype!s:>10} NO QUANTIZATION INFO {tens.name}") else: print( - " Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s" - % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name) + f" Output {idx:02d} {tens.dtype!s:>10}" + f" min={q.min} max={q.max} scale={q.scale_f32!s} zero_point={q.zero_point} {tens.name}" ) print() -- cgit v1.2.1