diff options
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r-- | scripts/operator_api/generate_api.py | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 7f10568..e511f19 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -10,8 +10,6 @@ from xml.dom import minidom from jinja2 import Environment from jinja2 import FileSystemLoader -# Note: main script designed to be run from the scripts/operator_api/ directory - def getBasePath(): return Path(__file__).resolve().parent.parent.parent @@ -82,10 +80,7 @@ def getSerializeOpType(tosaOpName): "cond_if": "CondIf", "while_loop": "WhileLoop", } - if tosaOpName not in map.keys(): - return "None" - else: - return map[tosaOpName] + return map.get(tosaOpName, "None") def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): @@ -268,7 +263,7 @@ def getOperators(tosaXml): operator["serialLibAtts"] = serialLibAtts serializationAttNames = [att["name"] for att in serialLibAtts] operator["inputs"] = [ - arg["name"] + {"name": arg["name"], "type": arg["type"]} for arg in tosaArgs if arg["category"] == "input" and arg["name"] not in serializationAttNames @@ -308,7 +303,7 @@ def getTosaArgs(opXml): if argType[-1:] == "*": argType = argType[:-1] if argCategory in ["input", "output"] and argType in tosaTensorTypes: - argType = "tosa_tensor_t" + argType = f"tosa_{argType}" argShape = "" if argType in tosaTypeMap: argType = tosaTypeMap[argType] @@ -386,7 +381,7 @@ def getSerialLibAtts(): return serialLibAtts -def renderTemplate(environment, dataTypes, operators, template, outfile): +def renderTemplate(dataTypes, operators, template, outfile): content = template.render(dataTypes=dataTypes, operators=operators) with open(outfile, mode="w", encoding="utf-8") as output: output.write(content) @@ -399,12 +394,12 @@ def generate(environment, dataTypes, operators, base_path): # Generate include/operators.h template = environment.get_template("operators_h.j2") outfile = base_path / "reference_model/include/operators.h" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) # Generate src/operators.cc template = environment.get_template("operators_cc.j2") outfile = base_path / "reference_model/src/operators.cc" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) if __name__ == "__main__": |