diff options
Diffstat (limited to 'src/mlia/backend')
-rw-r--r-- | src/mlia/backend/vela/compiler.py | 552 | ||||
-rw-r--r-- | src/mlia/backend/vela/performance.py | 85 |
2 files changed, 431 insertions, 206 deletions
diff --git a/src/mlia/backend/vela/compiler.py b/src/mlia/backend/vela/compiler.py index fe9e365..211721a 100644 --- a/src/mlia/backend/vela/compiler.py +++ b/src/mlia/backend/vela/compiler.py @@ -3,36 +3,156 @@ """Vela compiler wrapper module.""" from __future__ import annotations +import csv import logging +import re import sys from dataclasses import dataclass +from dataclasses import fields from io import StringIO from pathlib import Path -from typing import Any from typing import Literal -from ethosu.vela.architecture_features import ArchitectureFeatures -from ethosu.vela.compiler_driver import compiler_driver -from ethosu.vela.compiler_driver import CompilerOptions -from ethosu.vela.compiler_driver import TensorAllocator from ethosu.vela.model_reader import ModelReaderOptions from ethosu.vela.model_reader import read_model from ethosu.vela.nn_graph import Graph from ethosu.vela.nn_graph import NetworkType from ethosu.vela.operation import CustomType -from ethosu.vela.scheduler import OptimizationStrategy -from ethosu.vela.scheduler import SchedulerOptions -from ethosu.vela.tensor import BandwidthDirection -from ethosu.vela.tensor import MemArea -from ethosu.vela.tensor import Tensor -from ethosu.vela.tflite_writer import write_tflite +from ethosu.vela.vela import main +from mlia.utils.filesystem import get_vela_config from mlia.utils.logging import redirect_output +from mlia.utils.logging import redirect_raw_output logger = logging.getLogger(__name__) @dataclass +class VelaInitMemoryData: + """Memory Data from vela.ini.""" + + clock_scale: float | None + burst_length: int | None + read_latency: int | None + write_latency: int | None + + +@dataclass +class VelaInitData: # pylint: disable=too-many-instance-attributes + """Data gathered from the vela.ini file we provide to vela.""" + + system_config: str + core_clock: float + axi0_port: str + axi1_port: str + sram_memory_data: VelaInitMemoryData + dram_memory_data: VelaInitMemoryData + off_chip_flash_memory_data: VelaInitMemoryData + on_chip_flash_memory_data: VelaInitMemoryData + memory_mode: str + const_mem_area: str + arena_mem_area: str + cache_mem_area: str + arena_cache_size: int | None + + +@dataclass +class VelaSummary: # pylint: disable=too-many-instance-attributes + """Data gathered from the summary CSV file that Vela produces.""" + + cycles_total: float + cycles_npu: float + cycles_sram_access: float + cycles_dram_access: float + cycles_on_chip_flash_access: float + cycles_off_chip_flash_access: float + core_clock: float + dram_memory_used: float + sram_memory_used: float + on_chip_flash_memory_used: float + off_chip_flash_memory_used: float + batch_size: int + memory_mode: str + system_config: str + accelerator_configuration: str + arena_cache_size: float + + def __repr__(self) -> str: + """Return String Representation of VelaSummary object.""" + header_values = dict(summary_metrics) + string_to_check = "" + for field in fields(self): + string_to_check += ( + f"{header_values[field.name]}: {getattr(self, field.name)}, " + ) + return string_to_check + + +complete_summary_metrics = [ + ("experiment", "experiment"), + ("network", "network"), + ("accelerator_configuration", "accelerator_configuration"), + ("system_config", "system_config"), + ("memory_mode", "memory_mode"), + ("core_clock", "core_clock"), + ("arena_cache_size", "arena_cache_size"), + ("sram_bandwidth", "sram_bandwidth"), + ("dram_bandwidth", "dram_bandwidth"), + ("on_chip_flash_bandwidth", "on_chip_flash_bandwidth"), + ("off_chip_flash_bandwidth", "off_chip_flash_bandwidth"), + ("weights_storage_area", "weights_storage_area"), + ("feature_map_storage_area", "feature_map_storage_area"), + ("inferences_per_second", "inferences_per_second"), + ("batch_size", "batch_size"), + ("inference_time", "inference_time"), + ("passes_before_fusing", "passes_before_fusing"), + ("sram_memory_used", "sram_memory_used"), + ("dram_memory_used", "dram_memory_used"), + ( + "on_chip_flash_memory_used", + "on_chip_flash_memory_used", + ), + ("off_chip_flash_memory_used", "off_chip_flash_memory_used"), + ("total_original_weights", "total_original_weights"), + ("total_npu_encoded_weights", "total_npu_encoded_weights"), + ("dram_total_bytes", "dram_total_bytes"), + ( + "on_chip_flash_feature_map_read_bytes", + "on_chip_flash_feature_map_read_bytes", + ), + ("on_chip_flash_feature_map_write_bytes", "on_chip_flash_feature_map_write_bytes"), + ("on_chip_flash_weight_read_bytes", "on_chip_flash_weight_read_bytes"), + ("on_chip_flash_weight_write_bytes", "on_chip_flash_weight_write_bytes"), + ("on_chip_flash_total_bytes", "on_chip_flash_total_bytes"), + ("off_chip_flash_feature_map_read_bytes", "off_chip_flash_feature_map_read_bytes"), + ( + "off_chip_flash_feature_map_write_bytes", + "off_chip_flash_feature_map_write_bytes", + ), + ("off_chip_flash_weight_read_bytes", "off_chip_flash_weight_read_bytes"), + ("off_chip_flash_weight_write_bytes", "off_chip_flash_weight_write_bytes"), + ("off_chip_flash_total_bytes", "off_chip_flash_total_bytes"), + ("nn_macs", "nn_macs"), + ("nn_tops", "nn_tops"), + ("cycles_npu", "cycles_npu"), + ("cycles_sram_access", "cycles_sram_access"), + ("cycles_dram_access", "cycles_dram_access"), + ("cycles_on_chip_flash_access", "cycles_on_chip_flash_access"), + ("cycles_off_chip_flash_access", "cycles_off_chip_flash_access"), + ("cycles_total", "cycles_total"), +] + +OUTPUT_METRICS = [field.name for field in fields(VelaSummary)] + +summary_metrics = [ + summary_metric + for summary_metric in complete_summary_metrics + if summary_metric[0] in OUTPUT_METRICS +] +summary_metrics.sort(key=lambda e: OUTPUT_METRICS.index(e[0])) + + +@dataclass class Model: """Model metadata.""" @@ -49,20 +169,6 @@ class Model: ) -@dataclass -class OptimizedModel: - """Instance of the Vela optimized model.""" - - nng: Graph - arch: ArchitectureFeatures - compiler_options: CompilerOptions - scheduler_options: SchedulerOptions - - def save(self, output_filename: str | Path) -> None: - """Save instance of the optimized model to the file.""" - write_tflite(self.nng, output_filename) - - AcceleratorConfigType = Literal[ "ethos-u55-32", "ethos-u55-64", @@ -82,16 +188,17 @@ class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes """Vela compiler options.""" config_files: str | list[str] | None = None - system_config: str = ArchitectureFeatures.DEFAULT_CONFIG - memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG + system_config: str = "internal-default" + memory_mode: str = "internal-default" accelerator_config: AcceleratorConfigType | None = None - max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP + max_block_dependency: int = 3 arena_cache_size: int | None = None tensor_allocator: TensorAllocatorType = "HillClimb" - cpu_tensor_alignment: int = Tensor.AllocationQuantum + cpu_tensor_alignment: int = 16 optimization_strategy: OptimizationStrategyType = "Performance" output_dir: Path = Path("output") recursion_limit: int = 1000 + verbose_performance: bool = True class VelaCompiler: # pylint: disable=too-many-instance-attributes @@ -105,13 +212,12 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes self.accelerator_config = compiler_options.accelerator_config self.max_block_dependency = compiler_options.max_block_dependency self.arena_cache_size = compiler_options.arena_cache_size - self.tensor_allocator = TensorAllocator[compiler_options.tensor_allocator] + self.tensor_allocator = compiler_options.tensor_allocator self.cpu_tensor_alignment = compiler_options.cpu_tensor_alignment - self.optimization_strategy = OptimizationStrategy[ - compiler_options.optimization_strategy - ] - self.output_dir = compiler_options.output_dir + self.optimization_strategy = compiler_options.optimization_strategy + self.output_dir = Path(compiler_options.output_dir) self.recursion_limit = compiler_options.recursion_limit + self.verbose_performance = compiler_options.verbose_performance sys.setrecursionlimit(self.recursion_limit) @@ -122,36 +228,48 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes nng, network_type = self._read_model(model) return Model(nng, network_type) - def compile_model(self, model: str | Path | Model) -> OptimizedModel: + def compile_model( + self, model_path: Path, already_compiled: bool = False + ) -> tuple[VelaSummary, Path]: """Compile the model.""" - if isinstance(model, (str, Path)): - nng, network_type = self._read_model(model) - else: - nng, network_type = model.nng, NetworkType.TFLite - - if not nng: - raise ValueError("Unable to read model: model.nng is not available") - - output_basename = f"{self.output_dir}/{nng.name}" - try: - arch = self._architecture_features() - compiler_options = self._compiler_options() - scheduler_options = self._scheduler_options() - - with redirect_output( + with redirect_raw_output( logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG ): tmp = sys.stdout output_message = StringIO() sys.stdout = output_message - compiler_driver( - nng, - arch, - compiler_options, - scheduler_options, - network_type, - output_basename, + main_args = [ + "--output-dir", + str(self.output_dir.as_posix()), + "--tensor-allocator", + str(self.tensor_allocator), + "--cpu-tensor-alignment", + str(self.cpu_tensor_alignment), + "--accelerator-config", + str(self.accelerator_config), + "--system-config", + str(self.system_config), + "--memory-mode", + str(self.memory_mode), + "--max-block-dependency", + str(self.max_block_dependency), + "--optimise", + str(self.optimization_strategy), + model_path.as_posix(), + "--config", + str(self.config_files), + ] + if self.verbose_performance: + main_args.append("--verbose-performance") + if not already_compiled: + main(main_args) + optimized_model_path = Path( + self.output_dir.as_posix() + + "/" + + model_path.stem + + "_vela" + + model_path.suffix ) sys.stdout = tmp if ( @@ -159,51 +277,29 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes in output_message.getvalue() ): raise MemoryError("Model is too large and uses too much RAM") - - return OptimizedModel(nng, arch, compiler_options, scheduler_options) + summary_data = parse_summary_csv_file( + Path( + self.output_dir.as_posix() + + "/" + + model_path.stem + + "_summary_" + + self.system_config + + ".csv" + ) + ) + return summary_data, optimized_model_path except MemoryError as err: raise err except (SystemExit, Exception) as err: + if ( + "Error: Invalid tflite file." in output_message.getvalue() + and isinstance(err, SystemExit) + ): + raise RuntimeError(f"Unable to read model {model_path}") from err raise RuntimeError( "Model could not be optimized with Vela compiler." ) from err - def get_config(self) -> dict[str, Any]: - """Get compiler configuration.""" - arch = self._architecture_features() - - memory_area = { - mem.name: { - "clock_scales": arch.memory_clock_scales[mem], - "burst_length": arch.memory_burst_length[mem], - "read_latency": arch.memory_latency[mem][BandwidthDirection.Read], - "write_latency": arch.memory_latency[mem][BandwidthDirection.Write], - } - for mem in ( - MemArea.Sram, - MemArea.Dram, - MemArea.OnChipFlash, - MemArea.OffChipFlash, - ) - } - - return { - "accelerator_config": arch.accelerator_config.value, - "system_config": arch.system_config, - "core_clock": arch.core_clock, - "axi0_port": arch.axi0_port.name, - "axi1_port": arch.axi1_port.name, - "memory_mode": arch.memory_mode, - "const_mem_area": arch.const_mem_area.name, - "arena_mem_area": arch.arena_mem_area.name, - "cache_mem_area": arch.cache_mem_area.name, - "arena_cache_size": arch.arena_cache_size, - "permanent_storage_mem_area": arch.permanent_storage_mem_area.name, - "feature_map_storage_mem_area": arch.feature_map_storage_mem_area.name, - "fast_storage_mem_area": arch.fast_storage_mem_area.name, - "memory_area": memory_area, - } - @staticmethod def _read_model(model: str | Path) -> tuple[Graph, NetworkType]: """Read TensorFlow Lite model.""" @@ -216,57 +312,10 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes except (SystemExit, Exception) as err: raise RuntimeError(f"Unable to read model {model_path}.") from err - def _architecture_features(self) -> ArchitectureFeatures: - """Return ArchitectureFeatures instance.""" - return ArchitectureFeatures( - vela_config_files=self.config_files, - accelerator_config=self.accelerator_config, - system_config=self.system_config, - memory_mode=self.memory_mode, - max_blockdep=self.max_block_dependency, - verbose_config=False, - arena_cache_size=self.arena_cache_size, - ) - - def _scheduler_options(self) -> SchedulerOptions: - """Return SchedulerOptions instance.""" - arch = self._architecture_features() - - return SchedulerOptions( - optimization_strategy=self.optimization_strategy, - sram_target=arch.arena_cache_size, - verbose_schedule=False, - ) - - def _compiler_options(self) -> CompilerOptions: - """Return CompilerOptions instance.""" - return CompilerOptions( - verbose_graph=False, - verbose_quantization=False, - verbose_packing=False, - verbose_tensor_purpose=False, - verbose_tensor_format=False, - verbose_allocation=False, - verbose_high_level_command_stream=False, - verbose_register_command_stream=False, - verbose_operators=False, - verbose_weights=False, - verbose_performance=True, - show_cpu_operations=False, - tensor_allocator=self.tensor_allocator, - timing=False, - output_dir=self.output_dir, - cpu_tensor_alignment=self.cpu_tensor_alignment, - ) - - def return_compiler_options(self) -> CompilerOptions: - """Return CompilerOptions instance for test purposes.""" - return self._compiler_options() - def resolve_compiler_config( vela_compiler_options: VelaCompilerOptions, -) -> dict[str, Any]: +) -> VelaInitData: """Resolve passed compiler options. Vela has number of configuration parameters that being @@ -278,22 +327,209 @@ def resolve_compiler_config( In order to get this information we need to create instance of the Vela compiler first. """ - vela_compiler = VelaCompiler(vela_compiler_options) - return vela_compiler.get_config() - - -def optimize_model( - model_path: Path, compiler_options: VelaCompilerOptions, output_model_path: Path -) -> None: - """Optimize model and return it's path after optimization.""" - logger.debug( - "Optimize model %s for target %s", - model_path, - compiler_options.accelerator_config, + return parse_vela_initialisation_file( + get_vela_config(), + vela_compiler_options.system_config, + vela_compiler_options.memory_mode, ) + +def compile_model(model_path: Path, compiler_options: VelaCompilerOptions) -> Path: + """Compile model.""" vela_compiler = VelaCompiler(compiler_options) - optimized_model = vela_compiler.compile_model(model_path) + # output dir could be a path or str, cast to Path object + output_dir = Path(compiler_options.output_dir) + if Path( + output_dir.as_posix() + + "/" + + model_path.stem + + "_summary_" + + compiler_options.system_config + + ".csv" + ).is_file(): + _, optimized_model_path = vela_compiler.compile_model(model_path, True) + else: + _, optimized_model_path = vela_compiler.compile_model(model_path) + return optimized_model_path + + +def parse_summary_csv_file(vela_summary_csv_file: Path) -> VelaSummary: + """Parse the summary csv file from Vela.""" + if not vela_summary_csv_file.is_file(): + raise FileNotFoundError(f"CSV File not found at {vela_summary_csv_file}") + + with open(vela_summary_csv_file, encoding="UTF-8") as csv_file: + summary_reader = csv.DictReader(csv_file, delimiter=",") + try: + row = next(summary_reader) + except StopIteration as err: + raise RuntimeError("Generated Vela Summary CSV is empty") from err + try: + # pylint: disable=eval-used + key_types = { + field.name: eval(field.type) # type: ignore # nosec + for field in fields(VelaSummary) + } + # pylint: enable=eval-used + summary_data = VelaSummary( + **{key: key_types[key](row[title]) for key, title in summary_metrics} + ) + except KeyError as err: + raise KeyError( + f"Generated Vela Summary CSV missing expected header: {err.args[0]}." + ) from err + return summary_data + + +def parse_vela_initialisation_file( # pylint: disable=too-many-locals + vela_init_file: Path, system_config: str, memory_mode: str +) -> VelaInitData: + """Parse the vela.ini to retrieve data for the target information table.""" + if not vela_init_file.is_file(): + raise FileNotFoundError( + f"Vela Initialisation File not found at {vela_init_file}" + ) + + lines = [] + with open(vela_init_file, encoding="UTF-8") as init_file: + lines = init_file.readlines() + + if len(lines) == 0: + raise OSError("vela.ini File Is Empty") + + lines = [line.strip("\n][ ") for line in lines] + + idxs_memory_mode = [ + idx for idx, item in enumerate(lines) if re.search("^Memory_Mode.*", item) + ] + + if len(idxs_memory_mode) == 0: + raise IndexError("No memory modes are present in vela.ini file.") + + idxs_system_config = [ + idx for idx, item in enumerate(lines) if re.search("^System_Config.*", item) + ] + [idxs_memory_mode[0]] + + if len(idxs_system_config) <= 1: + raise IndexError("No system configs are present in vela.ini file.") + + try: + idx_config = lines.index("System_Config." + system_config) + except ValueError as err: + raise ValueError( + f"System Config: {system_config} not present in vela.ini file." + ) from err + + lines_to_probe = lines[ + idx_config : idxs_system_config[ # noqa: E203 + idxs_system_config.index(idx_config) + 1 + ] + ] + + def collect_memory_mode_lines(memory_mode: str) -> list[str]: + try: + idx_memory_mode = lines.index("Memory_Mode." + memory_mode) + except ValueError as err: + raise ValueError( + f"Memory Mode: {memory_mode} not present in vela.ini file." + ) from err + if idxs_memory_mode.index(idx_memory_mode) == len(idxs_memory_mode) - 1: + lines_to_probe = lines[idx_memory_mode:] + else: + lines_to_probe = lines[ + idx_memory_mode : idxs_memory_mode[ # noqa: E203 + idxs_memory_mode.index(idx_memory_mode) + 1 + ] + ] + return lines_to_probe + + lines_to_probe_memory_mode = collect_memory_mode_lines(memory_mode) + extra_memory_mode_lines = [] + for line in lines_to_probe_memory_mode: + if "inherit=Memory_Mode." in line: + extra_memory_mode = line[line.rindex(".") + 1 :] # noqa: E203 + extra_memory_mode_lines = collect_memory_mode_lines(extra_memory_mode) + + lines_to_probe += extra_memory_mode_lines + lines_to_probe_memory_mode + + init_dict = {} + for line in lines_to_probe: + if "=" in line: + init_dict[line[: line.index("=")]] = line[ + line.index("=") + 1 : # noqa: E203 + ] + try: + init_data = VelaInitData( + system_config=system_config, + core_clock=float(init_dict["core_clock"]), + axi0_port=str(init_dict["axi0_port"]), + axi1_port=str(init_dict["axi1_port"]), + memory_mode=memory_mode, + sram_memory_data=VelaInitMemoryData( + clock_scale=float(init_dict["Sram_clock_scale"]) + if "Sram_clock_scale" in init_dict + else None, + burst_length=int(init_dict["Sram_burst_length"]) + if "Sram_burst_length" in init_dict + else None, + read_latency=int(init_dict["Sram_read_latency"]) + if "Sram_read_latency" in init_dict + else None, + write_latency=int(init_dict["Sram_write_latency"]) + if "Sram_write_latency" in init_dict + else None, + ), + dram_memory_data=VelaInitMemoryData( + clock_scale=float(init_dict["Dram_clock_scale"]) + if "Dram_clock_scale" in init_dict + else None, + burst_length=int(init_dict["Dram_burst_length"]) + if "Dram_burst_length" in init_dict + else None, + read_latency=int(init_dict["Dram_read_latency"]) + if "Dram_read_latency" in init_dict + else None, + write_latency=int(init_dict["Dram_write_latency"]) + if "Dram_write_latency" in init_dict + else None, + ), + off_chip_flash_memory_data=VelaInitMemoryData( + clock_scale=float(init_dict["OffChipFlash_clock_scale"]) + if "OffChipFlash_clock_scale" in init_dict + else None, + burst_length=int(init_dict["OffChipFlash_burst_length"]) + if "OffChipFlash_burst_length" in init_dict + else None, + read_latency=int(init_dict["OffChipFlash_read_latency"]) + if "OffChipFlash_read_latency" in init_dict + else None, + write_latency=int(init_dict["OffChipFlash_write_latency"]) + if "OffChipFlash_write_latency" in init_dict + else None, + ), + on_chip_flash_memory_data=VelaInitMemoryData( + clock_scale=float(init_dict["OnChipFlash_clock_scale"]) + if "OnChipFlash_clock_scale" in init_dict + else None, + burst_length=int(init_dict["OnChipFlash_burst_length"]) + if "OnChipFlash_burst_length" in init_dict + else None, + read_latency=int(init_dict["OnChipFlash_read_latency"]) + if "OnChipFlash_read_latency" in init_dict + else None, + write_latency=int(init_dict["OnChipFlash_write_latency"]) + if "OnChipFlash_write_latency" in init_dict + else None, + ), + const_mem_area=str(init_dict["const_mem_area"]), + arena_mem_area=str(init_dict["arena_mem_area"]), + cache_mem_area=str(init_dict["cache_mem_area"]), + arena_cache_size=int(init_dict["arena_cache_size"]) + if "arena_cache_size" in init_dict + else None, + ) + + except KeyError as err: + raise KeyError(f"Vela.ini file missing expected header: {err.args[0]}") from err - logger.debug("Save optimized model into %s", output_model_path) - optimized_model.save(output_model_path) + return init_data diff --git a/src/mlia/backend/vela/performance.py b/src/mlia/backend/vela/performance.py index 72a8ceb..2cf945d 100644 --- a/src/mlia/backend/vela/performance.py +++ b/src/mlia/backend/vela/performance.py @@ -10,15 +10,12 @@ from collections import Counter from dataclasses import dataclass from dataclasses import fields from pathlib import Path -from pydoc import locate import numpy as np -from ethosu.vela.npu_performance import PassCycles -from ethosu.vela.tensor import MemArea -from mlia.backend.vela.compiler import OptimizedModel from mlia.backend.vela.compiler import VelaCompiler from mlia.backend.vela.compiler import VelaCompilerOptions +from mlia.backend.vela.compiler import VelaSummary logger = logging.getLogger(__name__) @@ -37,11 +34,10 @@ class PerformanceMetrics: # pylint: disable=too-many-instance-attributes batch_inference_time: float inferences_per_second: float batch_size: int - unknown_memory_area_size: int - sram_memory_area_size: int - dram_memory_area_size: int - on_chip_flash_memory_area_size: int - off_chip_flash_memory_area_size: int + sram_memory_area_size: float + dram_memory_area_size: float + on_chip_flash_memory_area_size: float + off_chip_flash_memory_area_size: float layerwise_performance_info: LayerwisePerfInfo @@ -145,19 +141,19 @@ def parse_layerwise_perf_csv( # pylint: disable=too-many-locals if row == headers_to_check_cpu_ops: continue try: + # pylint: disable=eval-used key_types = { - field.name: locate(str(field.type)) + field.name: eval(field.type) # type: ignore # nosec for field in fields(LayerPerfInfo) } + # pylint: enable=eval-used ids_to_metrics = {} for key, title, _ in metrics: try: - ids_to_metrics[key] = key_types[key]( # type: ignore - row_as_dict[title] - ) + ids_to_metrics[key] = key_types[key](row_as_dict[title]) except ValueError as err: if "invalid literal for int() with base 10" in str(err): - ids_to_metrics[key] = key_types[key]( # type: ignore + ids_to_metrics[key] = key_types[key]( float(row_as_dict[title]) ) else: @@ -180,17 +176,20 @@ def estimate_performance( model_path, compiler_options.accelerator_config, ) - vela_compiler = VelaCompiler(compiler_options) - - initial_model = vela_compiler.read_model(model_path) - if initial_model.optimized: - raise ValueError( - "Unable to estimate performance for the given optimized model." - ) - - optimized_model = vela_compiler.compile_model(initial_model) - output_dir = optimized_model.compiler_options.output_dir + if Path( + Path(compiler_options.output_dir).as_posix() + + "/" + + model_path.stem + + "_summary_" + + compiler_options.system_config + + ".csv" + ).is_file(): + summary_data, _ = vela_compiler.compile_model(model_path, True) + else: + summary_data, _ = vela_compiler.compile_model(model_path) + + output_dir = compiler_options.output_dir csv_paths = [entry for entry in os.listdir(output_dir) if "per-layer.csv" in entry] model_name = str(model_path.stem) csv_file_found = None @@ -204,41 +203,31 @@ def estimate_performance( vela_csv_file=csv_path, metrics=layer_metrics ) - return _performance_metrics(layerwise_performance_info, optimized_model) + return _performance_metrics(layerwise_performance_info, summary_data) def _performance_metrics( - layerwise_performance_info: LayerwisePerfInfo, optimized_model: OptimizedModel + layerwise_performance_info: LayerwisePerfInfo, summary_data: VelaSummary ) -> PerformanceMetrics: """Return performance metrics for optimized model.""" - cycles = optimized_model.nng.cycles - - def memory_usage(mem_area: MemArea) -> int: - """Get memory usage for the proviced memory area type.""" - memory_used: dict[MemArea, int] = optimized_model.nng.memory_used - bandwidths = optimized_model.nng.bandwidths - - return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0 - midpoint_fps = np.nan - midpoint_inference_time = cycles[PassCycles.Total] / optimized_model.arch.core_clock + midpoint_inference_time = summary_data.cycles_total / summary_data.core_clock if midpoint_inference_time > 0: midpoint_fps = 1 / midpoint_inference_time return PerformanceMetrics( - npu_cycles=int(cycles[PassCycles.Npu]), - sram_access_cycles=int(cycles[PassCycles.SramAccess]), - dram_access_cycles=int(cycles[PassCycles.DramAccess]), - on_chip_flash_access_cycles=int(cycles[PassCycles.OnChipFlashAccess]), - off_chip_flash_access_cycles=int(cycles[PassCycles.OffChipFlashAccess]), - total_cycles=int(cycles[PassCycles.Total]), + npu_cycles=int(summary_data.cycles_npu), + sram_access_cycles=int(summary_data.cycles_sram_access), + dram_access_cycles=int(summary_data.cycles_dram_access), + on_chip_flash_access_cycles=int(summary_data.cycles_on_chip_flash_access), + off_chip_flash_access_cycles=int(summary_data.cycles_off_chip_flash_access), + total_cycles=int(summary_data.cycles_total), batch_inference_time=midpoint_inference_time * 1000, inferences_per_second=midpoint_fps, - batch_size=optimized_model.nng.batch_size, - unknown_memory_area_size=memory_usage(MemArea.Unknown), - sram_memory_area_size=memory_usage(MemArea.Sram), - dram_memory_area_size=memory_usage(MemArea.Dram), - on_chip_flash_memory_area_size=memory_usage(MemArea.OnChipFlash), - off_chip_flash_memory_area_size=memory_usage(MemArea.OffChipFlash), + batch_size=summary_data.batch_size, + sram_memory_area_size=float(summary_data.sram_memory_used), + dram_memory_area_size=float(summary_data.dram_memory_used), + on_chip_flash_memory_area_size=float(summary_data.on_chip_flash_memory_used), + off_chip_flash_memory_area_size=float(summary_data.off_chip_flash_memory_used), layerwise_performance_info=layerwise_performance_info, ) |