diff options
author | Diqing Zhong <diqing.zhong@arm.com> | 2021-01-26 12:12:51 +0100 |
---|---|---|
committer | Diqing Zhong <diqing.zhong@arm.com> | 2021-01-29 16:17:40 +0100 |
commit | 189f748e1a79ed88044efbe7137963bca830cbb5 (patch) | |
tree | 4d3db8614574b5aedcf952941c2194e2bf7f8285 /ethosu/vela/test/test_supported_operators.py | |
parent | 2c2522dd44229a03d3d778cd239478fedc19ee57 (diff) | |
download | ethos-u-vela-189f748e1a79ed88044efbe7137963bca830cbb5.tar.gz |
MLBEDSW-3224: Support HardSwish
Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 36213b73..5c01027d 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -834,3 +834,26 @@ def test_constraint_alpha_valid(): assert support.is_operator_supported(op) op.attrs["alpha"] = -1 assert not support.is_operator_supported(op) + + +def test_constraint_hardswish_dtype(): + # HardSwish operator dtype should be int8 or uint8, and input dtype must match output + # UINT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8]) + assert support.is_operator_supported(op) + # INT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) + assert support.is_operator_supported(op) + + # Invalid + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16) + assert not support.is_operator_supported(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16) + assert not support.is_operator_supported(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32) + assert not support.is_operator_supported(op) + + in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in") + out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out") + op = testutil.create_op(Op.HardSwish, [in_tens], out_tens) + assert not support.is_operator_supported(op) |