aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/templates/operators_cc.j2
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/templates/operators_cc.j2')
-rw-r--r--scripts/operator_api/templates/operators_cc.j2176
1 files changed, 176 insertions, 0 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
new file mode 100644
index 0000000..6b0ed6e
--- /dev/null
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -0,0 +1,176 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#include "operators.h"
+#include "model_runner_impl.h"
+#include "ops/op_factory.h"
+
+#define TOSA_RETURN_ON_ERROR(status) \
+ do \
+ { \
+ if (status != 0) \
+ { \
+ return tosa_status_error; \
+ } \
+ } while (false)
+
+#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \
+ do \
+ { \
+ if (status != GraphStatus::TOSA_VALID) \
+ { \
+ auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(status); \
+ return static_cast<tosa_status_t>(ustatus); \
+ } \
+ } while (false)
+
+namespace {
+
+tosa::DType translate_client_datatype(tosa_datatype_t type)
+{
+ switch (type)
+ {
+ case tosa_datatype_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_datatype_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+};
+
+tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
+{
+ std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
+ return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
+}
+
+tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
+ switch(mode) {
+ case tosa_mode_nearest:
+ return tosa::ResizeMode_NEAREST;
+ case tosa_mode_max:
+ case tosa_mode_bilinear:
+ return tosa::ResizeMode_BILINEAR;
+ default:
+ return tosa::ResizeMode_UNKNOWN;
+ }
+}
+
+} // namespace
+
+extern "C"
+{
+ {% for operator in operators: %}
+ tosa_status_t tosa_run_{{ operator.name }} (
+ {%- for arg in operator.arguments: -%}
+ {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+ {% if loop.index < operator.arguments|length %},{% endif %}
+ {%- endfor -%}
+ )
+ {
+ // Create operator attributes
+ {% for arg in operator.serializeArgs: %}
+ {%- if arg.SV == "V": -%}
+ const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
+ {%- else: -%}
+ const {{arg.dType}} {{arg.name}}{{arg.init}};
+ {%- endif -%}
+ {%- endfor -%}
+
+ Tosa{{operator.serializeAttType}}Attribute attr
+ {%- if operator.serializeArgs|length > 0 -%}
+ (
+ {%- for arg in operator.serializeArgs: -%}
+ {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- endfor -%}
+ )
+ {%- endif -%};
+
+ // Create tensors
+ {% for input in operator.inputs: -%}
+ tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
+ {%- endfor -%}
+ {% for output in operator.outputs: %}
+ tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
+ {%- endfor %}
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
+ {%- if operator.serializeAttType != "None" -%}
+ tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
+ {%- else -%}
+ tosa::Attribute::Attribute_NONE
+ {%- endif -%},
+ &attr, {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}},
+ {%- endfor -%}
+ {%- for output in operator.outputs: -%}
+ {{output}}
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner;
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ {% for input in operator.inputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
+ {%- endfor %}
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ {% for output in operator.outputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
+ {%- endfor %}
+
+ return tosa_status_valid;
+ }
+ {% endfor %}
+
+} // extern "C" \ No newline at end of file