aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-30 09:01:52 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-10-08 16:29:29 +0200
commitaee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/tensor.py
parent36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
downloadethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string - Removed unused operator codes - Refactored some attributes like npu_block_type, fused_activation_function - Refactored operator index calculation - Refactored a number of operator sets Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c0786bfc..98dfa3d3 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -25,6 +25,7 @@ import numpy as np
from . import numeric_util
from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .operation import Op
from .operation import Operation
from .range_set import MemoryRangeSet
@@ -242,7 +243,7 @@ def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=Te
const_tensor.values = np.array(values, dtype=value_dtype)
const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
# Operator
- const_op = Operation("Const", name)
+ const_op = Operation(Op.Const, name)
const_op.set_output_tensor(const_tensor)
return const_tensor
@@ -258,7 +259,7 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
if not ifm_reshape:
reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
# Operator
- reshape_op = Operation("Reshape", name)
+ reshape_op = Operation(Op.Reshape, name)
reshape_op.attrs["new_shape"] = shape
reshape_op.add_input_tensor(reshape_ifm)
reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
@@ -649,7 +650,7 @@ class Tensor:
return strides
def needs_dma(self):
- return len(self.ops) == 1 and self.ops[0].type == "DMA"
+ return len(self.ops) == 1 and self.ops[0].type == Op.DMA
def get_dma_src_tensor(self):
# For weight tensors that need DMA: returns the source tensor in Flash, else None
@@ -659,7 +660,7 @@ class Tensor:
def find_npu_op(self):
# Returns the NPU operator that uses this tensor, excluding DMA operators.
for op in self.consumers():
- if op.type == "DMA":
+ if op.type == Op.DMA:
return op.outputs[0].find_npu_op()
if op.run_on_npu:
return op