diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-04 14:17:55 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-10-04 18:45:40 +0000 |
commit | b20b0c9cb4c85bb9a3c901d5acaf421d84656850 (patch) | |
tree | 8af9d6338b62bc65e7e4292427f06a4ef0346312 /reference_model/src/generate/generate_entry.cc | |
parent | 12ee1a79374b451602784fd6dc8f63886bf2a997 (diff) | |
download | reference_model-b20b0c9cb4c85bb9a3c901d5acaf421d84656850.tar.gz |
Add initial TOSA MI generator support
Add support for dot-product MatMul - test set 0
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ifd15b42570014b634f59c94a1fd1cd56bac79ea4
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'reference_model/src/generate/generate_entry.cc')
-rw-r--r-- | reference_model/src/generate/generate_entry.cc | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc new file mode 100644 index 0000000..95dbe8f --- /dev/null +++ b/reference_model/src/generate/generate_entry.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate.h" + +#include "generate_dot_product.h" +#include "generate_utils.h" + +#include "func_debug.h" +#include "model_common.h" + +namespace TosaReference +{ + +bool generate(const GenerateConfig& cfg, void* data, size_t size) +{ + switch (cfg.generatorType) + { + case GeneratorType::DotProduct: { + return generateDotProduct(cfg, data, size); + break; + } + default: { + WARNING("[Generator] Unsupported generation mode."); + break; + } + } + return false; +} + +} // namespace TosaReference + +extern "C" +{ + bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size) + { + // Check inputs for nullptr + if (!config_json || !tensor_name || !data) + { + WARNING("[Generator] One of the inputs is missing."); + return false; + } + + // Check JSON config validity + auto cfg = TosaReference::parseGenerateConfig(config_json, tensor_name); + if (!cfg) + { + WARNING("[Generator] Invalid json config."); + return false; + } + + // Check size + const size_t totalBytesNeeded = + TosaReference::numElementsFromShape(cfg->shape) * TosaReference::elementSizeFromType(cfg->dataType); + if (totalBytesNeeded > size) + { + WARNING("[Generator] Not enough space in provided buffer."); + return false; + } + + // Run generator + return generate(cfg.value(), data, size); + } +} // extern "C" |