aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_entry.cc
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-04 14:17:55 +0100
committerEric Kunze <eric.kunze@arm.com>2023-10-04 18:45:40 +0000
commitb20b0c9cb4c85bb9a3c901d5acaf421d84656850 (patch)
tree8af9d6338b62bc65e7e4292427f06a4ef0346312 /reference_model/src/generate/generate_entry.cc
parent12ee1a79374b451602784fd6dc8f63886bf2a997 (diff)
downloadreference_model-b20b0c9cb4c85bb9a3c901d5acaf421d84656850.tar.gz
Add initial TOSA MI generator support
Add support for dot-product MatMul - test set 0 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ifd15b42570014b634f59c94a1fd1cd56bac79ea4 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'reference_model/src/generate/generate_entry.cc')
-rw-r--r--reference_model/src/generate/generate_entry.cc75
1 files changed, 75 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
new file mode 100644
index 0000000..95dbe8f
--- /dev/null
+++ b/reference_model/src/generate/generate_entry.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate.h"
+
+#include "generate_dot_product.h"
+#include "generate_utils.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+
+namespace TosaReference
+{
+
+bool generate(const GenerateConfig& cfg, void* data, size_t size)
+{
+ switch (cfg.generatorType)
+ {
+ case GeneratorType::DotProduct: {
+ return generateDotProduct(cfg, data, size);
+ break;
+ }
+ default: {
+ WARNING("[Generator] Unsupported generation mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size)
+ {
+ // Check inputs for nullptr
+ if (!config_json || !tensor_name || !data)
+ {
+ WARNING("[Generator] One of the inputs is missing.");
+ return false;
+ }
+
+ // Check JSON config validity
+ auto cfg = TosaReference::parseGenerateConfig(config_json, tensor_name);
+ if (!cfg)
+ {
+ WARNING("[Generator] Invalid json config.");
+ return false;
+ }
+
+ // Check size
+ const size_t totalBytesNeeded =
+ TosaReference::numElementsFromShape(cfg->shape) * TosaReference::elementSizeFromType(cfg->dataType);
+ if (totalBytesNeeded > size)
+ {
+ WARNING("[Generator] Not enough space in provided buffer.");
+ return false;
+ }
+
+ // Run generator
+ return generate(cfg.value(), data, size);
+ }
+} // extern "C"