author     Michael McGeagh <michael.mcgeagh@arm.com>  2020-10-01 15:37:44 +0100
committer  Michael McGeagh <michael.mcgeagh@arm.com>  2020-10-05 16:17:07 +0100
commit     37ded34e3d71a13aa3a14803c1fdbb6f2e73d79e (patch)
tree       ea9d6734053fb442de5e7cc869591209740d157e /ethosu/vela/test/test_supported_operators.py
parent     1eeea515402e38f4715250dbca1764bb791da17c (diff)
MLBEDSW-2412 Replace generic restrictions
A new mechanism to report generic restrictions/constraints for operators
has been implemented. Each check is its own defined function, with the
general reason for the constraint given in its docstring. This allows all
reasons to be queried up front and reported without having to run through
real data to trigger the checks.

This is part of a larger refactoring; the specific restrictions will be
replaced by a similar mechanism.

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Id3fb2639f91cfac5fc5b8c14f7620de1a85972b2
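For illustration, here is a minimal sketch of the docstring-as-reason pattern the commit message describes. Everything below (the Op and Tens stubs, constraint_tens_defined_shape, generic_constraints, report_generic_restrictions) is an assumption made for this sketch; the actual structure lives in ethosu/vela/supported_operators.py and is not shown in this diff.

    # Hedged sketch only, not the contents of supported_operators.py.
    # Each constraint is a plain function whose docstring states the
    # general reason, so reasons can be listed without running the
    # checks against real data.

    class Tens:
        # Minimal stand-in for vela's Tensor (assumption)
        def __init__(self, shape):
            self.shape = shape

    class Op:
        # Minimal stand-in for vela's Operation (assumption)
        def __init__(self, inputs, outputs):
            self.inputs = inputs
            self.outputs = outputs

    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = all(tens.shape is not None and None not in tens.shape
                    for tens in op.inputs + op.outputs)
        return valid, "a tensor shape contains None"

    generic_constraints = [constraint_tens_defined_shape]

    def is_operator_supported(op):
        # Run every generic check; any failure rejects the operator
        for constraint in generic_constraints:
            valid, extra = constraint(op)
            if not valid:
                print("Unsupported: {} ({})".format(constraint.__doc__, extra))
                return False
        return True

    def report_generic_restrictions():
        # Reasons are queryable up front, straight from the docstrings
        return [constraint.__doc__ for constraint in generic_constraints]

    # Usage: an undefined shape is rejected, and the reason is reportable
    assert not is_operator_supported(Op([Tens([1, 8, None, 8])], [Tens([1, 8, 8, 8])]))
    assert report_generic_restrictions() == [
        "Input(s) and Output Tensors must have a defined shape"]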
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--  ethosu/vela/test/test_supported_operators.py  76
1 file changed, 76 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index df310434..53c20927 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -16,9 +16,12 @@
#
# Description:
# Unit tests for supported_operators
+import numpy as np
+
from ethosu.vela.data_type import DataType
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
@@ -84,3 +87,76 @@ def test_strided_slice():
op = create_strided_slice()
op.attrs["end_mask"] = 0
assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_defined_shape():
+ # Tensor shapes must be fully defined; None is not allowed
+ inp = Tensor([1, 8, None, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shapeless():
+ # Shapeless input is allowed only for certain op types:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # Shapeless output is not allowed at all:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ assert not support.is_operator_supported(op)
+ # Invalid shapeless input due to op type:
+ inp = Tensor([], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shape_size():
+ # Tensors cannot be > 4D
+ inp = Tensor([1, 1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dtype():
+ # Tensors can only be of type uint8, int8, int16 and, for certain ops, int32
+ inp = Tensor([1, 8, 8, 8], DataType.float32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.float32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ # For int32, only select op types are allowed:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
+ assert support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 8], DataType.int32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.int32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dimension():
+ # Tensor dimensions must be in the inclusive range 1-65535
+ inp = Tensor([1, 8, 8, 0], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 0], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 65536], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 65536], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_faf():
+ # Fused activation functions, if set, must be a valid op type
+ inp = Tensor([1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out, attrs={"fused_activation_function": "Conv2D"})
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_quant_scale():
+ # Quantization scale cannot be infinite
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ op.inputs[0].quantization = QuantizationParameters()
+ op.inputs[0].quantization.scale_f32 = np.inf
+ assert not support.is_operator_supported(op)
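As a usage note, the new tests follow the plain-assert pytest style already used in this file. Assuming a development checkout of ethos-u-vela with pytest installed (an assumption; the runner is not named in this diff), they can be exercised directly with:

    pytest ethosu/vela/test/test_supported_operators.py -v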