aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_cut.py
blob: 7d267edd3c36429408d2239119ba79d377e93464 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.graph_edit.cut."""
from pathlib import Path

import numpy as np
import tensorflow as tf

from mlia.nn.rewrite.core.graph_edit.cut import cut_model


def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
    """Cut a sub-graph out of the test model and check the resulting file.

    The sub-graph runs from the ``serving_default_input:0`` tensor to the
    ``sequential/flatten/Reshape`` tensor of subgraph 0.  It is written to a
    temporary ``.tflite`` file.  That file is then loaded with the TFLite
    interpreter to confirm it exposes exactly one output whose name contains
    ``Reshape`` and whose element count matches the expected value.
    """
    cut_file = tmp_path / "out.tflite"
    cut_model(
        model_file=str(test_tflite_model),
        input_names=["serving_default_input:0"],
        output_names=["sequential/flatten/Reshape"],
        subgraph_index=0,
        output_file=str(cut_file),
    )
    assert cut_file.is_file()

    # Only the output metadata of the cut model is inspected; the model is
    # never invoked here.
    details = tf.lite.Interpreter(model_path=str(cut_file)).get_output_details()
    assert len(details) == 1
    (only_output,) = details

    assert "Reshape" in only_output["name"]
    # NOTE(review): 1728 is the expected number of elements in the flattened
    # tensor; presumably it depends on the fixture model's architecture.
    element_count = np.prod(only_output["shape"])
    assert element_count == 1728