From bf31d647dc5df47410ee577b12427ddf076d816b Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 16 Dec 2020 13:08:06 +0100 Subject: MLBEDSW-3645 4D class for op ifm/ofm shapes Add 4D shape class for op Ifm/ofm shapes Signed-off-by: Patrik Gustavsson Change-Id: Ic0a98da9d2f9d085605e39a9ab5a26bad6e702a3 --- ethosu/vela/operation.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) (limited to 'ethosu/vela/operation.py') diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index be26a26b..c80e18b5 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING from .errors import VelaError from .numeric_util import full_shape +from .shape4d import Shape4D if TYPE_CHECKING: @@ -372,7 +373,7 @@ def create_activation_function(op_type: Op) -> ActivationFunction: return act -def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True): +def get_slice_offsets(input_shape: List[int], offset_tens: int, offset_mask: int, is_begin: bool = True): # For strided slice operator: get start or end offsets offsets = len(input_shape) * [0] if is_begin else input_shape[:] for idx in range(len(input_shape)): @@ -427,8 +428,8 @@ class Operation: self.op_index = None # input network operator index self.activation_lut = None self._kernel = None - self.ifm_shapes = [] - self.ofm_shapes = [] + self.ifm_shapes: List[Shape4D] = [] + self.ofm_shapes: List[Shape4D] = [] def clone(self, suffix="_clone"): res = Operation(self.type, self.name + suffix) @@ -707,6 +708,9 @@ class Operation: raise VelaError("\n".join(lines)) def set_ifm_ofm_shapes(self): + self.ifm_shapes = [] + self.ofm_shapes = [] + ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm() # set all shapes to op, as 4D @@ -716,24 +720,24 @@ class Operation: batch_size = elms // n_in_elems assert batch_size * n_in_elems == elms - self.ifm_shapes.append([batch_size, 1, 1, n_in_elems]) - self.ofm_shapes.append(ofm_tensor.get_full_shape()) + self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems])) + self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) elif self.type == Op.Softmax: - self.ifm_shapes.append(ifm_tensor.get_full_shape()) - self.ofm_shapes.append(ofm_tensor.get_full_shape()) + self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) + self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) elif self.type.is_split_op or self.type.is_concat_op(): for inp in self.inputs: if inp is not None: - self.ifm_shapes.append(full_shape(4, inp.shape, 1)) + self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1))) else: self.ifm_shapes.append(None) for out in self.outputs: if out is not None: - self.ofm_shapes.append(full_shape(4, out.shape, 1)) + self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1))) else: self.ofm_shapes.append(None) else: - self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1)) + self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1))) if ifm2_tensor is not None: - self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1)) - self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1)) + self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1))) + self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1))) -- cgit v1.2.1