From 189f748e1a79ed88044efbe7137963bca830cbb5 Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Tue, 26 Jan 2021 12:12:51 +0100 Subject: MLBEDSW-3224: Support HardSwish Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f Signed-off-by: Diqing Zhong --- ethosu/vela/test/test_supported_operators.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'ethosu/vela/test/test_supported_operators.py') diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 36213b73..5c01027d 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -834,3 +834,26 @@ def test_constraint_alpha_valid(): assert support.is_operator_supported(op) op.attrs["alpha"] = -1 assert not support.is_operator_supported(op) + + +def test_constraint_hardswish_dtype(): + # HardSwish operator dtype should be int8 or uint8, and input dtype must match output + # UINT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8]) + assert support.is_operator_supported(op) + # INT8 + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) + assert support.is_operator_supported(op) + + # Invalid + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16) + assert not support.is_operator_supported(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16) + assert not support.is_operator_supported(op) + op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32) + assert not support.is_operator_supported(op) + + in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in") + out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out") + op = testutil.create_op(Op.HardSwish, [in_tens], out_tens) + assert not support.is_operator_supported(op) -- cgit v1.2.1