aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_output_consumer.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-11 12:33:42 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-26 14:08:21 +0100
commit5d81f37de09efe10f90512e50252be9c36925fcf (patch)
treeb4d7cdfd051da0a6e882bdfcf280fd7ca7b39e57 /tests/test_backend_output_consumer.py
parent7899b908c1fe6d86b92a80f3827ddd0ac05b674b (diff)
downloadmlia-5d81f37de09efe10f90512e50252be9c36925fcf.tar.gz
MLIA-551 Rework remains of AIET architecture
Re-factoring the code base to further merge the old AIET code into MLIA. - Remove last traces of the backend type 'tool' - Controlled systems removed, including SSH protocol, controller, RunningCommand, locks etc. - Build command / build dir and deploy functionality removed from Applications and Systems - Moving working_dir() - Replace module 'output_parser' with new module 'output_consumer' and merge Base64 parsing into it - Change the output consumption to optionally remove (i.e. actually consume) lines - Use Base64 parsing in GenericInferenceOutputParser, replacing the regex-based parsing and remove the now unused regex parsing - Remove AIET reporting - Pre-install applications by moving them to src/mlia/resources/backends - Rename aiet-config.json to backend-config.json - Move tests from tests/mlia/ to tests/ - Adapt unit tests to code changes - Dependencies removed: paramiko, filelock, psutil - Fix bug in corstone.py: The wrong resource directory was used which broke the functionality to download backends. - Use f-string formatting. - Use logging instead of print. Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76
Diffstat (limited to 'tests/test_backend_output_consumer.py')
-rw-r--r--tests/test_backend_output_consumer.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py
new file mode 100644
index 0000000..881112e
--- /dev/null
+++ b/tests/test_backend_output_consumer.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the output parsing."""
+import base64
+import json
+from typing import Any
+from typing import Dict
+
+import pytest
+
+from mlia.backend.output_consumer import Base64OutputConsumer
+from mlia.backend.output_consumer import OutputConsumer
+
+
+OUTPUT_MATCH_ALL = bytearray(
+ """
+String1: My awesome string!
+String2: STRINGS_ARE_GREAT!!!
+Int: 12
+Float: 3.14
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_NO_MATCH = bytearray(
+ """
+This contains no matches...
+Test1234567890!"£$%^&*()_+@~{}[]/.,<>?|
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_PARTIAL_MATCH = bytearray(
+ "String1: My awesome string!",
+ encoding="utf-8",
+)
+
+REGEX_CONFIG = {
+ "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"},
+ "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"},
+ "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"},
+ "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"},
+}
+
+EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {}
+
+EXPECTED_METRICS_ALL = {
+ "FirstString": "My awesome string!",
+ "SecondString": "STRINGS_ARE_GREAT",
+ "IntegerValue": 12,
+ "FloatValue": 3.14,
+}
+
+EXPECTED_METRICS_PARTIAL = {
+ "FirstString": "My awesome string!",
+}
+
+
+@pytest.mark.parametrize(
+ "expected_metrics",
+ [
+ EXPECTED_METRICS_ALL,
+ EXPECTED_METRICS_PARTIAL,
+ ],
+)
+def test_base64_output_consumer(expected_metrics: Dict) -> None:
+ """
+ Make sure the Base64OutputConsumer yields valid results.
+
+ I.e. return an empty dict if either the input or the config is empty and
+ return the parsed metrics otherwise.
+ """
+ parser = Base64OutputConsumer()
+ assert isinstance(parser, OutputConsumer)
+
+ def create_base64_output(expected_metrics: Dict) -> bytearray:
+ json_str = json.dumps(expected_metrics, indent=4)
+ json_b64 = base64.b64encode(json_str.encode("utf-8"))
+ return (
+ OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputConsumer
+ + f"<{Base64OutputConsumer.TAG_NAME}>".encode("utf-8")
+ + bytearray(json_b64)
+ + f"</{Base64OutputConsumer.TAG_NAME}>".encode("utf-8")
+ + OUTPUT_NO_MATCH # Just to add some difficulty...
+ )
+
+ output = create_base64_output(expected_metrics)
+
+ consumed = False
+ for line in output.splitlines():
+ if parser.feed(line.decode("utf-8")):
+ consumed = True
+ assert consumed # we should have consumed at least one line
+
+ res = parser.parsed_output
+ assert len(res) == 1
+ assert isinstance(res, list)
+ for val in res:
+ assert val == expected_metrics