Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--  ethosu/vela/test/test_supported_operators.py | 76
1 file changed, 76 insertions(+), 0 deletions(-)
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index df310434..53c20927 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -16,9 +16,12 @@
#
# Description:
# Unit tests for support_operators
+import numpy as np
+
from ethosu.vela.data_type import DataType
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
@@ -84,3 +87,76 @@ def test_strided_slice():
op = create_strided_slice()
op.attrs["end_mask"] = 0
assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_defined_shape():
+    # Tensor shapes cannot contain None
+ inp = Tensor([1, 8, None, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shapeless():
+    # Shapeless input is allowed only for certain op types:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # Shapeless output is not allowed at all:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ assert not support.is_operator_supported(op)
+ # Invalid shapeless input due to op type:
+ inp = Tensor([], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_shape_size():
+ # Tensors cannot be > 4D
+ inp = Tensor([1, 1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dtype():
+ # Tensors can only be of type uint8, int8, int16 (and int32)
+ inp = Tensor([1, 8, 8, 8], DataType.float32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.float32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ # For int32, only select op types are allowed:
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
+ assert support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 8], DataType.int32, "in")
+ out = Tensor([1, 8, 8, 8], DataType.int32, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_dimension():
+    # Tensor dimensions can only be in the inclusive range of 1-65535
+ inp = Tensor([1, 8, 8, 0], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 0], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+ inp = Tensor([1, 8, 8, 65536], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 65536], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_faf():
+ # Fused activation functions, if set, must be a valid op type
+ inp = Tensor([1, 8, 8, 8], DataType.uint8, "in")
+ out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op("Relu", [inp], out, attrs={"fused_activation_function": "Conv2D"})
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_quant_scale():
+    # Quantization scale cannot be infinite
+ op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ op.inputs[0].quantization = QuantizationParameters()
+ op.inputs[0].quantization.scale_f32 = np.inf
+ assert not support.is_operator_supported(op)
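
For reference, the shape-related constraints exercised above (a fully defined shape, at most 4 dimensions, and every dimension in the inclusive range 1-65535) can be summarised by a small standalone check. This is a minimal illustrative sketch: shape_is_acceptable is a hypothetical helper and not part of vela's SupportedOperators API; the public entry point used by these tests is is_operator_supported().

    # Hypothetical helper, for illustration only; mirrors the shape constraints
    # covered by test_constraint_tens_defined_shape, test_constraint_tens_shape_size
    # and test_constraint_tens_dimension.
    def shape_is_acceptable(shape):
        if any(dim is None for dim in shape):
            return False  # shape must be fully defined
        if len(shape) > 4:
            return False  # tensors may be at most 4D
        return all(1 <= dim <= 65535 for dim in shape)

    assert shape_is_acceptable([1, 8, 8, 8])
    assert not shape_is_acceptable([1, 8, None, 8])   # undefined dimension
    assert not shape_is_acceptable([1, 1, 8, 8, 8])   # more than 4 dimensions
    assert not shape_is_acceptable([1, 8, 8, 65536])  # dimension out of range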