diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/cut.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index 13a5268..53d5389 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cut module.""" +from __future__ import annotations + import os from collections import defaultdict -from typing import Optional +from pathlib import Path import tensorflow as tf from tensorflow.lite.python.schema_py_generated import ModelT @@ -25,8 +27,8 @@ def tensors_by_name(subgraph: SubGraphT, names: list) -> list: def cut_subgraph( subgraph: SubGraphT, - input_tensor_names: Optional[list], - output_tensor_names: Optional[list], + input_tensor_names: list | None, + output_tensor_names: list | None, ) -> None: """Change the global inputs and outputs of a graph to the provided named tensors.""" if input_tensor_names is not None: @@ -131,11 +133,11 @@ def filter_relabel(src_subgraph: SubGraphT, relabel_filter: set) -> tuple: def cut_model( - model_file: str, - input_names: Optional[list], - output_names: Optional[list], + model_file: str | Path, + input_names: list | None, + output_names: list | None, subgraph_index: int, - output_file: str, + output_file: str | Path, ) -> None: """Cut subgraphs and simplify a given model.""" model = load_fb(model_file) |