aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/QLSTMLayerNormalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/QLSTMLayerNormalization.cpp')
-rw-r--r--tests/validation/NEON/QLSTMLayerNormalization.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/tests/validation/NEON/QLSTMLayerNormalization.cpp b/tests/validation/NEON/QLSTMLayerNormalization.cpp
index 248bf5cf78..3d71175a6f 100644
--- a/tests/validation/NEON/QLSTMLayerNormalization.cpp
+++ b/tests/validation/NEON/QLSTMLayerNormalization.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
+#include "tests/NEON/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
@@ -46,6 +47,8 @@ namespace
constexpr uint32_t vector_size_byte = 16;
using test::datasets::ShapeDataset;
+using NEQLSTMLayerNormalization = NESynthetizeFunction<NEQLSTMLayerNormalizationKernel>;
+
template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
class QLSTMLayerNormShapeDataSet : public ShapeDataset
{
@@ -150,7 +153,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL,
),
input_info, weight_info, bias_info, output_info)
{
- const Status s = NEQLSTMLayerNormalizationKernel::validate(&input_info, &output_info, &weight_info, &bias_info);
+ const Status s = NEQLSTMLayerNormalization::validate(&input_info, &output_info, &weight_info, &bias_info);
ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
}
@@ -158,7 +161,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL,
// *INDENT-ON*
template <typename T>
-using NEQLSTMLayerNormalizationFixture = NEQLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalizationKernel, T>;
+using NEQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalization, T>;
TEST_SUITE(Quantized)
TEST_SUITE(QSYMM16)