aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/nn_graph.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-30 09:01:52 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-10-08 16:29:29 +0200
commitaee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/nn_graph.py
parent36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
downloadethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string - Removed unused operator codes - Refactored some attributes like npu_block_type, fused_activation_function - Refactored operator index calculation - Refactored a number of operator sets Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r--ethosu/vela/nn_graph.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 58aab611..12edf5ef 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -22,6 +22,8 @@
# Graph - A full neural network graph with one or more Subgraphs.
import enum
+from .operation import Op
+
class PassPlacement(enum.Enum):
Unknown = 0
@@ -176,7 +178,7 @@ class Subgraph:
visit_tensor(inp)
inp.consumer_list.append(op)
- if op.type in set(("Placeholder", "SubgraphInput")):
+ if op.type in set((Op.Placeholder, Op.SubgraphInput)):
assert len(op.outputs) == 1
self.input_tensors.append(op.outputs[0])
@@ -319,19 +321,14 @@ class Subgraph:
all_ops = self.get_all_ops()
unique_ops = []
for op in all_ops:
- if op.type in set(("Const", "Identity", "Placeholder")):
+ if op.type in set((Op.Const, Op.Identity, Op.Placeholder)):
continue
- attrs = op.attrs
- if (
- op.type == "Conv2D"
- or op.type == "DepthwiseConv2dNative"
- or op.type == "Conv2DBiasAct"
- or op.type == "DepthwiseConv2dBiasAct"
- ):
+ attrs = op.attrs.copy()
+ if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
kshape = op.inputs[1].shape
attrs["kshape"] = [kshape[0], kshape[1]]
- attrs["type"] = op.type
+ attrs["type"] = op.type.name
attrs.pop("use_cudnn_on_gpu", None)
if attrs not in unique_ops:
unique_ops.append(attrs)