aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/backend/vela/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/backend/vela/compiler.py')
-rw-r--r--src/mlia/backend/vela/compiler.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/mlia/backend/vela/compiler.py b/src/mlia/backend/vela/compiler.py
index 3d3847a..b62df24 100644
--- a/src/mlia/backend/vela/compiler.py
+++ b/src/mlia/backend/vela/compiler.py
@@ -89,7 +89,7 @@ class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
tensor_allocator: TensorAllocatorType = "HillClimb"
cpu_tensor_alignment: int = Tensor.AllocationQuantum
optimization_strategy: OptimizationStrategyType = "Performance"
- output_dir: str | None = None
+ output_dir: str = "output"
recursion_limit: int = 1000
@@ -131,6 +131,8 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
if not nng:
raise Exception("Unable to read model")
+ output_basename = f"{self.output_dir}/{nng.name}"
+
try:
arch = self._architecture_features()
compiler_options = self._compiler_options()
@@ -140,7 +142,12 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
):
compiler_driver(
- nng, arch, compiler_options, scheduler_options, network_type
+ nng,
+ arch,
+ compiler_options,
+ scheduler_options,
+ network_type,
+ output_basename,
)
return OptimizedModel(nng, arch, compiler_options, scheduler_options)
@@ -186,9 +193,8 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
@staticmethod
def _read_model(model: str | Path) -> tuple[Graph, NetworkType]:
"""Read TensorFlow Lite model."""
+ model_path = str(model) if isinstance(model, Path) else model
try:
- model_path = str(model) if isinstance(model, Path) else model
-
with redirect_output(
logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
):