aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/nn_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r--ethosu/vela/nn_graph.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 58aab611..12edf5ef 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -22,6 +22,8 @@
# Graph - A full neural network graph with one or more Subgraphs.
import enum
+from .operation import Op
+
class PassPlacement(enum.Enum):
Unknown = 0
@@ -176,7 +178,7 @@ class Subgraph:
visit_tensor(inp)
inp.consumer_list.append(op)
- if op.type in set(("Placeholder", "SubgraphInput")):
+ if op.type in set((Op.Placeholder, Op.SubgraphInput)):
assert len(op.outputs) == 1
self.input_tensors.append(op.outputs[0])
@@ -319,19 +321,14 @@ class Subgraph:
all_ops = self.get_all_ops()
unique_ops = []
for op in all_ops:
- if op.type in set(("Const", "Identity", "Placeholder")):
+ if op.type in set((Op.Const, Op.Identity, Op.Placeholder)):
continue
- attrs = op.attrs
- if (
- op.type == "Conv2D"
- or op.type == "DepthwiseConv2dNative"
- or op.type == "Conv2DBiasAct"
- or op.type == "DepthwiseConv2dBiasAct"
- ):
+ attrs = op.attrs.copy()
+ if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
kshape = op.inputs[1].shape
attrs["kshape"] = [kshape[0], kshape[1]]
- attrs["type"] = op.type
+ attrs["type"] = op.type.name
attrs.pop("use_cudnn_on_gpu", None)
if attrs not in unique_ops:
unique_ops.append(attrs)