aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/cut.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/cut.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 13a5268..53d5389 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -1,9 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut module."""
+from __future__ import annotations
+
import os
from collections import defaultdict
-from typing import Optional
+from pathlib import Path
import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
@@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
def cut_subgraph(
subgraph: SubGraphT,
- input_tensor_names: Optional[list],
- output_tensor_names: Optional[list],
+ input_tensor_names: list | None,
+ output_tensor_names: list | None,
) -> None:
"""Change the global inputs and outputs of a graph to the provided named tensors."""
if input_tensor_names is not None:
@@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple:
def cut_model(
- model_file: str,
- input_names: Optional[list],
- output_names: Optional[list],
+ model_file: str | Path,
+ input_names: list | None,
+ output_names: list | None,
subgraph_index: int,
- output_file: str,
+ output_file: str | Path,
) -> None:
"""Cut subgraphs and simplify a given model."""
model = load_fb(model_file)