aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-05-31 14:56:13 +0100
committerGrant Watson <grant.watson@arm.com>2023-06-12 17:15:54 +0100
commit6168047ef0354927cb175ad295722924dfc3053c (patch)
tree033b1568c35eea8c2dc65487c48c4530a71ab2b3
parent5dd5a55bc00d0eaf9aa38511cf553b0d78dfed51 (diff)
downloadreference_model-6168047ef0354927cb175ad295722924dfc3053c.tar.gz
Correctly identify "axis" attributes.
- Allows axis attributes to be treated differently to other arguments in attribute.def Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I1be2595c24bf22e5391a2911a5283391d310df37
-rw-r--r--reference_model/src/operators.cc64
-rw-r--r--scripts/operator_api/generate_api.py12
2 files changed, 49 insertions, 27 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index a070326..e9d6cad 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -85,15 +85,16 @@ extern "C"
tosa_status_t tosa_run_argmax(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_ARGMAX, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_ARGMAX, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("argmax", "main", { op }, { input, output }, { input->GetName() },
@@ -1659,15 +1660,16 @@ extern "C"
tosa_run_reduce_all(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ALL, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ALL, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_all", "main", { op }, { input, output }, { input->GetName() },
@@ -1691,15 +1693,16 @@ extern "C"
tosa_run_reduce_any(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ANY, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_ANY, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_any", "main", { op }, { input, output }, { input->GetName() },
@@ -1723,15 +1726,16 @@ extern "C"
tosa_run_reduce_max(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MAX, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MAX, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_max", "main", { op }, { input, output }, { input->GetName() },
@@ -1755,15 +1759,16 @@ extern "C"
tosa_run_reduce_min(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MIN, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_MIN, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_min", "main", { op }, { input, output }, { input->GetName() },
@@ -1787,15 +1792,17 @@ extern "C"
tosa_run_reduce_product(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_PRODUCT, tosa::Attribute::Attribute_NONE,
- &attr, { input->GetName() }, { output->GetName() });
+ auto op =
+ new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_PRODUCT, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_product", "main", { op }, { input, output },
@@ -1819,15 +1826,16 @@ extern "C"
tosa_run_reduce_sum(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_SUM, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REDUCE_SUM, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reduce_sum", "main", { op }, { input, output }, { input->GetName() },
@@ -1850,15 +1858,16 @@ extern "C"
tosa_status_t tosa_run_concat(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_CONCAT, tosa::Attribute::Attribute_NONE, &attr,
- { input1->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_CONCAT, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("concat", "main", { op }, { input1, output }, { input1->GetName() },
@@ -1955,15 +1964,16 @@ extern "C"
tosa_status_t tosa_run_reverse(tosa_tensor_t client_input, const int32_t client_axis, tosa_tensor_t client_output)
{
// Create operator attributes
- TosaNoneAttribute attr;
+ const int32_t axis = client_axis;
+ TosaAxisAttribute attr(axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REVERSE, tosa::Attribute::Attribute_NONE, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_REVERSE, tosa::Attribute::Attribute_AxisAttribute,
+ &attr, { input->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
tosa::TosaSerializationBasicBlock block("reverse", "main", { op }, { input, output }, { input->GetName() },
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 671d902..5038973 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -192,6 +192,18 @@ def getOperators(tosaXml):
operator["serializeAttType"] = getSerializeOpType(opName)
tosaArgs = getTosaArgs(opXml)
serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
+ # Handle "axis" arguments
+ axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"]
+ if operator["serializeAttType"] == "None" and len(axisList) > 0:
+ operator["serializeAttType"] = "Axis"
+ serializeArgs = [
+ {
+ "name": "axis",
+ "dType": "int32_t",
+ "SV": "S",
+ "init": "= client_axis",
+ }
+ ]
updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
operator["arguments"] = tosaArgs
operator["serializeArgs"] = serializeArgs