diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-08-31 18:12:42 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 583137cc60580023abfd9d05abf933e7e117e29f (patch) | |
tree | b29ec55c11b65e2882e60c0cf8b592bf25e78b1b /tests/validation/CPP | |
parent | 3021edfb5e72ef4cd91dbc754ce6ac55388ebc4e (diff) | |
download | ComputeLibrary-583137cc60580023abfd9d05abf933e7e117e29f.tar.gz |
COMPMID-417: Add support for floats in scale.
Change-Id: I7d714ba13861509080a89817f54e9d32da83e970
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/86026
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r-- | tests/validation/CPP/ActivationLayer.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/ArithmeticAddition.cpp | 6 | ||||
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/ConvolutionLayer.cpp | 5 | ||||
-rw-r--r-- | tests/validation/CPP/DepthConcatenateLayer.cpp | 3 | ||||
-rw-r--r-- | tests/validation/CPP/DepthConvert.cpp | 1 | ||||
-rw-r--r-- | tests/validation/CPP/DepthwiseConvolution.cpp | 1 | ||||
-rw-r--r-- | tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp | 1 | ||||
-rw-r--r-- | tests/validation/CPP/FullyConnectedLayer.cpp | 5 | ||||
-rw-r--r-- | tests/validation/CPP/GEMM.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/NormalizationLayer.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/PoolingLayer.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/ReshapeLayer.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/Scale.cpp | 5 | ||||
-rw-r--r-- | tests/validation/CPP/Scale.h | 2 | ||||
-rw-r--r-- | tests/validation/CPP/SoftmaxLayer.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CPP/Utils.cpp | 26 | ||||
-rw-r--r-- | tests/validation/CPP/Utils.h | 2 |
18 files changed, 44 insertions, 41 deletions
diff --git a/tests/validation/CPP/ActivationLayer.cpp b/tests/validation/CPP/ActivationLayer.cpp index 8fcacca1e2..2243e6ff59 100644 --- a/tests/validation/CPP/ActivationLayer.cpp +++ b/tests/validation/CPP/ActivationLayer.cpp @@ -23,9 +23,9 @@ */ #include "ActivationLayer.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -155,7 +155,7 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo } template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info); -template SimpleTensor<half_float::half> activation_layer(const SimpleTensor<half_float::half> &src, ActivationLayerInfo info); +template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info); template SimpleTensor<qint8_t> activation_layer(const SimpleTensor<qint8_t> &src, ActivationLayerInfo info); template SimpleTensor<qint16_t> activation_layer(const SimpleTensor<qint16_t> &src, ActivationLayerInfo info); } // namespace reference diff --git a/tests/validation/CPP/ArithmeticAddition.cpp b/tests/validation/CPP/ArithmeticAddition.cpp index 41052c469b..82dd1437cd 100644 --- a/tests/validation/CPP/ArithmeticAddition.cpp +++ b/tests/validation/CPP/ArithmeticAddition.cpp @@ -23,9 +23,9 @@ */ #include "ArithmeticAddition.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -54,8 +54,8 @@ SimpleTensor<T> arithmetic_addition(const SimpleTensor<T> &src1, const SimpleTen template SimpleTensor<uint8_t> arithmetic_addition(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int16_t> arithmetic_addition(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_addition(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); -template SimpleTensor<half_float::half> arithmetic_addition(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, DataType dst_data_type, - ConvertPolicy convert_policy); +template SimpleTensor<half> arithmetic_addition(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, + ConvertPolicy convert_policy); template SimpleTensor<float> arithmetic_addition(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy); } // namespace reference } // namespace validation diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp index fa7fec9d1b..80bdb15a49 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.cpp +++ b/tests/validation/CPP/ArithmeticSubtraction.cpp @@ -25,7 +25,6 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -54,8 +53,7 @@ SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const Simple template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); -template SimpleTensor<half_float::half> arithmetic_subtraction(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, DataType dst_data_type, - ConvertPolicy convert_policy); +template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<float> arithmetic_subtraction(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, DataType dst_data_type, ConvertPolicy convert_policy); } // namespace reference } // namespace validation diff --git a/tests/validation/CPP/ConvolutionLayer.cpp b/tests/validation/CPP/ConvolutionLayer.cpp index 1824ada791..656cd2ee26 100644 --- a/tests/validation/CPP/ConvolutionLayer.cpp +++ b/tests/validation/CPP/ConvolutionLayer.cpp @@ -25,7 +25,6 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -193,8 +192,8 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape, const PadStrideInfo &info); -template SimpleTensor<half_float::half> convolution_layer(const SimpleTensor<half_float::half> &src, const SimpleTensor<half_float::half> &weights, const SimpleTensor<half_float::half> &bias, - const TensorShape &output_shape, const PadStrideInfo &info); +template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &output_shape, + const PadStrideInfo &info); template SimpleTensor<qint8_t> convolution_layer(const SimpleTensor<qint8_t> &src, const SimpleTensor<qint8_t> &weights, const SimpleTensor<qint8_t> &bias, const TensorShape &output_shape, const PadStrideInfo &info); template SimpleTensor<qint16_t> convolution_layer(const SimpleTensor<qint16_t> &src, const SimpleTensor<qint16_t> &weights, const SimpleTensor<qint16_t> &bias, const TensorShape &output_shape, diff --git a/tests/validation/CPP/DepthConcatenateLayer.cpp b/tests/validation/CPP/DepthConcatenateLayer.cpp index 139d26f2b6..9a7248493d 100644 --- a/tests/validation/CPP/DepthConcatenateLayer.cpp +++ b/tests/validation/CPP/DepthConcatenateLayer.cpp @@ -25,7 +25,6 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -95,7 +94,7 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs) } template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs); -template SimpleTensor<half_float::half> depthconcatenate_layer(const std::vector<SimpleTensor<half_float::half>> &srcs); +template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs); template SimpleTensor<qint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<qint8_t>> &srcs); template SimpleTensor<qint16_t> depthconcatenate_layer(const std::vector<SimpleTensor<qint16_t>> &srcs); } // namespace reference diff --git a/tests/validation/CPP/DepthConvert.cpp b/tests/validation/CPP/DepthConvert.cpp index bb34f67086..110174a73f 100644 --- a/tests/validation/CPP/DepthConvert.cpp +++ b/tests/validation/CPP/DepthConvert.cpp @@ -25,7 +25,6 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" #include "tests/Types.h" diff --git a/tests/validation/CPP/DepthwiseConvolution.cpp b/tests/validation/CPP/DepthwiseConvolution.cpp index ebca333715..ce30bed640 100644 --- a/tests/validation/CPP/DepthwiseConvolution.cpp +++ b/tests/validation/CPP/DepthwiseConvolution.cpp @@ -27,7 +27,6 @@ #include "Utils.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { diff --git a/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp b/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp index 7020a854cf..3942ecf02a 100644 --- a/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp +++ b/tests/validation/CPP/DepthwiseSeparableConvolutionLayer.cpp @@ -29,7 +29,6 @@ #include "Utils.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { diff --git a/tests/validation/CPP/FullyConnectedLayer.cpp b/tests/validation/CPP/FullyConnectedLayer.cpp index a146535bd9..2b32c4b161 100644 --- a/tests/validation/CPP/FullyConnectedLayer.cpp +++ b/tests/validation/CPP/FullyConnectedLayer.cpp @@ -23,8 +23,8 @@ */ #include "FullyConnectedLayer.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" -#include "tests/validation/half.h" #include <numeric> @@ -123,8 +123,7 @@ SimpleTensor<T> fully_connected_layer(const SimpleTensor<T> &src, const SimpleTe } template SimpleTensor<float> fully_connected_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &dst_shape); -template SimpleTensor<half_float::half> fully_connected_layer(const SimpleTensor<half_float::half> &src, const SimpleTensor<half_float::half> &weights, const SimpleTensor<half_float::half> &bias, - const TensorShape &dst_shape); +template SimpleTensor<half> fully_connected_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &dst_shape); template SimpleTensor<qint8_t> fully_connected_layer(const SimpleTensor<qint8_t> &src, const SimpleTensor<qint8_t> &weights, const SimpleTensor<qint8_t> &bias, const TensorShape &dst_shape); template SimpleTensor<qint16_t> fully_connected_layer(const SimpleTensor<qint16_t> &src, const SimpleTensor<qint16_t> &weights, const SimpleTensor<qint16_t> &bias, const TensorShape &dst_shape); } // namespace reference diff --git a/tests/validation/CPP/GEMM.cpp b/tests/validation/CPP/GEMM.cpp index 9b66597eb8..77d025ec8e 100644 --- a/tests/validation/CPP/GEMM.cpp +++ b/tests/validation/CPP/GEMM.cpp @@ -23,8 +23,8 @@ */ #include "GEMM.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -113,7 +113,7 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S } template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta); -template SimpleTensor<half_float::half> gemm(const SimpleTensor<half_float::half> &a, const SimpleTensor<half_float::half> &b, const SimpleTensor<half_float::half> &c, float alpha, float beta); +template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); template SimpleTensor<qint8_t> gemm(const SimpleTensor<qint8_t> &a, const SimpleTensor<qint8_t> &b, const SimpleTensor<qint8_t> &c, float alpha, float beta); template SimpleTensor<qint16_t> gemm(const SimpleTensor<qint16_t> &a, const SimpleTensor<qint16_t> &b, const SimpleTensor<qint16_t> &c, float alpha, float beta); } // namespace reference diff --git a/tests/validation/CPP/NormalizationLayer.cpp b/tests/validation/CPP/NormalizationLayer.cpp index 3c6f5e1a54..226af96fe3 100644 --- a/tests/validation/CPP/NormalizationLayer.cpp +++ b/tests/validation/CPP/NormalizationLayer.cpp @@ -23,8 +23,8 @@ */ #include "NormalizationLayer.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -266,7 +266,7 @@ SimpleTensor<T> normalization_layer(const SimpleTensor<T> &src, NormalizationLay } template SimpleTensor<float> normalization_layer(const SimpleTensor<float> &src, NormalizationLayerInfo info); -template SimpleTensor<half_float::half> normalization_layer(const SimpleTensor<half_float::half> &src, NormalizationLayerInfo info); +template SimpleTensor<half> normalization_layer(const SimpleTensor<half> &src, NormalizationLayerInfo info); template SimpleTensor<qint8_t> normalization_layer(const SimpleTensor<qint8_t> &src, NormalizationLayerInfo info); template SimpleTensor<qint16_t> normalization_layer(const SimpleTensor<qint16_t> &src, NormalizationLayerInfo info); } // namespace reference diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp index c4425ca9a1..f7273f073f 100644 --- a/tests/validation/CPP/PoolingLayer.cpp +++ b/tests/validation/CPP/PoolingLayer.cpp @@ -23,8 +23,8 @@ */ #include "PoolingLayer.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -234,7 +234,7 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) } template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, PoolingLayerInfo info); -template SimpleTensor<half_float::half> pooling_layer(const SimpleTensor<half_float::half> &src, PoolingLayerInfo info); +template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, PoolingLayerInfo info); template SimpleTensor<qint8_t> pooling_layer(const SimpleTensor<qint8_t> &src, PoolingLayerInfo info); template SimpleTensor<qint16_t> pooling_layer(const SimpleTensor<qint16_t> &src, PoolingLayerInfo info); } // namespace reference diff --git a/tests/validation/CPP/ReshapeLayer.cpp b/tests/validation/CPP/ReshapeLayer.cpp index cc7f15e4b1..42f06e4f5a 100644 --- a/tests/validation/CPP/ReshapeLayer.cpp +++ b/tests/validation/CPP/ReshapeLayer.cpp @@ -23,7 +23,7 @@ */ #include "ReshapeLayer.h" -#include "tests/validation/half.h" +#include "arm_compute/core/Types.h" namespace arm_compute { @@ -49,7 +49,7 @@ template SimpleTensor<uint16_t> reshape_layer(const SimpleTensor<uint16_t> &src, template SimpleTensor<int16_t> reshape_layer(const SimpleTensor<int16_t> &src, const TensorShape &output_shape); template SimpleTensor<uint32_t> reshape_layer(const SimpleTensor<uint32_t> &src, const TensorShape &output_shape); template SimpleTensor<int32_t> reshape_layer(const SimpleTensor<int32_t> &src, const TensorShape &output_shape); -template SimpleTensor<half_float::half> reshape_layer(const SimpleTensor<half_float::half> &src, const TensorShape &output_shape); +template SimpleTensor<half> reshape_layer(const SimpleTensor<half> &src, const TensorShape &output_shape); template SimpleTensor<float> reshape_layer(const SimpleTensor<float> &src, const TensorShape &output_shape); } // namespace reference } // namespace validation diff --git a/tests/validation/CPP/Scale.cpp b/tests/validation/CPP/Scale.cpp index a1119f33b9..ba34553a99 100644 --- a/tests/validation/CPP/Scale.cpp +++ b/tests/validation/CPP/Scale.cpp @@ -36,7 +36,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value) +SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value) { TensorShape shape_scaled(in.shape()); shape_scaled.set(0, in.shape()[0] * scale_x); @@ -160,6 +160,9 @@ SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, I } template SimpleTensor<uint8_t> scale(const SimpleTensor<uint8_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value); +template SimpleTensor<int16_t> scale(const SimpleTensor<int16_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, int16_t constant_border_value); +template SimpleTensor<half> scale(const SimpleTensor<half> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, half constant_border_value); +template SimpleTensor<float> scale(const SimpleTensor<float> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, float constant_border_value); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/CPP/Scale.h b/tests/validation/CPP/Scale.h index b882915946..53183ae742 100644 --- a/tests/validation/CPP/Scale.h +++ b/tests/validation/CPP/Scale.h @@ -35,7 +35,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value = 0); +SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value = 0); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/CPP/SoftmaxLayer.cpp b/tests/validation/CPP/SoftmaxLayer.cpp index 4fe87d07dc..eb7655078c 100644 --- a/tests/validation/CPP/SoftmaxLayer.cpp +++ b/tests/validation/CPP/SoftmaxLayer.cpp @@ -23,8 +23,8 @@ */ #include "SoftmaxLayer.h" +#include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -113,7 +113,7 @@ SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src) } template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src); -template SimpleTensor<half_float::half> softmax_layer(const SimpleTensor<half_float::half> &src); +template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src); template SimpleTensor<qint8_t> softmax_layer(const SimpleTensor<qint8_t> &src); template SimpleTensor<qint16_t> softmax_layer(const SimpleTensor<qint16_t> &src); } // namespace reference diff --git a/tests/validation/CPP/Utils.cpp b/tests/validation/CPP/Utils.cpp index 15e9fc3138..2f54879818 100644 --- a/tests/validation/CPP/Utils.cpp +++ b/tests/validation/CPP/Utils.cpp @@ -24,7 +24,6 @@ #include "Utils.h" #include "tests/validation/Helpers.h" -#include "tests/validation/half.h" namespace arm_compute { @@ -51,17 +50,20 @@ T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border } else { - return constant_border_value; + return static_cast<T>(constant_border_value); } } return in[coord2index(in.shape(), coord)]; } -template float tensor_elem_at(const SimpleTensor<float> &in, Coordinates coord, BorderMode border_mode, float constant_border_value); + template uint8_t tensor_elem_at(const SimpleTensor<uint8_t> &in, Coordinates coord, BorderMode border_mode, uint8_t constant_border_value); +template int16_t tensor_elem_at(const SimpleTensor<int16_t> &in, Coordinates coord, BorderMode border_mode, int16_t constant_border_value); +template half tensor_elem_at(const SimpleTensor<half> &in, Coordinates coord, BorderMode border_mode, half constant_border_value); +template float tensor_elem_at(const SimpleTensor<float> &in, Coordinates coord, BorderMode border_mode, float constant_border_value); // Return the bilinear value at a specified coordinate with different border modes template <typename T> -T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value) +T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value) { int idx = std::floor(xn); int idy = std::floor(yn); @@ -71,22 +73,28 @@ T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, const float dx_1 = 1.0f - dx; const float dy_1 = 1.0f - dy; + const T border_value = constant_border_value; + id.set(0, idx); id.set(1, idy); - const T tl = tensor_elem_at(in, id, border_mode, constant_border_value); + const float tl = tensor_elem_at(in, id, border_mode, border_value); id.set(0, idx + 1); id.set(1, idy); - const T tr = tensor_elem_at(in, id, border_mode, constant_border_value); + const float tr = tensor_elem_at(in, id, border_mode, border_value); id.set(0, idx); id.set(1, idy + 1); - const T bl = tensor_elem_at(in, id, border_mode, constant_border_value); + const float bl = tensor_elem_at(in, id, border_mode, border_value); id.set(0, idx + 1); id.set(1, idy + 1); - const T br = tensor_elem_at(in, id, border_mode, constant_border_value); + const float br = tensor_elem_at(in, id, border_mode, border_value); - return tl * (dx_1 * dy_1) + tr * (dx * dy_1) + bl * (dx_1 * dy) + br * (dx * dy); + return static_cast<T>(tl * (dx_1 * dy_1) + tr * (dx * dy_1) + bl * (dx_1 * dy) + br * (dx * dy)); } + template uint8_t bilinear_policy(const SimpleTensor<uint8_t> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value); +template int16_t bilinear_policy(const SimpleTensor<int16_t> &in, Coordinates id, float xn, float yn, BorderMode border_mode, int16_t constant_border_value); +template half bilinear_policy(const SimpleTensor<half> &in, Coordinates id, float xn, float yn, BorderMode border_mode, half constant_border_value); +template float bilinear_policy(const SimpleTensor<float> &in, Coordinates id, float xn, float yn, BorderMode border_mode, float constant_border_value); /* Apply 2D spatial filter on a single element of @p in at coordinates @p coord * diff --git a/tests/validation/CPP/Utils.h b/tests/validation/CPP/Utils.h index 34ba60bed6..557d85f204 100644 --- a/tests/validation/CPP/Utils.h +++ b/tests/validation/CPP/Utils.h @@ -45,7 +45,7 @@ template <typename T> T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border_mode, T constant_border_value); template <typename T> -T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, uint8_t constant_border_value); +T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value); template <typename T1, typename T2, typename T3> void apply_2d_spatial_filter(Coordinates coord, const SimpleTensor<T1> &in, SimpleTensor<T3> &out, const TensorShape &filter_shape, const T2 *filter_itr, float scale, BorderMode border_mode, |