aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-10-20 11:49:28 +0100
committerMichael McGeagh <michael.mcgeagh@arm.com>2020-11-04 14:11:24 +0000
commit65fd99830a762b2c59aaa446b55cbfa43a92f8ba (patch)
tree2320b8b0573234c7976d8228679f3b8f577b4590 /ethosu/vela/operation.py
parent37ce38c208601c6a7901d2dc266ed7db6842405b (diff)
downloadethos-u-vela-65fd99830a762b2c59aaa446b55cbfa43a92f8ba.tar.gz
MLBEDSW-2412 All constraints have been refactored
All existing constraints have now been refactored using the new framework. Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: Ic9ba0d7040cb9f114b959a949bfdf777f86752c7
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py42
1 files changed, 31 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index cc52ff4b..1ba2a388 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -220,13 +220,13 @@ class Op(Enum):
Sin = OperatorInfo()
SkipGram = OperatorInfo()
Slice = OperatorInfo(indices=IFM_INDICES)
- Softmax = OperatorInfo()
+ Softmax = OperatorInfo(indices=IFM_INDICES)
SpaceToBatchND = OperatorInfo()
SpaceToDepth = OperatorInfo()
SparseToDense = OperatorInfo()
Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
- SplitV = OperatorInfo(indices=IFM_INDICES)
+ SplitV = OperatorInfo(indices=IFM_IFM2_INDICES)
Sqrt = OperatorInfo()
Square = OperatorInfo()
SquaredDifference = OperatorInfo()
@@ -399,19 +399,39 @@ class Operation:
__repr__ = __str__
- @property
- def kernel(self):
- strides = self.attrs.get("strides", (1, 1, 1, 1))
- dilation = self.attrs.get("dilation", (1, 1, 1, 1))
+ def get_kernel_size(self):
weights = self.weights
if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
weight_shape = full_shape(4, weights.shape, 1)
- k_h = weight_shape[-4]
- k_w = weight_shape[-3]
+ h = weight_shape[-4]
+ w = weight_shape[-3]
else:
- k_h = self.attrs.get("filter_height", 1)
- k_w = self.attrs.get("filter_width", 1)
- self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+ h = self.attrs.get("filter_height", 1)
+ w = self.attrs.get("filter_width", 1)
+ return w, h
+
+ def get_kernel_stride(self):
+ if "strides" in self.attrs:
+ _, h, w, _ = self.attrs["strides"]
+ else:
+ h = self.attrs.get("stride_h", 1)
+ w = self.attrs.get("stride_w", 1)
+ return w, h
+
+ def get_kernel_dilation(self):
+ if "dilation" in self.attrs:
+ _, h, w, _ = self.attrs["dilation"]
+ else:
+ h = self.attrs.get("dilation_h_factor", 1)
+ w = self.attrs.get("dilation_w_factor", 1)
+ return w, h
+
+ @property
+ def kernel(self):
+ k_w, k_h = self.get_kernel_size()
+ s_w, s_h = self.get_kernel_stride()
+ d_w, d_h = self.get_kernel_dilation()
+ self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
return self._kernel
def get_ifm_ifm2_weights_ofm(self):