From b0b9e33c3500bd8dc9b12ef012d4234b1245247a Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Wed, 1 Nov 2023 13:49:37 +0000 Subject: Fix TransposeConv2d in operator API - Change name of the TransposeConv2d attribute output_shape to out_shape in generate_api.py to match with TOSA specification - Fix serialization attributes mapping for operator TransposeConv2d - Add a unit test for TransposeConv2d operator Signed-off-by: Dmitrii Agibov Change-Id: I6613c0d093aeea0af30012bcc1c8e5d26dec746c --- scripts/operator_api/generate_api.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'scripts') diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 99639f4..d9077f0 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -64,7 +64,7 @@ def getSerializeOpType(tosaOpName): "fully_connected": "FullyConnected", "matmul": "MatMul", "max_pool2d": "Pool", - "transpose_conv2d": "Conv", + "transpose_conv2d": "TransposeConv", "clamp": "Clamp", "arithmetic_right_shift": "ArithmeticRightShift", "mul": "Mul", @@ -99,9 +99,16 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType]) tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} serTosaTypeMap = {"ResizeMode": "tosa_mode"} - # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml - if tosaOpName == "reshape": - serLibOpAtts[0]["name"] = "shape" + serAttsToFix = { + "reshape": {"new_shape": "shape"}, + "transpose_conv2d": {"output_shape": "out_shape"}, + } + if tosaOpName in serAttsToFix: + # Fix attributes names to match with tosa.xml + for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items(): + for opAtts in serLibOpAtts: + if opAtts["name"] == attDefName: + opAtts["name"] = tosaSpecName for att in serLibOpAtts: attName = att["name"] attType = att["dType"] -- cgit v1.2.1