diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index e2c097c..8fba806 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -24,13 +24,11 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite -from mlia.nn.rewrite.library.fc_sparsity24_layer import ( - get_keras_model as fc_rewrite_sparsity24, -) +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -227,8 +225,10 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite), ] ) @@ -251,6 +251,12 @@ class RewritingOptimizer(Optimizer): self.optimizer_configuration.optimization_target ] + if self.optimizer_configuration.optimization_target in [ + "conv2d-clustering", + "conv2d-sparsity24", + ]: + raise NotImplementedError + use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) |