aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate')
-rw-r--r--reference_model/src/generate/generate_dot_product.cc88
-rw-r--r--reference_model/src/generate/generate_dot_product.h3
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc48
-rw-r--r--reference_model/src/generate/generate_utils.cc1
4 files changed, 131 insertions, 9 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index fe829e3..046007e 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -736,6 +736,92 @@ bool generateTransposeConv2D(const TosaReference::GenerateConfig& cfg,
return false;
}
}
+//---------------------------------------------------------------------------//
+// FFT2D //
+//---------------------------------------------------------------------------//
+
+template <typename DataType>
+bool generateFFT2DReal(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ DataType* data,
+ size_t size)
+{
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t H = cfg.shape[1];
+ const uint32_t W = cfg.shape[2];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t x = t % W;
+ uint32_t y = (t / W) % H;
+ uint32_t k = y * W + x;
+
+ data[t] = static_cast<DataType>(generator(k));
+ }
+ return true;
+}
+
+template <typename DataType>
+bool generateFFT2DImag(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ DataType* data,
+ size_t size)
+{
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t H = cfg.shape[1];
+ const uint32_t W = cfg.shape[2];
+
+ // The index expression of ((1*N+n)*H+y)*W+x in the spec equates to
+ // using the values after those used for the Real tensor, but we need
+ // to iterate through all those values to get to the Imaginary data
+ for (int64_t n = 0; n < 2; ++n)
+ {
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t x = t % W;
+ uint32_t y = (t / W) % H;
+ uint32_t k = y * W + x;
+
+ data[t] = static_cast<DataType>(generator(k));
+ }
+ }
+ return true;
+}
+
+bool generateFFT2D(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 3)
+ {
+ WARNING("[Generator][DP][FFT2D] Tensor shape expected 3 dimensions.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_FP32: {
+ float* outData = reinterpret_cast<float*>(data);
+ switch (cfg.inputPos)
+ {
+ case 0:
+ return generateFFT2DReal(cfg, generator, outData, size);
+ case 1:
+ return generateFFT2DImag(cfg, generator, outData, size);
+ default:
+ WARNING("[Generator][DP][FFT2D] Invalid input tensor slot position to operator.");
+ return false;
+ }
+ break;
+ }
+ default:
+ WARNING("[Generator][DP][FFT2D] Only supports FP32.");
+ return false;
+ }
+
+ return true;
+}
} // namespace
namespace TosaReference
@@ -772,6 +858,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
return generateDepthwiseConv2D(cfg, *generator, data, size);
case tosa::Op_TRANSPOSE_CONV2D:
return generateTransposeConv2D(cfg, *generator, data, size);
+ case tosa::Op_FFT2D:
+ return generateFFT2D(cfg, *generator, data, size);
default:
WARNING("[Generator][DP] Unsupported operator.");
return false;
diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h
index cd9d4ba..bf1b1ff 100644
--- a/reference_model/src/generate/generate_dot_product.h
+++ b/reference_model/src/generate/generate_dot_product.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ class IDotProductGenerator
public:
virtual float operator()(uint32_t k) = 0;
virtual ~IDotProductGenerator() = default;
+ virtual uint32_t nextIndex() = 0;
};
/// \brief Dot-product stage generator selector
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index 9ce32ff..b78be71 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@ public:
return pseudo;
}
- uint32_t index()
+ uint32_t nextIndex()
{
return _index;
}
@@ -101,6 +101,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -129,6 +134,10 @@ public:
else
return (_B * _B / (_KS + 1)) * v;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -158,6 +167,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -186,6 +199,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -229,6 +246,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -258,6 +280,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf
float B = getBoundParameter(cfg.dataType, dpinfo.accType);
if (B > 0.f)
{
+ auto param = cfg.inputPos;
+ if (cfg.opType == Op_FFT2D)
+ {
+ // We only use param of zero for FFT2D tensors
+ param = 0;
+ }
// Create the generator
switch (dpinfo.s)
{
case 0:
- return std::make_unique<GeneratorS0>(cfg.inputPos);
+ return std::make_unique<GeneratorS0>(param);
case 1:
- return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS1>(param, dpinfo.ks, B);
case 2:
- return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+ return std::make_unique<GeneratorS2>(param, dpinfo.ks);
case 3:
- return std::make_unique<GeneratorS3>(cfg.inputPos);
+ return std::make_unique<GeneratorS3>(param);
case 4:
- return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS4>(param, dpinfo.ks, B);
case 5:
- return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS5>(param, dpinfo.ks, B);
default:
WARNING("[Generator][DP] Unsupported dot product test series for generator.");
return nullptr;
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index a8b472a..2e40b04 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -54,6 +54,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_ERF, "ERF" },
{ Op::Op_EXP, "EXP" },
{ Op::Op_FLOOR, "FLOOR" },
+ { Op::Op_FFT2D, "FFT2D" },
{ Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" },
{ Op::Op_GATHER, "GATHER" },
{ Op::Op_GREATER, "GREATER" },