about summary refs log tree commit diff
path: root/src/mlia/nn/rewrite/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core')
-rw-r--r--  src/mlia/nn/rewrite/core/__init__.py                  2
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/__init__.py       2
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py          30
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/record.py        33
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                    88
-rw-r--r--  src/mlia/nn/rewrite/core/utils/__init__.py            2
-rw-r--r--  src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py      3
-rw-r--r--  src/mlia/nn/rewrite/core/utils/parallel.py            4
-rw-r--r--  src/mlia/nn/rewrite/core/utils/utils.py               2
9 files changed, 31 insertions, 135 deletions
diff --git a/src/mlia/nn/rewrite/core/__init__.py b/src/mlia/nn/rewrite/core/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/__init__.py
+++ b/src/mlia/nn/rewrite/core/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/graph_edit/__init__.py b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/__init__.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index b6c9616..0829f0a 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -13,36 +13,6 @@ from tensorflow.lite.python import interpreter as interpreter_wrapper
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader, NumpyTFWriter
-def diff(file1, file2):
- results = []
-
- dataset1 = NumpyTFReader(file1)
- dataset2 = NumpyTFReader(file2)
-
- for i, (x1, x2) in enumerate(zip(dataset1, dataset2)):
- assert x1.keys() == x2.keys(), (
- "At input %d the files have different sets of tensors.\n%s: %s\n%s: %s\n"
- % (
- i,
- file1,
- ", ".join(x1.keys()),
- file2,
- ", ".join(x2.keys()),
- )
- )
- results.append({})
- for k in x1.keys():
- v1 = x1[k].numpy().astype(np.double)
- v2 = x2[k].numpy().astype(np.double)
- mae = abs(v1 - v2).mean()
- results[-1][k] = mae
-
- total = sum(sum(x.values()) for x in results)
- count = sum(len(x.values()) for x in results)
- mean = total / count
- return results, mean
-
-
def diff_stats(file1, file2, per_tensor_and_channel=False):
dataset1 = NumpyTFReader(file1)
dataset2 = NumpyTFReader(file2)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py
index 03cd3f9..ae13313 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/record.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/record.py
@@ -37,7 +37,6 @@ def record_model(
total = numpytf_count(input_filename)
dataset = NumpyTFReader(input_filename)
- writer = NumpyTFWriter(output_filename)
if batch_size > 1:
# Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now.
@@ -47,16 +46,22 @@ def record_model(
dataset = dataset.batch(batch_size, drop_remainder=False)
total = int(math.ceil(total / batch_size))
- for j, named_x in enumerate(
- tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
- ):
- named_y = model(named_x)
- if batch_size > 1:
- for i in range(batch_size):
- # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output
- d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]}
- if d:
- writer.write(d)
- else:
- writer.write(named_y)
- model.close()
+ with NumpyTFWriter(output_filename) as writer:
+ for _, named_x in enumerate(
+ tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress)
+ ):
+ named_y = model(named_x)
+ if batch_size > 1:
+ for i in range(batch_size):
+ # Expand the batches and recreate each dict as a
+ # batch-size 1 item for the tfrec output
+ recreated_dict = {
+ k: v[i : i + 1] # noqa: E203
+ for k, v in named_y.items()
+ if i < v.shape[0]
+ }
+ if recreated_dict:
+ writer.write(recreated_dict)
+ else:
+ writer.write(named_y)
+ model.close()
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index a929b14..096daf4 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -40,85 +40,7 @@ augmentation_presets = {
"mix_gaussian_small": (1.6, 0.3),
}
-
-class SequentialTrainer:
- def __init__(
- self,
- source_model,
- output_model,
- input_tfrec,
- augment="gaussian",
- steps=6000,
- lr=1e-3,
- batch_size=32,
- show_progress=True,
- eval_fn=None,
- num_procs=1,
- num_threads=0,
- ):
- self.source_model = source_model
- self.output_model = output_model
- self.input_tfrec = input_tfrec
- self.default_augment = augment
- self.default_steps = steps
- self.default_lr = lr
- self.default_batch_size = batch_size
- self.show_progress = show_progress
- self.num_procs = num_procs
- self.num_threads = num_threads
- self.first_replace = True
- self.eval_fn = eval_fn
-
- def replace(
- self,
- model_fn,
- input_tensors,
- output_tensors,
- augment=None,
- steps=None,
- lr=None,
- batch_size=None,
- ):
- augment = self.default_augment if augment is None else augment
- steps = self.default_steps if steps is None else steps
- lr = self.default_lr if lr is None else lr
- batch_size = self.default_batch_size if batch_size is None else batch_size
-
- if isinstance(augment, str):
- augment = augmentation_presets[augment]
-
- if self.first_replace:
- source_model = self.source_model
- unmodified_model = None
- else:
- source_model = self.output_model
- unmodified_model = self.source_model
-
- mae, nrmse = train(
- source_model,
- unmodified_model,
- self.output_model,
- self.input_tfrec,
- model_fn,
- input_tensors,
- output_tensors,
- augment,
- steps,
- lr,
- batch_size,
- False,
- self.show_progress,
- None,
- 0,
- self.num_procs,
- self.num_threads,
- )
-
- self.first_replace = False
- if self.eval_fn:
- return self.eval_fn(mae, nrmse, self.output_model)
- else:
- return mae, nrmse
+learning_rate_schedules = {"cosine", "late", "constant"}
def train(
@@ -135,6 +57,7 @@ def train(
batch_size,
verbose,
show_progress,
+ learning_rate_schedule="cosine",
checkpoint_at=None,
checkpoint_decay_steps=0,
num_procs=1,
@@ -183,6 +106,7 @@ def train(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ schedule=learning_rate_schedule,
)
for i, filename in enumerate(tflite_filenames):
@@ -363,9 +287,9 @@ def train_in_dir(
elif schedule == "constant":
callbacks = []
else:
- assert False, (
- 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"'
- % schedule
+ assert schedule not in learning_rate_schedules
+ raise ValueError(
+ f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
)
output_filenames = []
diff --git a/src/mlia/nn/rewrite/core/utils/__init__.py b/src/mlia/nn/rewrite/core/utils/__init__.py
index 48b1622..8c1f750 100644
--- a/src/mlia/nn/rewrite/core/utils/__init__.py
+++ b/src/mlia/nn/rewrite/core/utils/__init__.py
@@ -1,2 +1,2 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file
+# SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index ac3e875..2141003 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -63,9 +63,6 @@ class NumpyTFWriter:
def __exit__(self, type, value, traceback):
self.close()
- def __del__(self):
- self.close()
-
def write(self, array_dict):
type_map = {n: str(a.dtype.name) for n, a in array_dict.items()}
self.type_map.update(type_map)
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index 5affc03..b1a2914 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -3,10 +3,10 @@
import math
import os
from collections import defaultdict
+from multiprocessing import cpu_count
from multiprocessing import Pool
import numpy as np
-from psutil import cpu_count
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
@@ -25,7 +25,7 @@ class ParallelTFLiteModel(TFLiteModel):
self.pool = None
self.filename = filename
if not num_procs:
- self.num_procs = cpu_count(logical=False)
+ self.num_procs = cpu_count()
else:
self.num_procs = int(num_procs)
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
index ed6c81d..d1ed322 100644
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ b/src/mlia/nn/rewrite/core/utils/utils.py
@@ -8,7 +8,7 @@ from tensorflow.lite.python import schema_py_generated as schema_fb
def load(input_tflite_file):
if not os.path.exists(input_tflite_file):
- raise RuntimeError("TFLite file not found at %r\n" % input_tflite_file)
+ raise FileNotFoundError("TFLite file not found at %r\n" % input_tflite_file)
with open(input_tflite_file, "rb") as file_handle:
file_data = bytearray(file_handle.read())
model_obj = schema_fb.Model.GetRootAsModel(file_data, 0)