aboutsummaryrefslogtreecommitdiff
path: root/examples/tosa_checker.ipynb
diff options
context:
space:
mode:
authorThibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>2022-07-06 10:23:22 +0100
committerThibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>2022-08-18 14:35:45 +0100
commit52dacd6556d60815253d4e4938e218ea3d8084a2 (patch)
tree4c470c567da6f70f65987d5af161bf4f950d107b /examples/tosa_checker.ipynb
parentcc5d89eea4ff3dc398cac3b6025450f48ac20c1e (diff)
downloadtosa_checker-52dacd6556d60815253d4e4938e218ea3d8084a2.tar.gz
Initial commit0.1.0-rc.1
Change-Id: I2fb0933d595a6ede6417d09dd905ef72d6c60c9b
Diffstat (limited to 'examples/tosa_checker.ipynb')
-rw-r--r--examples/tosa_checker.ipynb117
1 files changed, 117 insertions, 0 deletions
diff --git a/examples/tosa_checker.ipynb b/examples/tosa_checker.ipynb
new file mode 100644
index 0000000..8fe1f26
--- /dev/null
+++ b/examples/tosa_checker.ipynb
@@ -0,0 +1,117 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Import the modules needed to create a test model and run the TOSA Checker."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tosa_checker as tc\n",
+ "import tensorflow as tf\n",
+ "import tempfile\n",
+ "import os"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a simple model that is compatible with the TOSA specification."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n"
+ ]
+ }
+ ],
+ "source": [
+ "input = tf.keras.layers.Input(shape=(16,))\n",
+ "x = tf.keras.layers.Dense(8, activation=\"relu\")(input)\n",
+ "model = tf.keras.models.Model(inputs=[input], outputs=x)\n",
+ "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save this model in `.tflite` format. Note that the TOSA Checker only accepts models in this format currently."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_, tflite_file = tempfile.mkstemp('.tflite')\n",
+ "with open(tflite_file, \"wb\") as f:\n",
+ " f.write(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the TOSA Checker to check this model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Is model TOSA compatible ? True\n"
+ ]
+ }
+ ],
+ "source": [
+ "checker = tc.TOSAChecker(model_path=tflite_file)\n",
+ "result = checker.is_tosa_compatible()\n",
+ "print(\"Is model TOSA compatible ? {}\".format(result))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.0 ('tosa_checker': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}