aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/templates/operators_cc.j2
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r--scripts/operator_api/templates/operators_cc.j226
1 files changed, 14 insertions, 12 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index a8f1c24..6b6f864 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -67,6 +67,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_UINT16;
case tosa_datatype_uint8_t:
return tosa::DType::DType_UINT8;
+ case tosa_datatype_shape_t:
+ return tosa::DType::DType_SHAPE;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -99,24 +101,24 @@ extern "C"
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}, const func_config_t &func_config,
- const func_debug_t &func_debug
+ {%- endfor -%},const func_ctx_t& func_ctx
)
{
// Create operator attributes
- {% for arg in operator.serializeArgs: %}
- {%- if arg.SV == "V": -%}
- const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
- {%- else: -%}
- const {{arg.dType}} {{arg.name}}{{arg.init}};
- {%- endif -%}
+ {% for att in operator.serialLibAtts: -%}
+ {{att.init}}
{%- endfor -%}
Tosa{{operator.serializeAttType}}Attribute attr
- {%- if operator.serializeArgs|length > 0 -%}
+ {%- if operator.serialLibAtts|length > 0 -%}
(
- {%- for arg in operator.serializeArgs: -%}
- {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- for att in operator.serialLibAtts: -%}
+ {%- if att.init == "" -%}
+ client_{{att.name}}
+ {%- else -%}
+ {{att.name}}
+ {%- endif -%}
+ {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
@@ -174,7 +176,7 @@ extern "C"
});
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));