diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 28 |
1 files changed, 12 insertions, 16 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index c80e18b5..be26a26b 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -26,7 +26,6 @@ from typing import TYPE_CHECKING from .errors import VelaError from .numeric_util import full_shape -from .shape4d import Shape4D if TYPE_CHECKING: @@ -373,7 +372,7 @@ def create_activation_function(op_type: Op) -> ActivationFunction: return act -def get_slice_offsets(input_shape: List[int], offset_tens: int, offset_mask: int, is_begin: bool = True): +def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True): # For strided slice operator: get start or end offsets offsets = len(input_shape) * [0] if is_begin else input_shape[:] for idx in range(len(input_shape)): @@ -428,8 +427,8 @@ class Operation: self.op_index = None # input network operator index self.activation_lut = None self._kernel = None - self.ifm_shapes: List[Shape4D] = [] - self.ofm_shapes: List[Shape4D] = [] + self.ifm_shapes = [] + self.ofm_shapes = [] def clone(self, suffix="_clone"): res = Operation(self.type, self.name + suffix) @@ -708,9 +707,6 @@ class Operation: raise VelaError("\n".join(lines)) def set_ifm_ofm_shapes(self): - self.ifm_shapes = [] - self.ofm_shapes = [] - ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm() # set all shapes to op, as 4D @@ -720,24 +716,24 @@ class Operation: batch_size = elms // n_in_elems assert batch_size * n_in_elems == elms - self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems])) - self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) + self.ifm_shapes.append([batch_size, 1, 1, n_in_elems]) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) elif self.type == Op.Softmax: - self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) - self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) + self.ifm_shapes.append(ifm_tensor.get_full_shape()) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) elif self.type.is_split_op or self.type.is_concat_op(): for inp in self.inputs: if inp is not None: - self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1))) + self.ifm_shapes.append(full_shape(4, inp.shape, 1)) else: self.ifm_shapes.append(None) for out in self.outputs: if out is not None: - self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1))) + self.ofm_shapes.append(full_shape(4, out.shape, 1)) else: self.ofm_shapes.append(None) else: - self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1))) + self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1)) if ifm2_tensor is not None: - self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1))) - self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1))) + self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1)) + self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1)) |