From a21b2e88d19d8cb11a9120d40bacbb594d600565 Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Thu, 10 Aug 2023 10:33:01 +0000 Subject: Add DIM operator to reference model Signed-off-by: Won Jeon Change-Id: Iea11ee5d3d98773e9c5e9b827593c05afb41ce3b --- reference_model/src/operators.cc | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'reference_model/src/operators.cc') diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 1ae0683..ae5963d 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -68,6 +68,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_UINT16; case tosa_datatype_uint8_t: return tosa::DType::DType_UINT8; + case tosa_datatype_shape_t: + return tosa::DType::DType_SHAPE; default: return tosa::DType::DType_UNKNOWN; } @@ -1978,6 +1980,34 @@ extern "C" return tosa_status_valid; } + tosa_status_t tosa_run_dim(tosa_tensor_t client_input1, const int32_t client_axis, tosa_tensor_t client_output) + { + // Create operator attributes + TosaAxisAttribute attr(client_axis); + + // Create tensors + tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1"); + tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); + + // Create operator + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_DIM, tosa::Attribute::Attribute_AxisAttribute, &attr, + { input1->GetName() }, { output->GetName() }); + + // Create a tosa single-op basic block + tosa::TosaSerializationBasicBlock block("dim", "main", { op }, { input1, output }, { input1->GetName() }, + { output->GetName() }); + + // Setup model + TosaReference::ModelRunnerImpl runner; + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); + TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size)); + + // Execute + TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size)); + + return tosa_status_valid; + } + tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1, const int32_t client_new_shape_len, const int32_t client_new_shape[], -- cgit v1.2.1