diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mlia/backend/vela/compiler.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/src/mlia/backend/vela/compiler.py b/src/mlia/backend/vela/compiler.py index 3d3847a..b62df24 100644 --- a/src/mlia/backend/vela/compiler.py +++ b/src/mlia/backend/vela/compiler.py @@ -89,7 +89,7 @@ class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes tensor_allocator: TensorAllocatorType = "HillClimb" cpu_tensor_alignment: int = Tensor.AllocationQuantum optimization_strategy: OptimizationStrategyType = "Performance" - output_dir: str | None = None + output_dir: str = "output" recursion_limit: int = 1000 @@ -131,6 +131,8 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes if not nng: raise Exception("Unable to read model") + output_basename = f"{self.output_dir}/{nng.name}" + try: arch = self._architecture_features() compiler_options = self._compiler_options() @@ -140,7 +142,12 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG ): compiler_driver( - nng, arch, compiler_options, scheduler_options, network_type + nng, + arch, + compiler_options, + scheduler_options, + network_type, + output_basename, ) return OptimizedModel(nng, arch, compiler_options, scheduler_options) @@ -186,9 +193,8 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes @staticmethod def _read_model(model: str | Path) -> tuple[Graph, NetworkType]: """Read TensorFlow Lite model.""" + model_path = str(model) if isinstance(model, Path) else model try: - model_path = str(model) if isinstance(model, Path) else model - with redirect_output( logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG ): |