diff options
author | Michael McGeagh <michael.mcgeagh@arm.com> | 2020-10-20 11:49:28 +0100 |
---|---|---|
committer | Michael McGeagh <michael.mcgeagh@arm.com> | 2020-11-04 14:11:24 +0000 |
commit | 65fd99830a762b2c59aaa446b55cbfa43a92f8ba (patch) | |
tree | 2320b8b0573234c7976d8228679f3b8f577b4590 /ethosu/vela/operation.py | |
parent | 37ce38c208601c6a7901d2dc266ed7db6842405b (diff) | |
download | ethos-u-vela-65fd99830a762b2c59aaa446b55cbfa43a92f8ba.tar.gz |
MLBEDSW-2412 All constraints have been refactored
All existing constraints have now been refactored using the new
framework.
Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Ic9ba0d7040cb9f114b959a949bfdf777f86752c7
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 42 |
1 files changed, 31 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index cc52ff4b..1ba2a388 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -220,13 +220,13 @@ class Op(Enum): Sin = OperatorInfo() SkipGram = OperatorInfo() Slice = OperatorInfo(indices=IFM_INDICES) - Softmax = OperatorInfo() + Softmax = OperatorInfo(indices=IFM_INDICES) SpaceToBatchND = OperatorInfo() SpaceToDepth = OperatorInfo() SparseToDense = OperatorInfo() Split = OperatorInfo(indices=SPLIT_IFM_INDICES) SplitSliceRead = OperatorInfo(indices=IFM_INDICES) - SplitV = OperatorInfo(indices=IFM_INDICES) + SplitV = OperatorInfo(indices=IFM_IFM2_INDICES) Sqrt = OperatorInfo() Square = OperatorInfo() SquaredDifference = OperatorInfo() @@ -399,19 +399,39 @@ class Operation: __repr__ = __str__ - @property - def kernel(self): - strides = self.attrs.get("strides", (1, 1, 1, 1)) - dilation = self.attrs.get("dilation", (1, 1, 1, 1)) + def get_kernel_size(self): weights = self.weights if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN): weight_shape = full_shape(4, weights.shape, 1) - k_h = weight_shape[-4] - k_w = weight_shape[-3] + h = weight_shape[-4] + w = weight_shape[-3] else: - k_h = self.attrs.get("filter_height", 1) - k_w = self.attrs.get("filter_width", 1) - self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1]) + h = self.attrs.get("filter_height", 1) + w = self.attrs.get("filter_width", 1) + return w, h + + def get_kernel_stride(self): + if "strides" in self.attrs: + _, h, w, _ = self.attrs["strides"] + else: + h = self.attrs.get("stride_h", 1) + w = self.attrs.get("stride_w", 1) + return w, h + + def get_kernel_dilation(self): + if "dilation" in self.attrs: + _, h, w, _ = self.attrs["dilation"] + else: + h = self.attrs.get("dilation_h_factor", 1) + w = self.attrs.get("dilation_w_factor", 1) + return w, h + + @property + def kernel(self): + k_w, k_h = self.get_kernel_size() + s_w, s_h = self.get_kernel_stride() + d_w, d_h = self.get_kernel_dilation() + self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h) return self._kernel def get_ifm_ifm2_weights_ofm(self): |