aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 99639f4..d9077f0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -64,7 +64,7 @@ def getSerializeOpType(tosaOpName):
"fully_connected": "FullyConnected",
"matmul": "MatMul",
"max_pool2d": "Pool",
- "transpose_conv2d": "Conv",
+ "transpose_conv2d": "TransposeConv",
"clamp": "Clamp",
"arithmetic_right_shift": "ArithmeticRightShift",
"mul": "Mul",
@@ -99,9 +99,16 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
- if tosaOpName == "reshape":
- serLibOpAtts[0]["name"] = "shape"
+ serAttsToFix = {
+ "reshape": {"new_shape": "shape"},
+ "transpose_conv2d": {"output_shape": "out_shape"},
+ }
+ if tosaOpName in serAttsToFix:
+ # Fix attributes names to match with tosa.xml
+ for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items():
+ for opAtts in serLibOpAtts:
+ if opAtts["name"] == attDefName:
+ opAtts["name"] = tosaSpecName
for att in serLibOpAtts:
attName = att["name"]
attType = att["dType"]