aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/operators.cc86
1 files changed, 86 insertions, 0 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 13e8b12..015b28f 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -342,6 +342,51 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_fft2d(tosa_tensor_t client_input_real,
+ tosa_tensor_t client_input_imag,
+ const bool client_inverse,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx)
+ {
+ // Create operator attributes
+ TosaFFTAttribute attr(client_inverse, client_local_bound);
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input_real = translate_client_tensor(client_input_real, "input_real");
+ tosa::TosaSerializationTensor* input_imag = translate_client_tensor(client_input_imag, "input_imag");
+ tosa::TosaSerializationTensor* output_real = translate_client_tensor(client_output_real, "output_real");
+ tosa::TosaSerializationTensor* output_imag = translate_client_tensor(client_output_imag, "output_imag");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FFT2D, tosa::Attribute::Attribute_FFTAttribute,
+ &attr, { input_real->GetName(), input_imag->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block(
+ "fft2d", "main", { op }, { input_real, input_imag, output_real, output_imag },
+ { input_real->GetName(), input_imag->GetName() }, { output_real->GetName(), output_imag->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input_real->GetName(), client_input_real.data, client_input_real.size));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input_imag->GetName(), client_input_imag.data, client_input_imag.size));
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_real->GetName(), client_output_real.data, client_output_real.size));
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_imag->GetName(), client_output_imag.data, client_output_imag.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
@@ -465,6 +510,47 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_rfft2d(tosa_tensor_t client_input,
+ tosa_tensor_t client_output_real,
+ const bool client_local_bound,
+ tosa_tensor_t client_output_imag,
+ const func_ctx_t& func_ctx)
+ {
+ // Create operator attributes
+ TosaRFFTAttribute attr(client_local_bound);
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
+ tosa::TosaSerializationTensor* output_real = translate_client_tensor(client_output_real, "output_real");
+ tosa::TosaSerializationTensor* output_imag = translate_client_tensor(client_output_imag, "output_imag");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RFFT2D, tosa::Attribute::Attribute_RFFTAttribute,
+ &attr, { input->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("rfft2d", "main", { op }, { input, output_real, output_imag },
+ { input->GetName() },
+ { output_real->GetName(), output_imag->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_real->GetName(), client_output_real.data, client_output_real.size));
+ TOSA_RETURN_ON_ERROR(
+ runner.getOutput(output_imag->GetName(), client_output_imag.data, client_output_imag.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,