aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-08-06 17:31:02 +0100
committerFredrik Knutsson <fredrik.knutsson.hunnebo@gmail.com>2020-08-12 06:29:18 +0000
commit5778ffdab61a46369c73c91f2c6289ba9833e3a3 (patch)
tree3225088facdb0fe46190170c47da304df45e4aee /ethosu/vela/tensor.py
parent22f74e1c39572f084ad05cc2f208446fd2f50138 (diff)
downloadethos-u-vela-5778ffdab61a46369c73c91f2c6289ba9833e3a3.tar.gz
MLBEDSW-2637 Refactor util funcs out of softmax.py
There were a number of "TensorUtil" functions defined in softmax.py These have been moved to their respective classes for Tensor and Operator respectively. Two of the functions were not a simple tensor/op function. These helper functions have been moved to tensor.py for the simple fact that they return Tensor's Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: I17d39c4e11f0837b7867b4a54da2e4a56383e095
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c2d6b6e5..98487324 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -21,8 +21,10 @@ import uuid
import numpy as np
from . import numeric_util
+from .data_type import DataType
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import round_up_divide
+from .operation import Operation
from .range_set import MemoryRangeSet
@@ -231,6 +233,38 @@ class QuantizationParameters:
return res
+def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
+ # Tensor
+ const_tensor = Tensor(shape, dtype, name + "_0")
+ const_tensor.purpose = purpose
+ const_tensor.quantization = quantization
+ const_tensor.values = np.array(values, dtype=value_dtype)
+ const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
+ # Operator
+ const_op = Operation("Const", name)
+ const_op.set_output_tensor(const_tensor)
+ return const_tensor
+
+
+def create_reshape_tensor(tens, shape, ifm_reshape=True):
+ if shape == tens.shape:
+ return tens
+ # Tensors
+ name = tens.name + "_reshape"
+ reshape_ifm = tens
+ reshape_ofm = tens.clone("_reshaped")
+ reshape_ofm.set_all_shapes(shape)
+ if not ifm_reshape:
+ reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
+ # Operator
+ reshape_op = Operation("Reshape", name)
+ reshape_op.attrs["new_shape"] = shape
+ reshape_op.add_input_tensor(reshape_ifm)
+ reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
+ reshape_op.set_output_tensor(reshape_ofm)
+ return reshape_ofm if ifm_reshape else reshape_ifm
+
+
class Tensor:
__slots__ = (
"shape",
@@ -696,6 +730,15 @@ class Tensor:
self.storage_shape = shape
self.bandwidth_shape = shape
+ def get_full_shape(self):
+ d = len(self.shape)
+ if d in (1, 3):
+ return [1] * (4 - d) + self.shape
+ elif d == 2:
+ return [self.shape[0], 1, 1, self.shape[1]]
+ else:
+ return self.shape
+
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)