aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/common.py')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/common.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py
new file mode 100644
index 0000000..1dce0b2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/common.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common items for the optimizations module."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+
+@dataclass
+class OptimizerConfiguration:
+ """Abstract optimizer configuration."""
+
+
+class Optimizer(ABC):
+ """Abstract class for the optimizer."""
+
+ @abstractmethod
+ def get_model(self) -> tf.keras.Model:
+ """Abstract method to return the model instance from the optimizer."""
+
+ @abstractmethod
+ def apply_optimization(self) -> None:
+ """Abstract method to apply optimization to the model."""
+
+ @abstractmethod
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""