about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- ethosu/vela/operation.py 42
1 files changed, 42 insertions, 0 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6e5b482..cc52ff4 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,6 +18,11 @@
from collections import namedtuple
from enum import Enum
+from .numeric_util import full_shape
+
+# Lightweight immutable coordinate records with named fields.
+PointXY = namedtuple("PointXY", ["x", "y"])
+PointXYZ = namedtuple("PointXYZ", ["x", "y", "z"])
+
class NpuBlockType(Enum):
Default = 0
@@ -29,6 +34,26 @@ class NpuBlockType(Enum):
ReduceSum = 6
+class Kernel:
+    """Kernel dimensions together with per-axis stride and dilation.
+
+    Attributes set by the constructor:
+      width, height -- the values passed as ``w`` and ``h``
+      stride        -- PointXY(sx, sy)
+      dilation      -- PointXY(dx, dy)
+      upscale       -- always initialised to 1
+    """
+
+    def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1):
+        # Strides and dilations must be strictly positive on both axes.
+        assert sx > 0 and sy > 0
+        assert dx > 0 and dy > 0
+        self.width, self.height = w, h
+        self.stride = PointXY(x=sx, y=sy)
+        self.dilation = PointXY(x=dx, y=dy)
+        self.upscale = 1
+
+    def elements_wh(self):
+        """Return width multiplied by height."""
+        element_count = self.width * self.height
+        return element_count
+
+    def area_width(self):
+        """Return the width once dilation in x is applied: (width - 1) * dilation.x + 1."""
+        gaps_x = self.width - 1
+        return gaps_x * self.dilation.x + 1
+
+    def area_height(self):
+        """Return the height once dilation in y is applied: (height - 1) * dilation.y + 1."""
+        gaps_y = self.height - 1
+        return gaps_y * self.dilation.y + 1
+
+
# Classifies operators of type Custom
class CustomType(Enum):
ThirdPartyOp = 0 # Third party custom op
@@ -330,6 +355,7 @@ class Operation:
"memory_function",
"forced_output_quantization",
"activation_lut",
+ "_kernel",
)
def __init__(self, op_type, name):
@@ -350,6 +376,7 @@ class Operation:
self.scheduled_pass = None
self.op_index = None # input network operator index
self.activation_lut = None
+ self._kernel = None
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -372,6 +399,21 @@ class Operation:
__repr__ = __str__
+    @property
+    def kernel(self):
+        """Build a Kernel for this operation, store it in ``_kernel`` and return it.
+
+        A new Kernel is constructed on every access. Stride and dilation come
+        from the "strides" and "dilation" attrs (default ``(1, 1, 1, 1)``):
+        index 2 supplies the x value and index 1 the y value. For
+        ConvolutionDepthWise / ConvolutionMxN block types with weights present,
+        height and width are read from the weight shape; otherwise they come
+        from the "filter_height" / "filter_width" attrs, defaulting to 1.
+        """
+        stride_attr = self.attrs.get("strides", (1, 1, 1, 1))
+        dilation_attr = self.attrs.get("dilation", (1, 1, 1, 1))
+        weight_tensor = self.weights
+        shape_driven_blocks = (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN)
+        if weight_tensor and self.type.npu_block_type in shape_driven_blocks:
+            # full_shape presumably pads the shape to 4 entries -- confirm in numeric_util.
+            shape4 = full_shape(4, weight_tensor.shape, 1)
+            kernel_h = shape4[-4]
+            kernel_w = shape4[-3]
+        else:
+            kernel_h = self.attrs.get("filter_height", 1)
+            kernel_w = self.attrs.get("filter_width", 1)
+        stride_x = stride_attr[2]
+        stride_y = stride_attr[1]
+        dilation_x = dilation_attr[2]
+        dilation_y = dilation_attr[1]
+        built = Kernel(kernel_w, kernel_h, stride_x, stride_y, dilation_x, dilation_y)
+        self._kernel = built
+        return built
+
def get_ifm_ifm2_weights_ofm(self):
return self.ifm, self.ifm2, self.weights, self.ofm