diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/operators.cc | 64 |
1 files changed, 37 insertions, 27 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index a070326..e9d6cad 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -85,15 +85,16 @@ extern "C" tosa_status_t tosa_run_argmax(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_ARGMAX, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_ARGMAX, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("argmax", "main", { op }, { input, output }, { input->GetName() }, @@ -1659,15 +1660,16 @@ extern "C" tosa_run_reduce_all(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ALL, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ALL, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_all", "main", { op }, { input, output }, { input->GetName() }, @@ -1691,15 +1693,16 @@ extern "C" tosa_run_reduce_any(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ANY, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ANY, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_any", "main", { op }, { input, output }, { input->GetName() }, @@ -1723,15 +1726,16 @@ extern "C" tosa_run_reduce_max(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MAX, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MAX, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_max", "main", { op }, { input, output }, { input->GetName() }, @@ -1755,15 +1759,16 @@ extern "C" tosa_run_reduce_min(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MIN, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MIN, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_min", "main", { op }, { input, output }, { input->GetName() }, @@ -1787,15 +1792,17 @@ extern "C" tosa_run_reduce_product(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_PRODUCT, tosa::Attribute::Attribute_NONE, - &attr, { input->GetName() }, { output->GetName() }); + auto op = + new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_PRODUCT, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_product", "main", { op }, { input, output }, @@ -1819,15 +1826,16 @@ extern "C" tosa_run_reduce_sum(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_SUM, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_SUM, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reduce_sum", "main", { op }, { input, output }, { input->GetName() }, @@ -1850,15 +1858,16 @@ extern "C" tosa_status_t tosa_run_concat(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_CONCAT, tosa::Attribute::Attribute_NONE, &attr, - { input1->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_CONCAT, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("concat", "main", { op }, { input1, output }, { input1->GetName() }, @@ -1955,15 +1964,16 @@ extern "C" tosa_status_t tosa_run_reverse(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output) { // Create operator attributes - TosaNoneAttribute attr; + const int32_t axis = client_axis; + TosaAxisAttribute attr(axis); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REVERSE, tosa::Attribute::Attribute_NONE, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REVERSE, tosa::Attribute::Attribute_AxisAttribute, + &attr, { input->GetName() }, { output->GetName() }); // Create a tosa single-op basic block tosa::TosaSerializationBasicBlock block("reverse", "main", { op }, { input, output }, { input->GetName() }, |