diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 17 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 128 |
2 files changed, 89 insertions, 56 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 7f10568..e511f19 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -10,8 +10,6 @@ from xml.dom import minidom from jinja2 import Environment from jinja2 import FileSystemLoader -# Note: main script designed to be run from the scripts/operator_api/ directory - def getBasePath(): return Path(__file__).resolve().parent.parent.parent @@ -82,10 +80,7 @@ def getSerializeOpType(tosaOpName): "cond_if": "CondIf", "while_loop": "WhileLoop", } - if tosaOpName not in map.keys(): - return "None" - else: - return map[tosaOpName] + return map.get(tosaOpName, "None") def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): @@ -268,7 +263,7 @@ def getOperators(tosaXml): operator["serialLibAtts"] = serialLibAtts serializationAttNames = [att["name"] for att in serialLibAtts] operator["inputs"] = [ - arg["name"] + {"name": arg["name"], "type": arg["type"]} for arg in tosaArgs if arg["category"] == "input" and arg["name"] not in serializationAttNames @@ -308,7 +303,7 @@ def getTosaArgs(opXml): if argType[-1:] == "*": argType = argType[:-1] if argCategory in ["input", "output"] and argType in tosaTensorTypes: - argType = "tosa_tensor_t" + argType = f"tosa_{argType}" argShape = "" if argType in tosaTypeMap: argType = tosaTypeMap[argType] @@ -386,7 +381,7 @@ def getSerialLibAtts(): return serialLibAtts -def renderTemplate(environment, dataTypes, operators, template, outfile): +def renderTemplate(dataTypes, operators, template, outfile): content = template.render(dataTypes=dataTypes, operators=operators) with open(outfile, mode="w", encoding="utf-8") as output: output.write(content) @@ -399,12 +394,12 @@ def generate(environment, dataTypes, operators, base_path): # Generate include/operators.h template = environment.get_template("operators_h.j2") outfile = base_path / "reference_model/include/operators.h" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) # Generate src/operators.cc template = environment.get_template("operators_cc.j2") outfile = base_path / "reference_model/src/operators.cc" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) if __name__ == "__main__": diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 1de103a..0fc52ab 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -20,6 +20,17 @@ #include "model_runner_impl.h" #include "ops/op_factory.h" + +#define TOSA_PROPAGATE_ERROR(status) \ + do \ + { \ + if (status != 0) \ + { \ + return status; \ + } \ + } while (false) + + #define TOSA_RETURN_ON_ERROR(status) \ do \ { \ @@ -74,12 +85,60 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) } }; +using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>; + tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) { std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); } +void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) { + auto tensorDescr = translate_client_tensor(tensor, tensorName); + tensors.push_back(std::make_pair(tensorDescr, &tensor)); +} + +int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors) +{ + for (const auto& [tensorDescr, tensorData] : inputTensors) + { + auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors) +{ + for (const auto& [tensorDescr, tensorData] : outputTensors) + { + auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors) +{ + std::vector<std::string> tensorNames; + const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); }; + + std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping); + return tensorNames; +} + +std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) { + std::vector<TosaSerializationTensor*> result; + const auto mapping = [](const TosaTensorInfo &info){ return info.first; }; + + std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping); + std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping); + + return result; +} + tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { switch(mode) { case tosa_mode_nearest: @@ -113,7 +172,7 @@ extern "C" tosa_status_t tosa_run_{{ operator.name }} ( {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} - {% if loop.index < operator.arguments|length %},{% endif %} + {% if not loop.last %},{% endif %} {%- endfor -%},const func_ctx_t& func_ctx ) { @@ -131,17 +190,26 @@ extern "C" {%- else -%} {{att.name}} {%- endif -%} - {% if loop.index < operator.serialLibAtts|length %}, {% endif %} + {% if not loop.last %}, {% endif %} {%- endfor -%} ) {%- endif -%}; // Create tensors + std::vector<TosaTensorInfo> inputTensors; {% for input in operator.inputs: -%} - tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}"); - {%- endfor -%} - {% for output in operator.outputs: %} - tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}"); + {%- if input.type == "tosa_tensor_list_t" -%} + for (int i = 0; i < client_{{input.name}}.size; i++) { + addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i)); + } + {%- else -%} + addTensor(inputTensors, client_{{input.name}}, "{{input.name}}"); + {%- endif -%} + {%- endfor %} + + std::vector<TosaTensorInfo> outputTensors; + {% for output in operator.outputs: -%} + addTensor(outputTensors, client_{{output}}, "{{output}}"); {%- endfor %} // Create operator @@ -151,57 +219,27 @@ extern "C" {%- else -%} tosa::Attribute::Attribute_NONE {%- endif -%}, - &attr, { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + &attr, + getTensorNames(inputTensors), + getTensorNames(outputTensors)); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op }, - { - {%- for input in operator.inputs: -%} - {{input}}, - {%- endfor -%} - {%- for output in operator.outputs: -%} - {{output}} - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + allTensors(inputTensors, outputTensors), + op->GetInputTensorNames(), + op->GetOutputTensorNames()); // Setup model TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); - {% for input in operator.inputs: -%} - TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); - {%- endfor %} + + TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); // Extract outputs - {% for output in operator.outputs: -%} - TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size)); - {%- endfor %} + TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors)); return tosa_status_valid; } |