aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/templates
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/templates')
-rw-r--r--scripts/operator_api/templates/operators_cc.j2176
-rw-r--r--scripts/operator_api/templates/operators_h.j274
2 files changed, 250 insertions, 0 deletions
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
new file mode 100644
index 0000000..6b0ed6e
--- /dev/null
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -0,0 +1,176 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#include "operators.h"
+#include "model_runner_impl.h"
+#include "ops/op_factory.h"
+
+#define TOSA_RETURN_ON_ERROR(status) \
+ do \
+ { \
+ if (status != 0) \
+ { \
+ return tosa_status_error; \
+ } \
+ } while (false)
+
+#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \
+ do \
+ { \
+ if (status != GraphStatus::TOSA_VALID) \
+ { \
+ auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(status); \
+ return static_cast<tosa_status_t>(ustatus); \
+ } \
+ } while (false)
+
+namespace {
+
+tosa::DType translate_client_datatype(tosa_datatype_t type)
+{
+ switch (type)
+ {
+ case tosa_datatype_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_datatype_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+};
+
+tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
+{
+ std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
+ return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
+}
+
+tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
+ switch(mode) {
+ case tosa_mode_nearest:
+ return tosa::ResizeMode_NEAREST;
+ case tosa_mode_max:
+ case tosa_mode_bilinear:
+ return tosa::ResizeMode_BILINEAR;
+ default:
+ return tosa::ResizeMode_UNKNOWN;
+ }
+}
+
+} // namespace
+
+extern "C"
+{
+ {% for operator in operators: %}
+ tosa_status_t tosa_run_{{ operator.name }} (
+ {%- for arg in operator.arguments: -%}
+ {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+ {% if loop.index < operator.arguments|length %},{% endif %}
+ {%- endfor -%}
+ )
+ {
+ // Create operator attributes
+ {% for arg in operator.serializeArgs: %}
+ {%- if arg.SV == "V": -%}
+ const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
+ {%- else: -%}
+ const {{arg.dType}} {{arg.name}}{{arg.init}};
+ {%- endif -%}
+ {%- endfor -%}
+
+ Tosa{{operator.serializeAttType}}Attribute attr
+ {%- if operator.serializeArgs|length > 0 -%}
+ (
+ {%- for arg in operator.serializeArgs: -%}
+ {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- endfor -%}
+ )
+ {%- endif -%};
+
+ // Create tensors
+ {% for input in operator.inputs: -%}
+ tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
+ {%- endfor -%}
+ {% for output in operator.outputs: %}
+ tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
+ {%- endfor %}
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
+ {%- if operator.serializeAttType != "None" -%}
+ tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
+ {%- else -%}
+ tosa::Attribute::Attribute_NONE
+ {%- endif -%},
+ &attr, {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}},
+ {%- endfor -%}
+ {%- for output in operator.outputs: -%}
+ {{output}}
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner;
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ {% for input in operator.inputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
+ {%- endfor %}
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ {% for output in operator.outputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
+ {%- endfor %}
+
+ return tosa_status_valid;
+ }
+ {% endfor %}
+
+} // extern "C" \ No newline at end of file
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
new file mode 100644
index 0000000..803b76a
--- /dev/null
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -0,0 +1,74 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#ifndef OPERATORS_H_
+#define OPERATORS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+ // Note status needs to be aligned with graph_status
+ enum tosa_status_t
+ {
+ tosa_status_valid = 0,
+ tosa_status_unpredictable = 1,
+ tosa_status_error = 2
+ };
+
+ enum tosa_mode_t
+ {
+ tosa_mode_unknown = 0,
+ tosa_mode_nearest = 1,
+ tosa_mode_bilinear = 2,
+ tosa_mode_min = 3,
+ tosa_mode_max = 4
+ };
+
+ enum tosa_datatype_t
+ {
+ {% for dataType in dataTypes: -%}
+ {{dataType}} = {{loop.index-1}},
+ {% endfor -%}
+ };
+
+ struct tosa_tensor_t
+ {
+ int32_t* shape;
+ int32_t num_dims;
+ tosa_datatype_t data_type;
+ uint8_t* data;
+ size_t size;
+ };
+
+ {% for operator in operators: %}
+ tosa_status_t tosa_run_{{ operator.name }} (
+ {%- for arg in operator.arguments: -%}
+ {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+ {% if loop.index < operator.arguments|length %},{% endif %}
+ {%- endfor -%});
+ {% endfor %}
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // OPERATORS_H_ \ No newline at end of file