diff options
Diffstat (limited to 'reference_model/src/generate')
-rw-r--r-- | reference_model/src/generate/generate_entry.cc | 7 | ||||
-rw-r--r-- | reference_model/src/generate/generate_fixed_data.cc | 56 | ||||
-rw-r--r-- | reference_model/src/generate/generate_fixed_data.h | 34 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 15 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.h | 12 |
5 files changed, 122 insertions, 2 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc index 741cd79..91b2fc7 100644 --- a/reference_model/src/generate/generate_entry.cc +++ b/reference_model/src/generate/generate_entry.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ #include "generate.h" #include "generate_dot_product.h" +#include "generate_fixed_data.h" #include "generate_pseudo_random.h" #include "generate_utils.h" @@ -36,6 +37,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size) return generatePseudoRandom(cfg, data, size); break; } + case GeneratorType::FixedData: { + return generateFixedData(cfg, data, size); + break; + } default: { WARNING("[Generator] Unsupported generation mode."); break; diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc new file mode 100644 index 0000000..d83ee58 --- /dev/null +++ b/reference_model/src/generate/generate_fixed_data.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "generate.h" +#include "generate_utils.h" + +#include <algorithm> +#include <array> +#include <iterator> +#include <type_traits> +#include <vector> + +namespace TosaReference +{ +bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size) +{ + // Check we support the operator + if (cfg.opType == Op::Op_UNKNOWN) + { + WARNING("[Generator][FD] Unknown operator."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_SHAPE: { + int32_t* outData = reinterpret_cast<int32_t*>(data); + std::vector<int32_t> inData = cfg.fixedDataInfo.data; + const auto T = TosaReference::numElementsFromShape(cfg.shape); + if (T != inData.size()) + { + WARNING("[Generator][FD] Size does not match."); + return false; + } + for (auto t = 0; t < T; t++) + { + outData[t] = inData[t]; + } + return true; + } + default: + WARNING("[Generator][FD] Unsupported type."); + return false; + } +} +} // namespace TosaReference diff --git a/reference_model/src/generate/generate_fixed_data.h b/reference_model/src/generate/generate_fixed_data.h new file mode 100644 index 0000000..50371c8 --- /dev/null +++ b/reference_model/src/generate/generate_fixed_data.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERATE_FIXED_DATA_H_ +#define GENERATE_FIXED_DATA_H_ + +#include "generate_utils.h" + +namespace TosaReference +{ + +/// \brief Perform fixed data generation +/// +/// \param cfg Generator related meta-data +/// \param data Buffer to generate the data to +/// \param size Size of the buffer +/// +/// \return True on successful generation +bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size); + +}; // namespace TosaReference + +#endif // GENERATE_FIXED_DATA_H_ diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index c16d1c6..9eda0b6 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -33,6 +33,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType, { DType::DType_FP16, "FP16" }, { DType::DType_BF16, "BF16" }, { DType::DType_FP32, "FP32" }, + { DType::DType_SHAPE, "SHAPE" }, }) NLOHMANN_JSON_SERIALIZE_ENUM(Op, @@ -93,6 +94,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, { GeneratorType::OpFullRange, "OP_FULL_RANGE" }, { GeneratorType::OpBoundary, "OP_BOUNDARY" }, { GeneratorType::OpSpecial, "OP_SPECIAL" }, + { GeneratorType::FixedData, "FIXED_DATA" }, }) // NOTE: This assumes it's VARIABLE if the InputType is not recognized @@ -130,6 +132,11 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo) } } +void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo) +{ + j.at("data").get_to(fixedDataInfo.data); +} + void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("data_type").get_to(cfg.dataType); @@ -158,6 +165,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo); } + + // Set up defaults for fixedDataInfo + cfg.fixedDataInfo.data = std::vector<int32_t>(); + if (j.contains("fixed_data_info")) + { + j.at("fixed_data_info").get_to(cfg.fixedDataInfo); + } } std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName) @@ -209,6 +223,7 @@ size_t elementSizeFromType(DType type) return 2; case DType::DType_INT32: case DType::DType_FP32: + case DType::DType_SHAPE: return 4; default: return 0; diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h index f9ec713..697b404 100644 --- a/reference_model/src/generate/generate_utils.h +++ b/reference_model/src/generate/generate_utils.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ enum class GeneratorType OpFullRange, OpBoundary, OpSpecial, + FixedData, }; /// \brief Supported input types @@ -65,6 +66,14 @@ struct PseudoRandomInfo bool round; }; +/// \brief Fixed data generator meta-data +struct FixedDataInfo +{ + FixedDataInfo() = default; + + std::vector<int32_t> data; +}; + /// \brief Generator configuration struct GenerateConfig { @@ -76,6 +85,7 @@ struct GenerateConfig tosa::Op opType; DotProductInfo dotProductInfo; PseudoRandomInfo pseudoRandomInfo; + FixedDataInfo fixedDataInfo; }; /// \brief Parse the generator config when given in JSON form |