diff options
author | Won Jeon <won.jeon@arm.com> | 2023-08-10 10:33:01 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2023-08-18 15:21:15 -0700 |
commit | a21b2e88d19d8cb11a9120d40bacbb594d600565 (patch) | |
tree | 3bc8a40db72a31c1e552a3bd6339627a1175686e /reference_model/src/operators.cc | |
parent | e0247481eb1f83f6eb7161d3f7ac2690b180952a (diff) | |
download | reference_model-a21b2e88d19d8cb11a9120d40bacbb594d600565.tar.gz |
Add DIM operator to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Iea11ee5d3d98773e9c5e9b827593c05afb41ce3b
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r-- | reference_model/src/operators.cc | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 1ae0683..ae5963d 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -68,6 +68,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_UINT16; case tosa_datatype_uint8_t: return tosa::DType::DType_UINT8; + case tosa_datatype_shape_t: + return tosa::DType::DType_SHAPE; default: return tosa::DType::DType_UNKNOWN; } @@ -1978,6 +1980,34 @@ extern "C" return tosa_status_valid; } + tosa_status_t tosa_run_dim(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output) + { + // Create operator attributes + TosaAxisAttribute attr(client_axis); + + // Create tensors + tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1"); + tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); + + // Create operator + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_DIM, tosa::Attribute::Attribute_AxisAttribute, &attr, + { input1->GetName() }, { output->GetName() }); + + // Create a tosa single-op basic block + tosa::TosaSerializationBasicBlock block("dim", "main", { op }, { input1, output }, { input1->GetName() }, + { output->GetName() }); + + // Setup model + TosaReference::ModelRunnerImpl runner; + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); + TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size)); + + // Execute + TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size)); + + return tosa_status_valid; + } + tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1, const int32_t client_new_shape_len, const int32_t client_new_shape[], |