aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2023-11-01 13:49:37 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2023-11-01 14:51:59 +0000
commitb0b9e33c3500bd8dc9b12ef012d4234b1245247a (patch)
tree9d7579558126028f48374dac507fa8f145cfbf5a /scripts
parentce53cd103cc2ac09b43b4fdf586249e626bd5627 (diff)
downloadreference_model-b0b9e33c3500bd8dc9b12ef012d4234b1245247a.tar.gz
Fix TransposeConv2d in operator API
- Change name of the TransposeConv2d attribute output_shape to out_shape in generate_api.py to match with TOSA specification - Fix serialization attributes mapping for operator TransposeConv2d - Add a unit test for TransposeConv2d operator Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com> Change-Id: I6613c0d093aeea0af30012bcc1c8e5d26dec746c
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 99639f4..d9077f0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -64,7 +64,7 @@ def getSerializeOpType(tosaOpName):
"fully_connected": "FullyConnected",
"matmul": "MatMul",
"max_pool2d": "Pool",
- "transpose_conv2d": "Conv",
+ "transpose_conv2d": "TransposeConv",
"clamp": "Clamp",
"arithmetic_right_shift": "ArithmeticRightShift",
"mul": "Mul",
@@ -99,9 +99,16 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
- if tosaOpName == "reshape":
- serLibOpAtts[0]["name"] = "shape"
+ serAttsToFix = {
+ "reshape": {"new_shape": "shape"},
+ "transpose_conv2d": {"output_shape": "out_shape"},
+ }
+ if tosaOpName in serAttsToFix:
+ # Fix attributes names to match with tosa.xml
+ for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items():
+ for opAtts in serLibOpAtts:
+ if opAtts["name"] == attDefName:
+ opAtts["name"] = tosaSpecName
for att in serLibOpAtts:
attName = att["name"]
attType = att["dType"]