aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/generate_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r--scripts/operator_api/generate_api.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 7f10568..e511f19 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -10,8 +10,6 @@ from xml.dom import minidom
from jinja2 import Environment
from jinja2 import FileSystemLoader
-# Note: main script designed to be run from the scripts/operator_api/ directory
-
def getBasePath():
return Path(__file__).resolve().parent.parent.parent
@@ -82,10 +80,7 @@ def getSerializeOpType(tosaOpName):
"cond_if": "CondIf",
"while_loop": "WhileLoop",
}
- if tosaOpName not in map.keys():
- return "None"
- else:
- return map[tosaOpName]
+ return map.get(tosaOpName, "None")
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
@@ -268,7 +263,7 @@ def getOperators(tosaXml):
operator["serialLibAtts"] = serialLibAtts
serializationAttNames = [att["name"] for att in serialLibAtts]
operator["inputs"] = [
- arg["name"]
+ {"name": arg["name"], "type": arg["type"]}
for arg in tosaArgs
if arg["category"] == "input"
and arg["name"] not in serializationAttNames
@@ -308,7 +303,7 @@ def getTosaArgs(opXml):
if argType[-1:] == "*":
argType = argType[:-1]
if argCategory in ["input", "output"] and argType in tosaTensorTypes:
- argType = "tosa_tensor_t"
+ argType = f"tosa_{argType}"
argShape = ""
if argType in tosaTypeMap:
argType = tosaTypeMap[argType]
@@ -386,7 +381,7 @@ def getSerialLibAtts():
return serialLibAtts
-def renderTemplate(environment, dataTypes, operators, template, outfile):
+def renderTemplate(dataTypes, operators, template, outfile):
content = template.render(dataTypes=dataTypes, operators=operators)
with open(outfile, mode="w", encoding="utf-8") as output:
output.write(content)
@@ -399,12 +394,12 @@ def generate(environment, dataTypes, operators, base_path):
# Generate include/operators.h
template = environment.get_template("operators_h.j2")
outfile = base_path / "reference_model/include/operators.h"
- renderTemplate(environment, dataTypes, operators, template, outfile)
+ renderTemplate(dataTypes, operators, template, outfile)
# Generate src/operators.cc
template = environment.get_template("operators_cc.j2")
outfile = base_path / "reference_model/src/operators.cc"
- renderTemplate(environment, dataTypes, operators, template, outfile)
+ renderTemplate(dataTypes, operators, template, outfile)
if __name__ == "__main__":