From 2bbd96457e3740fd9df5556607514b5e80a25720 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 4 Jul 2017 16:46:32 +0100 Subject: COMPMID-436, COMPMID-437 - Port NEConvolutionLayer & NEFullyConnectedLayer to support 16 bit fixed point Change-Id: I69edf2dac242f941bac95c8479d921e7be6abca7 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79725 Tested-by: Kaizen Reviewed-by: Pablo Tello --- tests/validation/TensorOperations.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'tests/validation/TensorOperations.h') diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 488ffa90d9..0502f53186 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -158,7 +158,7 @@ void convolution3d(const T *in, const T *weights, const T *bias, T *out, int xi, *out = res.raw(); } -template +template ::value, int>::type * = nullptr> void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position) { for(int x = 0; x < cols_weights; ++x) @@ -172,11 +172,12 @@ void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out } } -template <> -void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_t *bias, int8_t *out, int cols_weights, int rows_weights, uint8_t fixed_point_position) +// Vector matrix multiply for fixed point type +template ::value, int>::type * = nullptr> +void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position) { using namespace fixed_point_arithmetic; - using promoted_type = typename fixed_point_arithmetic::traits::promote::type; + using promoted_type = typename fixed_point_arithmetic::traits::promote::type; for(int x = 0; x < cols_weights; ++x) { @@ -192,10 +193,10 @@ void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_ } // Get the bias - const fixed_point b(bias[x], fixed_point_position, true); + const fixed_point b(bias[x], fixed_point_position, true); // Convert back and accumulate the bias - fixed_point res(acc); + fixed_point res(acc); res = res + b; // Store the result -- cgit v1.2.1