aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/datagenerator.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/datagenerator.py')
-rw-r--r--verif/generator/datagenerator.py59
1 files changed, 44 insertions, 15 deletions
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py
index 408c83e..0d59084 100644
--- a/verif/generator/datagenerator.py
+++ b/verif/generator/datagenerator.py
@@ -6,7 +6,7 @@ import json
from pathlib import Path
import numpy as np
-from schemavalidation import schemavalidation
+import schemavalidation.schemavalidation as sch
class GenerateError(Exception):
@@ -14,7 +14,15 @@ class GenerateError(Exception):
class GenerateLibrary:
- """Python interface to the C generate library."""
+ """Python interface to the C generate library.
+
+ Simple usage to write out all input files:
+ set_config(test_desc)
+ write_numpy_files(test_path)
+
+ To get data buffers (for const data):
+ get_tensor_data(tensor_name)
+ """
def __init__(self, generate_lib_path):
"""Find the library and set up the interface."""
@@ -22,6 +30,8 @@ class GenerateLibrary:
if not self.lib_path.is_file():
raise GenerateError(f"Could not find generate library - {self.lib_path}")
+ self.schema_validator = sch.TestDescSchemaValidator()
+
self.test_desc = None
self.json_config = None
self.lib = ct.cdll.LoadLibrary(self.lib_path)
@@ -51,8 +61,7 @@ class GenerateLibrary:
raise GenerateError("No meta/data_gen section found in desc.json")
# Validate the config versus the schema
- tdsv = schemavalidation.TestDescSchemaValidator()
- tdsv.validate_config(test_desc)
+ self.schema_validator.validate_config(test_desc)
self.test_desc = test_desc
self.json_config = test_desc["meta"]["data_gen"]
@@ -72,25 +81,25 @@ class GenerateLibrary:
return buffer, size_bytes
- def _data_gen_write(
- self, test_path: Path, json_bytes: bytes, ifm_name: str, ifm_file: str
- ):
- """Generate the named tensor data and save it in numpy format."""
+ def _data_gen_array(self, json_config: str, tensor_name: str):
+ """Generate the named tensor data and return a numpy array."""
try:
- tensor = self.json_config["tensors"][ifm_name]
+ tensor = json_config["tensors"][tensor_name]
dtype = tensor["data_type"]
shape = tuple(tensor["shape"])
except KeyError as e:
raise GenerateError(
- f"Missing data in desc.json for input {ifm_name} - {repr(e)}"
+ f"Missing data in json config for input {tensor_name} - {repr(e)}"
)
buffer, size_bytes = self._create_buffer(dtype, shape)
buffer_ptr = ct.cast(buffer, ct.c_void_p)
+ json_bytes = bytes(json.dumps(json_config), "utf8")
+
result = self.tgd_generate_data(
ct.c_char_p(json_bytes),
- ct.c_char_p(bytes(ifm_name, "utf8")),
+ ct.c_char_p(bytes(tensor_name, "utf8")),
buffer_ptr,
ct.c_size_t(size_bytes),
)
@@ -100,11 +109,19 @@ class GenerateLibrary:
arr = np.ctypeslib.as_array(buffer)
arr = np.reshape(arr, shape)
+ return arr
+
+ def _data_gen_write(
+ self, test_path: Path, json_config: str, ifm_name: str, ifm_file: str
+ ):
+ """Generate the named tensor data and save it in numpy format."""
+ arr = self._data_gen_array(json_config, ifm_name)
+
file_name = test_path / ifm_file
np.save(file_name, arr)
def write_numpy_files(self, test_path: Path):
- """Write out all the specified tensors to numpy data files."""
+ """Write out all the desc.json input tensors to numpy data files."""
if self.test_desc is None or self.json_config is None:
raise GenerateError("Cannot write numpy files as no config set up")
@@ -114,12 +131,10 @@ class GenerateLibrary:
except KeyError as e:
raise GenerateError(f"Missing data in desc.json - {repr(e)}")
- json_bytes = bytes(json.dumps(self.json_config), "utf8")
-
failures = []
for iname, ifile in zip(ifm_names, ifm_files):
try:
- self._data_gen_write(test_path, json_bytes, iname, ifile)
+ self._data_gen_write(test_path, self.json_config, iname, ifile)
except GenerateError as e:
failures.append(
f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
@@ -128,6 +143,20 @@ class GenerateLibrary:
if len(failures) > 0:
raise GenerateError("\n".join(failures))
+ def get_tensor_data(self, tensor_name: str, json_config=None):
+ """Get a numpy array for a named tensor in the data_gen meta data."""
+ if json_config is None:
+ if self.json_config is None:
+ raise GenerateError("Cannot get tensor data as no config set up")
+ json_config = self.json_config
+ else:
+ # Validate the given config
+ self.schema_validator.validate_config(
+ json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
+ )
+
+ return self._data_gen_array(json_config, tensor_name)
+
def main(argv=None):
"""Simple command line interface for the data generator."""