aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/utils/common.py
diff options
context:
space:
mode:
authorDiego Russo <diego.russo@arm.com>2022-05-30 13:34:14 +0100
committerDiego Russo <diego.russo@arm.com>2022-05-30 13:34:14 +0100
commit0efca3cadbad5517a59884576ddb90cfe7ac30f8 (patch)
treeabed6cb6fbf3c439fc8d947f505b6a53d5daeb1e /tests/mlia/utils/common.py
parent0777092695c143c3a54680b5748287d40c914c35 (diff)
downloadmlia-0efca3cadbad5517a59884576ddb90cfe7ac30f8.tar.gz
Add MLIA codebase0.3.0-rc.1
Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
Diffstat (limited to 'tests/mlia/utils/common.py')
-rw-r--r--tests/mlia/utils/common.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py
new file mode 100644
index 0000000..4313cde
--- /dev/null
+++ b/tests/mlia/utils/common.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils module."""
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+
+
+def get_dataset() -> Tuple[np.array, np.array]:
+ """Return sample dataset."""
+ mnist = tf.keras.datasets.mnist
+ (x_train, y_train), _ = mnist.load_data()
+ x_train = x_train / 255.0
+
+ # Use subset of 60000 examples to keep unit test speed fast.
+ x_train = x_train[0:1]
+ y_train = y_train[0:1]
+
+ return x_train, y_train
+
+
+def train_model(model: tf.keras.Model) -> None:
+ """Train model using sample dataset."""
+ num_epochs = 1
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ x_train, y_train = get_dataset()
+
+ model.fit(x_train, y_train, epochs=num_epochs)