diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-07 13:46:39 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 0d3cc76284f9311c99169b568570d767f5b0aeb6 (patch) | |
tree | 2f187e003db300a61e91759040d867c568cca2c8 /src/mlia/nn/rewrite/core/rewrite.py | |
parent | fa1bf7cde005283eb8ef195ada4af48b31ff043e (diff) | |
download | mlia-0d3cc76284f9311c99169b568570d767f5b0aeb6.tar.gz |
feat: CLI and API changes for the conv2d rewrites
Implements CLI and API changes for the new conv2d rewrite targets
Resolves: MLIA-1157
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I03c7a3a536d2f0a805b4689a9d96b95f8b4ab86c
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index e2c097c..8fba806 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -24,13 +24,11 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite -from mlia.nn.rewrite.library.fc_sparsity24_layer import ( - get_keras_model as fc_rewrite_sparsity24, -) +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -227,8 +225,10 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite), ] ) @@ -251,6 +251,12 @@ class RewritingOptimizer(Optimizer): self.optimizer_configuration.optimization_target ] + if self.optimizer_configuration.optimization_target in [ + "conv2d-clustering", + "conv2d-sparsity24", + ]: + raise NotImplementedError + use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) |