aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/extract_npu_subgraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/extract_npu_subgraphs.py')
-rw-r--r--ethosu/vela/extract_npu_subgraphs.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index c0430b5d..e08392dc 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -25,17 +25,19 @@ import numpy as np
from .nn_graph import Pass
from .nn_graph import PassPlacement
from .nn_graph import Subgraph
+from .operation import CustomType
from .operation import NpuBlockType
+from .operation import Op
from .operation import Operation
def make_npu_call_op_pass(npu_subgraph):
- op = Operation("NpuOp", "call_" + npu_subgraph.name)
+ op = Operation(Op.CustomNpuOp, "call_" + npu_subgraph.name)
op.attrs["subgraph"] = npu_subgraph
+ op.attrs["custom_type"] = CustomType.NpuOp
ps = Pass(op.name, PassPlacement.MemoryOnly, False, NpuBlockType.Default)
ps.ops = [op]
ps.primary_op = op
- op.attrs["npu_block_type"] = ps.npu_block_type
op.scheduled_pass = ps
# Inputs and outputs filled in later as we cut the graphs
@@ -69,14 +71,13 @@ def switch_tensor_for_op(op, orig_tens, new_tens):
def rewrite_tensor_cpu_producer_npu_consumers(
orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
- is_const = orig_tens.ops[0].type == "Const"
+ is_const = orig_tens.ops[0].type == Op.Const
new_tens = orig_tens.clone("_npu")
- op_type = "SubgraphInput"
+ op_type = Op.SubgraphInput
if is_const:
- op_type = "Const"
+ op_type = Op.Const
op = Operation(op_type, orig_tens.name + "_input")
- op.attrs["npu_block_type"] = NpuBlockType.Default
op.scheduled_pass = startup_init_ps
op.set_output_tensor(new_tens)
startup_init_ps.ops.append(op)