aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/nn_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r--ethosu/vela/nn_graph.py54
1 files changed, 17 insertions, 37 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index a43aac2a..6dc6b583 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -338,41 +338,21 @@ class Subgraph:
return all_ops
- def print_operators(self):
- print("print_operators()", self.name)
- all_ops = self.get_all_ops()
- unique_ops = []
- for op in all_ops:
- if op.type in (Op.Const, Op.Identity, Op.Placeholder):
- continue
-
- attrs = op.attrs.copy()
- if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
- kshape = op.inputs[1].shape
- attrs["kshape"] = [kshape[0], kshape[1]]
- attrs["type"] = op.type.name
- attrs.pop("use_cudnn_on_gpu", None)
- custom_options = attrs.pop("custom_options", None)
- if attrs not in unique_ops:
- unique_ops.append(attrs)
- # print attributes in human readable format
- a = attrs.copy()
- if custom_options is not None:
- a["custom_options"] = custom_options
- s = a.pop("type")
- data_format = a.pop("data_format", None)
- if data_format and data_format != b"NHWC":
- s += " " + str(data_format)
- t = a.pop("T", None)
- if t:
- s += " " + str(t)[9:-2]
- srct = a.pop("SrcT", None)
- if srct:
- s += " " + str(srct)[9:-2]
- dstt = a.pop("DstT", None)
- if dstt:
- s += "->" + str(dstt)[9:-2]
- print(s + " " + str(a))
+ def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
+ print(f"Operators of Subgraph {self.name}")
+
+ ignore_ops = (Op.Const, Op.Identity, Op.Placeholder) if ignore_placeholder_const else ()
+ all_ops = [op for op in self.get_all_ops() if op.type not in ignore_ops]
+
+ if len(all_ops) > 0:
+ max_op_type_len = max([len(op.type.name) for op in all_ops])
+
+ for idx, op in enumerate(all_ops):
+ attrs_str = f" - {op.attrs}" if show_attributes else ""
+ print(f"{idx:3}: {op.type:{max_op_type_len}}{attrs_str} - {op.name}")
+
+ else:
+ print("No Operators")
def print_graph(self, label=None):
if label:
@@ -562,9 +542,9 @@ class Graph:
for sg in self.subgraphs:
sg.refresh_after_modification()
- def print_operators(self):
+ def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
for sg in self.subgraphs:
- sg.print_operators()
+ sg.print_operators(ignore_placeholder_const, show_attributes)
def print_graph(self, label=None):
for sg in self.subgraphs: