diff options
Diffstat (limited to 'src/mlia/utils/console.py')
-rw-r--r-- | src/mlia/utils/console.py | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py new file mode 100644 index 0000000..7cb3d83 --- /dev/null +++ b/src/mlia/utils/console.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Console output utility functions.""" +from typing import Iterable +from typing import List +from typing import Optional + +from rich.console import Console +from rich.console import RenderableType +from rich.table import box +from rich.table import Table +from rich.text import Text + + +def create_section_header( + section_name: Optional[str] = None, length: int = 80, sep: str = "-" +) -> str: + """Return section header.""" + if not section_name: + content = sep * length + else: + before = 3 + spaces = 2 + after = length - (len(section_name) + before + spaces) + if after < 0: + raise ValueError("Section name too long") + content = f"{sep * before} {section_name} {sep * after}" + + return f"\n{content}\n" + + +def apply_style(value: str, style: str) -> str: + """Apply style to the value.""" + return f"[{style}]{value}" + + +def style_improvement(result: bool) -> str: + """Return different text style based on result.""" + return "green" if result else "yellow" + + +def produce_table( + rows: Iterable, + headers: Optional[List[str]] = None, + table_style: str = "default", +) -> str: + """Represent data in tabular form.""" + table = _get_table(table_style) + + if headers: + table.show_header = True + for header in headers: + table.add_column(header) + + for row in rows: + table.add_row(*row) + + return _convert_to_text(table) + + +def _get_table(table_style: str) -> Table: + """Get Table instance for the provided style.""" + if table_style == "default": + return Table( + show_header=False, + show_lines=True, + box=box.SQUARE_DOUBLE_HEAD, + ) + + if table_style == "nested": + return Table( + show_header=False, + box=None, + padding=(0, 1, 1, 0), + ) + + if table_style == "no_borders": + return Table(show_header=False, box=None) + + raise Exception(f"Unsupported table style {table_style}") + + +def _convert_to_text(*renderables: RenderableType) -> str: + """Convert renderable object to text.""" + console = Console() + with console.capture() as capture: + for item in renderables: + console.print(item) + + text = capture.get() + return text.rstrip() + + +def remove_ascii_codes(value: str) -> str: + """Decode and remove ASCII codes.""" + text = Text.from_ansi(value) + return text.plain |