diff options
Diffstat (limited to 'tests/validation_old/TensorOperations.h')
-rw-r--r-- | tests/validation_old/TensorOperations.h | 74 |
1 files changed, 0 insertions, 74 deletions
diff --git a/tests/validation_old/TensorOperations.h b/tests/validation_old/TensorOperations.h index 2b326930f6..e03336505b 100644 --- a/tests/validation_old/TensorOperations.h +++ b/tests/validation_old/TensorOperations.h @@ -528,80 +528,6 @@ void non_linear_filter(const Tensor<T> &in, Tensor<T> &out, NonLinearFilterFunct } } -// Pixel-wise multiplication -template <typename T1, typename T2, typename T3> -void pixel_wise_multiplication(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) -{ - if(scale < 0) - { - ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); - } - using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type; - for(int i = 0; i < in1.num_elements(); ++i) - { - double val = static_cast<intermediate_type>(in1[i]) * static_cast<intermediate_type>(in2[i]) * static_cast<double>(scale); - if(is_floating_point<T3>::value) - { - out[i] = val; - } - else - { - double rounded_val = 0; - switch(rounding_policy) - { - case(RoundingPolicy::TO_ZERO): - rounded_val = support::cpp11::trunc(val); - break; - case(RoundingPolicy::TO_NEAREST_UP): - rounded_val = round_half_up(val); - break; - case(RoundingPolicy::TO_NEAREST_EVEN): - rounded_val = round_half_even(val); - break; - default: - ARM_COMPUTE_ERROR("Unsupported rounding policy"); - } - out[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : static_cast<T3>(rounded_val); - } - } -} - -// Fixed-point Pixel-wise Multiplication -template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type> -void fixed_point_pixel_wise_multiplication(const Tensor<T> &in1, const Tensor<T> &in2, Tensor<T> &out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) -{ - using namespace fixed_point_arithmetic; - - const int fixed_point_position = in1.fixed_point_position(); - - ARM_COMPUTE_ERROR_ON_MSG(in1.data_type() != in2.data_type() || in1.data_type() != out.data_type(), - "Tensors must all have the same DataType"); - ARM_COMPUTE_ERROR_ON_MSG(fixed_point_position != in2.fixed_point_position() || fixed_point_position != out.fixed_point_position(), - "Fixed-point position must be the same for both inputs and outputs"); - - // Validate fixed_point_position - ARM_COMPUTE_ERROR_ON((in1.data_type() == DataType::QS8) && (fixed_point_position == 0 || fixed_point_position > 7)); - ARM_COMPUTE_ERROR_ON((in1.data_type() == DataType::QS16) && (fixed_point_position == 0 || fixed_point_position > 15)); - - const fixed_point<T> fp_scale(scale, fixed_point_position); - const bool is_sat = convert_policy == ConvertPolicy::SATURATE; - - for(int i = 0; i < in1.num_elements(); ++i) - { - const fixed_point<T> val1(in1[i], fixed_point_position, true); - fixed_point<T> res(in2[i], fixed_point_position, true); - if(is_sat) - { - res = mul(mul(res, val1), fp_scale); - } - else - { - res = mul<OverflowPolicy::WRAP>(mul<OverflowPolicy::WRAP>(res, val1), fp_scale); - } - out[i] = res.raw(); - } -} - // Threshold template <typename T> void threshold(const Tensor<T> &in, Tensor<T> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper) |