diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 42 |
1 files changed, 31 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index cc52ff4b..1ba2a388 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -220,13 +220,13 @@ class Op(Enum): Sin = OperatorInfo() SkipGram = OperatorInfo() Slice = OperatorInfo(indices=IFM_INDICES) - Softmax = OperatorInfo() + Softmax = OperatorInfo(indices=IFM_INDICES) SpaceToBatchND = OperatorInfo() SpaceToDepth = OperatorInfo() SparseToDense = OperatorInfo() Split = OperatorInfo(indices=SPLIT_IFM_INDICES) SplitSliceRead = OperatorInfo(indices=IFM_INDICES) - SplitV = OperatorInfo(indices=IFM_INDICES) + SplitV = OperatorInfo(indices=IFM_IFM2_INDICES) Sqrt = OperatorInfo() Square = OperatorInfo() SquaredDifference = OperatorInfo() @@ -399,19 +399,39 @@ class Operation: __repr__ = __str__ - @property - def kernel(self): - strides = self.attrs.get("strides", (1, 1, 1, 1)) - dilation = self.attrs.get("dilation", (1, 1, 1, 1)) + def get_kernel_size(self): weights = self.weights if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN): weight_shape = full_shape(4, weights.shape, 1) - k_h = weight_shape[-4] - k_w = weight_shape[-3] + h = weight_shape[-4] + w = weight_shape[-3] else: - k_h = self.attrs.get("filter_height", 1) - k_w = self.attrs.get("filter_width", 1) - self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1]) + h = self.attrs.get("filter_height", 1) + w = self.attrs.get("filter_width", 1) + return w, h + + def get_kernel_stride(self): + if "strides" in self.attrs: + _, h, w, _ = self.attrs["strides"] + else: + h = self.attrs.get("stride_h", 1) + w = self.attrs.get("stride_w", 1) + return w, h + + def get_kernel_dilation(self): + if "dilation" in self.attrs: + _, h, w, _ = self.attrs["dilation"] + else: + h = self.attrs.get("dilation_h_factor", 1) + w = self.attrs.get("dilation_w_factor", 1) + return w, h + + @property + def kernel(self): + k_w, k_h = self.get_kernel_size() + s_w, s_h = self.get_kernel_stride() + d_w, d_h = self.get_kernel_dilation() + self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h) return self._kernel def get_ifm_ifm2_weights_ofm(self): |