diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 12 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 13 |
2 files changed, 23 insertions, 2 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index fd33466..afe12c1 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -107,6 +107,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): if attType in serTosaTypeMap.keys(): init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});" # Initialize Serialization library attributes to their matching function parameter + elif tosaOpName == "avg_pool2d" and attName == "accum_dtype": + init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);" + att["dType"] = "tosa::DType" elif attName in tosaArgsDict: if att["SV"] == "V": if tosaArgsDict[attName]["type"] == "tosa_tensor_t": @@ -275,10 +278,15 @@ def getTosaArgs(opXml): argsXml = opXml.getElementsByTagName("argument") tosaTensorTypes = getTosaArgTypes(tosaXml) tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"} + tensorElemTypeMap = { + "resize_mode_t": "tosa_mode_t", + "acc_size_t": "tosa_acc_size_t", + } for xmlArg in argsXml: argName = xmlArg.getAttribute("name").lower() - if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t": - argType = "tosa_mode_t" + tensorElemType = xmlArg.getAttribute("tensor-element-type") + if tensorElemType in tensorElemTypeMap: + argType = tensorElemTypeMap[tensorElemType] else: argType = xmlArg.getAttribute("type") argShape = xmlArg.getAttribute("shape") diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 6b6f864..1de103a 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -92,6 +92,19 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { } } +tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) { + switch(acc_size) { + case tosa_acc_size_int32_t: + return tosa::DType::DType_INT32; + case tosa_acc_size_fp16_t: + return tosa::DType::DType_FP16; + case tosa_acc_size_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +} + } // namespace extern "C" |