aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichiel Olieslagers <michiel.olieslagers@arm.com>2023-09-18 17:11:09 +0100
committerGergely Nagy <gergely.nagy@arm.com>2023-11-22 09:22:25 +0000
commit82a5fe34d7464bba70577c734dc446111adb4d93 (patch)
treee805148f20689ded267e433828cc2c0b56c07d4d
parentb09d9fa0028ebb7496327786810e5f0abbcdfd68 (diff)
downloadmlia-82a5fe34d7464bba70577c734dc446111adb4d93.tar.gz
MLIA-963: Capture and handle Vela warning
Signed-off-by: Gergely Nagy <gergely.nagy@arm.com> Change-Id: I6c8b0b74d6d35261eb0ff1a37b9577f9033be8f9
-rw-r--r--src/mlia/backend/vela/compiler.py12
-rw-r--r--tests/test_backend_vela_compiler.py20
2 files changed, 32 insertions, 0 deletions
diff --git a/src/mlia/backend/vela/compiler.py b/src/mlia/backend/vela/compiler.py
index 78f97b2..b591056 100644
--- a/src/mlia/backend/vela/compiler.py
+++ b/src/mlia/backend/vela/compiler.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
+from io import StringIO
from pathlib import Path
from typing import Any
from typing import Literal
@@ -141,6 +142,9 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
with redirect_output(
logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
):
+ tmp = sys.stdout
+ output_message = StringIO()
+ sys.stdout = output_message
compiler_driver(
nng,
arch,
@@ -149,8 +153,16 @@ class VelaCompiler: # pylint: disable=too-many-instance-attributes
network_type,
output_basename,
)
+ sys.stdout = tmp
+ if (
+ "Warning: SRAM target for arena memory area exceeded."
+ in output_message.getvalue()
+ ):
+ raise MemoryError("Model is too large and uses too much RAM")
return OptimizedModel(nng, arch, compiler_options, scheduler_options)
+ except MemoryError as err:
+ raise err
except (SystemExit, Exception) as err:
raise RuntimeError(
"Model could not be optimized with Vela compiler."
diff --git a/tests/test_backend_vela_compiler.py b/tests/test_backend_vela_compiler.py
index 9b69ada..9f09efb 100644
--- a/tests/test_backend_vela_compiler.py
+++ b/tests/test_backend_vela_compiler.py
@@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for module vela/compiler."""
from pathlib import Path
+from typing import Any
+import pytest
from ethosu.vela.compiler_driver import TensorAllocator
from ethosu.vela.scheduler import OptimizationStrategy
@@ -154,6 +156,24 @@ def test_compile_model(test_tflite_model: Path) -> None:
assert isinstance(optimized_model, OptimizedModel)
+def test_compile_model_fail_sram_exceeded(
+ test_tflite_model: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """Test model optimization."""
+ compiler = VelaCompiler(
+ EthosUConfiguration.load_profile("ethos-u55-256").compiler_options
+ )
+
+ def fake_compiler(*_: Any) -> None:
+ print("Warning: SRAM target for arena memory area exceeded.")
+
+ monkeypatch.setattr("mlia.backend.vela.compiler.compiler_driver", fake_compiler)
+ with pytest.raises(Exception) as exc_info:
+ compiler.compile_model(test_tflite_model)
+
+ assert str(exc_info.value) == "Model is too large and uses too much RAM"
+
+
def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None:
"""Test model optimization and saving into file."""
tmp_file = tmp_path / "temp.tflite"