aboutsummaryrefslogtreecommitdiff
path: root/python/pytests
diff options
context:
space:
mode:
authorKaushik Varadharajan <kaushik.varadharajan@arm.com>2024-06-25 18:21:04 +0000
committerKaushik Varadharajan <kaushik.varadharajan@arm.com>2024-06-25 18:21:04 +0000
commit11f1ef0e5fc14c402f654ae4c21492b114f925a3 (patch)
tree966eb9057735dab0e29fec876c170a5fc998a213 /python/pytests
parent6e13113b5d1a1d8afe00c8a577a014db7df5b0a4 (diff)
downloadserialization_lib-11f1ef0e5fc14c402f654ae4c21492b114f925a3.tar.gz
Add way to explicitly skip datatypes in Pytest
Signed-off-by: Kaushik Varadharajan <kaushik.varadharajan@arm.com> Change-Id: Ic640f7d48720ce9403790588e46333c028960756
Diffstat (limited to 'python/pytests')
-rw-r--r--python/pytests/test_single_tensor.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/python/pytests/test_single_tensor.py b/python/pytests/test_single_tensor.py
index f665161..02e84c2 100644
--- a/python/pytests/test_single_tensor.py
+++ b/python/pytests/test_single_tensor.py
@@ -24,6 +24,13 @@ import numpy as np
from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2, int4, finfo, iinfo
+# These datatypes are skipped during testing, presumably since they are new
+# and tests have not yet been implemented. This should be emptied frequently.
+SKIPPED_DTYPES = []
+
+TESTED_DTYPES = set(ts.DTypeNames) - set(SKIPPED_DTYPES)
+
+
def generate_random_data(dtype_str):
# Creating the random data.
@@ -78,7 +85,8 @@ def generate_random_data(dtype_str):
).astype(py_dtype)
else:
raise NotImplementedError(
- f"Random tensor generation for type {dtype_str} not implemented"
+ f"Random tensor generation for type {dtype_str} not implemented. \
+Consider adding to SKIPPED_DTYPES."
)
return data, shape, py_dtype
@@ -116,7 +124,7 @@ def serialize_and_load_json(ser: ts.TosaSerializer, tosa_filename) -> dict:
return json.load(f)
-@pytest.mark.parametrize("dtype_str", ts.DTypeNames)
+@pytest.mark.parametrize("dtype_str", TESTED_DTYPES)
def test_single_intermediate(request, dtype_str):
"""
Creating an intermediate tensor of each dtype
@@ -152,7 +160,7 @@ def test_single_intermediate(request, dtype_str):
def placeholder_cases():
- for dtype_str in ts.DTypeNames:
+ for dtype_str in TESTED_DTYPES:
# The ml_dtypes library has issues with serializing FP8E5M2 to .npy
# files, so we don't test it.
if dtype_str in ["UNKNOWN", "FP8E5M2"]:
@@ -201,7 +209,7 @@ def test_single_placeholder(request, dtype_str):
def const_cases():
- for dtype_str in ts.DTypeNames:
+ for dtype_str in TESTED_DTYPES:
for const_mode in ts.ConstMode.__members__.values():
# We don't support uint8 or uint16 serialization to flatbuffer;
# see convertDataToUint8Vec