diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-11-01 13:49:37 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-11-01 14:51:59 +0000 |
commit | b0b9e33c3500bd8dc9b12ef012d4234b1245247a (patch) | |
tree | 9d7579558126028f48374dac507fa8f145cfbf5a /scripts | |
parent | ce53cd103cc2ac09b43b4fdf586249e626bd5627 (diff) | |
download | reference_model-b0b9e33c3500bd8dc9b12ef012d4234b1245247a.tar.gz |
Fix TransposeConv2d in operator API
- Change name of the TransposeConv2d attribute output_shape to out_shape
in generate_api.py to match with TOSA specification
- Fix serialization attributes mapping for operator TransposeConv2d
- Add a unit test for TransposeConv2d operator
Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com>
Change-Id: I6613c0d093aeea0af30012bcc1c8e5d26dec746c
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 99639f4..d9077f0 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -64,7 +64,7 @@ def getSerializeOpType(tosaOpName): "fully_connected": "FullyConnected", "matmul": "MatMul", "max_pool2d": "Pool", - "transpose_conv2d": "Conv", + "transpose_conv2d": "TransposeConv", "clamp": "Clamp", "arithmetic_right_shift": "ArithmeticRightShift", "mul": "Mul", @@ -99,9 +99,16 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType]) tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} serTosaTypeMap = {"ResizeMode": "tosa_mode"} - # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml - if tosaOpName == "reshape": - serLibOpAtts[0]["name"] = "shape" + serAttsToFix = { + "reshape": {"new_shape": "shape"}, + "transpose_conv2d": {"output_shape": "out_shape"}, + } + if tosaOpName in serAttsToFix: + # Fix attributes names to match with tosa.xml + for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items(): + for opAtts in serLibOpAtts: + if opAtts["name"] == attDefName: + opAtts["name"] = tosaSpecName for att in serLibOpAtts: attName = att["name"] attType = att["dType"] |