diff options
Diffstat (limited to 'reference_model/src/generate/generate_pseudo_random.cc')
-rw-r--r-- | reference_model/src/generate/generate_pseudo_random.cc | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc index b51424d..b62c38f 100644 --- a/reference_model/src/generate/generate_pseudo_random.cc +++ b/reference_model/src/generate/generate_pseudo_random.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "generate.h" #include "generate_utils.h" +#include "half.hpp" #include <array> #include <iterator> @@ -88,7 +89,8 @@ private: bool _useUniform; }; -bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t size) +template <typename DataType> +bool generateFP(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size) { const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo; @@ -106,21 +108,20 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s generator = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed); } - float* a = reinterpret_cast<float*>(data); const auto T = TosaReference::numElementsFromShape(cfg.shape); const bool comparisonOp = (cfg.opType == Op::Op_EQUAL) || (cfg.opType == Op::Op_GREATER_EQUAL) || (cfg.opType == Op::Op_GREATER); for (auto t = 0; t < T; ++t) { - a[t] = generator->getRandomFloat(); + data[t] = static_cast<DataType>(generator->getRandomFloat()); if (comparisonOp && (t % 4 == 0)) { // Set every 4th value to 0 to enable better comparison testing - a[t] = 0.f; + data[t] = static_cast<DataType>(0.f); } else if (roundMode) { - a[t] = std::roundf(a[t]); + data[t] = static_cast<DataType>(std::roundf(data[t])); } } return true; @@ -146,8 +147,14 @@ bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size) switch (cfg.dataType) { - case DType::DType_FP32: - return generateFP32(cfg, data, size); + case DType::DType_FP32: { + float* outData = reinterpret_cast<float*>(data); + return generateFP(cfg, outData, size); + } + case DType::DType_FP16: { + half_float::half* outData = reinterpret_cast<half_float::half*>(data); + return generateFP(cfg, outData, size); + } default: WARNING("[Generator][PR] Unsupported type."); return false; |