aboutsummaryrefslogtreecommitdiff
path: root/reference_model/include
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-10-30 10:18:45 -0700
committerJerry Ge <jerry.ge@arm.com>2023-11-16 21:56:43 +0000
commit5637a8606bc3caeec3c590350de770c7fcec8dd7 (patch)
treeb83a0d33d8a76c77cf560026e6cc8e8db22ad712 /reference_model/include
parentd3797f014811ca1ea876989b4839a8297eb1731e (diff)
downloadreference_model-5637a8606bc3caeec3c590350de770c7fcec8dd7.tar.gz
Support loading shared libraries for custom operators
- Add a new command line option to allow users to specify a custom defined dll library - Add a custom registry to store all registered libraries - Add a dummy example (custom_op_example.cpp) for demonstrating this new feature Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I7c360835933f77e33fcbd772cabfe01d82282d47
Diffstat (limited to 'reference_model/include')
-rw-r--r--reference_model/include/custom_op_interface.h38
-rw-r--r--reference_model/include/custom_registry.h88
-rw-r--r--reference_model/include/func_config.h1
3 files changed, 127 insertions, 0 deletions
diff --git a/reference_model/include/custom_op_interface.h b/reference_model/include/custom_op_interface.h
new file mode 100644
index 0000000..aea9086
--- /dev/null
+++ b/reference_model/include/custom_op_interface.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef CUSTOMOPINTERFACE_H
+#define CUSTOMOPINTERFACE_H
+
+#include "tensor.h"
+#include <vector>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+class CustomOpInterface
+{
+public:
+ CustomOpInterface() = default;
+ virtual std::string getDomainName() const = 0;
+ virtual std::string getOperatorName() const = 0;
+ virtual int eval(std::vector<TosaReference::Tensor*>& input_tensors,
+ std::vector<TosaReference::Tensor*>& output_tensors,
+ const std::string& implementation_attrs) = 0;
+ virtual std::string getVersion() const = 0;
+};
+} // namespace TosaReference
+
+#endif
diff --git a/reference_model/include/custom_registry.h b/reference_model/include/custom_registry.h
new file mode 100644
index 0000000..f1a9b8c
--- /dev/null
+++ b/reference_model/include/custom_registry.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef CUSTOMREGISTRY_H
+#define CUSTOMREGISTRY_H
+
+#include "custom_op_interface.h"
+#include <dlfcn.h>
+#include <unordered_map>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+typedef CustomOpInterface* (*op_creation_function_t)();
+typedef int (*registration_callback_t)(const std::string& domain_name,
+ const std::string& operator_name,
+ const op_creation_function_t& op_creation_function);
+
+class MasterRegistry
+{
+public:
+ static int register_function(const std::string& domain_name,
+ const std::string& operator_name,
+ const op_creation_function_t& op_creation_function)
+ {
+ std::string unique_id = domain_name + "::" + operator_name;
+ MasterRegistry& instance = get_instance();
+ if (instance.op_creation_map.find(unique_id) != instance.op_creation_map.end())
+ {
+ std::cout << std::endl;
+ printf("domain_name: %s and operator_name: %s pair has already been registered", domain_name.c_str(),
+ operator_name.c_str());
+ return 1;
+ }
+ instance.op_creation_map[unique_id] = op_creation_function;
+ return 0;
+ }
+
+ static MasterRegistry& get_instance()
+ {
+ static MasterRegistry instance;
+ return instance;
+ }
+
+ MasterRegistry(const MasterRegistry&) = delete;
+ void operator=(const MasterRegistry&) = delete;
+
+ std::unordered_map<std::string, op_creation_function_t> get_ops() const
+ {
+ return op_creation_map;
+ }
+
+ static op_creation_function_t get_op(const std::string& domain_name, const std::string& operator_name)
+ {
+ std::string unique_id = domain_name + "::" + operator_name;
+ MasterRegistry& instance = get_instance();
+ auto all_ops_map = instance.get_ops();
+ if (all_ops_map.find(unique_id) == all_ops_map.end())
+ {
+ return nullptr;
+ }
+ else
+ {
+ op_creation_function_t& op_creation_function = all_ops_map[unique_id];
+ return op_creation_function;
+ }
+ }
+
+private:
+ MasterRegistry() = default;
+ std::unordered_map<std::string, op_creation_function_t> op_creation_map;
+};
+} // namespace TosaReference
+
+#endif
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index 22e7e2c..97afa82 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -54,6 +54,7 @@ struct func_config_t
uint32_t dump_intermediates = 0;
uint32_t initialize_variable_tensor_from_numpy = 0;
std::string fp_format = "0.5";
+ std::string custom_op_lib_path = "";
uint32_t precise_mode = 0;
bool abs_mode = 0; // set in main as second run of precise_mode
bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()