diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-12-01 12:18:15 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-12-07 17:19:33 +0000 |
commit | 0de60f38e828e13359acbcfac51b6c179a34d042 (patch) | |
tree | 532cda49e0c9101440715b17a3e18ddaebc75858 /scripts | |
parent | e9059775c0486de4a96d42b41104496f4aefe8e8 (diff) | |
download | reference_model-0de60f38e828e13359acbcfac51b6c179a34d042.tar.gz |
Add support for list of tensors as input parameter
Some operators (e.g. Concat) expect list of tensor as an input
parameter. Currently operators API does not support passing
such parameters from the client code.
In order to enable it:
- Add new type tensor_list_t
- Update operators API generation script to support new type
- Add unit test for operator Concat
Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com>
Change-Id: Ib2f61bcea5e5ecabf56ce031d905cb46a4cc68ea
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 17 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 128 |
2 files changed, 89 insertions, 56 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 7f10568..e511f19 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -10,8 +10,6 @@ from xml.dom import minidom from jinja2 import Environment from jinja2 import FileSystemLoader -# Note: main script designed to be run from the scripts/operator_api/ directory - def getBasePath(): return Path(__file__).resolve().parent.parent.parent @@ -82,10 +80,7 @@ def getSerializeOpType(tosaOpName): "cond_if": "CondIf", "while_loop": "WhileLoop", } - if tosaOpName not in map.keys(): - return "None" - else: - return map[tosaOpName] + return map.get(tosaOpName, "None") def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): @@ -268,7 +263,7 @@ def getOperators(tosaXml): operator["serialLibAtts"] = serialLibAtts serializationAttNames = [att["name"] for att in serialLibAtts] operator["inputs"] = [ - arg["name"] + {"name": arg["name"], "type": arg["type"]} for arg in tosaArgs if arg["category"] == "input" and arg["name"] not in serializationAttNames @@ -308,7 +303,7 @@ def getTosaArgs(opXml): if argType[-1:] == "*": argType = argType[:-1] if argCategory in ["input", "output"] and argType in tosaTensorTypes: - argType = "tosa_tensor_t" + argType = f"tosa_{argType}" argShape = "" if argType in tosaTypeMap: argType = tosaTypeMap[argType] @@ -386,7 +381,7 @@ def getSerialLibAtts(): return serialLibAtts -def renderTemplate(environment, dataTypes, operators, template, outfile): +def renderTemplate(dataTypes, operators, template, outfile): content = template.render(dataTypes=dataTypes, operators=operators) with open(outfile, mode="w", encoding="utf-8") as output: output.write(content) @@ -399,12 +394,12 @@ def generate(environment, dataTypes, operators, base_path): # Generate include/operators.h template = environment.get_template("operators_h.j2") outfile = base_path / "reference_model/include/operators.h" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) # Generate src/operators.cc template = environment.get_template("operators_cc.j2") outfile = base_path / "reference_model/src/operators.cc" - renderTemplate(environment, dataTypes, operators, template, outfile) + renderTemplate(dataTypes, operators, template, outfile) if __name__ == "__main__": diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 1de103a..0fc52ab 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -20,6 +20,17 @@ #include "model_runner_impl.h" #include "ops/op_factory.h" + +#define TOSA_PROPAGATE_ERROR(status) \ + do \ + { \ + if (status != 0) \ + { \ + return status; \ + } \ + } while (false) + + #define TOSA_RETURN_ON_ERROR(status) \ do \ { \ @@ -74,12 +85,60 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) } }; +using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>; + tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) { std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); } +void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) { + auto tensorDescr = translate_client_tensor(tensor, tensorName); + tensors.push_back(std::make_pair(tensorDescr, &tensor)); +} + +int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors) +{ + for (const auto& [tensorDescr, tensorData] : inputTensors) + { + auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors) +{ + for (const auto& [tensorDescr, tensorData] : outputTensors) + { + auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors) +{ + std::vector<std::string> tensorNames; + const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); }; + + std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping); + return tensorNames; +} + +std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) { + std::vector<TosaSerializationTensor*> result; + const auto mapping = [](const TosaTensorInfo &info){ return info.first; }; + + std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping); + std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping); + + return result; +} + tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { switch(mode) { case tosa_mode_nearest: @@ -113,7 +172,7 @@ extern "C" tosa_status_t tosa_run_{{ operator.name }} ( {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} - {% if loop.index < operator.arguments|length %},{% endif %} + {% if not loop.last %},{% endif %} {%- endfor -%},const func_ctx_t& func_ctx ) { @@ -131,17 +190,26 @@ extern "C" {%- else -%} {{att.name}} {%- endif -%} - {% if loop.index < operator.serialLibAtts|length %}, {% endif %} + {% if not loop.last %}, {% endif %} {%- endfor -%} ) {%- endif -%}; // Create tensors + std::vector<TosaTensorInfo> inputTensors; {% for input in operator.inputs: -%} - tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}"); - {%- endfor -%} - {% for output in operator.outputs: %} - tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}"); + {%- if input.type == "tosa_tensor_list_t" -%} + for (int i = 0; i < client_{{input.name}}.size; i++) { + addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i)); + } + {%- else -%} + addTensor(inputTensors, client_{{input.name}}, "{{input.name}}"); + {%- endif -%} + {%- endfor %} + + std::vector<TosaTensorInfo> outputTensors; + {% for output in operator.outputs: -%} + addTensor(outputTensors, client_{{output}}, "{{output}}"); {%- endfor %} // Create operator @@ -151,57 +219,27 @@ extern "C" {%- else -%} tosa::Attribute::Attribute_NONE {%- endif -%}, - &attr, { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + &attr, + getTensorNames(inputTensors), + getTensorNames(outputTensors)); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op }, - { - {%- for input in operator.inputs: -%} - {{input}}, - {%- endfor -%} - {%- for output in operator.outputs: -%} - {{output}} - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + allTensors(inputTensors, outputTensors), + op->GetInputTensorNames(), + op->GetOutputTensorNames()); // Setup model TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); - {% for input in operator.inputs: -%} - TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); - {%- endfor %} + + TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); // Extract outputs - {% for output in operator.outputs: -%} - TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size)); - {%- endfor %} + TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors)); return tosa_status_valid; } |