aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py42
1 files changed, 31 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index cc52ff4b..1ba2a388 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -220,13 +220,13 @@ class Op(Enum):
Sin = OperatorInfo()
SkipGram = OperatorInfo()
Slice = OperatorInfo(indices=IFM_INDICES)
- Softmax = OperatorInfo()
+ Softmax = OperatorInfo(indices=IFM_INDICES)
SpaceToBatchND = OperatorInfo()
SpaceToDepth = OperatorInfo()
SparseToDense = OperatorInfo()
Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
- SplitV = OperatorInfo(indices=IFM_INDICES)
+ SplitV = OperatorInfo(indices=IFM_IFM2_INDICES)
Sqrt = OperatorInfo()
Square = OperatorInfo()
SquaredDifference = OperatorInfo()
@@ -399,19 +399,39 @@ class Operation:
__repr__ = __str__
- @property
- def kernel(self):
- strides = self.attrs.get("strides", (1, 1, 1, 1))
- dilation = self.attrs.get("dilation", (1, 1, 1, 1))
+ def get_kernel_size(self):
weights = self.weights
if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
weight_shape = full_shape(4, weights.shape, 1)
- k_h = weight_shape[-4]
- k_w = weight_shape[-3]
+ h = weight_shape[-4]
+ w = weight_shape[-3]
else:
- k_h = self.attrs.get("filter_height", 1)
- k_w = self.attrs.get("filter_width", 1)
- self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+ h = self.attrs.get("filter_height", 1)
+ w = self.attrs.get("filter_width", 1)
+ return w, h
+
+ def get_kernel_stride(self):
+ if "strides" in self.attrs:
+ _, h, w, _ = self.attrs["strides"]
+ else:
+ h = self.attrs.get("stride_h", 1)
+ w = self.attrs.get("stride_w", 1)
+ return w, h
+
+ def get_kernel_dilation(self):
+ if "dilation" in self.attrs:
+ _, h, w, _ = self.attrs["dilation"]
+ else:
+ h = self.attrs.get("dilation_h_factor", 1)
+ w = self.attrs.get("dilation_w_factor", 1)
+ return w, h
+
+ @property
+ def kernel(self):
+ k_w, k_h = self.get_kernel_size()
+ s_w, s_h = self.get_kernel_stride()
+ d_w, d_h = self.get_kernel_dilation()
+ self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
return self._kernel
def get_ifm_ifm2_weights_ofm(self):