aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/generate/generate_entry.cc7
-rw-r--r--reference_model/src/generate/generate_fixed_data.cc56
-rw-r--r--reference_model/src/generate/generate_fixed_data.h34
-rw-r--r--reference_model/src/generate/generate_utils.cc15
-rw-r--r--reference_model/src/generate/generate_utils.h12
5 files changed, 122 insertions, 2 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
index 741cd79..91b2fc7 100644
--- a/reference_model/src/generate/generate_entry.cc
+++ b/reference_model/src/generate/generate_entry.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
#include "generate.h"
#include "generate_dot_product.h"
+#include "generate_fixed_data.h"
#include "generate_pseudo_random.h"
#include "generate_utils.h"
@@ -36,6 +37,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size)
return generatePseudoRandom(cfg, data, size);
break;
}
+ case GeneratorType::FixedData: {
+ return generateFixedData(cfg, data, size);
+ break;
+ }
default: {
WARNING("[Generator] Unsupported generation mode.");
break;
diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc
new file mode 100644
index 0000000..d83ee58
--- /dev/null
+++ b/reference_model/src/generate/generate_fixed_data.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "generate.h"
+#include "generate_utils.h"
+
+#include <algorithm>
+#include <array>
+#include <iterator>
+#include <type_traits>
+#include <vector>
+
+namespace TosaReference
+{
+bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size)
+{
+ // Check we support the operator
+ if (cfg.opType == Op::Op_UNKNOWN)
+ {
+ WARNING("[Generator][FD] Unknown operator.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_SHAPE: {
+ int32_t* outData = reinterpret_cast<int32_t*>(data);
+ std::vector<int32_t> inData = cfg.fixedDataInfo.data;
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ if (T != inData.size())
+ {
+ WARNING("[Generator][FD] Size does not match.");
+ return false;
+ }
+ for (auto t = 0; t < T; t++)
+ {
+ outData[t] = inData[t];
+ }
+ return true;
+ }
+ default:
+ WARNING("[Generator][FD] Unsupported type.");
+ return false;
+ }
+}
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_fixed_data.h b/reference_model/src/generate/generate_fixed_data.h
new file mode 100644
index 0000000..50371c8
--- /dev/null
+++ b/reference_model/src/generate/generate_fixed_data.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_FIXED_DATA_H_
+#define GENERATE_FIXED_DATA_H_
+
+#include "generate_utils.h"
+
+namespace TosaReference
+{
+
+/// \brief Perform fixed data generation
+///
+/// \param cfg Generator related meta-data
+/// \param data Buffer to generate the data to
+/// \param size Size of the buffer
+///
+/// \return True on successful generation
+bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_FIXED_DATA_H_
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index c16d1c6..9eda0b6 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -33,6 +33,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType,
{ DType::DType_FP16, "FP16" },
{ DType::DType_BF16, "BF16" },
{ DType::DType_FP32, "FP32" },
+ { DType::DType_SHAPE, "SHAPE" },
})
NLOHMANN_JSON_SERIALIZE_ENUM(Op,
@@ -93,6 +94,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::OpFullRange, "OP_FULL_RANGE" },
{ GeneratorType::OpBoundary, "OP_BOUNDARY" },
{ GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FixedData, "FIXED_DATA" },
})
// NOTE: This assumes it's VARIABLE if the InputType is not recognized
@@ -130,6 +132,11 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
}
}
+void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
+{
+ j.at("data").get_to(fixedDataInfo.data);
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -158,6 +165,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);
}
+
+ // Set up defaults for fixedDataInfo
+ cfg.fixedDataInfo.data = std::vector<int32_t>();
+ if (j.contains("fixed_data_info"))
+ {
+ j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
@@ -209,6 +223,7 @@ size_t elementSizeFromType(DType type)
return 2;
case DType::DType_INT32:
case DType::DType_FP32:
+ case DType::DType_SHAPE:
return 4;
default:
return 0;
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index f9ec713..697b404 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ enum class GeneratorType
OpFullRange,
OpBoundary,
OpSpecial,
+ FixedData,
};
/// \brief Supported input types
@@ -65,6 +66,14 @@ struct PseudoRandomInfo
bool round;
};
+/// \brief Fixed data generator meta-data
+struct FixedDataInfo
+{
+ FixedDataInfo() = default;
+
+ std::vector<int32_t> data;
+};
+
/// \brief Generator configuration
struct GenerateConfig
{
@@ -76,6 +85,7 @@ struct GenerateConfig
tosa::Op opType;
DotProductInfo dotProductInfo;
PseudoRandomInfo pseudoRandomInfo;
+ FixedDataInfo fixedDataInfo;
};
/// \brief Parse the generator config when given in JSON form