aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 0430d59d33..843c52fec4 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -711,6 +711,67 @@ void gemm(const Tensor<T> &in1, const Tensor<T> &in2, const Tensor<T> &in3, Tens
}
}
+// Non linear filter
+template <typename T>
+void non_linear_filter(const Tensor<T> &in, Tensor<T> &out, NonLinearFilterFunction function, unsigned int mask_size,
+ MatrixPattern pattern, const uint8_t *mask, BorderMode border_mode, uint8_t constant_border_value)
+{
+ ARM_COMPUTE_ERROR_ON(MatrixPattern::OTHER == pattern && nullptr == mask);
+
+ using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
+
+ const int sq_mask_size = mask_size * mask_size;
+ const int half_mask_size = mask_size / 2;
+ std::vector<intermediate_type> vals(sq_mask_size);
+ intermediate_type current_value = 0;
+
+ ValidRegion valid_region = shape_to_valid_region(in.shape());
+ if(border_mode == BorderMode::UNDEFINED)
+ {
+ valid_region = shape_to_valid_region_undefined_border(in.shape(), BorderSize(half_mask_size));
+ }
+
+ for(int element_idx = 0, count = 0, index = 0; element_idx < in.num_elements(); ++element_idx, count = 0, index = 0)
+ {
+ Coordinates id = index2coord(in.shape(), element_idx);
+ if(is_in_valid_region(valid_region, id))
+ {
+ int idx = id.x();
+ int idy = id.y();
+ for(int y = idy - half_mask_size; y <= idy + half_mask_size; ++y)
+ {
+ for(int x = idx - half_mask_size; x <= idx + half_mask_size; ++x, ++index)
+ {
+ id.set(0, x);
+ id.set(1, y);
+ current_value = tensor_elem_at(in, id, border_mode, constant_border_value);
+
+ if(mask[index] == 255)
+ {
+ vals[count] = static_cast<intermediate_type>(current_value);
+ ++count;
+ }
+ }
+ }
+ std::sort(vals.begin(), vals.begin() + count);
+ switch(function)
+ {
+ case NonLinearFilterFunction::MIN:
+ out[element_idx] = saturate_cast<T>(vals[0]);
+ break;
+ case NonLinearFilterFunction::MAX:
+ out[element_idx] = saturate_cast<T>(vals[count - 1]);
+ break;
+ case NonLinearFilterFunction::MEDIAN:
+ out[element_idx] = saturate_cast<T>(vals[count / 2]);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported NonLinearFilter function.");
+ }
+ }
+ }
+}
+
// Pixel-wise multiplication
template <typename T1, typename T2, typename T3>
void pixel_wise_multiplication(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3> &out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)