diff options
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r-- | scripts/operator_api/generate_api.py | 413 |
1 files changed, 0 insertions, 413 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py deleted file mode 100644 index e511f19..0000000 --- a/scripts/operator_api/generate_api.py +++ /dev/null @@ -1,413 +0,0 @@ -"""Generate extended reference model API with eager operator execution entrypoints""" -# Copyright (c) 2021-2023, ARM Limited. -# SPDX-License-Identifier: Apache-2.0 -import copy -import os -import subprocess -from pathlib import Path -from xml.dom import minidom - -from jinja2 import Environment -from jinja2 import FileSystemLoader - - -def getBasePath(): - return Path(__file__).resolve().parent.parent.parent - - -def getTosaArgTypes(tosaXml): - """ - Returns a list of the TOSA argument types from tosa.xml. - """ - argTypes = { - "tensor_t", - "in_t", - "out_t", - "mul_t", - "weight_t", - "in_out_t", - "tensor_list_t", - } - argTypesXml = tosaXml.getElementsByTagName("type") - for argTypeXml in argTypesXml: - argTypes.add(argTypeXml.getAttribute("name")) - argTypes.remove("TABLE_SIZE") - return argTypes - - -def getTosaDataTypes(tosaXml): - """ - Returns a list of the TOSA data types from tosa.xml. - """ - argTypes = getTosaArgTypes(tosaXml) - dataTypes = set() - dataTypesXml = tosaXml.getElementsByTagName("typesupport") - for dataTypeXml in dataTypesXml: - for argType in argTypes: - dataType = dataTypeXml.getAttribute(argType) - if dataType != "": - dataTypes.add(f"tosa_datatype_{dataType}") - return sorted(dataTypes) - - -def getSerializeOpType(tosaOpName): - """ - Returns the Serialization library operator that matches the TOSA operator specified. - """ - map = { - "avg_pool2d": "Pool", - "conv2d": "Conv", - "conv3d": "Conv", - "depthwise_conv2d": "Conv", - "fully_connected": "FullyConnected", - "fft2d": "FFT", - "rfft2d": "RFFT", - "matmul": "MatMul", - "max_pool2d": "Pool", - "transpose_conv2d": "TransposeConv", - "clamp": "Clamp", - "arithmetic_right_shift": "ArithmeticRightShift", - "mul": "Mul", - "table": "Table", - "negate": "Negate", - "pad": "Pad", - "reshape": "Reshape", - "slice": "Slice", - "tile": "Tile", - "transpose": "Transpose", - "resize": "Resize", - "rescale": "Rescale", - "cond_if": "CondIf", - "while_loop": "WhileLoop", - } - return map.get(tosaOpName, "None") - - -def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): - """ - Returns the attributes required by the Serialization library for the TOSA operator specified. - Generates code to initialize Serialization library attributes. If a matching TOSA argument exists, - that value is used for initialization, otherwise a default value e.g. 0 is used. - """ - serLibOpType = getSerializeOpType(tosaOpName) - if serLibOpType not in allSerialLibAtts.keys(): - return {} - else: - serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType]) - tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} - serTosaTypeMap = {"ResizeMode": "tosa_mode"} - serAttsToFix = { - "reshape": {"new_shape": "shape"}, - "transpose_conv2d": {"output_shape": "out_shape"}, - } - if tosaOpName in serAttsToFix: - # Fix attributes names to match with tosa.xml - for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items(): - for opAtts in serLibOpAtts: - if opAtts["name"] == attDefName: - opAtts["name"] = tosaSpecName - for att in serLibOpAtts: - attName = att["name"] - attType = att["dType"] - init = "" - # Translate TOSA data types to Serialization library data types for initialization - if attType in serTosaTypeMap.keys(): - init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});" - # Initialize Serialization library attributes to their matching function parameter - elif tosaOpName == "avg_pool2d" and attName == "accum_dtype": - init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);" - att["dType"] = "tosa::DType" - elif attName in tosaArgsDict: - if att["SV"] == "V": - if tosaArgsDict[attName]["type"] == "tosa_tensor_t": - init = f"std::vector<{attType}> {attName};" - init = ( - init - + f"size_t {attName}_size = client_{attName}.size / sizeof({attType});" - ) - init = ( - init - + f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);" - ) - init = ( - init - + f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);" - ) - else: - init = f"const std::vector<{attType}> {attName}" - shape = tosaArgsDict[attName]["shape"] - if shape == "[]": - init = ( - init - + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);" - ) - else: - init = ( - init - + f"(&client_{attName}[0], &client_{attName}{shape});" - ) - else: - init = "" - else: - # Initialize Serialization library attributes with no matching fuction parameter - if att["SV"] == "V": - init = f"std::vector<int32_t> {attName};" - else: - if att["dType"] == "DType": - att["dType"] = "tosa::DType" - init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;" - else: - init = f"const {attType} {attName} = 0;" - att["init"] = init - return serLibOpAtts - - -def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml): - """ - Replace TOSA argument data types with their matching Serialization attribute data types. - Delete TOSA arguments where the type couldn't be determined. - Add Serialization attributes that have no matching TOSA argument. - """ - tosaArgTypes = getTosaArgTypes(tosaXml) - serAttsDict = {att["name"]: att for att in serialLibAtts} - tosaArgsNames = [arg["name"] for arg in tosaArgs] - delTosaArgs = [] - # Replace TOSA argument data types with their matching Serialization attribute data types. - for tosaArg in tosaArgs: - if tosaArg["type"] in tosaArgTypes: - if tosaArg["name"] in serAttsDict: - tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"] - else: - # Delete TOSA argument whose data type can't be determined - delTosaArgs.append(tosaArgsNames.index(tosaArg["name"])) - # Delete corresponding length argument if one exists - lenArgName = f"{tosaArg['name']}_len" - if lenArgName in tosaArgsNames: - delTosaArgs.append(tosaArgsNames.index(lenArgName)) - # Delete TOSA arguments where the type couldn't be determined - for index in sorted(delTosaArgs, key=int, reverse=True): - del tosaArgs[index] - # Add Serialization attributes that have no matching TOSA argument - tosaArgNames = [arg["name"] for arg in tosaArgs] - for serAtt in serialLibAtts: - attName = serAtt["name"] - attType = serAtt["dType"] - if (attName not in tosaArgNames) and (not attType == "tosa::DType"): - serAttName = serAtt["name"] - if serAtt["SV"] == "V": - # For vector data types, insert a matching length argument - tosaArgs.insert( - len(tosaArgs) - 1, - { - "name": f"{serAttName}_len", - "type": "int32_t", - "shape": "", - "category": "", - }, - ) - init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);" - shape = "[]" - else: - init = "" - shape = "" - serAtt["init"] = init - # Insert new argument - tosaArgs.insert( - len(tosaArgs) - 1, - { - "name": serAttName, - "type": serAtt["dType"], - "shape": shape, - "category": "", - }, - ) - - -def getOperators(tosaXml): - """ - Return a list of TOSA operators as defined by tosa.xml. - """ - operators = [] - ignoreOps = [ - "while_loop", - "cond_if", - "const", - "custom", - "variable", - "variable_read", - "variable_write", - ] - opsXml = tosaXml.getElementsByTagName("operator") - allSerialLibAtts = getSerialLibAtts() - for opXml in opsXml: - opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower() - if opName not in ignoreOps: - operator = {"name": opName} - operator["serializeAttType"] = getSerializeOpType(opName) - tosaArgs = getTosaArgs(opXml) - serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs) - # Handle "axis" arguments - axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"] - if operator["serializeAttType"] == "None" and len(axisList) > 0: - operator["serializeAttType"] = "Axis" - serialLibAtts = [ - { - "name": "axis", - "dType": "int32_t", - "SV": "S", - "init": "", - } - ] - updateTosaArgs(tosaArgs, serialLibAtts, tosaXml) - operator["arguments"] = tosaArgs - operator["serialLibAtts"] = serialLibAtts - serializationAttNames = [att["name"] for att in serialLibAtts] - operator["inputs"] = [ - {"name": arg["name"], "type": arg["type"]} - for arg in tosaArgs - if arg["category"] == "input" - and arg["name"] not in serializationAttNames - ] - operator["outputs"] = [ - arg["name"] for arg in tosaArgs if arg["category"] == "output" - ] - operators.append(operator) - return operators - - -def getTosaArgs(opXml): - """ - Return the arguments required for the TOSA operator specified. - """ - arguments = [] - argsXml = opXml.getElementsByTagName("argument") - tosaTensorTypes = getTosaArgTypes(tosaXml) - tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"} - tensorElemTypeMap = { - "resize_mode_t": "tosa_mode_t", - "acc_size_t": "tosa_acc_size_t", - } - for xmlArg in argsXml: - argName = xmlArg.getAttribute("name").lower() - tensorElemType = xmlArg.getAttribute("tensor-element-type") - if tensorElemType in tensorElemTypeMap: - argType = tensorElemTypeMap[tensorElemType] - else: - argType = xmlArg.getAttribute("type") - argShape = xmlArg.getAttribute("shape") - argCategory = xmlArg.getAttribute("category") - # FullyConnected workaround - if (argName == "weight" or argName == "bias") and (argCategory == "attribute"): - argCategory = "input" - # Update argument type - if argType[-1:] == "*": - argType = argType[:-1] - if argCategory in ["input", "output"] and argType in tosaTensorTypes: - argType = f"tosa_{argType}" - argShape = "" - if argType in tosaTypeMap: - argType = tosaTypeMap[argType] - # Add a length argument for arrays with unknown compile-time size - if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric(): - argShape = "[]" - arguments.append( - { - "name": f"{argName}_len", - "type": "int32_t", - "shape": "", - "category": "", - } - ) - elif argShape == "" or not argShape[0] == "[": - argShape = "" - # Append argument - arguments.append( - { - "name": argName, - "type": argType, - "shape": argShape, - "category": argCategory, - } - ) - return arguments - - -def clangFormat(filename): - cmd = ["clang-format", "-i", filename] - with open(os.devnull, "w") as devnull: - subprocess.check_call(cmd, stdout=devnull) - - -def getSerialLibAtts(): - """ - Parse attribute.def file and return a dictionary where the keys are Serialization library operator names. - The values are the arguments required by each Serialization library operator. - """ - serialLibAtts = {} - base_path = getBasePath() - attr_def = ( - base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def" - ) - with open(attr_def) as file: - preamble = True - inAtt = False - opName = "" - args = [] - for line in file: - if preamble and not line[: len("DEF_ATTRIBUTE(")] == "DEF_ATTRIBUTE(": - continue - else: - preamble = False - line = line.lstrip().rstrip() - if not inAtt and "DEF_ATTRIBUTE(" in line: - opName = line[len("DEF_ATTRIBUTE(") : line.find(",")] - inAtt = True - elif inAtt: - vals = line.split(",") - argName = vals[2].lstrip().strip() - if ")" in argName: - argName = argName[:-1] - arg = { - "name": argName, - "dType": vals[0].lstrip().strip(), - "SV": vals[1].lstrip().strip(), - } - args.append(arg) - if ")" in line: - serialLibAtts[opName] = args - opName = "" - args = [] - inAtt = False - return serialLibAtts - - -def renderTemplate(dataTypes, operators, template, outfile): - content = template.render(dataTypes=dataTypes, operators=operators) - with open(outfile, mode="w", encoding="utf-8") as output: - output.write(content) - print(f"Created {outfile}") - - clangFormat(outfile) - - -def generate(environment, dataTypes, operators, base_path): - # Generate include/operators.h - template = environment.get_template("operators_h.j2") - outfile = base_path / "reference_model/include/operators.h" - renderTemplate(dataTypes, operators, template, outfile) - - # Generate src/operators.cc - template = environment.get_template("operators_cc.j2") - outfile = base_path / "reference_model/src/operators.cc" - renderTemplate(dataTypes, operators, template, outfile) - - -if __name__ == "__main__": - base_path = getBasePath() - environment = Environment( - loader=FileSystemLoader(Path(__file__).resolve().parent / "templates") - ) - tosaXml = minidom.parse(str(base_path / "thirdparty/specification/tosa.xml")) - dataTypes = getTosaDataTypes(tosaXml) - operators = getOperators(tosaXml) - generate(environment, dataTypes, operators, base_path) |