diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index c2d6b6e5..98487324 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -21,8 +21,10 @@ import uuid import numpy as np from . import numeric_util +from .data_type import DataType from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import round_up_divide +from .operation import Operation from .range_set import MemoryRangeSet @@ -231,6 +233,38 @@ class QuantizationParameters: return res +def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None): + # Tensor + const_tensor = Tensor(shape, dtype, name + "_0") + const_tensor.purpose = purpose + const_tensor.quantization = quantization + const_tensor.values = np.array(values, dtype=value_dtype) + const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8) + # Operator + const_op = Operation("Const", name) + const_op.set_output_tensor(const_tensor) + return const_tensor + + +def create_reshape_tensor(tens, shape, ifm_reshape=True): + if shape == tens.shape: + return tens + # Tensors + name = tens.name + "_reshape" + reshape_ifm = tens + reshape_ofm = tens.clone("_reshaped") + reshape_ofm.set_all_shapes(shape) + if not ifm_reshape: + reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm + # Operator + reshape_op = Operation("Reshape", name) + reshape_op.attrs["new_shape"] = shape + reshape_op.add_input_tensor(reshape_ifm) + reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) + reshape_op.set_output_tensor(reshape_ofm) + return reshape_ofm if ifm_reshape else reshape_ifm + + class Tensor: __slots__ = ( "shape", @@ -696,6 +730,15 @@ class Tensor: self.storage_shape = shape self.bandwidth_shape = shape + def get_full_shape(self): + d = len(self.shape) + if d in (1, 3): + return [1] * (4 - d) + self.shape + elif d == 2: + return [self.shape[0], 1, 1, self.shape[1]] + else: + return self.shape + def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) |