diff options
Diffstat (limited to 'verif/generator/datagenerator.py')
-rw-r--r-- | verif/generator/datagenerator.py | 59 |
1 files changed, 44 insertions, 15 deletions
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py index 408c83e..0d59084 100644 --- a/verif/generator/datagenerator.py +++ b/verif/generator/datagenerator.py @@ -6,7 +6,7 @@ import json from pathlib import Path import numpy as np -from schemavalidation import schemavalidation +import schemavalidation.schemavalidation as sch class GenerateError(Exception): @@ -14,7 +14,15 @@ class GenerateError(Exception): class GenerateLibrary: - """Python interface to the C generate library.""" + """Python interface to the C generate library. + + Simple usage to write out all input files: + set_config(test_desc) + write_numpy_files(test_path) + + To get data buffers (for const data): + get_tensor_data(tensor_name) + """ def __init__(self, generate_lib_path): """Find the library and set up the interface.""" @@ -22,6 +30,8 @@ class GenerateLibrary: if not self.lib_path.is_file(): raise GenerateError(f"Could not find generate library - {self.lib_path}") + self.schema_validator = sch.TestDescSchemaValidator() + self.test_desc = None self.json_config = None self.lib = ct.cdll.LoadLibrary(self.lib_path) @@ -51,8 +61,7 @@ class GenerateLibrary: raise GenerateError("No meta/data_gen section found in desc.json") # Validate the config versus the schema - tdsv = schemavalidation.TestDescSchemaValidator() - tdsv.validate_config(test_desc) + self.schema_validator.validate_config(test_desc) self.test_desc = test_desc self.json_config = test_desc["meta"]["data_gen"] @@ -72,25 +81,25 @@ class GenerateLibrary: return buffer, size_bytes - def _data_gen_write( - self, test_path: Path, json_bytes: bytes, ifm_name: str, ifm_file: str - ): - """Generate the named tensor data and save it in numpy format.""" + def _data_gen_array(self, json_config: str, tensor_name: str): + """Generate the named tensor data and return a numpy array.""" try: - tensor = self.json_config["tensors"][ifm_name] + tensor = json_config["tensors"][tensor_name] dtype = tensor["data_type"] shape = tuple(tensor["shape"]) except KeyError as e: raise GenerateError( - f"Missing data in desc.json for input {ifm_name} - {repr(e)}" + f"Missing data in json config for input {tensor_name} - {repr(e)}" ) buffer, size_bytes = self._create_buffer(dtype, shape) buffer_ptr = ct.cast(buffer, ct.c_void_p) + json_bytes = bytes(json.dumps(json_config), "utf8") + result = self.tgd_generate_data( ct.c_char_p(json_bytes), - ct.c_char_p(bytes(ifm_name, "utf8")), + ct.c_char_p(bytes(tensor_name, "utf8")), buffer_ptr, ct.c_size_t(size_bytes), ) @@ -100,11 +109,19 @@ class GenerateLibrary: arr = np.ctypeslib.as_array(buffer) arr = np.reshape(arr, shape) + return arr + + def _data_gen_write( + self, test_path: Path, json_config: str, ifm_name: str, ifm_file: str + ): + """Generate the named tensor data and save it in numpy format.""" + arr = self._data_gen_array(json_config, ifm_name) + file_name = test_path / ifm_file np.save(file_name, arr) def write_numpy_files(self, test_path: Path): - """Write out all the specified tensors to numpy data files.""" + """Write out all the desc.json input tensors to numpy data files.""" if self.test_desc is None or self.json_config is None: raise GenerateError("Cannot write numpy files as no config set up") @@ -114,12 +131,10 @@ class GenerateLibrary: except KeyError as e: raise GenerateError(f"Missing data in desc.json - {repr(e)}") - json_bytes = bytes(json.dumps(self.json_config), "utf8") - failures = [] for iname, ifile in zip(ifm_names, ifm_files): try: - self._data_gen_write(test_path, json_bytes, iname, ifile) + self._data_gen_write(test_path, self.json_config, iname, ifile) except GenerateError as e: failures.append( f"ERROR: Failed to create data for tensor {iname} - {repr(e)}" @@ -128,6 +143,20 @@ class GenerateLibrary: if len(failures) > 0: raise GenerateError("\n".join(failures)) + def get_tensor_data(self, tensor_name: str, json_config=None): + """Get a numpy array for a named tensor in the data_gen meta data.""" + if json_config is None: + if self.json_config is None: + raise GenerateError("Cannot get tensor data as no config set up") + json_config = self.json_config + else: + # Validate the given config + self.schema_validator.validate_config( + json_config, schema_type=sch.TD_SCHEMA_DATA_GEN + ) + + return self._data_gen_array(json_config, tensor_name) + def main(argv=None): """Simple command line interface for the data generator.""" |