aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDhruv Chauhan <dhruv.chauhan@arm.com>2023-11-28 15:00:34 +0000
committerDhruv Chauhan <dhruv.chauhan@arm.com>2023-11-28 17:21:01 +0000
commit35a3aa994cf18f735193a05a7eb2c61d497233d2 (patch)
tree2add19eb59682b9d1c04c2ac30e67d2ab36a998e
parenta015001dfbd0ed48caf54fd66b0509ee344a229e (diff)
downloadreference_model-35a3aa994cf18f735193a05a7eb2c61d497233d2.tar.gz
Fix Fast Fourier Transforms in operator API
* Change ignore list in generate_api.py to generate operators information. * Fix serialization attributes mapping for operator FFT and RFFT * Add a unit test for Fft2d and Rfft2d operator Change-Id: I3ad7a77a3c46aa586834188bab42cbdcc423e834 Signed-off-by: Dhruv Chauhan <dhruv.chauhan@arm.com>
-rw-r--r--reference_model/include/operators.h14
-rw-r--r--reference_model/src/operators.cc86
-rw-r--r--reference_model/test/model_runner_tests.cpp127
-rw-r--r--scripts/operator_api/generate_api.py4
4 files changed, 229 insertions, 2 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index d2bcf87..d037631 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -88,6 +88,14 @@ extern "C"
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
+ tosa_status_t tosa_run_fft2d(tosa_tensor_t client_input_real,
+ tosa_tensor_t client_input_imag,
+ const bool client_inverse,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx);
+
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
@@ -112,6 +120,12 @@ extern "C"
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
+ tosa_status_t tosa_run_rfft2d(tosa_tensor_t client_input,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx);
+
tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 13e8b12..015b28f 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -342,6 +342,51 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_fft2d(tosa_tensor_t client_input_real,
+ tosa_tensor_t client_input_imag,
+ const bool client_inverse,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx)
+ {
+ // Create operator attributes
+ TosaFFTAttribute attr(client_inverse, client_local_bound);
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input_real = translate_client_tensor(client_input_real, "input_real");
+ tosa::TosaSerializationTensor* input_imag = translate_client_tensor(client_input_imag, "input_imag");
+ tosa::TosaSerializationTensor* output_real = translate_client_tensor(client_output_real, "output_real");
+ tosa::TosaSerializationTensor* output_imag = translate_client_tensor(client_output_imag, "output_imag");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FFT2D, tosa::Attribute::Attribute_FFTAttribute,
+ &attr, { input_real->GetName(), input_imag->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block(
+ "fft2d", "main", { op }, { input_real, input_imag, output_real, output_imag },
+ { input_real->GetName(), input_imag->GetName() }, { output_real->GetName(), output_imag->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input_real->GetName(), client_input_real.data, client_input_real.size));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input_imag->GetName(), client_input_imag.data, client_input_imag.size));
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_real->GetName(), client_output_real.data, client_output_real.size));
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_imag->GetName(), client_output_imag.data, client_output_imag.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
@@ -465,6 +510,47 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_rfft2d(tosa_tensor_t client_input,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx)
+ {
+ // Create operator attributes
+ TosaRFFTAttribute attr(client_local_bound);
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
+ tosa::TosaSerializationTensor* output_real = translate_client_tensor(client_output_real, "output_real");
+ tosa::TosaSerializationTensor* output_imag = translate_client_tensor(client_output_imag, "output_imag");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RFFT2D, tosa::Attribute::Attribute_RFFTAttribute,
+ &attr, { input->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("rfft2d", "main", { op }, { input, output_real, output_imag },
+ { input->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_real->GetName(), client_output_real.data, client_output_real.size));
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_imag->GetName(), client_output_imag.data, client_output_imag.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index 2037b73..6580774 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -182,6 +182,133 @@ TEST_SUITE("model_runner")
compareOutput(dstData, expectedData, expectedData.size());
}
+ TEST_CASE("op_entry_fft2d")
+ {
+ // Fft2d parameters
+ const bool inverse = false;
+ const bool local_bound = false;
+
+ // Inputs/Outputs
+ tosa_datatype_t dt = tosa_datatype_fp32_t;
+ std::vector<int32_t> input_shape_real = { 1, 32, 32 };
+ std::vector<int32_t> output_shape_real = { 1, 32, 32 };
+ std::vector<int32_t> input_shape_imag = { 1, 32, 32 };
+ std::vector<int32_t> output_shape_imag = { 1, 32, 32 };
+ std::vector<float> srcData(32 * 32 * 1, 0.f);
+ std::vector<float> dstDataReal(32 * 32 * 1, 0.f);
+ std::vector<float> dstDataImag(32 * 32 * 1, 0.f);
+
+ tosa_tensor_t input_real;
+ input_real.shape = input_shape_real.data();
+ input_real.num_dims = input_shape_real.size();
+ input_real.data_type = dt;
+ input_real.data = reinterpret_cast<uint8_t*>(srcData.data());
+ input_real.size = srcData.size() * sizeof(float);
+
+ tosa_tensor_t input_imag;
+ input_imag.shape = input_shape_imag.data();
+ input_imag.num_dims = input_shape_imag.size();
+ input_imag.data_type = dt;
+ input_imag.data = reinterpret_cast<uint8_t*>(srcData.data());
+ input_imag.size = srcData.size() * sizeof(float);
+
+ tosa_tensor_t output_real;
+ output_real.shape = output_shape_real.data();
+ output_real.num_dims = output_shape_real.size();
+ output_real.data_type = dt;
+ output_real.data = reinterpret_cast<uint8_t*>(dstDataReal.data());
+ output_real.size = dstDataReal.size() * sizeof(float);
+
+ tosa_tensor_t output_imag;
+ output_imag.shape = output_shape_imag.data();
+ output_imag.num_dims = output_shape_imag.size();
+ output_imag.data_type = dt;
+ output_imag.data = reinterpret_cast<uint8_t*>(dstDataImag.data());
+ output_imag.size = dstDataImag.size() * sizeof(float);
+
+ // Execution
+ auto status = tosa_run_fft2d(input_real, input_imag, inverse, output_real, local_bound, output_imag, {});
+ CHECK((status == tosa_status_valid));
+
+ // Compare results
+ std::vector<float> expectedDataReal = {};
+ std::vector<float> expectedDataImag = {};
+ for (unsigned i = 0; i < dstDataReal.size(); ++i)
+ {
+ std::vector<float> sum_real = {};
+ std::vector<float> sum_imag = {};
+ for (unsigned j = 0; j < dstDataImag.size(); ++j)
+ {
+ float a = ((inverse) ? -1 : 1) * 402.123859659; /* 2 * pi * ((iY * oY) / H + (iX * oX) / W) */
+ sum_real.emplace_back(srcData[j] * std::cos(a) + srcData[j] * std::sin(a));
+ sum_imag.emplace_back((-1) * srcData[j] * std::sin(a) + srcData[j] * std::sin(a));
+ }
+ expectedDataReal.emplace_back(sum_real[i]);
+ expectedDataImag.emplace_back(sum_imag[i]);
+ }
+ compareOutput(dstDataReal, expectedDataReal, expectedDataReal.size());
+ compareOutput(dstDataImag, expectedDataImag, expectedDataImag.size());
+ }
+
+ TEST_CASE("op_entry_rfft2d")
+ {
+ // Rfft2d parameters
+ const bool local_bound = false;
+
+ // Inputs/Outputs
+ tosa_datatype_t dt = tosa_datatype_fp32_t;
+ std::vector<int32_t> input_shape = { 1, 32, 32 };
+ std::vector<int32_t> output_shape_real = { 1, 32, 17 };
+ std::vector<int32_t> output_shape_imag = { 1, 32, 17 };
+ std::vector<float> srcData(32 * 32 * 1, 0.f);
+ std::vector<float> dstDataReal(32 * 17 * 1, 0.f);
+ std::vector<float> dstDataImag(32 * 17 * 1, 0.f);
+
+ tosa_tensor_t input;
+ input.shape = input_shape.data();
+ input.num_dims = input_shape.size();
+ input.data_type = dt;
+ input.data = reinterpret_cast<uint8_t*>(srcData.data());
+ input.size = srcData.size() * sizeof(float);
+
+ tosa_tensor_t output_real;
+ output_real.shape = output_shape_real.data();
+ output_real.num_dims = output_shape_real.size();
+ output_real.data_type = dt;
+ output_real.data = reinterpret_cast<uint8_t*>(dstDataReal.data());
+ output_real.size = dstDataReal.size() * sizeof(float);
+
+ tosa_tensor_t output_imag;
+ output_imag.shape = output_shape_imag.data();
+ output_imag.num_dims = output_shape_imag.size();
+ output_imag.data_type = dt;
+ output_imag.data = reinterpret_cast<uint8_t*>(dstDataImag.data());
+ output_imag.size = dstDataImag.size() * sizeof(float);
+
+ // Execution
+ auto status = tosa_run_rfft2d(input, output_real, local_bound, output_imag, {});
+ CHECK((status == tosa_status_valid));
+
+ // Compare results
+ std::vector<float> expectedDataReal = {};
+ std::vector<float> expectedDataImag = {};
+ for (unsigned i = 0; i < dstDataReal.size(); ++i)
+ {
+ std::vector<float> sum_real = {};
+ std::vector<float> sum_imag = {};
+ for (unsigned j = 0; j < dstDataImag.size(); ++j)
+ {
+ float a = 307.876080052; /* 2 * pi * ((iY * oY) / H + (iX * oX) / W) */
+ sum_real.emplace_back(srcData[j] * std::cos(a));
+ sum_imag.emplace_back((-1) * srcData[j] * std::sin(a));
+ }
+ expectedDataReal.emplace_back(sum_real[i]);
+ expectedDataImag.emplace_back(sum_imag[i]);
+ }
+ compareOutput(dstDataReal, expectedDataReal, expectedDataReal.size());
+ compareOutput(dstDataImag, expectedDataImag, expectedDataImag.size());
+ }
+
TEST_CASE("op_entry_transpose_conv2d")
{
// Transpose Conv 2D parameters
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index d9077f0..7f10568 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -62,6 +62,8 @@ def getSerializeOpType(tosaOpName):
"conv3d": "Conv",
"depthwise_conv2d": "Conv",
"fully_connected": "FullyConnected",
+ "fft2d": "FFT",
+ "rfft2d": "RFFT",
"matmul": "MatMul",
"max_pool2d": "Pool",
"transpose_conv2d": "TransposeConv",
@@ -236,8 +238,6 @@ def getOperators(tosaXml):
"cond_if",
"const",
"custom",
- "fft2d",
- "rfft2d",
"variable",
"variable_read",
"variable_write",