aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/utils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mlia/utils')
-rw-r--r--tests/mlia/utils/__init__.py3
-rw-r--r--tests/mlia/utils/common.py32
-rw-r--r--tests/mlia/utils/logging.py13
3 files changed, 48 insertions, 0 deletions
diff --git a/tests/mlia/utils/__init__.py b/tests/mlia/utils/__init__.py
new file mode 100644
index 0000000..27166ef
--- /dev/null
+++ b/tests/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test utils module."""
diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py
new file mode 100644
index 0000000..4313cde
--- /dev/null
+++ b/tests/mlia/utils/common.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils module."""
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+
+
+def get_dataset() -> Tuple[np.array, np.array]:
+ """Return sample dataset."""
+ mnist = tf.keras.datasets.mnist
+ (x_train, y_train), _ = mnist.load_data()
+ x_train = x_train / 255.0
+
+ # Use subset of 60000 examples to keep unit test speed fast.
+ x_train = x_train[0:1]
+ y_train = y_train[0:1]
+
+ return x_train, y_train
+
+
+def train_model(model: tf.keras.Model) -> None:
+ """Train model using sample dataset."""
+ num_epochs = 1
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ x_train, y_train = get_dataset()
+
+ model.fit(x_train, y_train, epochs=num_epochs)
diff --git a/tests/mlia/utils/logging.py b/tests/mlia/utils/logging.py
new file mode 100644
index 0000000..d223fb2
--- /dev/null
+++ b/tests/mlia/utils/logging.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for logging."""
+import logging
+
+
+def clear_loggers() -> None:
+ """Close the log handlers."""
+ for _, logger in logging.Logger.manager.loggerDict.items():
+ if not isinstance(logger, logging.PlaceHolder):
+ for handler in logger.handlers:
+ handler.close()
+ logger.removeHandler(handler)