about summary refs log tree commit diff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--  ethosu/vela/test/test_supported_operators.py  28
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 53c20927..20d448d7 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -19,6 +19,7 @@
import numpy as np
from ethosu.vela.data_type import DataType
+from ethosu.vela.operation import Op
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
@@ -35,7 +36,7 @@ def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1])
out = Tensor(out_shape, DataType.uint8, "out")
attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
- return testutil.create_op("StridedSlice", [in0, in1, in2, in3], out, attrs=attrs)
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
def create_strided_slice():
@@ -93,21 +94,21 @@ def test_constraint_tens_defined_shape():
# Tensors cannot have None in them
inp = Tensor([1, 8, None, 8], DataType.uint8, "in")
out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
def test_constraint_tens_shapeless():
# Shapeless input is allowed if its of a certain type:
- op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
assert support.is_operator_supported(op)
# Shapeless output is not allowed at all:
- op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
assert not support.is_operator_supported(op)
# Invalid shapeless input due to op type:
inp = Tensor([], DataType.uint8, "in")
out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
@@ -115,7 +116,7 @@ def test_constraint_tens_shape_size():
# Tensors cannot be > 4D
inp = Tensor([1, 1, 8, 8, 8], DataType.uint8, "in")
out = Tensor([1, 1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
@@ -123,14 +124,14 @@ def test_constraint_tens_dtype():
# Tensors can only be of type uint8, int8, int16 (and int32)
inp = Tensor([1, 8, 8, 8], DataType.float32, "in")
out = Tensor([1, 8, 8, 8], DataType.float32, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
# For int32, only select op types are allowed:
- op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
+ op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
assert support.is_operator_supported(op)
inp = Tensor([1, 8, 8, 8], DataType.int32, "in")
out = Tensor([1, 8, 8, 8], DataType.int32, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
@@ -138,11 +139,11 @@ def test_constraint_tens_dimension():
# Tensors can only have values in the inclusive range of 1-65535
inp = Tensor([1, 8, 8, 0], DataType.uint8, "in")
out = Tensor([1, 8, 8, 0], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
inp = Tensor([1, 8, 8, 65536], DataType.uint8, "in")
out = Tensor([1, 8, 8, 65536], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out)
+ op = testutil.create_op(Op.Relu, [inp], out)
assert not support.is_operator_supported(op)
@@ -150,13 +151,14 @@ def test_constraint_faf():
# Fused activation functions, if set, must be a valid op type
inp = Tensor([1, 8, 8, 8], DataType.uint8, "in")
out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
- op = testutil.create_op("Relu", [inp], out, attrs={"fused_activation_function": "Conv2D"})
+ op = testutil.create_op(Op.Relu, [inp], out)
+ op.activation = Op.Conv2D
assert not support.is_operator_supported(op)
def test_constraint_tens_quant_scale():
# Quantization scale cannot be infinit
- op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+ op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
op.inputs[0].quantization = QuantizationParameters()
op.inputs[0].quantization.scale_f32 = np.inf
assert not support.is_operator_supported(op)