diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-09-30 09:01:52 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-10-08 16:29:29 +0200 |
commit | aee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch) | |
tree | 495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/nn_graph.py | |
parent | 36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff) | |
download | ethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz |
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes like npu_block_type, fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets
Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r-- | ethosu/vela/nn_graph.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 58aab611..12edf5ef 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -22,6 +22,8 @@ # Graph - A full neural network graph with one or more Subgraphs. import enum +from .operation import Op + class PassPlacement(enum.Enum): Unknown = 0 @@ -176,7 +178,7 @@ class Subgraph: visit_tensor(inp) inp.consumer_list.append(op) - if op.type in set(("Placeholder", "SubgraphInput")): + if op.type in set((Op.Placeholder, Op.SubgraphInput)): assert len(op.outputs) == 1 self.input_tensors.append(op.outputs[0]) @@ -319,19 +321,14 @@ class Subgraph: all_ops = self.get_all_ops() unique_ops = [] for op in all_ops: - if op.type in set(("Const", "Identity", "Placeholder")): + if op.type in set((Op.Const, Op.Identity, Op.Placeholder)): continue - attrs = op.attrs - if ( - op.type == "Conv2D" - or op.type == "DepthwiseConv2dNative" - or op.type == "Conv2DBiasAct" - or op.type == "DepthwiseConv2dBiasAct" - ): + attrs = op.attrs.copy() + if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias): kshape = op.inputs[1].shape attrs["kshape"] = [kshape[0], kshape[1]] - attrs["type"] = op.type + attrs["type"] = op.type.name attrs.pop("use_cudnn_on_gpu", None) if attrs not in unique_ops: unique_ops.append(attrs) |