diff options
author | Michael McGeagh <michael.mcgeagh@arm.com> | 2020-08-06 17:31:02 +0100 |
---|---|---|
committer | Fredrik Knutsson <fredrik.knutsson.hunnebo@gmail.com> | 2020-08-12 06:29:18 +0000 |
commit | 5778ffdab61a46369c73c91f2c6289ba9833e3a3 (patch) | |
tree | 3225088facdb0fe46190170c47da304df45e4aee /ethosu/vela/tensor.py | |
parent | 22f74e1c39572f084ad05cc2f208446fd2f50138 (diff) | |
download | ethos-u-vela-5778ffdab61a46369c73c91f2c6289ba9833e3a3.tar.gz |
MLBEDSW-2637 Refactor util funcs out of softmax.py
There were a number of "TensorUtil" functions defined in softmax.py
These have been moved to their respective classes for Tensor and
Operator respectively.
Two of the functions were not simple tensor/op functions. These helper
functions have been moved to tensor.py for the simple fact that they
return Tensors.
Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: I17d39c4e11f0837b7867b4a54da2e4a56383e095
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 43 |
1 files changed, 43 insertions, 0 deletions
def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
    """Build a constant Tensor fed by a "Const" Operation.

    The values are stored both as a numpy array (`values`) and as a raw
    uint8 byte view (`quant_values`).  Returns the new tensor; the Const
    op is attached as its producer via set_output_tensor.
    """
    # The tensor itself
    tens = Tensor(shape, dtype, name + "_0")
    tens.purpose = purpose
    tens.quantization = quantization
    tens.values = np.array(values, dtype=value_dtype)
    # NOTE(review): quantized storage is a plain byte reinterpretation of
    # the value buffer — confirm consumers expect uint8 here.
    tens.quant_values = np.frombuffer(tens.values.tobytes(), dtype=np.uint8)
    # The producing operator
    producer = Operation("Const", name)
    producer.set_output_tensor(tens)
    return tens


def create_reshape_tensor(tens, shape, ifm_reshape=True):
    """Connect `tens` to a new Reshape operation and return the other side.

    When ifm_reshape is True, `tens` is the Reshape input and a clone with
    `shape` is returned; otherwise the clone is the input and `tens` the
    output (the clone is returned in both cases).  If `shape` already
    equals tens.shape, `tens` is returned unchanged and no op is created.
    """
    if shape == tens.shape:
        return tens
    # Tensors: original plus a clone carrying the new shape
    name = tens.name + "_reshape"
    clone = tens.clone("_reshaped")
    clone.set_all_shapes(shape)
    ifm, ofm = (tens, clone) if ifm_reshape else (clone, tens)
    # Operator: new shape is recorded both as an attribute and as a
    # constant input tensor, as TFLite's Reshape expects
    op = Operation("Reshape", name)
    op.attrs["new_shape"] = shape
    op.add_input_tensor(ifm)
    op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
    op.set_output_tensor(ofm)
    return ofm if ifm_reshape else ifm
def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) |