diff options
author | Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com> | 2022-07-06 10:23:22 +0100 |
---|---|---|
committer | Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com> | 2022-08-18 14:35:45 +0100 |
commit | 52dacd6556d60815253d4e4938e218ea3d8084a2 (patch) | |
tree | 4c470c567da6f70f65987d5af161bf4f950d107b /examples/tosa_checker.ipynb | |
parent | cc5d89eea4ff3dc398cac3b6025450f48ac20c1e (diff) | |
download | tosa_checker-52dacd6556d60815253d4e4938e218ea3d8084a2.tar.gz |
Initial commit0.1.0-rc.1
Change-Id: I2fb0933d595a6ede6417d09dd905ef72d6c60c9b
Diffstat (limited to 'examples/tosa_checker.ipynb')
-rw-r--r-- | examples/tosa_checker.ipynb | 117 |
1 files changed, 117 insertions, 0 deletions
diff --git a/examples/tosa_checker.ipynb b/examples/tosa_checker.ipynb new file mode 100644 index 0000000..8fe1f26 --- /dev/null +++ b/examples/tosa_checker.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import the modules needed to create a test model and run the TOSA Checker." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import tosa_checker as tc\n", + "import tensorflow as tf\n", + "import tempfile\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a simple model that is compatible with the TOSA specification." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n" + ] + } + ], + "source": [ + "input = tf.keras.layers.Input(shape=(16,))\n", + "x = tf.keras.layers.Dense(8, activation=\"relu\")(input)\n", + "model = tf.keras.models.Model(inputs=[input], outputs=x)\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", + "tflite_model = converter.convert()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save this model in `.tflite` format. Note that the TOSA Checker only accepts models in this format currently." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "_, tflite_file = tempfile.mkstemp('.tflite')\n", + "with open(tflite_file, \"wb\") as f:\n", + " f.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the TOSA Checker to check this model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is model TOSA compatible ? True\n" + ] + } + ], + "source": [ + "checker = tc.TOSAChecker(model_path=tflite_file)\n", + "result = checker.is_tosa_compatible()\n", + "print(\"Is model TOSA compatible ? {}\".format(result))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('tosa_checker': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} |