{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checking a model with the TOSA Checker\n",
    "\n",
    "Import the modules needed to create a test model and run the TOSA Checker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import tosa_checker as tc\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a simple model that is compatible with the TOSA specification, then convert it to TensorFlow Lite."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n"
     ]
    }
   ],
   "source": [
    "inputs = tf.keras.layers.Input(shape=(16,))\n",
    "outputs = tf.keras.layers.Dense(8, activation=\"relu\")(inputs)\n",
    "model = tf.keras.models.Model(inputs=[inputs], outputs=outputs)\n",
    "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
    "tflite_model = converter.convert()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save this model in `.tflite` format. Note that the TOSA Checker currently only accepts models in this format.\n",
    "\n",
    "`tempfile.mkstemp` returns an already-open OS-level file descriptor together with the path, so write through that descriptor; this way it is closed instead of leaked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "fd, tflite_file = tempfile.mkstemp(suffix=\".tflite\")\n",
    "with os.fdopen(fd, \"wb\") as f:\n",
    "    f.write(tflite_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the TOSA Checker to check this model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is model TOSA compatible ? True\n"
     ]
    }
   ],
   "source": [
    "checker = tc.TOSAChecker(model_path=tflite_file)\n",
    "result = checker.is_tosa_compatible()\n",
    "print(f\"Is model TOSA compatible ? {result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remove the temporary model file: files created with `tempfile.mkstemp` are not deleted automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.remove(tflite_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.0 ('tosa_checker': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}