aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/templates/operators_cc.j2
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r--scripts/operator_api/templates/operators_cc.j225
1 files changed, 21 insertions, 4 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 37a0af6..a8f1c24 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -45,12 +45,28 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
{
switch (type)
{
+ case tosa_datatype_bf16_t:
+ return tosa::DType::DType_BF16;
+ case tosa_datatype_bool_t:
+ return tosa::DType::DType_BOOL;
case tosa_datatype_fp16_t:
return tosa::DType::DType_FP16;
case tosa_datatype_fp32_t:
return tosa::DType::DType_FP32;
- case tosa_datatype_bool_t:
- return tosa::DType::DType_BOOL;
+ case tosa_datatype_int16_t:
+ return tosa::DType::DType_INT16;
+ case tosa_datatype_int32_t:
+ return tosa::DType::DType_INT32;
+ case tosa_datatype_int48_t:
+ return tosa::DType::DType_INT48;
+ case tosa_datatype_int4_t:
+ return tosa::DType::DType_INT4;
+ case tosa_datatype_int8_t:
+ return tosa::DType::DType_INT8;
+ case tosa_datatype_uint16_t:
+ return tosa::DType::DType_UINT16;
+ case tosa_datatype_uint8_t:
+ return tosa::DType::DType_UINT8;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -83,7 +99,8 @@ extern "C"
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}
+ {%- endfor -%}, const func_config_t &func_config,
+ const func_debug_t &func_debug
)
{
// Create operator attributes
@@ -157,7 +174,7 @@ extern "C"
});
// Setup model
- TosaReference::ModelRunnerImpl runner;
+ TosaReference::ModelRunnerImpl runner(func_config, func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));