diff options
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index a8f1c24..6b6f864 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -67,6 +67,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_UINT16; case tosa_datatype_uint8_t: return tosa::DType::DType_UINT8; + case tosa_datatype_shape_t: + return tosa::DType::DType_SHAPE; default: return tosa::DType::DType_UNKNOWN; } @@ -99,24 +101,24 @@ extern "C" {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} {% if loop.index < operator.arguments|length %},{% endif %} - {%- endfor -%}, const func_config_t &func_config, - const func_debug_t &func_debug + {%- endfor -%},const func_ctx_t& func_ctx ) { // Create operator attributes - {% for arg in operator.serializeArgs: %} - {%- if arg.SV == "V": -%} - const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}}; - {%- else: -%} - const {{arg.dType}} {{arg.name}}{{arg.init}}; - {%- endif -%} + {% for att in operator.serialLibAtts: -%} + {{att.init}} {%- endfor -%} Tosa{{operator.serializeAttType}}Attribute attr - {%- if operator.serializeArgs|length > 0 -%} + {%- if operator.serialLibAtts|length > 0 -%} ( - {%- for arg in operator.serializeArgs: -%} - {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %} + {%- for att in operator.serialLibAtts: -%} + {%- if att.init == "" -%} + client_{{att.name}} + {%- else -%} + {{att.name}} + {%- endif -%} + {% if loop.index < operator.serialLibAtts|length %}, {% endif %} {%- endfor -%} ) {%- endif -%}; @@ -174,7 +176,7 @@ extern "C" }); // Setup model - TosaReference::ModelRunnerImpl runner(func_config, func_debug); + TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); {% for input in operator.inputs: -%} TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); |