aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c0786bfc..98dfa3d3 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -25,6 +25,7 @@ import numpy as np
from . import numeric_util
from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .operation import Op
from .operation import Operation
from .range_set import MemoryRangeSet
@@ -242,7 +243,7 @@ def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=Te
const_tensor.values = np.array(values, dtype=value_dtype)
const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
# Operator
- const_op = Operation("Const", name)
+ const_op = Operation(Op.Const, name)
const_op.set_output_tensor(const_tensor)
return const_tensor
@@ -258,7 +259,7 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
if not ifm_reshape:
reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
# Operator
- reshape_op = Operation("Reshape", name)
+ reshape_op = Operation(Op.Reshape, name)
reshape_op.attrs["new_shape"] = shape
reshape_op.add_input_tensor(reshape_ifm)
reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
@@ -649,7 +650,7 @@ class Tensor:
return strides
def needs_dma(self):
- return len(self.ops) == 1 and self.ops[0].type == "DMA"
+ return len(self.ops) == 1 and self.ops[0].type == Op.DMA
def get_dma_src_tensor(self):
# For weight tensors that need DMA: returns the source tensor in Flash, else None
@@ -659,7 +660,7 @@ class Tensor:
def find_npu_op(self):
# Returns the NPU operator that uses this tensor, excluding DMA operators.
for op in self.consumers():
- if op.type == "DMA":
+ if op.type == Op.DMA:
return op.outputs[0].find_npu_op()
if op.run_on_npu:
return op