diff options
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 128 |
1 files changed, 83 insertions, 45 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 1de103a..0fc52ab 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -20,6 +20,17 @@ #include "model_runner_impl.h" #include "ops/op_factory.h" + +#define TOSA_PROPAGATE_ERROR(status) \ + do \ + { \ + if (status != 0) \ + { \ + return status; \ + } \ + } while (false) + + #define TOSA_RETURN_ON_ERROR(status) \ do \ { \ @@ -74,12 +85,60 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) } }; +using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>; + tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) { std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); } +void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) { + auto tensorDescr = translate_client_tensor(tensor, tensorName); + tensors.push_back(std::make_pair(tensorDescr, &tensor)); +} + +int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors) +{ + for (const auto& [tensorDescr, tensorData] : inputTensors) + { + auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors) +{ + for (const auto& [tensorDescr, tensorData] : outputTensors) + { + auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size); + TOSA_PROPAGATE_ERROR(status); + } + + return 0; +} + +std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors) +{ + std::vector<std::string> tensorNames; + const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); }; + + std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping); + return tensorNames; +} + +std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) { + std::vector<TosaSerializationTensor*> result; + const auto mapping = [](const TosaTensorInfo &info){ return info.first; }; + + std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping); + std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping); + + return result; +} + tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { switch(mode) { case tosa_mode_nearest: @@ -113,7 +172,7 @@ extern "C" tosa_status_t tosa_run_{{ operator.name }} ( {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} - {% if loop.index < operator.arguments|length %},{% endif %} + {% if not loop.last %},{% endif %} {%- endfor -%},const func_ctx_t& func_ctx ) { @@ -131,17 +190,26 @@ extern "C" {%- else -%} {{att.name}} {%- endif -%} - {% if loop.index < operator.serialLibAtts|length %}, {% endif %} + {% if not loop.last %}, {% endif %} {%- endfor -%} ) {%- endif -%}; // Create tensors + std::vector<TosaTensorInfo> inputTensors; {% for input in operator.inputs: -%} - tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}"); - {%- endfor -%} - {% for output in operator.outputs: %} - tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}"); + {%- if input.type == "tosa_tensor_list_t" -%} + for (int i = 0; i < client_{{input.name}}.size; i++) { + addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i)); + } + {%- else -%} + addTensor(inputTensors, client_{{input.name}}, "{{input.name}}"); + {%- endif -%} + {%- endfor %} + + std::vector<TosaTensorInfo> outputTensors; + {% for output in operator.outputs: -%} + addTensor(outputTensors, client_{{output}}, "{{output}}"); {%- endfor %} // Create operator @@ -151,57 +219,27 @@ extern "C" {%- else -%} tosa::Attribute::Attribute_NONE {%- endif -%}, - &attr, { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + &attr, + getTensorNames(inputTensors), + getTensorNames(outputTensors)); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op }, - { - {%- for input in operator.inputs: -%} - {{input}}, - {%- endfor -%} - {%- for output in operator.outputs: -%} - {{output}} - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for input in operator.inputs: -%} - {{input}}->GetName() - {%- if loop.index < operator.inputs|length -%},{%- endif -%} - {%- endfor -%} - }, - { - {%- for output in operator.outputs: -%} - {{output}}->GetName() - {%- if loop.index < operator.outputs|length -%},{%- endif -%} - {%- endfor -%} - }); + allTensors(inputTensors, outputTensors), + op->GetInputTensorNames(), + op->GetOutputTensorNames()); // Setup model TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); - {% for input in operator.inputs: -%} - TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); - {%- endfor %} + + TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); // Extract outputs - {% for output in operator.outputs: -%} - TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size)); - {%- endfor %} + TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors)); return tosa_status_valid; } |