aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/generate_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r--scripts/operator_api/generate_api.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index c5c762d..31ee151 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -102,6 +102,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
if attType in serTosaTypeMap.keys():
init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
# Initialize Serialization library attributes to their matching function parameter
+ elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
+ init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
+ att["dType"] = "tosa::DType"
elif attName in tosaArgsDict:
if att["SV"] == "V":
if tosaArgsDict[attName]["type"] == "tosa_tensor_t":
@@ -270,10 +273,15 @@ def getTosaArgs(opXml):
argsXml = opXml.getElementsByTagName("argument")
tosaTensorTypes = getTosaArgTypes(tosaXml)
tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
+ tensorElemTypeMap = {
+ "resize_mode_t": "tosa_mode_t",
+ "acc_size_t": "tosa_acc_size_t",
+ }
for xmlArg in argsXml:
argName = xmlArg.getAttribute("name").lower()
- if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
- argType = "tosa_mode_t"
+ tensorElemType = xmlArg.getAttribute("tensor-element-type")
+ if tensorElemType in tensorElemTypeMap:
+ argType = tensorElemTypeMap[tensorElemType]
else:
argType = xmlArg.getAttribute("type")
argShape = xmlArg.getAttribute("shape")