aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index e2c097c..8fba806 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -24,13 +24,11 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from mlia.nn.rewrite.library.fc_clustering_layer import (
- get_keras_model_clus as fc_clustering_rewrite,
-)
-from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
-from mlia.nn.rewrite.library.fc_sparsity24_layer import (
- get_keras_model as fc_rewrite_sparsity24,
-)
+from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite
+from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite
+from mlia.nn.rewrite.library.fc_layer import fc_rewrite
+from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite
+from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -227,8 +225,10 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
GenericRewrite("fully-connected", fc_rewrite),
- Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
+ ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite),
+ Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite),
]
)
@@ -251,6 +251,12 @@ class RewritingOptimizer(Optimizer):
self.optimizer_configuration.optimization_target
]
+ if self.optimizer_configuration.optimization_target in [
+ "conv2d-clustering",
+ "conv2d-sparsity24",
+ ]:
+ raise NotImplementedError
+
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)