// Copyright (c) 2022-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // THIS FILE IS GENERATED. DO NOT EDIT! // See scripts/operator_api/generate_api.py #include "operators.h" #include "model_runner_impl.h" #include "ops/op_factory.h" #define TOSA_RETURN_ON_ERROR(status) \ do \ { \ if (status != 0) \ { \ return tosa_status_error; \ } \ } while (false) #define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \ do \ { \ if (status != GraphStatus::TOSA_VALID) \ { \ auto ustatus = static_cast>(status); \ return static_cast(ustatus); \ } \ } while (false) namespace { tosa::DType translate_client_datatype(tosa_datatype_t type) { switch (type) { case tosa_datatype_fp16_t: return tosa::DType::DType_FP16; case tosa_datatype_fp32_t: return tosa::DType::DType_FP32; case tosa_datatype_bool_t: return tosa::DType::DType_BOOL; default: return tosa::DType::DType_UNKNOWN; } }; tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) { std::vector shape(tensor.shape, tensor.shape + tensor.num_dims); return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); } tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { switch(mode) { case tosa_mode_nearest: return tosa::ResizeMode_NEAREST; case tosa_mode_max: case tosa_mode_bilinear: return tosa::ResizeMode_BILINEAR; default: return tosa::ResizeMode_UNKNOWN; } } } // namespace extern "C" { {% for operator in operators: %} 
{# One C entry point is rendered per operator. Each call builds a single-operator TOSA
   basic block from the client tensors and attribute arguments, runs it with
   TosaReference::ModelRunnerImpl, and reads the outputs back into the client tensors.
   NOTE: keep a literal newline after every `//` comment below. The `-` whitespace
   control markers strip newlines, and a stripped newline would fold the next generated
   statement into the comment. #}
// Runs the TOSA {{ operator.name }} operator through TosaReference::ModelRunnerImpl.
// Returns tosa_status_valid on success and tosa_status_error if setting an input or
// reading an output fails. If initialization or execution fails, it returns the
// failing GraphStatus converted to tosa_status_t.
tosa_status_t tosa_run_{{ operator.name }} (
    {%- for arg in operator.arguments: -%}
        {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
        {% if loop.index < operator.arguments|length %},{% endif %}
    {%- endfor -%}
)
{
    // Create operator attributes
    {% for arg in operator.serializeArgs: %}
        {%- if arg.SV == "V": -%}
    const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
        {%- else: -%}
    const {{arg.dType}} {{arg.name}}{{arg.init}};
        {%- endif -%}
    {%- endfor -%}

    Tosa{{operator.serializeAttType}}Attribute attr
    {%- if operator.serializeArgs|length > 0 -%}
        (
        {%- for arg in operator.serializeArgs: -%}
            {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
        {%- endfor -%}
        )
    {%- endif -%};

    // Create tensors
    {% for input in operator.inputs: -%}
    tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
    {%- endfor -%}
    {% for output in operator.outputs: %}
    tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
    {%- endfor %}

    // Create operator
    auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
    {%- if operator.serializeAttType != "None" -%}
        tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
    {%- else -%}
        tosa::Attribute::Attribute_NONE
    {%- endif -%},
        &attr,
        {
        {%- for input in operator.inputs: -%}
            {{input}}->GetName()
            {%- if loop.index < operator.inputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for output in operator.outputs: -%}
            {{output}}->GetName()
            {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        });

    // Create a tosa single-op basic block
    tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
        {
        {%- for input in operator.inputs: -%}
            {{input}},
        {%- endfor -%}
        {%- for output in operator.outputs: -%}
            {{output}}
            {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for input in operator.inputs: -%}
            {{input}}->GetName()
            {%- if loop.index < operator.inputs|length -%},{%- endif -%}
        {%- endfor -%}
        },
        {
        {%- for output in operator.outputs: -%}
            {{output}}->GetName()
            {%- if loop.index < operator.outputs|length -%},{%- endif -%}
        {%- endfor -%}
        });

    // Setup model
    // Each status is bound to a local first so the runner call happens exactly once,
    // even if the checking macro expands its argument more than once.
    TosaReference::ModelRunnerImpl runner;
    const auto init_status = runner.initialize(block);
    TOSA_RETURN_ON_GRAPH_STATUS_ERROR(init_status);
    {% for input in operator.inputs: -%}
    TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
    {%- endfor %}

    // Execute
    const auto run_status = runner.run();
    TOSA_RETURN_ON_GRAPH_STATUS_ERROR(run_status);

    // Extract outputs
    {% for output in operator.outputs: -%}
    TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
    {%- endfor %}

    return tosa_status_valid;
}
{% endfor %}

}    // extern "C"