aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/operators.cc
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2023-08-10 10:33:01 +0000
committerWon Jeon <won.jeon@arm.com>2023-08-18 15:21:15 -0700
commita21b2e88d19d8cb11a9120d40bacbb594d600565 (patch)
tree3bc8a40db72a31c1e552a3bd6339627a1175686e /reference_model/src/operators.cc
parente0247481eb1f83f6eb7161d3f7ac2690b180952a (diff)
downloadreference_model-a21b2e88d19d8cb11a9120d40bacbb594d600565.tar.gz
Add DIM operator to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Iea11ee5d3d98773e9c5e9b827593c05afb41ce3b
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r--reference_model/src/operators.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 1ae0683..ae5963d 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -68,6 +68,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_UINT16;
case tosa_datatype_uint8_t:
return tosa::DType::DType_UINT8;
+ case tosa_datatype_shape_t:
+ return tosa::DType::DType_SHAPE;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -1978,6 +1980,34 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_dim(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output)
+ {
+ // Create operator attributes
+ TosaAxisAttribute attr(client_axis);
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
+ tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_DIM, tosa::Attribute::Attribute_AxisAttribute, &attr,
+ { input1->GetName() }, { output->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("dim", "main", { op }, { input1, output }, { input1->GetName() },
+ { output->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner;
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
+
+ // Execute
+ TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
const int32_t client_new_shape_len,
const int32_t client_new_shape[],