aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2021-03-26 10:53:28 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-04-07 10:50:39 +0000
commit95b279f1454d58a93238851cb5ff394c7782ad32 (patch)
treeeb2e8f4db229f0581a894084c75f44b55877a05b /ethosu/vela/test/test_supported_operators.py
parentfe368bc231fb680ebfa48e2c35e92dec5639df5e (diff)
downloadethos-u-vela-95b279f1454d58a93238851cb5ff394c7782ad32.tar.gz
MEAN implementation changed to Average Pool
This is a small commit which changes one of the four MEAN implementations to a simpler one, using an AvgPool instead of a DepthwiseConv. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: I9e8af071e8b820796577ee4792b4812a1212602b
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index aad2849a..355b472c 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -864,7 +864,7 @@ def test_mean_axis():
def test_mean_hw_product():
- op = create_mean([1, 64, 64, 16], [1, 1, 16], [1, 2], DataType.uint8, {})
+ op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
assert support.is_operator_supported(op)
op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
@@ -875,3 +875,10 @@ def test_mean_hw_product_int8():
assert support.is_operator_supported(op)
op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product_avgpool():
+ op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)