diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 99639f4..d9077f0 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -64,7 +64,7 @@ def getSerializeOpType(tosaOpName): "fully_connected": "FullyConnected", "matmul": "MatMul", "max_pool2d": "Pool", - "transpose_conv2d": "Conv", + "transpose_conv2d": "TransposeConv", "clamp": "Clamp", "arithmetic_right_shift": "ArithmeticRightShift", "mul": "Mul", @@ -99,9 +99,16 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType]) tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} serTosaTypeMap = {"ResizeMode": "tosa_mode"} - # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml - if tosaOpName == "reshape": - serLibOpAtts[0]["name"] = "shape" + serAttsToFix = { + "reshape": {"new_shape": "shape"}, + "transpose_conv2d": {"output_shape": "out_shape"}, + } + if tosaOpName in serAttsToFix: + # Fix attributes names to match with tosa.xml + for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items(): + for opAtts in serLibOpAtts: + if opAtts["name"] == attDefName: + opAtts["name"] = tosaSpecName for att in serLibOpAtts: attName = att["name"] attType = att["dType"] |