aboutsummaryrefslogtreecommitdiff
path: root/examples/tosa_checker.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'examples/tosa_checker.ipynb')
-rw-r--r--examples/tosa_checker.ipynb117
1 files changed, 117 insertions, 0 deletions
diff --git a/examples/tosa_checker.ipynb b/examples/tosa_checker.ipynb
new file mode 100644
index 0000000..8fe1f26
--- /dev/null
+++ b/examples/tosa_checker.ipynb
@@ -0,0 +1,117 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Import the modules needed to create a test model and run the TOSA Checker."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tosa_checker as tc\n",
+ "import tensorflow as tf\n",
+ "import tempfile\n",
+ "import os"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a simple model that is compatible with the TOSA specification."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n"
+ ]
+ }
+ ],
+ "source": [
+ "input = tf.keras.layers.Input(shape=(16,))\n",
+ "x = tf.keras.layers.Dense(8, activation=\"relu\")(input)\n",
+ "model = tf.keras.models.Model(inputs=[input], outputs=x)\n",
+ "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save this model in `.tflite` format. Note that the TOSA Checker only accepts models in this format currently."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_, tflite_file = tempfile.mkstemp('.tflite')\n",
+ "with open(tflite_file, \"wb\") as f:\n",
+ " f.write(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the TOSA Checker to check this model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Is model TOSA compatible ? True\n"
+ ]
+ }
+ ],
+ "source": [
+ "checker = tc.TOSAChecker(model_path=tflite_file)\n",
+ "result = checker.is_tosa_compatible()\n",
+ "print(\"Is model TOSA compatible ? {}\".format(result))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.0 ('tosa_checker': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}