diff options
Diffstat (limited to 'ethosu/vela/insert_dma.py')
-rw-r--r-- | ethosu/vela/insert_dma.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py index 99b46c07..56d68d13 100644 --- a/ethosu/vela/insert_dma.py +++ b/ethosu/vela/insert_dma.py @@ -17,6 +17,7 @@ # Insert DMA operations into the graph for transfering weights. from . import rewrite_graph from .operation import NpuBlockType +from .operation import Op from .operation import Operation from .tensor import MemArea from .tensor import MemType @@ -24,9 +25,6 @@ from .tensor import TensorPurpose from .weight_compressor import compress_weights -binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum")) - - def weights_fit_sram(arch, op, tens, nng): if tens.purpose != TensorPurpose.Weights: return True @@ -57,7 +55,7 @@ def weights_fit_sram(arch, op, tens, nng): def insert_dma_cmd(op, arch, nng): - if op.type == "DMA" or not op.run_on_npu: + if op.type == Op.DMA or not op.run_on_npu: return op is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs) @@ -76,14 +74,14 @@ def insert_dma_cmd(op, arch, nng): ) or tens.purpose == TensorPurpose.LUT: if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or ( tens.purpose == TensorPurpose.FeatureMap - and op.type in binary_elementwise_op + and op.type.is_binary_elementwise_op() and tens.shape != [] and tens.shape != op.outputs[0].shape and tens.storage_size() > max_ifm_shram_avail ): only_vector_product_consumers = True for oper in tens.consumers(): - if oper is None or oper.attrs.get("npu_block_type") != NpuBlockType.VectorProduct: + if oper is None or oper.type.npu_block_type != NpuBlockType.VectorProduct: only_vector_product_consumers = False break @@ -95,7 +93,7 @@ def insert_dma_cmd(op, arch, nng): ) or tens.purpose == TensorPurpose.LUT: # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size. new_tens = tens.clone_into_fast_storage(arch) - dma_cmd = Operation("DMA", tens.ops[0].name + "_dma") + dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma") dma_cmd.inputs = [tens] dma_cmd.set_output_tensor(new_tens) dma_cmd.attrs["source"] = tens.mem_area |