blob: d4f61c5f210fa1b79c5d7e6bc599a00eed720ca7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.tensorflow.config import TFLiteModel
@dataclass
class RewriteConfiguration(OptimizerConfiguration):
    """Configuration for a model rewrite.

    Only ``optimization_target`` is mandatory; the remaining settings default
    to ``None`` when the caller does not supply them.
    """

    # Name of the rewrite target; it is echoed by ``str()``.
    optimization_target: str
    # Optional list of layer names; ``None`` when not given.
    layers_to_optimize: list[str] | None = None
    # Optional filesystem path to a dataset; ``None`` when not given.
    dataset: Path | None = None

    def __str__(self) -> str:
        """Render the configuration as ``"rewrite: <optimization_target>"``."""
        target = self.optimization_target
        return f"rewrite: {target}"
class Rewriter(Optimizer):
    """Rewriter class for the basic rewrite flow of a TFLite model."""

    def __init__(
        self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
    ) -> None:
        """Init Rewriter instance.

        Args:
            tflite_model_path: Path to the TFLite model; it is wrapped in a
                ``TFLiteModel`` and stored as ``self.model``.
            optimizer_configuration: Rewrite settings, stored unchanged as
                ``self.optimizer_configuration``.
        """
        self.model = TFLiteModel(tflite_model_path)
        self.optimizer_configuration = optimizer_configuration

    def apply_optimization(self) -> None:
        """Apply the rewrite flow.

        NOTE(review): the body is currently empty, so ``self.model`` is left
        unchanged. Presumably a placeholder until the rewrite is implemented.
        """

    def get_model(self) -> TFLiteModel:
        """Return the model held by this rewriter."""
        return self.model

    def optimization_config(self) -> str:
        """Return a human-readable description of the rewrite configuration.

        Returns:
            ``str(self.optimizer_configuration)``. For a
            ``RewriteConfiguration`` this has the form
            ``"rewrite: <optimization_target>"``.
        """
        # The method is annotated ``-> str``, so it must return the
        # configuration's string form rather than falling through to None.
        return str(self.optimizer_configuration)
|