aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-09-12 10:46:36 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-28 16:36:20 -0800
commit09ae449db8a45ab7c48af4541b43cb3dc80f9a30 (patch)
tree7c5766f0facee228bc2f801b749c9110f591329d
parent9c3754715368d84567db883bdbafc31860850141 (diff)
downloadreference_model-09ae449db8a45ab7c48af4541b43cb3dc80f9a30.tar.gz
Upgrade to latest version of TOSA specification
- Updates TOSA specification to the latest version - Updates generate_api.py to generate the operator API correctly for ops with additional tensor inputs. - Removes default arguments for func_debug and func_config to make the API C compliant again. - Updates model_runner_tests.cpp for operators that have changed. - Adds a unit test for the Tile operator to check that generated code for additional tensor inputs works correctly. Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I1e26065c6ed333b2ca4b3da39972d30f896fa6e5
-rw-r--r--reference_model/include/operators.h250
-rw-r--r--reference_model/src/operators.cc506
-rw-r--r--reference_model/test/model_runner_tests.cpp74
-rw-r--r--scripts/operator_api/generate_api.py171
-rw-r--r--scripts/operator_api/templates/operators_cc.j226
-rw-r--r--scripts/operator_api/templates/operators_h.j237
m---------thirdparty/specification0
7 files changed, 463 insertions, 601 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 1519d20..08da277 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -31,11 +31,16 @@ extern "C"
{
#endif /* __cplusplus */
+ struct func_ctx_t
+ {
+ func_config_t func_config = func_config_t{};
+ func_debug_t func_debug = func_debug_t{};
+ };
+
tosa_status_t tosa_run_argmax(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_avg_pool2d(tosa_tensor_t client_input,
const int32_t client_kernel[2],
@@ -44,8 +49,7 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
@@ -56,8 +60,7 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_conv3d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
@@ -68,8 +71,7 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_depthwise_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
@@ -80,8 +82,7 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
@@ -89,16 +90,14 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_matmul(tosa_tensor_t client_a,
tosa_tensor_t client_b,
const int32_t client_a_zp,
const int32_t client_b_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_max_pool2d(tosa_tensor_t client_input,
const int32_t client_kernel[2],
@@ -107,8 +106,7 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
@@ -121,8 +119,7 @@ extern "C"
const int32_t client_dilation_len,
const int32_t client_dilation[],
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_clamp(tosa_tensor_t client_input,
const int32_t client_min_int,
@@ -130,280 +127,210 @@ extern "C"
const float client_min_fp,
const float client_max_fp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_erf(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_erf(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_sigmoid(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_sigmoid(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_tanh(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_tanh(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
tosa_status_t tosa_run_add(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_arithmetic_right_shift(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
const bool client_round,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_bitwise_and(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_bitwise_or(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_bitwise_xor(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_intdiv(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_logical_and(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_logical_left_shift(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_logical_right_shift(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_logical_or(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_logical_xor(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_maximum(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_minimum(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_mul(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
const int32_t client_shift,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_pow(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_sub(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_table(tosa_tensor_t client_input,
const int32_t client_table_len,
const int16_t client_table[],
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_abs(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_abs(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_bitwise_not(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t
+ tosa_run_bitwise_not(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_ceil(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_ceil(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_clz(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_clz(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_exp(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_exp(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_floor(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_floor(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_log(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_log(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_logical_not(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t
+ tosa_run_logical_not(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
tosa_status_t tosa_run_negate(tosa_tensor_t client_input1,
const int32_t client_input1_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_reciprocal(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t
+ tosa_run_reciprocal(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_rsqrt(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_rsqrt(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
tosa_status_t tosa_run_select(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_input3,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_equal(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_greater(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_greater_equal(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_all(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_any(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_max(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_min(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_product(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reduce_sum(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_concat(tosa_tensor_t client_input1,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_pad(tosa_tensor_t client_input1,
- const int32_t client_padding_len,
- const int32_t client_padding[],
+ tosa_tensor_t client_padding,
const int32_t client_pad_const_int,
const float client_pad_const_fp,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_dim(tosa_tensor_t client_input1,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
+ tosa_tensor_t client_shape,
const int32_t client_new_shape_len,
const int32_t client_new_shape[],
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_reverse(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_slice(tosa_tensor_t client_input1,
const int32_t client_start_len,
@@ -411,49 +338,39 @@ extern "C"
const int32_t client_size_len,
const int32_t client_size[],
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_tile(tosa_tensor_t client_input1,
- const int32_t client_multiples_len,
- const int32_t client_multiples[],
+ tosa_tensor_t client_multiples,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_transpose(tosa_tensor_t client_input1,
const int32_t client_perms_len,
const int32_t client_perms[],
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_gather(tosa_tensor_t client_values,
tosa_tensor_t client_indices,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_scatter(tosa_tensor_t client_values_in,
tosa_tensor_t client_indices,
tosa_tensor_t client_input,
tosa_tensor_t client_values_out,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
tosa_status_t tosa_run_resize(tosa_tensor_t client_input,
- const int16_t client_scale[4],
- const int16_t client_offset[2],
- const int16_t client_border[2],
+ tosa_tensor_t client_scale,
+ tosa_tensor_t client_offset,
+ tosa_tensor_t client_border,
const tosa_mode_t client_mode,
tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_cast(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t tosa_run_cast(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
tosa_status_t tosa_run_rescale(tosa_tensor_t client_input,
tosa_tensor_t client_output,
@@ -468,16 +385,13 @@ extern "C"
const bool client_input_unsigned,
const bool client_output_unsigned,
const bool client_per_channel,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ const func_ctx_t& func_ctx);
- tosa_status_t tosa_run_identity(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config = func_config_t{},
- const func_debug_t& func_debug = func_debug_t{});
+ tosa_status_t
+ tosa_run_identity(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx);
#ifdef __cplusplus
}
#endif /* __cplusplus */
-#endif // OPERATORS_H_
+#endif // OPERATORS_H_ \ No newline at end of file
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 6c9b067..9b3721b 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -103,12 +103,10 @@ extern "C"
tosa_status_t tosa_run_argmax(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -123,7 +121,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -143,17 +141,14 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> kernel(&client_kernel[0], &client_kernel[2]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t output_zp = client_output_zp;
const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaPoolAttribute attr(pad, kernel, stride, input_zp, output_zp, accum_dtype);
+ TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -168,7 +163,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -190,16 +185,13 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -218,7 +210,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
@@ -242,16 +234,13 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[6]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[3]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[3]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -270,7 +259,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
@@ -294,16 +283,13 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -322,7 +308,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
@@ -343,13 +329,10 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- TosaFullyConnectedAttribute attr(input_zp, weight_zp);
+ TosaFullyConnectedAttribute attr(client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -368,7 +351,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
@@ -388,13 +371,10 @@ extern "C"
const int32_t client_a_zp,
const int32_t client_b_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t a_zp = client_a_zp;
- const int32_t b_zp = client_b_zp;
- TosaMatMulAttribute attr(a_zp, b_zp);
+ TosaMatMulAttribute attr(client_a_zp, client_b_zp);
// Create tensors
tosa::TosaSerializationTensor* a = translate_client_tensor(client_a, "a");
@@ -410,7 +390,7 @@ extern "C"
{ a->GetName(), b->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(a->GetName(), client_a.data, client_a.size));
TOSA_RETURN_ON_ERROR(runner.setInput(b->GetName(), client_b.data, client_b.size));
@@ -431,17 +411,14 @@ extern "C"
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> kernel(&client_kernel[0], &client_kernel[2]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t output_zp = client_output_zp;
const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaPoolAttribute attr(pad, kernel, stride, input_zp, output_zp, accum_dtype);
+ TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -456,7 +433,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -480,16 +457,13 @@ extern "C"
const int32_t client_dilation_len,
const int32_t client_dilation[],
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> pad(&client_pad[0], &client_pad[0] + client_pad_len);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -508,7 +482,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
@@ -529,15 +503,10 @@ extern "C"
const float client_min_fp,
const float client_max_fp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t min_int = client_min_int;
- const int32_t max_int = client_max_int;
- const float min_fp = client_min_fp;
- const float max_fp = client_max_fp;
- TosaClampAttribute attr(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attr(client_min_int, client_max_int, client_min_fp, client_max_fp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -552,7 +521,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -565,10 +534,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_erf(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_erf(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -586,7 +552,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -599,10 +565,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_sigmoid(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_sigmoid(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -620,7 +583,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -633,10 +596,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_tanh(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_tanh(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -654,7 +614,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -670,8 +630,7 @@ extern "C"
tosa_status_t tosa_run_add(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -690,7 +649,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -708,12 +667,10 @@ extern "C"
tosa_tensor_t client_input2,
const bool client_round,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const bool round = client_round;
- TosaArithmeticRightShiftAttribute attr(round);
+ TosaArithmeticRightShiftAttribute attr(client_round);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
@@ -730,7 +687,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -747,8 +704,7 @@ extern "C"
tosa_status_t tosa_run_bitwise_and(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -767,7 +723,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -784,8 +740,7 @@ extern "C"
tosa_status_t tosa_run_bitwise_or(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -804,7 +759,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -821,8 +776,7 @@ extern "C"
tosa_status_t tosa_run_bitwise_xor(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -841,7 +795,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -858,8 +812,7 @@ extern "C"
tosa_status_t tosa_run_intdiv(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -878,7 +831,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -895,8 +848,7 @@ extern "C"
tosa_status_t tosa_run_logical_and(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -915,7 +867,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -932,8 +884,7 @@ extern "C"
tosa_status_t tosa_run_logical_left_shift(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -953,7 +904,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -970,8 +921,7 @@ extern "C"
tosa_status_t tosa_run_logical_right_shift(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -991,7 +941,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1008,8 +958,7 @@ extern "C"
tosa_status_t tosa_run_logical_or(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1028,7 +977,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1045,8 +994,7 @@ extern "C"
tosa_status_t tosa_run_logical_xor(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1065,7 +1013,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1082,8 +1030,7 @@ extern "C"
tosa_status_t tosa_run_maximum(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1102,7 +1049,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1119,8 +1066,7 @@ extern "C"
tosa_status_t tosa_run_minimum(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1139,7 +1085,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1157,12 +1103,10 @@ extern "C"
tosa_tensor_t client_input2,
const int32_t client_shift,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t shift = client_shift;
- TosaMulAttribute attr(shift);
+ TosaMulAttribute attr(client_shift);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
@@ -1178,7 +1122,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1195,8 +1139,7 @@ extern "C"
tosa_status_t tosa_run_pow(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1215,7 +1158,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1232,8 +1175,7 @@ extern "C"
tosa_status_t tosa_run_sub(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1252,7 +1194,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1270,8 +1212,7 @@ extern "C"
const int32_t client_table_len,
const int16_t client_table[],
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int16_t> table(&client_table[0], &client_table[0] + client_table_len);
@@ -1290,7 +1231,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -1303,10 +1244,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_abs(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_abs(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1324,7 +1262,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1337,10 +1275,8 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_bitwise_not(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t
+ tosa_run_bitwise_not(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1358,7 +1294,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1371,10 +1307,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_ceil(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_ceil(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1392,7 +1325,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1405,10 +1338,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_clz(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_clz(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1426,7 +1356,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1439,10 +1369,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_exp(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_exp(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1460,7 +1387,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1473,10 +1400,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_floor(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_floor(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1494,7 +1418,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1507,10 +1431,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_log(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_log(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1528,7 +1449,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1541,10 +1462,8 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_logical_not(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t
+ tosa_run_logical_not(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1562,7 +1481,7 @@ extern "C"
{ input1->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1579,13 +1498,10 @@ extern "C"
const int32_t client_input1_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t input1_zp = client_input1_zp;
- const int32_t output_zp = client_output_zp;
- TosaNegateAttribute attr(input1_zp, output_zp);
+ TosaNegateAttribute attr(client_input1_zp, client_output_zp);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
@@ -1600,7 +1516,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1613,10 +1529,8 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_reciprocal(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t
+ tosa_run_reciprocal(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1634,7 +1548,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1647,10 +1561,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_rsqrt(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_rsqrt(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1668,7 +1579,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -1685,8 +1596,7 @@ extern "C"
tosa_tensor_t client_input2,
tosa_tensor_t client_input3,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1708,7 +1618,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1726,8 +1636,7 @@ extern "C"
tosa_status_t tosa_run_equal(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1746,7 +1655,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1763,8 +1672,7 @@ extern "C"
tosa_status_t tosa_run_greater(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1783,7 +1691,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1800,8 +1708,7 @@ extern "C"
tosa_status_t tosa_run_greater_equal(tosa_tensor_t client_input1,
tosa_tensor_t client_input2,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -1821,7 +1728,7 @@ extern "C"
{ input1->GetName(), input2->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
TOSA_RETURN_ON_ERROR(runner.setInput(input2->GetName(), client_input2.data, client_input2.size));
@@ -1838,12 +1745,10 @@ extern "C"
tosa_status_t tosa_run_reduce_all(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -1858,7 +1763,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -1874,12 +1779,10 @@ extern "C"
tosa_status_t tosa_run_reduce_any(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -1894,7 +1797,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -1910,12 +1813,10 @@ extern "C"
tosa_status_t tosa_run_reduce_max(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -1930,7 +1831,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -1946,12 +1847,10 @@ extern "C"
tosa_status_t tosa_run_reduce_min(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -1966,7 +1865,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -1982,12 +1881,10 @@ extern "C"
tosa_status_t tosa_run_reduce_product(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -2003,7 +1900,7 @@ extern "C"
{ input->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2019,12 +1916,10 @@ extern "C"
tosa_status_t tosa_run_reduce_sum(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -2039,7 +1934,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2055,12 +1950,10 @@ extern "C"
tosa_status_t tosa_run_concat(tosa_tensor_t client_input1,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
@@ -2075,7 +1968,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -2089,19 +1982,18 @@ extern "C"
}
tosa_status_t tosa_run_pad(tosa_tensor_t client_input1,
- const int32_t client_padding_len,
- const int32_t client_padding[],
+ tosa_tensor_t client_padding,
const int32_t client_pad_const_int,
const float client_pad_const_fp,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const std::vector<int32_t> padding(&client_padding[0], &client_padding[0] + client_padding_len);
- const int32_t pad_const_int = client_pad_const_int;
- const float pad_const_fp = client_pad_const_fp;
- TosaPadAttribute attr(padding, pad_const_int, pad_const_fp);
+ std::vector<int32_t> padding;
+ size_t padding_size = client_padding.size / sizeof(int32_t);
+ int32_t* padding_data = reinterpret_cast<int32_t*>(client_padding.data);
+ padding.assign(padding_data, padding_data + padding_size);
+ TosaPadAttribute attr(padding, client_pad_const_int, client_pad_const_fp);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
@@ -2116,7 +2008,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -2132,8 +2024,7 @@ extern "C"
tosa_status_t tosa_run_dim(tosa_tensor_t client_input1,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaAxisAttribute attr(client_axis);
@@ -2151,22 +2042,25 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
// Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size));
return tosa_status_valid;
}
tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
+ tosa_tensor_t client_shape,
const int32_t client_new_shape_len,
const int32_t client_new_shape[],
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> new_shape(&client_new_shape[0], &client_new_shape[0] + client_new_shape_len);
@@ -2174,20 +2068,23 @@ extern "C"
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
+ tosa::TosaSerializationTensor* shape = translate_client_tensor(client_shape, "shape");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute,
- &attr, { input1->GetName() }, { output->GetName() });
+ auto op =
+ new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute,
+ &attr, { input1->GetName(), shape->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() },
- { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, shape, output },
+ { input1->GetName(), shape->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
+ TOSA_RETURN_ON_ERROR(runner.setInput(shape->GetName(), client_shape.data, client_shape.size));
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
@@ -2201,12 +2098,10 @@ extern "C"
tosa_status_t tosa_run_reverse(tosa_tensor_t client_input,
const int32_t client_axis,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t axis = client_axis;
- TosaAxisAttribute attr(axis);
+ TosaAxisAttribute attr(client_axis);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -2221,7 +2116,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2240,8 +2135,7 @@ extern "C"
const int32_t client_size_len,
const int32_t client_size[],
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> start(&client_start[0], &client_start[0] + client_start_len);
@@ -2261,7 +2155,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -2275,14 +2169,15 @@ extern "C"
}
tosa_status_t tosa_run_tile(tosa_tensor_t client_input1,
- const int32_t client_multiples_len,
- const int32_t client_multiples[],
+ tosa_tensor_t client_multiples,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const std::vector<int32_t> multiples(&client_multiples[0], &client_multiples[0] + client_multiples_len);
+ std::vector<int32_t> multiples;
+ size_t multiples_size = client_multiples.size / sizeof(int32_t);
+ int32_t* multiples_data = reinterpret_cast<int32_t*>(client_multiples.data);
+ multiples.assign(multiples_data, multiples_data + multiples_size);
TosaTileAttribute attr(multiples);
// Create tensors
@@ -2298,7 +2193,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -2315,8 +2210,7 @@ extern "C"
const int32_t client_perms_len,
const int32_t client_perms[],
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
const std::vector<int32_t> perms(&client_perms[0], &client_perms[0] + client_perms_len);
@@ -2336,7 +2230,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
@@ -2352,8 +2246,7 @@ extern "C"
tosa_status_t tosa_run_gather(tosa_tensor_t client_values,
tosa_tensor_t client_indices,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -2372,7 +2265,7 @@ extern "C"
{ values->GetName(), indices->GetName() }, { output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(values->GetName(), client_values.data, client_values.size));
TOSA_RETURN_ON_ERROR(runner.setInput(indices->GetName(), client_indices.data, client_indices.size));
@@ -2390,8 +2283,7 @@ extern "C"
tosa_tensor_t client_indices,
tosa_tensor_t client_input,
tosa_tensor_t client_values_out,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -2413,7 +2305,7 @@ extern "C"
{ values_out->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(values_in->GetName(), client_values_in.data, client_values_in.size));
TOSA_RETURN_ON_ERROR(runner.setInput(indices->GetName(), client_indices.data, client_indices.size));
@@ -2429,18 +2321,26 @@ extern "C"
}
tosa_status_t tosa_run_resize(tosa_tensor_t client_input,
- const int16_t client_scale[4],
- const int16_t client_offset[2],
- const int16_t client_border[2],
+ tosa_tensor_t client_scale,
+ tosa_tensor_t client_offset,
+ tosa_tensor_t client_border,
const tosa_mode_t client_mode,
tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
- {
- // Create operator attributes
- const std::vector<int16_t> scale(&client_scale[0], &client_scale[4]);
- const std::vector<int16_t> offset(&client_offset[0], &client_offset[2]);
- const std::vector<int16_t> border(&client_border[0], &client_border[2]);
+ const func_ctx_t& func_ctx)
+ {
+ // Create operator attributes
+ std::vector<int16_t> scale;
+ size_t scale_size = client_scale.size / sizeof(int16_t);
+ int16_t* scale_data = reinterpret_cast<int16_t*>(client_scale.data);
+ scale.assign(scale_data, scale_data + scale_size);
+ std::vector<int16_t> offset;
+ size_t offset_size = client_offset.size / sizeof(int16_t);
+ int16_t* offset_data = reinterpret_cast<int16_t*>(client_offset.data);
+ offset.assign(offset_data, offset_data + offset_size);
+ std::vector<int16_t> border;
+ size_t border_size = client_border.size / sizeof(int16_t);
+ int16_t* border_data = reinterpret_cast<int16_t*>(client_border.data);
+ border.assign(border_data, border_data + border_size);
const ResizeMode mode = translate_client_tosa_mode(client_mode);
TosaResizeAttribute attr(scale, offset, border, mode);
@@ -2457,7 +2357,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2470,10 +2370,7 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_cast(tosa_tensor_t client_input,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t tosa_run_cast(tosa_tensor_t client_input, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -2491,7 +2388,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2517,21 +2414,14 @@ extern "C"
const bool client_input_unsigned,
const bool client_output_unsigned,
const bool client_per_channel,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ const func_ctx_t& func_ctx)
{
// Create operator attributes
- const int32_t input_zp = client_input_zp;
- const int32_t output_zp = client_output_zp;
const std::vector<int32_t> multiplier(&client_multiplier[0], &client_multiplier[0] + client_multiplier_len);
const std::vector<int32_t> shift(&client_shift[0], &client_shift[0] + client_shift_len);
- const bool scale32 = client_scale32;
- const bool double_round = client_double_round;
- const bool per_channel = client_per_channel;
- const bool input_unsigned = client_input_unsigned;
- const bool output_unsigned = client_output_unsigned;
- TosaRescaleAttribute attr(input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel,
- input_unsigned, output_unsigned);
+ TosaRescaleAttribute attr(client_input_zp, client_output_zp, multiplier, shift, client_scale32,
+ client_double_round, client_per_channel, client_input_unsigned,
+ client_output_unsigned);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -2546,7 +2436,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
@@ -2559,10 +2449,8 @@ extern "C"
return tosa_status_valid;
}
- tosa_status_t tosa_run_identity(tosa_tensor_t client_input1,
- tosa_tensor_t client_output,
- const func_config_t& func_config,
- const func_debug_t& func_debug)
+ tosa_status_t
+ tosa_run_identity(tosa_tensor_t client_input1, tosa_tensor_t client_output, const func_ctx_t& func_ctx)
{
// Create operator attributes
TosaNoneAttribute attr;
@@ -2580,7 +2468,7 @@ extern "C"
{ output->GetName() });
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index 0b73494..820ed63 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -75,7 +75,7 @@ TEST_SUITE("model_runner")
output.size = dstData.size() * sizeof(float);
// Execution
- auto status = tosa_run_add(input1, input2, output);
+ auto status = tosa_run_add(input1, input2, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -112,7 +112,7 @@ TEST_SUITE("model_runner")
output.size = dstData.size() * sizeof(float);
// Execution
- auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, 0, 0, output);
+ auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, 0, 0, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -170,7 +170,7 @@ TEST_SUITE("model_runner")
const int32_t weight_zp = 0;
// Execution
- auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output);
+ auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -228,10 +228,10 @@ TEST_SUITE("model_runner")
const int32_t weight_zp = 0;
// Execution
- func_config_t func_config;
- func_config.abs_mode = true;
+ func_ctx_t func_ctx;
+ func_ctx.func_config.abs_mode = true;
auto status =
- tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, func_config);
+ tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, func_ctx);
CHECK((status == tosa_status_valid));
// Compare results
@@ -269,7 +269,7 @@ TEST_SUITE("model_runner")
output.size = dstData.size() * sizeof(float);
// Execution
- auto status = tosa_run_max_pool2d(input, kernel, stride, pad, 0, 0, output);
+ auto status = tosa_run_max_pool2d(input, kernel, stride, pad, 0, 0, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -280,10 +280,12 @@ TEST_SUITE("model_runner")
TEST_CASE("op_entry_pad")
{
// Inputs/Outputs
- tosa_datatype_t dt = tosa_datatype_fp32_t;
- std::vector<int32_t> input_shape = { 2, 2 };
- std::vector<int32_t> output_shape = { 4, 4 };
+ tosa_datatype_t dt = tosa_datatype_fp32_t;
+ std::vector<int32_t> input_shape = { 2, 2 };
+ std::vector<int32_t> padding_shape = { 1, 4 };
+ std::vector<int32_t> output_shape = { 4, 4 };
std::vector<float> srcData1(4, 4.0f);
+ std::vector<int32_t> padData(4, 1);
std::vector<float> dstData(16, 0.0f);
tosa_tensor_t input1;
@@ -293,6 +295,13 @@ TEST_SUITE("model_runner")
input1.data = reinterpret_cast<uint8_t*>(srcData1.data());
input1.size = srcData1.size() * sizeof(float);
+ tosa_tensor_t padding;
+ padding.shape = padding_shape.data();
+ padding.num_dims = padding_shape.size();
+ padding.data_type = tosa_datatype_int32_t;
+ padding.data = reinterpret_cast<uint8_t*>(padData.data());
+ padding.size = padData.size() * sizeof(int32_t);
+
tosa_tensor_t output;
output.shape = output_shape.data();
output.num_dims = output_shape.size();
@@ -301,11 +310,9 @@ TEST_SUITE("model_runner")
output.size = dstData.size() * sizeof(float);
// Execution
- int32_t padding[4] = { 1, 1, 1, 1 };
- int32_t padding_len = 4;
int32_t pad_const_int = 0;
float pad_const_fp = 5.0f;
- auto status = tosa_run_pad(input1, padding_len, padding, pad_const_int, pad_const_fp, output);
+ auto status = tosa_run_pad(input1, padding, pad_const_int, pad_const_fp, output, func_ctx_t{});
CHECK((status == tosa_status_valid));
// Compare results
@@ -318,6 +325,47 @@ TEST_SUITE("model_runner")
compareOutput(dstData, expectedData, expectedData.size());
}
+ TEST_CASE("op_entry_tile")
+ {
+ // Inputs/Outputs
+ tosa_datatype_t dt = tosa_datatype_fp32_t;
+ std::vector<int32_t> input_shape = { 2, 3 };
+ std::vector<int32_t> multiples_shape = { 1, 2 };
+ std::vector<int32_t> output_shape = { 2, 6 };
+ std::vector<float> srcData1 = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
+ std::vector<int32_t> multiples_data = { 1, 2 };
+ std::vector<float> dstData(12, 0.0f);
+
+ tosa_tensor_t input1;
+ input1.shape = input_shape.data();
+ input1.num_dims = input_shape.size();
+ input1.data_type = dt;
+ input1.data = reinterpret_cast<uint8_t*>(srcData1.data());
+ input1.size = srcData1.size() * sizeof(float);
+
+ tosa_tensor_t multiples;
+ multiples.shape = multiples_shape.data();
+ multiples.num_dims = multiples_shape.size();
+ multiples.data_type = tosa_datatype_int32_t;
+ multiples.data = reinterpret_cast<uint8_t*>(multiples_data.data());
+ multiples.size = multiples_data.size() * sizeof(int32_t);
+
+ tosa_tensor_t output;
+ output.shape = output_shape.data();
+ output.num_dims = output_shape.size();
+ output.data_type = dt;
+ output.data = reinterpret_cast<uint8_t*>(dstData.data());
+ output.size = dstData.size() * sizeof(float);
+
+ // Execution
+ auto status = tosa_run_tile(input1, multiples, output, {});
+ CHECK((status == tosa_status_valid));
+
+ // Compare results
+ std::vector<float> expectedData = { 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0 };
+ compareOutput(dstData, expectedData, expectedData.size());
+ }
+
TEST_CASE("simple_add_f32_test")
{
std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/");
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index f1cb6e0..c5c762d 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -81,64 +81,87 @@ def getSerializeOpType(tosaOpName):
return map[tosaOpName]
-def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
+def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
"""
- Returns the arguments required by the Serialization library for the TOSA operator specified.
- Generates code to initialize Serialization arguments. If a matching TOSA argument exists,
+ Returns the attributes required by the Serialization library for the TOSA operator specified.
+ Generates code to initialize Serialization library attributes. If a matching TOSA argument exists,
that value is used for initialization, otherwise a default value e.g. 0 is used.
"""
- serOpType = getSerializeOpType(tosaOpName)
- if serOpType not in allSerializeArgs.keys():
+ serLibOpType = getSerializeOpType(tosaOpName)
+ if serLibOpType not in allSerialLibAtts.keys():
return {}
else:
- serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
+ serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- for arg in serOpArgs:
- argName = arg["name"]
+ for att in serLibOpAtts:
+ attName = att["name"]
+ attType = att["dType"]
init = ""
- # Translate TOSA data types to Serialization data types for initialization
- if arg["dType"] in serTosaTypeMap.keys():
- init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
- # Initialize Serialization arguments to their matching function parameter
- elif argName in tosaArgsDict:
- if arg["SV"] == "V":
- shape = tosaArgsDict[argName]["shape"]
- if shape == "[]":
- init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
+ # Translate TOSA data types to Serialization library data types for initialization
+ if attType in serTosaTypeMap.keys():
+ init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
+ # Initialize Serialization library attributes to their matching function parameter
+ elif attName in tosaArgsDict:
+ if att["SV"] == "V":
+ if tosaArgsDict[attName]["type"] == "tosa_tensor_t":
+ init = f"std::vector<{attType}> {attName};"
+ init = (
+ init
+ + f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
+ )
+ init = (
+ init
+ + f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
+ )
+ init = (
+ init
+ + f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
+ )
else:
- init = f"(&client_{argName}[0], &client_{argName}{shape})"
+ init = f"const std::vector<{attType}> {attName}"
+ shape = tosaArgsDict[attName]["shape"]
+ if shape == "[]":
+ init = (
+ init
+ + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
+ )
+ else:
+ init = (
+ init
+ + f"(&client_{attName}[0], &client_{attName}{shape});"
+ )
else:
- init = f" = client_{argName}"
- else:
- # Initialize Serialization arguments with no matching fuction parameter
- if arg["SV"] == "V":
init = ""
+ else:
+ # Initialize Serialization library attributes with no matching fuction parameter
+ if att["SV"] == "V":
+ init = f"std::vector<int32_t> {attName};"
else:
- if arg["dType"] == "DType":
- arg["dType"] = "tosa::DType"
- init = " = tosa::DType::DType_FP32"
+ if att["dType"] == "DType":
+ att["dType"] = "tosa::DType"
+ init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
else:
- init = " = 0"
- arg["init"] = init
- return serOpArgs
+ init = f"const {attType} {attName} = 0;"
+ att["init"] = init
+ return serLibOpAtts
-def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
+def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
"""
- Replace TOSA argument data types with their matching Serialization argument data types.
+ Replace TOSA argument data types with their matching Serialization attribute data types.
Delete TOSA arguments where the type couldn't be determined.
- Add Serialization arguments that have no matching TOSA argument.
+ Add Serialization attributes that have no matching TOSA argument.
"""
tosaArgTypes = getTosaArgTypes(tosaXml)
- serArgsDict = {arg["name"]: arg for arg in serializeArgs}
+ serAttsDict = {att["name"]: att for att in serialLibAtts}
tosaArgsNames = [arg["name"] for arg in tosaArgs]
delTosaArgs = []
- # Replace TOSA argument data types with their matching Serialization argument data types.
+ # Replace TOSA argument data types with their matching Serialization attribute data types.
for tosaArg in tosaArgs:
if tosaArg["type"] in tosaArgTypes:
- if tosaArg["name"] in serArgsDict:
- tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"]
+ if tosaArg["name"] in serAttsDict:
+ tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
else:
# Delete TOSA argument whose data type can't be determined
delTosaArgs.append(tosaArgsNames.index(tosaArg["name"]))
@@ -149,36 +172,36 @@ def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
# Delete TOSA arguments where the type couldn't be determined
for index in sorted(delTosaArgs, key=int, reverse=True):
del tosaArgs[index]
- # Add Serialization arguments that have no matching TOSA argument
+ # Add Serialization attributes that have no matching TOSA argument
tosaArgNames = [arg["name"] for arg in tosaArgs]
- for serArg in serializeArgs:
- if (serArg["name"] not in tosaArgNames) and (
- not serArg["dType"] == "tosa::DType"
- ):
- serArgName = serArg["name"]
- if serArg["SV"] == "V":
+ for serAtt in serialLibAtts:
+ attName = serAtt["name"]
+ attType = serAtt["dType"]
+ if (attName not in tosaArgNames) and (not attType == "tosa::DType"):
+ serAttName = serAtt["name"]
+ if serAtt["SV"] == "V":
# For vector data types, insert a matching length argument
tosaArgs.insert(
len(tosaArgs) - 1,
{
- "name": f"{serArgName}_len",
+ "name": f"{serAttName}_len",
"type": "int32_t",
"shape": "",
"category": "",
},
)
- init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
+ init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);"
shape = "[]"
else:
- init = f" = client_{serArg['name']}"
+ init = ""
shape = ""
- serArg["init"] = init
+ serAtt["init"] = init
# Insert new argument
tosaArgs.insert(
len(tosaArgs) - 1,
{
- "name": serArgName,
- "type": serArg["dType"],
+ "name": serAttName,
+ "type": serAtt["dType"],
"shape": shape,
"category": "",
},
@@ -190,33 +213,47 @@ def getOperators(tosaXml):
Return a list of TOSA operators as defined by tosa.xml.
"""
operators = []
- ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
+ ignoreOps = [
+ "while_loop",
+ "cond_if",
+ "const",
+ "custom",
+ "fft2d",
+ "rfft2d",
+ "variable",
+ "variable_read",
+ "variable_write",
+ ]
opsXml = tosaXml.getElementsByTagName("operator")
- allSerializeArgs = getSerializeArgs()
+ allSerialLibAtts = getSerialLibAtts()
for opXml in opsXml:
opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
if opName not in ignoreOps:
operator = {"name": opName}
operator["serializeAttType"] = getSerializeOpType(opName)
tosaArgs = getTosaArgs(opXml)
- serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
+ serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
# Handle "axis" arguments
axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"]
if operator["serializeAttType"] == "None" and len(axisList) > 0:
operator["serializeAttType"] = "Axis"
- serializeArgs = [
+ serialLibAtts = [
{
"name": "axis",
"dType": "int32_t",
"SV": "S",
- "init": "= client_axis",
+ "init": "",
}
]
- updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
+ updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
operator["arguments"] = tosaArgs
- operator["serializeArgs"] = serializeArgs
+ operator["serialLibAtts"] = serialLibAtts
+ serializationAttNames = [att["name"] for att in serialLibAtts]
operator["inputs"] = [
- arg["name"] for arg in tosaArgs if arg["category"] == "input"
+ arg["name"]
+ for arg in tosaArgs
+ if arg["category"] == "input"
+ and arg["name"] not in serializationAttNames
]
operator["outputs"] = [
arg["name"] for arg in tosaArgs if arg["category"] == "output"
@@ -283,12 +320,12 @@ def clangFormat(filename):
subprocess.check_call(cmd, stdout=devnull)
-def getSerializeArgs():
+def getSerialLibAtts():
"""
Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
The values are the arguments required by each Serialization library operator.
"""
- serializeArgs = {}
+ serialLibAtts = {}
with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
preamble = True
inAtt = False
@@ -315,11 +352,11 @@ def getSerializeArgs():
}
args.append(arg)
if ")" in line:
- serializeArgs[opName] = args
+ serialLibAtts[opName] = args
opName = ""
args = []
inAtt = False
- return serializeArgs
+ return serialLibAtts
def renderTemplate(environment, dataTypes, operators, template, outfile):
@@ -349,12 +386,12 @@ def getSerializeOpTypeMap():
"""
import re
- allSerializeArgs = getSerializeArgs()
- serArgs = [
+ allSerialLibAtts = getSerialLibAtts()
+ serAtts = [
re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
- for name in allSerializeArgs.keys()
+ for name in allSerialLibAtts.keys()
]
- serArgs = sorted(serArgs, key=len, reverse=True)
+ serAtts = sorted(serAtts, key=len, reverse=True)
tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
opsXml = tosaXml.getElementsByTagName("operator")
opNames = [
@@ -362,9 +399,9 @@ def getSerializeOpTypeMap():
]
map = {}
for opName in opNames:
- for serArg in serArgs:
- if serArg in opName:
- components = serArg.split("_")
+ for serAtt in serAtts:
+ if serAtt in opName:
+ components = serAtt.split("_")
map[opName] = "".join(x.title() for x in components)
return map
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index a8f1c24..6b6f864 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -67,6 +67,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_UINT16;
case tosa_datatype_uint8_t:
return tosa::DType::DType_UINT8;
+ case tosa_datatype_shape_t:
+ return tosa::DType::DType_SHAPE;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -99,24 +101,24 @@ extern "C"
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}, const func_config_t &func_config,
- const func_debug_t &func_debug
+ {%- endfor -%},const func_ctx_t& func_ctx
)
{
// Create operator attributes
- {% for arg in operator.serializeArgs: %}
- {%- if arg.SV == "V": -%}
- const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
- {%- else: -%}
- const {{arg.dType}} {{arg.name}}{{arg.init}};
- {%- endif -%}
+ {% for att in operator.serialLibAtts: -%}
+ {{att.init}}
{%- endfor -%}
Tosa{{operator.serializeAttType}}Attribute attr
- {%- if operator.serializeArgs|length > 0 -%}
+ {%- if operator.serialLibAtts|length > 0 -%}
(
- {%- for arg in operator.serializeArgs: -%}
- {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- for att in operator.serialLibAtts: -%}
+ {%- if att.init == "" -%}
+ client_{{att.name}}
+ {%- else -%}
+ {{att.name}}
+ {%- endif -%}
+ {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
@@ -174,7 +176,7 @@ extern "C"
});
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
index 042d7a5..0c98da8 100644
--- a/scripts/operator_api/templates/operators_h.j2
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -21,6 +21,7 @@
#include "func_config.h"
#include "func_debug.h"
+#include "types.h"
#include <stddef.h>
#include <stdint.h>
@@ -29,37 +30,10 @@
extern "C" {
#endif /* __cplusplus */
- // Note status needs to be aligned with graph_status
- enum tosa_status_t
+ struct func_ctx_t
{
- tosa_status_valid = 0,
- tosa_status_unpredictable = 1,
- tosa_status_error = 2
- };
-
- enum tosa_mode_t
- {
- tosa_mode_unknown = 0,
- tosa_mode_nearest = 1,
- tosa_mode_bilinear = 2,
- tosa_mode_min = 3,
- tosa_mode_max = 4
- };
-
- enum tosa_datatype_t
- {
- {% for dataType in dataTypes: -%}
- {{dataType}} = {{loop.index-1}},
- {% endfor -%}
- };
-
- struct tosa_tensor_t
- {
- int32_t* shape;
- int32_t num_dims;
- tosa_datatype_t data_type;
- uint8_t* data;
- size_t size;
+ func_config_t func_config = func_config_t{};
+ func_debug_t func_debug = func_debug_t{};
};
{% for operator in operators: %}
@@ -67,8 +41,7 @@ extern "C" {
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}, const func_config_t &func_config = func_config_t{},
- const func_debug_t &func_debug = func_debug_t{});
+ {%- endfor -%},const func_ctx_t& func_ctx);
{% endfor %}
#ifdef __cplusplus
diff --git a/thirdparty/specification b/thirdparty/specification
-Subproject 8e14dcd2f86e9a3b9c2283fb0f0325088565bbe
+Subproject b5b067819e5de11153b41cf3d26da4f3f9dd23e