aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 7a25e47..5e223fa 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -6,6 +6,8 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any
+from typing import cast
+from typing import List
from typing import NamedTuple
import tensorflow as tf
@@ -129,12 +131,12 @@ def get_optimizer(
return Clusterer(model, config)
if isinstance(config, RewriteConfiguration):
- return Rewriter(model, config) # type: ignore
+ return Rewriter(model, config)
- if isinstance(config, OptimizationSettings) or is_list_of(
- config, OptimizationSettings
- ):
- return _get_optimizer(model, config) # type: ignore
+ if isinstance(config, OptimizationSettings):
+ return _get_optimizer(model, cast(OptimizationSettings, config))
+ if is_list_of(config, OptimizationSettings):
+ return _get_optimizer(model, cast(List[OptimizationSettings], config))
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -186,7 +188,7 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- return RewriteConfiguration( # type: ignore
+ return RewriteConfiguration(
str(optimization_target), layers_to_optimize, dataset
)