# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations

import importlib
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import eval_in_dir
from mlia.nn.rewrite.core.train import join_in_dir
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel

# Dotted import paths of the replacement-model factories, keyed by the
# supported optimization targets.
_REPLACEMENT_FUNCTIONS = {
    "fully_connected": "mlia.nn.rewrite.library.fc_layer.get_keras_model",
}

# Training settings shared by both rewrite flows in Rewriter.
_AUGMENTATION_PRESET = (None, None)
_TRAIN_STEPS = 32
_LEARNING_RATE = 1e-3
_BATCH_SIZE = 1


def _load_function(dotted_path: str) -> Any:
    """Import and return the attribute named by a dotted path.

    For example ``"pkg.module.fn"`` imports ``pkg.module`` and returns its
    ``fn`` attribute.
    """
    module_name, _, fn_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, fn_name)


@dataclass
class RewriteConfiguration(OptimizerConfiguration):
    """Rewrite configuration."""

    # Name of the rewrite target; only "fully_connected" is supported.
    optimization_target: str
    # Tensor names for the subgraph: [0] is the input, [1] the output.
    layers_to_optimize: list[str] | None = None
    # Dataset path, passed to training as ``input_tfrec``.
    dataset: Path | None = None

    def __str__(self) -> str:
        """Return string representation of the configuration."""
        return f"rewrite: {self.optimization_target}"


class Rewriter(Optimizer):
    """Rewriter class for basic rewrite flow.

    Trains a replacement for part of a TFLite model. After
    ``apply_optimization()`` has run, ``get_model()`` returns the rewritten
    model.
    """

    def __init__(
        self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
    ):
        """Init Rewriter instance."""
        self.model = TFLiteModel(tflite_model_path)
        self.optimizer_configuration = optimizer_configuration
        # When non-empty, training runs from this directory via
        # train_in_dir/eval_in_dir/join_in_dir instead of train() on the
        # source model.
        self.train_dir = ""

    def apply_optimization(self) -> None:
        """Apply the rewrite flow and replace ``self.model`` with the result.

        Raises:
            ConfigurationError: if the optimization target is unsupported, or,
                when no ``train_dir`` is set, if ``layers_to_optimize`` holds
                fewer than two tensor names or no dataset is configured.
        """
        target = self.optimizer_configuration.optimization_target
        replace_function = _REPLACEMENT_FUNCTIONS.get(target)
        if replace_function is None:
            raise ConfigurationError(
                "Only fully_connected replacement is supported in rewrite module."
            )
        replace_fn = _load_function(replace_function)

        # Validate before creating any files on disk.
        training_inputs = None if self.train_dir else self._training_inputs()

        # The rewritten model must outlive this method (get_model() returns
        # it), so the directory is only removed if the rewrite fails.
        tmp_dir = Path(tempfile.mkdtemp())
        tmp_output = tmp_dir / "output.tflite"
        try:
            if training_inputs is None:
                self._rewrite_in_train_dir(
                    replace_fn, tmp_dir / "new.tflite", tmp_output
                )
            else:
                self._rewrite_from_model(replace_fn, tmp_output, *training_inputs)
        except BaseException:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            raise

        self.model = TFLiteModel(tmp_output)

    def _training_inputs(self) -> tuple[str, str, Path]:
        """Return (input tensor, output tensor, dataset) for train().

        Raises:
            ConfigurationError: if tensor names or the dataset are missing.
        """
        config = self.optimizer_configuration
        layers = config.layers_to_optimize
        if layers is None or len(layers) < 2:
            raise ConfigurationError(
                "Input and output tensor names need to be set for rewrite."
            )
        if config.dataset is None:
            raise ConfigurationError("Dataset needs to be set for rewrite.")
        return layers[0], layers[1], config.dataset

    def _rewrite_in_train_dir(
        self, replace_fn: Any, new_part_path: Path, output: Path
    ) -> None:
        """Train the replacement from ``self.train_dir`` and join it into output."""
        new_part = train_in_dir(
            train_dir=self.train_dir,
            baseline_dir=None,
            output_filename=new_part_path,
            replace_fn=replace_fn,
            augmentations=_AUGMENTATION_PRESET,
            steps=_TRAIN_STEPS,
            learning_rate=_LEARNING_RATE,
            batch_size=_BATCH_SIZE,
            verbose=True,
            show_progress=True,
        )
        eval_in_dir(self.train_dir, new_part[0])
        join_in_dir(self.train_dir, new_part[0], str(output))

    def _rewrite_from_model(
        self,
        replace_fn: Any,
        output: Path,
        input_tensor: str,
        output_tensor: str,
        dataset: Path,
    ) -> None:
        """Train the replacement on the source model and write it to output."""
        tflite_model = self.model.model_path
        train(
            source_model=tflite_model,
            unmodified_model=tflite_model,
            output_model=str(output),
            input_tfrec=str(dataset),
            replace_fn=replace_fn,
            input_tensors=[input_tensor],
            output_tensors=[output_tensor],
            augment=_AUGMENTATION_PRESET,
            steps=_TRAIN_STEPS,
            learning_rate=_LEARNING_RATE,
            batch_size=_BATCH_SIZE,
            verbose=True,
            show_progress=True,
        )

    def get_model(self) -> TFLiteModel:
        """Return optimized model.

        This is the input model until ``apply_optimization()`` succeeds.
        """
        return self.model

    def optimization_config(self) -> str:
        """Optimization configurations."""
        return str(self.optimizer_configuration)