diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index c0786bfc..98dfa3d3 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -25,6 +25,7 @@ import numpy as np from . import numeric_util from .data_type import DataType from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .operation import Op from .operation import Operation from .range_set import MemoryRangeSet @@ -242,7 +243,7 @@ def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=Te const_tensor.values = np.array(values, dtype=value_dtype) const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8) # Operator - const_op = Operation("Const", name) + const_op = Operation(Op.Const, name) const_op.set_output_tensor(const_tensor) return const_tensor @@ -258,7 +259,7 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True): if not ifm_reshape: reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm # Operator - reshape_op = Operation("Reshape", name) + reshape_op = Operation(Op.Reshape, name) reshape_op.attrs["new_shape"] = shape reshape_op.add_input_tensor(reshape_ifm) reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) @@ -649,7 +650,7 @@ class Tensor: return strides def needs_dma(self): - return len(self.ops) == 1 and self.ops[0].type == "DMA" + return len(self.ops) == 1 and self.ops[0].type == Op.DMA def get_dma_src_tensor(self): # For weight tensors that need DMA: returns the source tensor in Flash, else None @@ -659,7 +660,7 @@ class Tensor: def find_npu_op(self): # Returns the NPU operator that uses this tensor, excluding DMA operators. for op in self.consumers(): - if op.type == "DMA": + if op.type == Op.DMA: return op.outputs[0].find_npu_op() if op.run_on_npu: return op |