aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
authorDiqing Zhong <diqing.zhong@arm.com>2021-01-26 12:12:51 +0100
committerDiqing Zhong <diqing.zhong@arm.com>2021-01-29 16:17:40 +0100
commit189f748e1a79ed88044efbe7137963bca830cbb5 (patch)
tree4d3db8614574b5aedcf952941c2194e2bf7f8285 /ethosu/vela/test/test_supported_operators.py
parent2c2522dd44229a03d3d778cd239478fedc19ee57 (diff)
downloadethos-u-vela-189f748e1a79ed88044efbe7137963bca830cbb5.tar.gz
MLBEDSW-3224: Support HardSwish
Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 36213b73..5c01027d 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -834,3 +834,26 @@ def test_constraint_alpha_valid():
assert support.is_operator_supported(op)
op.attrs["alpha"] = -1
assert not support.is_operator_supported(op)
+
+
+def test_constraint_hardswish_dtype():
+ # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+ # UINT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # INT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ assert support.is_operator_supported(op)
+
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert not support.is_operator_supported(op)
+
+ in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+ out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+ assert not support.is_operator_supported(op)