aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-10-31 19:02:14 +0000
committerEric Kunze <eric.kunze@arm.com>2023-10-31 20:58:15 +0000
commitce53cd103cc2ac09b43b4fdf586249e626bd5627 (patch)
treefc2b664693a3d11587f9e875b6e3496d38f62a21 /scripts
parent72dcab775c7a84037135bf365086ca976f3220ef (diff)
downloadreference_model-ce53cd103cc2ac09b43b4fdf586249e626bd5627.tar.gz
Fix Reshape in operator API
- The API incorrectly requires the new shape to be passed in twice. - This fix changes the name of the attribute from new_shape to shape in the generate_api.py script. - Adds a unit test to verify that the reshape operator works correctly. Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I07dd0ef786c747896b6e54f4eada0e7b97c6cef3
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py30
1 files changed, 3 insertions, 27 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index afe12c1..99639f4 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -99,6 +99,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
+ # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
+ if tosaOpName == "reshape":
+ serLibOpAtts[0]["name"] = "shape"
for att in serLibOpAtts:
attName = att["name"]
attType = att["dType"]
@@ -397,33 +400,6 @@ def generate(environment, dataTypes, operators, base_path):
renderTemplate(environment, dataTypes, operators, template, outfile)
-def getSerializeOpTypeMap():
- """
- Utility function for generating the map used in getSerializeOpType()
- """
- import re
-
- allSerialLibAtts = getSerialLibAtts()
- serAtts = [
- re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
- for name in allSerialLibAtts.keys()
- ]
- serAtts = sorted(serAtts, key=len, reverse=True)
- base_path = getBasePath()
- tosaXml = minidom.parse(base_path / "thirdparty/specification/tosa.xml")
- opsXml = tosaXml.getElementsByTagName("operator")
- opNames = [
- op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml
- ]
- map = {}
- for opName in opNames:
- for serAtt in serAtts:
- if serAtt in opName:
- components = serAtt.split("_")
- map[opName] = "".join(x.title() for x in components)
- return map
-
-
if __name__ == "__main__":
base_path = getBasePath()
environment = Environment(