aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c2d6b6e5..98487324 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -21,8 +21,10 @@ import uuid
import numpy as np
from . import numeric_util
+from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import round_up_divide
+from .operation import Operation
from .range_set import MemoryRangeSet
@@ -231,6 +233,38 @@ class QuantizationParameters:
return res
+def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
+ # Tensor
+ const_tensor = Tensor(shape, dtype, name + "_0")
+ const_tensor.purpose = purpose
+ const_tensor.quantization = quantization
+ const_tensor.values = np.array(values, dtype=value_dtype)
+ const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
+ # Operator
+ const_op = Operation("Const", name)
+ const_op.set_output_tensor(const_tensor)
+ return const_tensor
+
+
+def create_reshape_tensor(tens, shape, ifm_reshape=True):
+ if shape == tens.shape:
+ return tens
+ # Tensors
+ name = tens.name + "_reshape"
+ reshape_ifm = tens
+ reshape_ofm = tens.clone("_reshaped")
+ reshape_ofm.set_all_shapes(shape)
+ if not ifm_reshape:
+ reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
+ # Operator
+ reshape_op = Operation("Reshape", name)
+ reshape_op.attrs["new_shape"] = shape
+ reshape_op.add_input_tensor(reshape_ifm)
+ reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
+ reshape_op.set_output_tensor(reshape_ofm)
+ return reshape_ofm if ifm_reshape else reshape_ifm
+
+
class Tensor:
__slots__ = (
"shape",
@@ -696,6 +730,15 @@ class Tensor:
self.storage_shape = shape
self.bandwidth_shape = shape
+ def get_full_shape(self):
+ d = len(self.shape)
+ if d in (1, 3):
+ return [1] * (4 - d) + self.shape
+ elif d == 2:
+ return [self.shape[0], 1, 1, self.shape[1]]
+ else:
+ return self.shape
+
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)