aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2017-07-04 16:46:32 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:15:39 +0100
commit2bbd96457e3740fd9df5556607514b5e80a25720 (patch)
tree679935dd849bdac044769dfff67516962493dd51 /tests/validation/TensorOperations.h
parent8a383694445dfebb84732b19d5b3299961e8ffe3 (diff)
downloadComputeLibrary-2bbd96457e3740fd9df5556607514b5e80a25720.tar.gz
COMPMID-436, COMPMID-437 - Port NEConvolutionLayer & NEFullyConnectedLayer to support 16 bit fixed point
Change-Id: I69edf2dac242f941bac95c8479d921e7be6abca7 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79725 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h13
1 files changed, 7 insertions, 6 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 488ffa90d9..0502f53186 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -158,7 +158,7 @@ void convolution3d(const T *in, const T *weights, const T *bias, T *out, int xi,
*out = res.raw();
}
-template <typename T>
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type * = nullptr>
void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
{
for(int x = 0; x < cols_weights; ++x)
@@ -172,11 +172,12 @@ void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out
}
}
-template <>
-void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_t *bias, int8_t *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
+// Vector matrix multiply for fixed point type
+template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type * = nullptr>
+void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out, int cols_weights, int rows_weights, uint8_t fixed_point_position)
{
using namespace fixed_point_arithmetic;
- using promoted_type = typename fixed_point_arithmetic::traits::promote<int8_t>::type;
+ using promoted_type = typename fixed_point_arithmetic::traits::promote<T>::type;
for(int x = 0; x < cols_weights; ++x)
{
@@ -192,10 +193,10 @@ void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_
}
// Get the bias
- const fixed_point<int8_t> b(bias[x], fixed_point_position, true);
+ const fixed_point<T> b(bias[x], fixed_point_position, true);
// Convert back and accumulate the bias
- fixed_point<int8_t> res(acc);
+ fixed_point<T> res(acc);
res = res + b;
// Store the result