diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-03-20 18:07:54 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:42:55 +0100 |
commit | 62768232c5fe4ed6b87136c336b65e13d030e9d4 (patch) | |
tree | 847c36a2f7e092982bc1d7a66d0bf601447c8d20 /src/mlia/nn/rewrite/core/graph_edit/record.py | |
parent | 446c379c92e15ad8f24ed0db853dd0fc9c271151 (diff) | |
download | mlia-62768232c5fe4ed6b87136c336b65e13d030e9d4.tar.gz |
MLIA-843 Add unit tests for module mlia.nn.rewrite
Note: The unit tests mostly call the main functions from the respective
modules only.
Change-Id: Ib2ce5c53d0c3eb222b8b8be42fba33ac8e007574
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/record.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/record.py | 33 |
1 files changed, 19 insertions, 14 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/record.py b/src/mlia/nn/rewrite/core/graph_edit/record.py index 03cd3f9..ae13313 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/record.py +++ b/src/mlia/nn/rewrite/core/graph_edit/record.py @@ -37,7 +37,6 @@ def record_model( total = numpytf_count(input_filename) dataset = NumpyTFReader(input_filename) - writer = NumpyTFWriter(output_filename) if batch_size > 1: # Collapse batch-size 1 items into batch-size n. I regret using batch-size 1 items in tfrecs now. @@ -47,16 +46,22 @@ def record_model( dataset = dataset.batch(batch_size, drop_remainder=False) total = int(math.ceil(total / batch_size)) - for j, named_x in enumerate( - tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) - ): - named_y = model(named_x) - if batch_size > 1: - for i in range(batch_size): - # Expand the batches and recreate each dict as a batch-size 1 item for the tfrec output - d = {k: v[i : i + 1] for k, v in named_y.items() if i < v.shape[0]} - if d: - writer.write(d) - else: - writer.write(named_y) - model.close() + with NumpyTFWriter(output_filename) as writer: + for _, named_x in enumerate( + tqdm(dataset.as_numpy_iterator(), total=total, disable=not show_progress) + ): + named_y = model(named_x) + if batch_size > 1: + for i in range(batch_size): + # Expand the batches and recreate each dict as a + # batch-size 1 item for the tfrec output + recreated_dict = { + k: v[i : i + 1] # noqa: E203 + for k, v in named_y.items() + if i < v.shape[0] + } + if recreated_dict: + writer.write(recreated_dict) + else: + writer.write(named_y) + model.close() |