aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 277f2de5..5a6423d8 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -15,6 +15,9 @@
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
+# For Class name forward references for the type annotations. (see PEP 563).
+from __future__ import annotations
+
import copy
from collections import namedtuple
from enum import Enum
@@ -24,13 +27,14 @@ from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
+from typing import Union
from .api import NpuRoundingMode
from .errors import VelaError
from .numeric_util import full_shape
from .shape4d import Shape4D
-
+# Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
if TYPE_CHECKING:
from .tensor import Tensor
@@ -80,9 +84,6 @@ class Kernel:
def area_height(self) -> int:
return (self.height - 1) * self.dilation.y + 1
- def dilation(self) -> PointXY:
- return self.dilation
-
def dilated_wh(self) -> Tuple[int, int]:
"""Returns the dilated kernel width/height"""
return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1
@@ -443,7 +444,7 @@ def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFun
return act
-def get_slice_offsets(input_shape: List[int], offset_tens: int, offset_mask: int, is_begin: bool = True):
+def get_slice_offsets(input_shape: List[int], offset_tens: Tensor, offset_mask: int, is_begin: bool = True):
# For strided slice operator: get start or end offsets
offsets = len(input_shape) * [0] if is_begin else input_shape[:]
for idx in range(len(input_shape)):
@@ -493,7 +494,7 @@ class Operation:
self.type = op_type
self.name = name
self.attrs: Dict[str, Any] = {}
- self.inputs: List[Tensor] = []
+ self.inputs: List[Optional[Tensor]] = []
self.outputs: List[Tensor] = []
self.intermediates: List[Tensor] = []
self.flops = 0
@@ -514,9 +515,9 @@ class Operation:
self.ofm_shapes: List[Shape4D] = []
# If not none: contains rescale to be used as output scaling
# (which overrides the ofm tensor's scale)
- self.rescale = None
- self.read_offsets: List[Shape4D] = [None, None] # offset for [ifm, ifm2]
- self.read_shapes: List[Shape4D] = [None, None] # read shape for [ifm, ifm2]
+ self.rescale: Optional[Union[Tuple[int, int], ExplicitScaling]] = None
+ self.read_offsets: List[Optional[Shape4D]] = [None, None] # offset for [ifm, ifm2]
+ self.read_shapes: List[Optional[Shape4D]] = [None, None] # read shape for [ifm, ifm2]
self.rounding_mode: Optional[NpuRoundingMode] = None
# Rescale op in TOSA supplies explicit multiplier and shift values
self.explicit_scaling: Optional[ExplicitScaling] = None