aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/library/fc_layer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/library/fc_layer.py')
-rw-r--r--src/mlia/nn/rewrite/library/fc_layer.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 8704154..2480500 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -7,12 +7,12 @@ import tensorflow as tf
def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
- """Generate tflite model for rewrite."""
- input_tensor = tf.keras.layers.Input(
- shape=input_shape, name="MbileNet/avg_pool/AvgPool"
+ """Generate TensorFlow Lite model for rewrite."""
+ model = tf.keras.Sequential(
+ (
+ tf.keras.layers.InputLayer(input_shape=input_shape),
+ tf.keras.layers.Reshape([-1]),
+ tf.keras.layers.Dense(output_shape),
+ )
)
- output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
- input_tensor
- )
- model = tf.keras.Model(input_tensor, output_tensor)
return model