aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/templates/operators_cc.j2
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r--scripts/operator_api/templates/operators_cc.j2128
1 files changed, 83 insertions, 45 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 1de103a..0fc52ab 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -20,6 +20,17 @@
#include "model_runner_impl.h"
#include "ops/op_factory.h"
+
+#define TOSA_PROPAGATE_ERROR(status) \
+ do \
+ { \
+ if (status != 0) \
+ { \
+ return status; \
+ } \
+ } while (false)
+
+
#define TOSA_RETURN_ON_ERROR(status) \
do \
{ \
@@ -74,12 +85,60 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
}
};
+using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>;
+
tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
{
std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
}
+void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) {
+ auto tensorDescr = translate_client_tensor(tensor, tensorName);
+ tensors.push_back(std::make_pair(tensorDescr, &tensor));
+}
+
+int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors)
+{
+ for (const auto& [tensorDescr, tensorData] : inputTensors)
+ {
+ auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size);
+ TOSA_PROPAGATE_ERROR(status);
+ }
+
+ return 0;
+}
+
+int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors)
+{
+ for (const auto& [tensorDescr, tensorData] : outputTensors)
+ {
+ auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size);
+ TOSA_PROPAGATE_ERROR(status);
+ }
+
+ return 0;
+}
+
+std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors)
+{
+ std::vector<std::string> tensorNames;
+ const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); };
+
+ std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping);
+ return tensorNames;
+}
+
+std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) {
+ std::vector<TosaSerializationTensor*> result;
+ const auto mapping = [](const TosaTensorInfo &info){ return info.first; };
+
+ std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping);
+ std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping);
+
+ return result;
+}
+
tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
switch(mode) {
case tosa_mode_nearest:
@@ -113,7 +172,7 @@ extern "C"
tosa_status_t tosa_run_{{ operator.name }} (
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
- {% if loop.index < operator.arguments|length %},{% endif %}
+ {% if not loop.last %},{% endif %}
{%- endfor -%},const func_ctx_t& func_ctx
)
{
@@ -131,17 +190,26 @@ extern "C"
{%- else -%}
{{att.name}}
{%- endif -%}
- {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
+ {% if not loop.last %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
// Create tensors
+ std::vector<TosaTensorInfo> inputTensors;
{% for input in operator.inputs: -%}
- tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
- {%- endfor -%}
- {% for output in operator.outputs: %}
- tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
+ {%- if input.type == "tosa_tensor_list_t" -%}
+ for (int i = 0; i < client_{{input.name}}.size; i++) {
+ addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i));
+ }
+ {%- else -%}
+ addTensor(inputTensors, client_{{input.name}}, "{{input.name}}");
+ {%- endif -%}
+ {%- endfor %}
+
+ std::vector<TosaTensorInfo> outputTensors;
+ {% for output in operator.outputs: -%}
+ addTensor(outputTensors, client_{{output}}, "{{output}}");
{%- endfor %}
// Create operator
@@ -151,57 +219,27 @@ extern "C"
{%- else -%}
tosa::Attribute::Attribute_NONE
{%- endif -%},
- &attr, {
- {%- for input in operator.inputs: -%}
- {{input}}->GetName()
- {%- if loop.index < operator.inputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for output in operator.outputs: -%}
- {{output}}->GetName()
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- });
+ &attr,
+ getTensorNames(inputTensors),
+ getTensorNames(outputTensors));
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
- {
- {%- for input in operator.inputs: -%}
- {{input}},
- {%- endfor -%}
- {%- for output in operator.outputs: -%}
- {{output}}
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for input in operator.inputs: -%}
- {{input}}->GetName()
- {%- if loop.index < operator.inputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for output in operator.outputs: -%}
- {{output}}->GetName()
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- });
+ allTensors(inputTensors, outputTensors),
+ op->GetInputTensorNames(),
+ op->GetOutputTensorNames());
// Setup model
TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
- {% for input in operator.inputs: -%}
- TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
- {%- endfor %}
+
+ TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors));
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
// Extract outputs
- {% for output in operator.outputs: -%}
- TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
- {%- endfor %}
+ TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors));
return tosa_status_valid;
}