diff options
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 new file mode 100644 index 0000000..6b0ed6e --- /dev/null +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -0,0 +1,176 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// THIS FILE IS GENERATED. DO NOT EDIT! +// See scripts/operator_api/generate_api.py + +#include "operators.h" +#include "model_runner_impl.h" +#include "ops/op_factory.h" + +#define TOSA_RETURN_ON_ERROR(status) \ + do \ + { \ + if (status != 0) \ + { \ + return tosa_status_error; \ + } \ + } while (false) + +#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \ + do \ + { \ + if (status != GraphStatus::TOSA_VALID) \ + { \ + auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(status); \ + return static_cast<tosa_status_t>(ustatus); \ + } \ + } while (false) + +namespace { + +tosa::DType translate_client_datatype(tosa_datatype_t type) +{ + switch (type) + { + case tosa_datatype_fp16_t: + return tosa::DType::DType_FP16; + case tosa_datatype_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +}; + +tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) +{ + std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); + return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); +} + +tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { + switch(mode) { + case tosa_mode_nearest: + return tosa::ResizeMode_NEAREST; + case tosa_mode_max: + case tosa_mode_bilinear: + return tosa::ResizeMode_BILINEAR; + default: + return tosa::ResizeMode_UNKNOWN; + } +} + +} // namespace + +extern "C" +{ + {% for operator in operators: %} + tosa_status_t tosa_run_{{ operator.name }} ( + {%- for arg in operator.arguments: -%} + {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} + {% if loop.index < operator.arguments|length %},{% endif %} + {%- endfor -%} + ) + { + // Create operator attributes + {% for arg in operator.serializeArgs: %} + {%- if arg.SV == "V": -%} + const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}}; + {%- else: -%} + const {{arg.dType}} {{arg.name}}{{arg.init}}; + {%- endif -%} + {%- endfor -%} + + Tosa{{operator.serializeAttType}}Attribute attr + {%- if operator.serializeArgs|length > 0 -%} + ( + {%- for arg in operator.serializeArgs: -%} + {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %} + {%- endfor -%} + ) + {%- endif -%}; + + // Create tensors + {% for input in operator.inputs: -%} + tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}"); + {%- endfor -%} + {% for output in operator.outputs: %} + tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}"); + {%- endfor %} + + // Create operator + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}}, + {%- if operator.serializeAttType != "None" -%} + tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute + {%- else -%} + tosa::Attribute::Attribute_NONE + {%- endif -%}, + &attr, { + {%- for input in operator.inputs: -%} + {{input}}->GetName() + {%- if loop.index < operator.inputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for output in operator.outputs: -%} + {{output}}->GetName() + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }); + + // Create a tosa single-op basic block + tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op }, + { + {%- for input in operator.inputs: -%} + {{input}}, + {%- endfor -%} + {%- for output in operator.outputs: -%} + {{output}} + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for input in operator.inputs: -%} + {{input}}->GetName() + {%- if loop.index < operator.inputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for output in operator.outputs: -%} + {{output}}->GetName() + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }); + + // Setup model + TosaReference::ModelRunnerImpl runner; + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); + {% for input in operator.inputs: -%} + TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); + {%- endfor %} + + // Execute + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); + + // Extract outputs + {% for output in operator.outputs: -%} + TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size)); + {%- endfor %} + + return tosa_status_valid; + } + {% endfor %} + +} // extern "C"
\ No newline at end of file |