diff options
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r-- | ethosu/vela/nn_graph.py | 54 |
1 files changed, 17 insertions, 37 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index a43aac2a..6dc6b583 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -338,41 +338,21 @@ class Subgraph: return all_ops - def print_operators(self): - print("print_operators()", self.name) - all_ops = self.get_all_ops() - unique_ops = [] - for op in all_ops: - if op.type in (Op.Const, Op.Identity, Op.Placeholder): - continue - - attrs = op.attrs.copy() - if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias): - kshape = op.inputs[1].shape - attrs["kshape"] = [kshape[0], kshape[1]] - attrs["type"] = op.type.name - attrs.pop("use_cudnn_on_gpu", None) - custom_options = attrs.pop("custom_options", None) - if attrs not in unique_ops: - unique_ops.append(attrs) - # print attributes in human readable format - a = attrs.copy() - if custom_options is not None: - a["custom_options"] = custom_options - s = a.pop("type") - data_format = a.pop("data_format", None) - if data_format and data_format != b"NHWC": - s += " " + str(data_format) - t = a.pop("T", None) - if t: - s += " " + str(t)[9:-2] - srct = a.pop("SrcT", None) - if srct: - s += " " + str(srct)[9:-2] - dstt = a.pop("DstT", None) - if dstt: - s += "->" + str(dstt)[9:-2] - print(s + " " + str(a)) + def print_operators(self, ignore_placeholder_const=True, show_attributes=True): + print(f"Operators of Subgraph {self.name}") + + ignore_ops = (Op.Const, Op.Identity, Op.Placeholder) if ignore_placeholder_const else () + all_ops = [op for op in self.get_all_ops() if op.type not in ignore_ops] + + if len(all_ops) > 0: + max_op_type_len = max([len(op.type.name) for op in all_ops]) + + for idx, op in enumerate(all_ops): + attrs_str = f" - {op.attrs}" if show_attributes else "" + print(f"{idx:3}: {op.type:{max_op_type_len}}{attrs_str} - {op.name}") + + else: + print("No Operators") def print_graph(self, label=None): if label: @@ -562,9 +542,9 @@ class Graph: for sg in self.subgraphs: sg.refresh_after_modification() - def print_operators(self): + def print_operators(self, ignore_placeholder_const=True, show_attributes=True): for sg in self.subgraphs: - sg.print_operators() + sg.print_operators(ignore_placeholder_const, show_attributes) def print_graph(self, label=None): for sg in self.subgraphs: |