aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py17
-rw-r--r--scripts/operator_api/templates/operators_cc.j2128
2 files changed, 89 insertions, 56 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 7f10568..e511f19 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -10,8 +10,6 @@ from xml.dom import minidom
from jinja2 import Environment
from jinja2 import FileSystemLoader
-# Note: main script designed to be run from the scripts/operator_api/ directory
-
def getBasePath():
return Path(__file__).resolve().parent.parent.parent
@@ -82,10 +80,7 @@ def getSerializeOpType(tosaOpName):
"cond_if": "CondIf",
"while_loop": "WhileLoop",
}
- if tosaOpName not in map.keys():
- return "None"
- else:
- return map[tosaOpName]
+ return map.get(tosaOpName, "None")
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
@@ -268,7 +263,7 @@ def getOperators(tosaXml):
operator["serialLibAtts"] = serialLibAtts
serializationAttNames = [att["name"] for att in serialLibAtts]
operator["inputs"] = [
- arg["name"]
+ {"name": arg["name"], "type": arg["type"]}
for arg in tosaArgs
if arg["category"] == "input"
and arg["name"] not in serializationAttNames
@@ -308,7 +303,7 @@ def getTosaArgs(opXml):
if argType[-1:] == "*":
argType = argType[:-1]
if argCategory in ["input", "output"] and argType in tosaTensorTypes:
- argType = "tosa_tensor_t"
+ argType = f"tosa_{argType}"
argShape = ""
if argType in tosaTypeMap:
argType = tosaTypeMap[argType]
@@ -386,7 +381,7 @@ def getSerialLibAtts():
return serialLibAtts
-def renderTemplate(environment, dataTypes, operators, template, outfile):
+def renderTemplate(dataTypes, operators, template, outfile):
content = template.render(dataTypes=dataTypes, operators=operators)
with open(outfile, mode="w", encoding="utf-8") as output:
output.write(content)
@@ -399,12 +394,12 @@ def generate(environment, dataTypes, operators, base_path):
# Generate include/operators.h
template = environment.get_template("operators_h.j2")
outfile = base_path / "reference_model/include/operators.h"
- renderTemplate(environment, dataTypes, operators, template, outfile)
+ renderTemplate(dataTypes, operators, template, outfile)
# Generate src/operators.cc
template = environment.get_template("operators_cc.j2")
outfile = base_path / "reference_model/src/operators.cc"
- renderTemplate(environment, dataTypes, operators, template, outfile)
+ renderTemplate(dataTypes, operators, template, outfile)
if __name__ == "__main__":
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 1de103a..0fc52ab 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -20,6 +20,17 @@
#include "model_runner_impl.h"
#include "ops/op_factory.h"
+
+#define TOSA_PROPAGATE_ERROR(status) \
+ do \
+ { \
+ if (status != 0) \
+ { \
+ return status; \
+ } \
+ } while (false)
+
+
#define TOSA_RETURN_ON_ERROR(status) \
do \
{ \
@@ -74,12 +85,60 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
}
};
+using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>;
+
tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
{
std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
}
+void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) {
+ auto tensorDescr = translate_client_tensor(tensor, tensorName);
+ tensors.push_back(std::make_pair(tensorDescr, &tensor));
+}
+
+int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors)
+{
+ for (const auto& [tensorDescr, tensorData] : inputTensors)
+ {
+ auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size);
+ TOSA_PROPAGATE_ERROR(status);
+ }
+
+ return 0;
+}
+
+int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors)
+{
+ for (const auto& [tensorDescr, tensorData] : outputTensors)
+ {
+ auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size);
+ TOSA_PROPAGATE_ERROR(status);
+ }
+
+ return 0;
+}
+
+std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors)
+{
+ std::vector<std::string> tensorNames;
+ const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); };
+
+ std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping);
+ return tensorNames;
+}
+
+std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) {
+ std::vector<TosaSerializationTensor*> result;
+ const auto mapping = [](const TosaTensorInfo &info){ return info.first; };
+
+ std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping);
+ std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping);
+
+ return result;
+}
+
tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
switch(mode) {
case tosa_mode_nearest:
@@ -113,7 +172,7 @@ extern "C"
tosa_status_t tosa_run_{{ operator.name }} (
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
- {% if loop.index < operator.arguments|length %},{% endif %}
+ {% if not loop.last %},{% endif %}
{%- endfor -%},const func_ctx_t& func_ctx
)
{
@@ -131,17 +190,26 @@ extern "C"
{%- else -%}
{{att.name}}
{%- endif -%}
- {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
+ {% if not loop.last %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
// Create tensors
+ std::vector<TosaTensorInfo> inputTensors;
{% for input in operator.inputs: -%}
- tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
- {%- endfor -%}
- {% for output in operator.outputs: %}
- tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
+ {%- if input.type == "tosa_tensor_list_t" -%}
+ for (int i = 0; i < client_{{input.name}}.size; i++) {
+ addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i));
+ }
+ {%- else -%}
+ addTensor(inputTensors, client_{{input.name}}, "{{input.name}}");
+ {%- endif -%}
+ {%- endfor %}
+
+ std::vector<TosaTensorInfo> outputTensors;
+ {% for output in operator.outputs: -%}
+ addTensor(outputTensors, client_{{output}}, "{{output}}");
{%- endfor %}
// Create operator
@@ -151,57 +219,27 @@ extern "C"
{%- else -%}
tosa::Attribute::Attribute_NONE
{%- endif -%},
- &attr, {
- {%- for input in operator.inputs: -%}
- {{input}}->GetName()
- {%- if loop.index < operator.inputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for output in operator.outputs: -%}
- {{output}}->GetName()
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- });
+ &attr,
+ getTensorNames(inputTensors),
+ getTensorNames(outputTensors));
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
- {
- {%- for input in operator.inputs: -%}
- {{input}},
- {%- endfor -%}
- {%- for output in operator.outputs: -%}
- {{output}}
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for input in operator.inputs: -%}
- {{input}}->GetName()
- {%- if loop.index < operator.inputs|length -%},{%- endif -%}
- {%- endfor -%}
- },
- {
- {%- for output in operator.outputs: -%}
- {{output}}->GetName()
- {%- if loop.index < operator.outputs|length -%},{%- endif -%}
- {%- endfor -%}
- });
+ allTensors(inputTensors, outputTensors),
+ op->GetInputTensorNames(),
+ op->GetOutputTensorNames());
// Setup model
TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
- {% for input in operator.inputs: -%}
- TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
- {%- endfor %}
+
+ TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors));
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
// Extract outputs
- {% for output in operator.outputs: -%}
- TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
- {%- endfor %}
+ TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors));
return tosa_status_valid;
}