aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_pseudo_random.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_pseudo_random.cc')
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc21
1 files changed, 14 insertions, 7 deletions
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index b51424d..b62c38f 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "generate.h"
#include "generate_utils.h"
+#include "half.hpp"
#include <array>
#include <iterator>
@@ -88,7 +89,8 @@ private:
bool _useUniform;
};
-bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t size)
+template <typename DataType>
+bool generateFP(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size)
{
const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
@@ -106,21 +108,20 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
generator = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed);
}
- float* a = reinterpret_cast<float*>(data);
const auto T = TosaReference::numElementsFromShape(cfg.shape);
const bool comparisonOp =
(cfg.opType == Op::Op_EQUAL) || (cfg.opType == Op::Op_GREATER_EQUAL) || (cfg.opType == Op::Op_GREATER);
for (auto t = 0; t < T; ++t)
{
- a[t] = generator->getRandomFloat();
+ data[t] = static_cast<DataType>(generator->getRandomFloat());
if (comparisonOp && (t % 4 == 0))
{
// Set every 4th value to 0 to enable better comparison testing
- a[t] = 0.f;
+ data[t] = static_cast<DataType>(0.f);
}
else if (roundMode)
{
- a[t] = std::roundf(a[t]);
+ data[t] = static_cast<DataType>(std::roundf(data[t]));
}
}
return true;
@@ -146,8 +147,14 @@ bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size)
switch (cfg.dataType)
{
- case DType::DType_FP32:
- return generateFP32(cfg, data, size);
+ case DType::DType_FP32: {
+ float* outData = reinterpret_cast<float*>(data);
+ return generateFP(cfg, outData, size);
+ }
+ case DType::DType_FP16: {
+ half_float::half* outData = reinterpret_cast<half_float::half*>(data);
+ return generateFP(cfg, outData, size);
+ }
default:
WARNING("[Generator][PR] Unsupported type.");
return false;