diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils')
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/__init__.py | 2 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 3 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/parallel.py | 4 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/utils.py | 2 |
4 files changed, 4 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py index 48b1622..8c1f750 100644 --- a/src/mlia/nn/rewrite/core/utils/__init__.py +++ b/src/mlia/nn/rewrite/core/utils/__init__.py @@ -1,2 +1,2 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index ac3e875..2141003 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -63,9 +63,6 @@ class NumpyTFWriter: def __exit__(self, type, value, traceback): self.close() - def __del__(self): - self.close() - def write(self, array_dict): type_map = {n: str(a.dtype.name) for n, a in array_dict.items()} self.type_map.update(type_map) diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index 5affc03..b1a2914 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -3,10 +3,10 @@ import math import os from collections import defaultdict +from multiprocessing import cpu_count from multiprocessing import Pool import numpy as np -from psutil import cpu_count os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf @@ -25,7 +25,7 @@ class ParallelTFLiteModel(TFLiteModel): self.pool = None self.filename = filename if not num_procs: - self.num_procs = cpu_count(logical=False) + self.num_procs = cpu_count() else: self.num_procs = int(num_procs) diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py index ed6c81d..d1ed322 100644 --- a/src/mlia/nn/rewrite/core/utils/utils.py +++ b/src/mlia/nn/rewrite/core/utils/utils.py @@ -8,7 +8,7 @@ from tensorflow.lite.python import schema_py_generated as schema_fb def load(input_tflite_file): if not os.path.exists(input_tflite_file): - raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file) + raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file) with open(input_tflite_file, "rb") as file_handle: file_data = bytearray(file_handle.read()) model_obj = schema_fb.Model.GetRootAsModel(file_data, 0) |