aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/utils')
-rw-r--r--src/mlia/nn/rewrite/core/utils/__init__.py2
-rw-r--r--src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py3
-rw-r--r--src/mlia/nn/rewrite/core/utils/parallel.py4
-rw-r--r--src/mlia/nn/rewrite/core/utils/utils.py2
4 files changed, 4 insertions, 7 deletions
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/utils/__init__.py
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index ac3e875..2141003 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -63,9 +63,6 @@ class NumpyTFWriter:
def __exit__(self, type, value, traceback):
self.close()
- def __del__(self):
- self.close()
-
def write(self, array_dict):
type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
self.type_map.update(type_map)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index 5affc03..b1a2914 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -3,10 +3,10 @@
import math
import os
from collections import defaultdict
+from multiprocessing import cpu_count
from multiprocessing import Pool
import numpy as np
-from psutil import cpu_count
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
@@ -25,7 +25,7 @@ class ParallelTFLiteModel(TFLiteModel):
self.pool = None
self.filename = filename
if not num_procs:
- self.num_procs = cpu_count(logical=False)
+ self.num_procs = cpu_count()
else:
self.num_procs = int(num_procs)
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
index ed6c81d..d1ed322 100644
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -8,7 +8,7 @@ from tensorflow.lite.python import schema_py_generated as schema_fb
def load(input_tflite_file):
if not os.path.exists(input_tflite_file):
- raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file)
+ raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file)
with open(input_tflite_file, "rb") as file_handle:
file_data = bytearray(file_handle.read())
model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)