aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-03-15 11:27:08 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:43:14 +0100
commit867f37d643e66c0223457c28f5345f2f21db97f2 (patch)
tree4e3c55896760e24a8b5eadc5176ce7f5586552e1 /src/mlia/nn/select.py
parent62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff)
downloadmlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz
Adapt rewrite module to MLIA coding standards
- Fix imports - Update variable names - Refactor helper functions - Add licence headers - Add docstrings - Use f-strings rather than % notation - Create type annotations in rewrite module - Migrate from tqdm to rich progress bar - Use logging module in rewrite module: All print statements are replaced with logging module Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 7a25e47..5e223fa 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -6,6 +6,8 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any
+from typing import cast
+from typing import List
from typing import NamedTuple
import tensorflow as tf
@@ -129,12 +131,12 @@ def get_optimizer(
return Clusterer(model, config)
if isinstance(config, RewriteConfiguration):
- return Rewriter(model, config) # type: ignore
+ return Rewriter(model, config)
- if isinstance(config, OptimizationSettings) or is_list_of(
- config, OptimizationSettings
- ):
- return _get_optimizer(model, config) # type: ignore
+ if isinstance(config, OptimizationSettings):
+ return _get_optimizer(model, cast(OptimizationSettings, config))
+ if is_list_of(config, OptimizationSettings):
+ return _get_optimizer(model, cast(List[OptimizationSettings], config))
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -186,7 +188,7 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- return RewriteConfiguration( # type: ignore
+ return RewriteConfiguration(
str(optimization_target), layers_to_optimize, dataset
)