aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorVisitors.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/TensorVisitors.h')
-rw-r--r--tests/validation/TensorVisitors.h29
1 files changed, 25 insertions, 4 deletions
diff --git a/tests/validation/TensorVisitors.h b/tests/validation/TensorVisitors.h
index fcc584dd46..c58b9a69c0 100644
--- a/tests/validation/TensorVisitors.h
+++ b/tests/validation/TensorVisitors.h
@@ -27,10 +27,12 @@
#include "Tensor.h"
#include "TensorOperations.h"
#include "arm_compute/core/Error.h"
+#include "arm_compute/runtime/Lut.h"
#include "boost_wrapper.h"
#include <algorithm>
+#include <map>
#include <memory>
#include <ostream>
#include <vector>
@@ -180,11 +182,30 @@ private:
ConvertPolicy _convert_policy;
RoundingPolicy _rounding_policy;
};
-// Threshold operation
-void threshold_operation(const Tensor<uint8_t> &in, Tensor<uint8_t> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
+// Table lookup operation
+template <typename T1>
+struct table_lookup : public boost::static_visitor<>
{
- tensor_operations::threshold(in, out, threshold, false_value, true_value, type, upper);
-}
+public:
+ explicit table_lookup(const TensorVariant &in, std::map<T1, T1> &lut)
+ : _in(in), _lut(lut)
+ {
+ }
+
+ template <typename T>
+ void operator()(Tensor<T> &out) const
+ {
+ const auto &in = boost::get<Tensor<T>>(_in);
+ tensor_operations::table_lookup(in, out, _lut);
+ }
+
+private:
+ const TensorVariant &_in;
+ std::map<T1, T1> &_lut;
+};
+template struct arm_compute::test::validation::tensor_visitors::table_lookup<uint8_t>;
+template struct arm_compute::test::validation::tensor_visitors::table_lookup<int16_t>;
+
// Activation layer visitor
struct activation_layer_visitor : public boost::static_visitor<>
{