aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_commands.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_commands.py')
-rw-r--r--tests/test_cli_commands.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 6765a53..e4bbe91 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -73,8 +73,8 @@ def test_performance_unknown_target(
None,
True,
"fully_connected",
- "node_a",
- "node_b",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
does_not_raise(),
],
[
@@ -85,8 +85,8 @@ def test_performance_unknown_target(
None,
True,
"fully_connected",
- "node_a",
- "node_b",
+ "sequential/flatten/Reshape",
+ "StatefulPartitionedCall:0",
pytest.raises(
Exception,
match=(r"Only 'rewrite' is supported for TensorFlow Lite files."),
@@ -157,7 +157,7 @@ def test_performance_unknown_target(
],
],
)
-def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
+def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments
target_profile: str,
sample_context: ExecutionContext,
pruning: bool,
@@ -171,12 +171,14 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
expected_error: Any,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
- test_tflite_model: Path,
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- model_type = test_tflite_model if rewrite else test_keras_model
+ model_type = test_tflite_model_fp32 if rewrite else test_keras_model
+ data = test_tfrecord_fp32 if rewrite else None
with expected_error:
optimize(
@@ -191,6 +193,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
rewrite_target=rewrite_target,
rewrite_start=rewrite_start,
rewrite_end=rewrite_end,
+ dataset=data,
)