diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-07-11 12:33:42 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-07-26 14:08:21 +0100 |
commit | 5d81f37de09efe10f90512e50252be9c36925fcf (patch) | |
tree | b4d7cdfd051da0a6e882bdfcf280fd7ca7b39e57 /tests/test_backend_output_consumer.py | |
parent | 7899b908c1fe6d86b92a80f3827ddd0ac05b674b (diff) | |
download | mlia-5d81f37de09efe10f90512e50252be9c36925fcf.tar.gz |
MLIA-551 Rework remains of AIET architecture
Re-factoring the code base to further merge the old AIET code into MLIA.
- Remove last traces of the backend type 'tool'
- Controlled systems removed, including SSH protocol, controller,
RunningCommand, locks etc.
- Build command / build dir and deploy functionality removed from
Applications and Systems
- Moving working_dir()
- Replace module 'output_parser' with new module 'output_consumer' and
merge Base64 parsing into it
- Change the output consumption to optionally remove (i.e. actually
consume) lines
- Use Base64 parsing in GenericInferenceOutputParser, replacing the
regex-based parsing and remove the now unused regex parsing
- Remove AIET reporting
- Pre-install applications by moving them to src/mlia/resources/backends
- Rename aiet-config.json to backend-config.json
- Move tests from tests/mlia/ to tests/
- Adapt unit tests to code changes
- Dependencies removed: paramiko, filelock, psutil
- Fix bug in corstone.py: The wrong resource directory was used which
broke the functionality to download backends.
- Use f-string formatting.
- Use logging instead of print.
Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76
Diffstat (limited to 'tests/test_backend_output_consumer.py')
-rw-r--r-- | tests/test_backend_output_consumer.py | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py new file mode 100644 index 0000000..881112e --- /dev/null +++ b/tests/test_backend_output_consumer.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the output parsing.""" +import base64 +import json +from typing import Any +from typing import Dict + +import pytest + +from mlia.backend.output_consumer import Base64OutputConsumer +from mlia.backend.output_consumer import OutputConsumer + + +OUTPUT_MATCH_ALL = bytearray( + """ +String1: My awesome string! +String2: STRINGS_ARE_GREAT!!! +Int: 12 +Float: 3.14 +""", + encoding="utf-8", +) + +OUTPUT_NO_MATCH = bytearray( + """ +This contains no matches... +Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| +""", + encoding="utf-8", +) + +OUTPUT_PARTIAL_MATCH = bytearray( + "String1: My awesome string!", + encoding="utf-8", +) + +REGEX_CONFIG = { + "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, + "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, + "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, + "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, +} + +EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} + +EXPECTED_METRICS_ALL = { + "FirstString": "My awesome string!", + "SecondString": "STRINGS_ARE_GREAT", + "IntegerValue": 12, + "FloatValue": 3.14, +} + +EXPECTED_METRICS_PARTIAL = { + "FirstString": "My awesome string!", +} + + +@pytest.mark.parametrize( + "expected_metrics", + [ + EXPECTED_METRICS_ALL, + EXPECTED_METRICS_PARTIAL, + ], +) +def test_base64_output_consumer(expected_metrics: Dict) -> None: + """ + Make sure the Base64OutputConsumer yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = Base64OutputConsumer() + assert isinstance(parser, OutputConsumer) + + def create_base64_output(expected_metrics: Dict) -> bytearray: + json_str = json.dumps(expected_metrics, indent=4) + json_b64 = base64.b64encode(json_str.encode("utf-8")) + return ( + OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputConsumer + + f"<{Base64OutputConsumer.TAG_NAME}>".encode("utf-8") + + bytearray(json_b64) + + f"</{Base64OutputConsumer.TAG_NAME}>".encode("utf-8") + + OUTPUT_NO_MATCH # Just to add some difficulty... + ) + + output = create_base64_output(expected_metrics) + + consumed = False + for line in output.splitlines(): + if parser.feed(line.decode("utf-8")): + consumed = True + assert consumed # we should have consumed at least one line + + res = parser.parsed_output + assert len(res) == 1 + assert isinstance(res, list) + for val in res: + assert val == expected_metrics |