diff options
author | Kaushik Varadharajan <kaushik.varadharajan@arm.com> | 2024-06-25 18:21:04 +0000 |
---|---|---|
committer | Kaushik Varadharajan <kaushik.varadharajan@arm.com> | 2024-06-25 18:21:04 +0000 |
commit | 11f1ef0e5fc14c402f654ae4c21492b114f925a3 (patch) | |
tree | 966eb9057735dab0e29fec876c170a5fc998a213 /python/pytests | |
parent | 6e13113b5d1a1d8afe00c8a577a014db7df5b0a4 (diff) | |
download | serialization_lib-11f1ef0e5fc14c402f654ae4c21492b114f925a3.tar.gz |
Add way to explicitly skip datatypes in Pytest
Signed-off-by: Kaushik Varadharajan <kaushik.varadharajan@arm.com>
Change-Id: Ic640f7d48720ce9403790588e46333c028960756
Diffstat (limited to 'python/pytests')
-rw-r--r-- | python/pytests/test_single_tensor.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/python/pytests/test_single_tensor.py b/python/pytests/test_single_tensor.py index f665161..02e84c2 100644 --- a/python/pytests/test_single_tensor.py +++ b/python/pytests/test_single_tensor.py @@ -24,6 +24,13 @@ import numpy as np from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2, int4, finfo, iinfo +# These datatypes are skipped during testing, presumably since they are new +# and tests have not yet been implemented. This should be emptied frequently. +SKIPPED_DTYPES = [] + +TESTED_DTYPES = set(ts.DTypeNames) - set(SKIPPED_DTYPES) + + def generate_random_data(dtype_str): # Creating the random data. @@ -78,7 +85,8 @@ def generate_random_data(dtype_str): ).astype(py_dtype) else: raise NotImplementedError( - f"Random tensor generation for type {dtype_str} not implemented" + f"Random tensor generation for type {dtype_str} not implemented. \ +Consider adding to SKIPPED_DTYPES." ) return data, shape, py_dtype @@ -116,7 +124,7 @@ def serialize_and_load_json(ser: ts.TosaSerializer, tosa_filename) -> dict: return json.load(f) -@pytest.mark.parametrize("dtype_str", ts.DTypeNames) +@pytest.mark.parametrize("dtype_str", TESTED_DTYPES) def test_single_intermediate(request, dtype_str): """ Creating an intermediate tensor of each dtype @@ -152,7 +160,7 @@ def test_single_intermediate(request, dtype_str): def placeholder_cases(): - for dtype_str in ts.DTypeNames: + for dtype_str in TESTED_DTYPES: # The ml_dtypes library has issues with serializing FP8E5M2 to .npy # files, so we don't test it. if dtype_str in ["UNKNOWN", "FP8E5M2"]: @@ -201,7 +209,7 @@ def test_single_placeholder(request, dtype_str): def const_cases(): - for dtype_str in ts.DTypeNames: + for dtype_str in TESTED_DTYPES: for const_mode in ts.ConstMode.__members__.values(): # We don't support uint8 or uint16 serialization to flatbuffer; # see convertDataToUint8Vec |