diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 171 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 26 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_h.j2 | 37 |
3 files changed, 123 insertions, 111 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index f1cb6e0..c5c762d 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -81,64 +81,87 @@ def getSerializeOpType(tosaOpName): return map[tosaOpName] -def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs): +def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): """ - Returns the arguments required by the Serialization library for the TOSA operator specified. - Generates code to initialize Serialization arguments. If a matching TOSA argument exists, + Returns the attributes required by the Serialization library for the TOSA operator specified. + Generates code to initialize Serialization library attributes. If a matching TOSA argument exists, that value is used for initialization, otherwise a default value e.g. 0 is used. """ - serOpType = getSerializeOpType(tosaOpName) - if serOpType not in allSerializeArgs.keys(): + serLibOpType = getSerializeOpType(tosaOpName) + if serLibOpType not in allSerialLibAtts.keys(): return {} else: - serOpArgs = copy.deepcopy(allSerializeArgs[serOpType]) + serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType]) tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} serTosaTypeMap = {"ResizeMode": "tosa_mode"} - for arg in serOpArgs: - argName = arg["name"] + for att in serLibOpAtts: + attName = att["name"] + attType = att["dType"] init = "" - # Translate TOSA data types to Serialization data types for initialization - if arg["dType"] in serTosaTypeMap.keys(): - init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})" - # Initialize Serialization arguments to their matching function parameter - elif argName in tosaArgsDict: - if arg["SV"] == "V": - shape = tosaArgsDict[argName]["shape"] - if shape == "[]": - init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)" + # Translate TOSA data types to Serialization library data types for initialization + if attType in serTosaTypeMap.keys(): + init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});" + # Initialize Serialization library attributes to their matching function parameter + elif attName in tosaArgsDict: + if att["SV"] == "V": + if tosaArgsDict[attName]["type"] == "tosa_tensor_t": + init = f"std::vector<{attType}> {attName};" + init = ( + init + + f"size_t {attName}_size = client_{attName}.size / sizeof({attType});" + ) + init = ( + init + + f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);" + ) + init = ( + init + + f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);" + ) else: - init = f"(&client_{argName}[0], &client_{argName}{shape})" + init = f"const std::vector<{attType}> {attName}" + shape = tosaArgsDict[attName]["shape"] + if shape == "[]": + init = ( + init + + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);" + ) + else: + init = ( + init + + f"(&client_{attName}[0], &client_{attName}{shape});" + ) else: - init = f" = client_{argName}" - else: - # Initialize Serialization arguments with no matching fuction parameter - if arg["SV"] == "V": init = "" + else: + # Initialize Serialization library attributes with no matching fuction parameter + if att["SV"] == "V": + init = f"std::vector<int32_t> {attName};" else: - if arg["dType"] == "DType": - arg["dType"] = "tosa::DType" - init = " = tosa::DType::DType_FP32" + if att["dType"] == "DType": + att["dType"] = "tosa::DType" + init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;" else: - init = " = 0" - arg["init"] = init - return serOpArgs + init = f"const {attType} {attName} = 0;" + att["init"] = init + return serLibOpAtts -def updateTosaArgs(tosaArgs, serializeArgs, tosaXml): +def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml): """ - Replace TOSA argument data types with their matching Serialization argument data types. + Replace TOSA argument data types with their matching Serialization attribute data types. Delete TOSA arguments where the type couldn't be determined. - Add Serialization arguments that have no matching TOSA argument. + Add Serialization attributes that have no matching TOSA argument. """ tosaArgTypes = getTosaArgTypes(tosaXml) - serArgsDict = {arg["name"]: arg for arg in serializeArgs} + serAttsDict = {att["name"]: att for att in serialLibAtts} tosaArgsNames = [arg["name"] for arg in tosaArgs] delTosaArgs = [] - # Replace TOSA argument data types with their matching Serialization argument data types. + # Replace TOSA argument data types with their matching Serialization attribute data types. for tosaArg in tosaArgs: if tosaArg["type"] in tosaArgTypes: - if tosaArg["name"] in serArgsDict: - tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"] + if tosaArg["name"] in serAttsDict: + tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"] else: # Delete TOSA argument whose data type can't be determined delTosaArgs.append(tosaArgsNames.index(tosaArg["name"])) @@ -149,36 +172,36 @@ def updateTosaArgs(tosaArgs, serializeArgs, tosaXml): # Delete TOSA arguments where the type couldn't be determined for index in sorted(delTosaArgs, key=int, reverse=True): del tosaArgs[index] - # Add Serialization arguments that have no matching TOSA argument + # Add Serialization attributes that have no matching TOSA argument tosaArgNames = [arg["name"] for arg in tosaArgs] - for serArg in serializeArgs: - if (serArg["name"] not in tosaArgNames) and ( - not serArg["dType"] == "tosa::DType" - ): - serArgName = serArg["name"] - if serArg["SV"] == "V": + for serAtt in serialLibAtts: + attName = serAtt["name"] + attType = serAtt["dType"] + if (attName not in tosaArgNames) and (not attType == "tosa::DType"): + serAttName = serAtt["name"] + if serAtt["SV"] == "V": # For vector data types, insert a matching length argument tosaArgs.insert( len(tosaArgs) - 1, { - "name": f"{serArgName}_len", + "name": f"{serAttName}_len", "type": "int32_t", "shape": "", "category": "", }, ) - init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)" + init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);" shape = "[]" else: - init = f" = client_{serArg['name']}" + init = "" shape = "" - serArg["init"] = init + serAtt["init"] = init # Insert new argument tosaArgs.insert( len(tosaArgs) - 1, { - "name": serArgName, - "type": serArg["dType"], + "name": serAttName, + "type": serAtt["dType"], "shape": shape, "category": "", }, @@ -190,33 +213,47 @@ def getOperators(tosaXml): Return a list of TOSA operators as defined by tosa.xml. """ operators = [] - ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"] + ignoreOps = [ + "while_loop", + "cond_if", + "const", + "custom", + "fft2d", + "rfft2d", + "variable", + "variable_read", + "variable_write", + ] opsXml = tosaXml.getElementsByTagName("operator") - allSerializeArgs = getSerializeArgs() + allSerialLibAtts = getSerialLibAtts() for opXml in opsXml: opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower() if opName not in ignoreOps: operator = {"name": opName} operator["serializeAttType"] = getSerializeOpType(opName) tosaArgs = getTosaArgs(opXml) - serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs) + serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs) # Handle "axis" arguments axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"] if operator["serializeAttType"] == "None" and len(axisList) > 0: operator["serializeAttType"] = "Axis" - serializeArgs = [ + serialLibAtts = [ { "name": "axis", "dType": "int32_t", "SV": "S", - "init": "= client_axis", + "init": "", } ] - updateTosaArgs(tosaArgs, serializeArgs, tosaXml) + updateTosaArgs(tosaArgs, serialLibAtts, tosaXml) operator["arguments"] = tosaArgs - operator["serializeArgs"] = serializeArgs + operator["serialLibAtts"] = serialLibAtts + serializationAttNames = [att["name"] for att in serialLibAtts] operator["inputs"] = [ - arg["name"] for arg in tosaArgs if arg["category"] == "input" + arg["name"] + for arg in tosaArgs + if arg["category"] == "input" + and arg["name"] not in serializationAttNames ] operator["outputs"] = [ arg["name"] for arg in tosaArgs if arg["category"] == "output" @@ -283,12 +320,12 @@ def clangFormat(filename): subprocess.check_call(cmd, stdout=devnull) -def getSerializeArgs(): +def getSerialLibAtts(): """ Parse attribute.def file and return a dictionary where the keys are Serialization library operator names. The values are the arguments required by each Serialization library operator. """ - serializeArgs = {} + serialLibAtts = {} with open("../../thirdparty/serialization_lib/include/attribute.def") as file: preamble = True inAtt = False @@ -315,11 +352,11 @@ def getSerializeArgs(): } args.append(arg) if ")" in line: - serializeArgs[opName] = args + serialLibAtts[opName] = args opName = "" args = [] inAtt = False - return serializeArgs + return serialLibAtts def renderTemplate(environment, dataTypes, operators, template, outfile): @@ -349,12 +386,12 @@ def getSerializeOpTypeMap(): """ import re - allSerializeArgs = getSerializeArgs() - serArgs = [ + allSerialLibAtts = getSerialLibAtts() + serAtts = [ re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() - for name in allSerializeArgs.keys() + for name in allSerialLibAtts.keys() ] - serArgs = sorted(serArgs, key=len, reverse=True) + serAtts = sorted(serAtts, key=len, reverse=True) tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml") opsXml = tosaXml.getElementsByTagName("operator") opNames = [ @@ -362,9 +399,9 @@ def getSerializeOpTypeMap(): ] map = {} for opName in opNames: - for serArg in serArgs: - if serArg in opName: - components = serArg.split("_") + for serAtt in serAtts: + if serAtt in opName: + components = serAtt.split("_") map[opName] = "".join(x.title() for x in components) return map diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index a8f1c24..6b6f864 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -67,6 +67,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_UINT16; case tosa_datatype_uint8_t: return tosa::DType::DType_UINT8; + case tosa_datatype_shape_t: + return tosa::DType::DType_SHAPE; default: return tosa::DType::DType_UNKNOWN; } @@ -99,24 +101,24 @@ extern "C" {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} {% if loop.index < operator.arguments|length %},{% endif %} - {%- endfor -%}, const func_config_t &func_config, - const func_debug_t &func_debug + {%- endfor -%},const func_ctx_t& func_ctx ) { // Create operator attributes - {% for arg in operator.serializeArgs: %} - {%- if arg.SV == "V": -%} - const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}}; - {%- else: -%} - const {{arg.dType}} {{arg.name}}{{arg.init}}; - {%- endif -%} + {% for att in operator.serialLibAtts: -%} + {{att.init}} {%- endfor -%} Tosa{{operator.serializeAttType}}Attribute attr - {%- if operator.serializeArgs|length > 0 -%} + {%- if operator.serialLibAtts|length > 0 -%} ( - {%- for arg in operator.serializeArgs: -%} - {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %} + {%- for att in operator.serialLibAtts: -%} + {%- if att.init == "" -%} + client_{{att.name}} + {%- else -%} + {{att.name}} + {%- endif -%} + {% if loop.index < operator.serialLibAtts|length %}, {% endif %} {%- endfor -%} ) {%- endif -%}; @@ -174,7 +176,7 @@ extern "C" }); // Setup model - TosaReference::ModelRunnerImpl runner(func_config, func_debug); + TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); {% for input in operator.inputs: -%} TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2 index 042d7a5..0c98da8 100644 --- a/scripts/operator_api/templates/operators_h.j2 +++ b/scripts/operator_api/templates/operators_h.j2 @@ -21,6 +21,7 @@ #include "func_config.h" #include "func_debug.h" +#include "types.h" #include <stddef.h> #include <stdint.h> @@ -29,37 +30,10 @@ extern "C" { #endif /* __cplusplus */ - // Note status needs to be aligned with graph_status - enum tosa_status_t + struct func_ctx_t { - tosa_status_valid = 0, - tosa_status_unpredictable = 1, - tosa_status_error = 2 - }; - - enum tosa_mode_t - { - tosa_mode_unknown = 0, - tosa_mode_nearest = 1, - tosa_mode_bilinear = 2, - tosa_mode_min = 3, - tosa_mode_max = 4 - }; - - enum tosa_datatype_t - { - {% for dataType in dataTypes: -%} - {{dataType}} = {{loop.index-1}}, - {% endfor -%} - }; - - struct tosa_tensor_t - { - int32_t* shape; - int32_t num_dims; - tosa_datatype_t data_type; - uint8_t* data; - size_t size; + func_config_t func_config = func_config_t{}; + func_debug_t func_debug = func_debug_t{}; }; {% for operator in operators: %} @@ -67,8 +41,7 @@ extern "C" { {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} {% if loop.index < operator.arguments|length %},{% endif %} - {%- endfor -%}, const func_config_t &func_config = func_config_t{}, - const func_debug_t &func_debug = func_debug_t{}); + {%- endfor -%},const func_ctx_t& func_ctx); {% endfor %} #ifdef __cplusplus |