aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/generate_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r--scripts/operator_api/generate_api.py30
1 files changed, 3 insertions, 27 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index afe12c1..99639f4 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -99,6 +99,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
+ # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
+ if tosaOpName == "reshape":
+ serLibOpAtts[0]["name"] = "shape"
for att in serLibOpAtts:
attName = att["name"]
attType = att["dType"]
@@ -397,33 +400,6 @@ def generate(environment, dataTypes, operators, base_path):
renderTemplate(environment, dataTypes, operators, template, outfile)
-def getSerializeOpTypeMap():
- """
- Utility function for generating the map used in getSerializeOpType()
- """
- import re
-
- allSerialLibAtts = getSerialLibAtts()
- serAtts = [
- re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
- for name in allSerialLibAtts.keys()
- ]
- serAtts = sorted(serAtts, key=len, reverse=True)
- base_path = getBasePath()
- tosaXml = minidom.parse(base_path / "thirdparty/specification/tosa.xml")
- opsXml = tosaXml.getElementsByTagName("operator")
- opNames = [
- op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml
- ]
- map = {}
- for opName in opNames:
- for serAtt in serAtts:
- if serAtt in opName:
- components = serAtt.split("_")
- map[opName] = "".join(x.title() for x in components)
- return map
-
-
if __name__ == "__main__":
base_path = getBasePath()
environment = Environment(