diff options
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r-- | ethosu/vela/nn_graph.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 58aab611..12edf5ef 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -22,6 +22,8 @@ # Graph - A full neural network graph with one or more Subgraphs. import enum +from .operation import Op + class PassPlacement(enum.Enum): Unknown = 0 @@ -176,7 +178,7 @@ class Subgraph: visit_tensor(inp) inp.consumer_list.append(op) - if op.type in set(("Placeholder", "SubgraphInput")): + if op.type in set((Op.Placeholder, Op.SubgraphInput)): assert len(op.outputs) == 1 self.input_tensors.append(op.outputs[0]) @@ -319,19 +321,14 @@ class Subgraph: all_ops = self.get_all_ops() unique_ops = [] for op in all_ops: - if op.type in set(("Const", "Identity", "Placeholder")): + if op.type in set((Op.Const, Op.Identity, Op.Placeholder)): continue - attrs = op.attrs - if ( - op.type == "Conv2D" - or op.type == "DepthwiseConv2dNative" - or op.type == "Conv2DBiasAct" - or op.type == "DepthwiseConv2dBiasAct" - ): + attrs = op.attrs.copy() + if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias): kshape = op.inputs[1].shape attrs["kshape"] = [kshape[0], kshape[1]] - attrs["type"] = op.type + attrs["type"] = op.type.name attrs.pop("use_cudnn_on_gpu", None) if attrs not in unique_ops: unique_ops.append(attrs) |