aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorVisitors.h
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2017-06-23 15:02:11 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:08 +0100
commitb797fa235f714440ffa7a2ad4eef7ae14ee45da4 (patch)
treeefdefae2963d612c1bb1f84b8b74823c64804702 /tests/validation/TensorVisitors.h
parentac4e873dad6aa6291fc36aff62047a896db04f6a (diff)
downloadComputeLibrary-b797fa235f714440ffa7a2ad4eef7ae14ee45da4.tar.gz
COMPMID-424 - Implemented reference implementation and tests (NEON and CL) for TableLookup
Change-Id: I53098ee750ab07fe64e9e2af8df91954d64017f5 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79411 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Steven Niu <steven.niu@arm.com>
Diffstat (limited to 'tests/validation/TensorVisitors.h')
-rw-r--r--tests/validation/TensorVisitors.h29
1 files changed, 25 insertions, 4 deletions
diff --git a/tests/validation/TensorVisitors.h b/tests/validation/TensorVisitors.h
index fcc584dd46..c58b9a69c0 100644
--- a/tests/validation/TensorVisitors.h
+++ b/tests/validation/TensorVisitors.h
@@ -27,10 +27,12 @@
#include "Tensor.h"
#include "TensorOperations.h"
#include "arm_compute/core/Error.h"
+#include "arm_compute/runtime/Lut.h"
#include "boost_wrapper.h"
#include <algorithm>
+#include <map>
#include <memory>
#include <ostream>
#include <vector>
@@ -180,11 +182,30 @@ private:
ConvertPolicy _convert_policy;
RoundingPolicy _rounding_policy;
};
-// Threshold operation
-void threshold_operation(const Tensor<uint8_t> &in, Tensor<uint8_t> &out, uint8_t threshold, uint8_t false_value, uint8_t true_value, ThresholdType type, uint8_t upper)
+// Table lookup operation
+template <typename T1>
+struct table_lookup : public boost::static_visitor<>
{
- tensor_operations::threshold(in, out, threshold, false_value, true_value, type, upper);
-}
+public:
+ explicit table_lookup(const TensorVariant &in, std::map<T1, T1> &lut)
+ : _in(in), _lut(lut)
+ {
+ }
+
+ template <typename T>
+ void operator()(Tensor<T> &out) const
+ {
+ const auto &in = boost::get<Tensor<T>>(_in);
+ tensor_operations::table_lookup(in, out, _lut);
+ }
+
+private:
+ const TensorVariant &_in;
+ std::map<T1, T1> &_lut;
+};
+template struct arm_compute::test::validation::tensor_visitors::table_lookup<uint8_t>;
+template struct arm_compute::test::validation::tensor_visitors::table_lookup<int16_t>;
+
// Activation layer visitor
struct activation_layer_visitor : public boost::static_visitor<>
{