aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 8446ec2..84432c7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -119,7 +119,7 @@ class SupportedOperators:
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
# Supported consumers
- supported_pad_consumers = convolution_ops | depthwise_convolution_ops
+ supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -878,18 +878,29 @@ class SupportedOperators:
# which makes it impossible to calculate kernel size, hence use cached _kernel for those operators
k = cons.kernel if cons.inputs else cons._kernel
k_w, k_h = k.dilated_wh()
- if left > k_w // 2:
- return False, f"Left padding is {left}, kernel width is {k_w}"
- if right > k_w // 2:
- return False, f"Right padding is {right}, kernel width is {k_w}"
- if top > k_h // 2:
- return False, f"Top padding is {top}, kernel height is {k_h}"
- if bottom > k_h // 2:
- return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
- if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
- return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
- if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
- return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
+ if cons.type.is_avgpool_op():
+ # For average pool, padding works different on the NPU; more restrictions apply
+ for name, pad, k_size in (
+ ("Left", left, k_w),
+ ("Right", right, k_w),
+ ("Top", top, k_h),
+ ("Bottom", bottom, k_h),
+ ):
+ if pad not in (0, k_size // 2):
+ return False, f"{name} padding is {pad}, only 0 or {k_size // 2} are supported"
+ else:
+ if left > k_w // 2:
+ return False, f"Left padding is {left}, kernel width is {k_w}"
+ if right > k_w // 2:
+ return False, f"Right padding is {right}, kernel width is {k_w}"
+ if top > k_h // 2:
+ return False, f"Top padding is {top}, kernel height is {k_h}"
+ if bottom > k_h // 2:
+ return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
+ if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
+ return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
+ if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
+ return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
return True, "Pad size is ok"
@staticmethod