diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 42 |
1 files changed, 42 insertions, 0 deletions
# Reviewer summary of this patch to ethosu/vela/operation.py.
# (Free text placed ahead of the "diff --git" header is not part of any hunk.)
#
# Hunk 1 (imports): imports full_shape from the sibling numeric_util module and
#   defines two namedtuples, PointXY(x, y) and PointXYZ(x, y, z).
#
# Hunk 2 (new class Kernel): describes a 2-D kernel.
#   - __init__(w, h, sx=1, sy=1, dx=1, dy=1) asserts that both strides and both
#     dilations are > 0, then stores width, height, stride=PointXY(sx, sy),
#     dilation=PointXY(dx, dy) and upscale=1.
#   - elements_wh() returns width * height.
#   - area_width() / area_height() return (size - 1) * dilation + 1 for the
#     respective axis, i.e. the extent covered once dilation is applied.
#   NOTE(review): validation uses assert, which is skipped under "python -O".
#
# Hunks 3 and 4 (class Operation): add "_kernel" to a tuple of attribute-name
#   strings (presumably __slots__ - the tuple's opening line is outside this
#   diff, confirm) and initialise self._kernel = None in __init__.
#
# Hunk 5 (new Operation.kernel property): builds a Kernel for the operation.
#   - Reads attrs "strides" and "dilation", each defaulting to (1, 1, 1, 1).
#   - If self.weights is truthy and the op's npu_block_type is
#     ConvolutionDepthWise or ConvolutionMxN, kernel height/width are elements
#     [-4] and [-3] of full_shape(4, weights.shape, 1); otherwise they come from
#     attrs "filter_height" / "filter_width" (default 1).
#   - Kernel receives (k_w, k_h, strides[2], strides[1], dilation[2],
#     dilation[1]); index 2 is passed as the x component and index 1 as y.
#     NOTE(review): this looks like an N,H,W,C attribute layout - confirm.
#   - The Kernel is rebuilt and re-assigned to self._kernel on every access;
#     the stored value is never consulted before recomputing.
#
# NOTE(review): the page this diff was taken from had lost its line breaks.
#   Blank context lines and indentation below were restored so that every hunk
#   matches its "@@ -old,count +new,count @@" header; verify against the
#   upstream commit before applying.
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6e5b4820..cc52ff4b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,6 +18,11 @@
 from collections import namedtuple
 from enum import Enum
 
+from .numeric_util import full_shape
+
+PointXY = namedtuple("PointXY", "x y")
+PointXYZ = namedtuple("PointXYZ", "x y z")
+
 
 class NpuBlockType(Enum):
     Default = 0
@@ -29,6 +34,26 @@ class NpuBlockType(Enum):
     ReduceSum = 6
 
 
+class Kernel:
+    def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1):
+        assert sx > 0 and sy > 0
+        assert dx > 0 and dy > 0
+        self.width = w
+        self.height = h
+        self.stride = PointXY(sx, sy)
+        self.dilation = PointXY(dx, dy)
+        self.upscale = 1
+
+    def elements_wh(self):
+        return self.width * self.height
+
+    def area_width(self):
+        return (self.width - 1) * self.dilation.x + 1
+
+    def area_height(self):
+        return (self.height - 1) * self.dilation.y + 1
+
+
 # Classifies operators of type Custom
 class CustomType(Enum):
     ThirdPartyOp = 0  # Third party custom op
@@ -330,6 +355,7 @@ class Operation:
         "memory_function",
         "forced_output_quantization",
         "activation_lut",
+        "_kernel",
     )
 
     def __init__(self, op_type, name):
@@ -350,6 +376,7 @@ class Operation:
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
         self.activation_lut = None
+        self._kernel = None
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -372,6 +399,21 @@ class Operation:
 
     __repr__ = __str__
 
+    @property
+    def kernel(self):
+        strides = self.attrs.get("strides", (1, 1, 1, 1))
+        dilation = self.attrs.get("dilation", (1, 1, 1, 1))
+        weights = self.weights
+        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
+            weight_shape = full_shape(4, weights.shape, 1)
+            k_h = weight_shape[-4]
+            k_w = weight_shape[-3]
+        else:
+            k_h = self.attrs.get("filter_height", 1)
+            k_w = self.attrs.get("filter_width", 1)
+        self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+        return self._kernel
+
     def get_ifm_ifm2_weights_ofm(self):
         return self.ifm, self.ifm2, self.weights, self.ofm
 
|