diff options
Diffstat (limited to 'tests/validation/CL')
-rw-r--r-- | tests/validation/CL/QLSTMLayerNormalization.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tests/validation/CL/QLSTMLayerNormalization.cpp b/tests/validation/CL/QLSTMLayerNormalization.cpp index ea5eca6261..17f431cbbf 100644 --- a/tests/validation/CL/QLSTMLayerNormalization.cpp +++ b/tests/validation/CL/QLSTMLayerNormalization.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h" #include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" #include "tests/PaddingCalculator.h" #include "tests/datasets/ShapeDatasets.h" #include "tests/framework/Asserts.h" @@ -44,6 +45,7 @@ constexpr AbsoluteTolerance<int16_t> tolerance_s16(0); /**< Tolerance value for constexpr uint32_t vector_size_byte = 16; using test::datasets::ShapeDataset; +using CLQLSTMLayerNormalization = CLSynthetizeFunction<CLQLSTMLayerNormalizationKernel>; template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration> class QLSTMLayerNormShapeDataSet : public ShapeDataset { @@ -127,7 +129,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, ), input_info, weight_info, bias_info) { TensorInfo dummy_output{}; - const Status s = CLQLSTMLayerNormalizationKernel::validate(&input_info, &dummy_output, &weight_info, &bias_info); + const Status s = CLQLSTMLayerNormalization::validate(&input_info, &dummy_output, &weight_info, &bias_info); ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS); } @@ -135,7 +137,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, // *INDENT-ON* template <typename T> -using CLQLSTMLayerNormalizationFixture = CLQLSTMLayerNormalizationValidationFixture<CLTensor, CLAccessor, CLQLSTMLayerNormalizationKernel, T>; +using CLQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<CLTensor, CLAccessor, CLQLSTMLayerNormalization, T>; TEST_SUITE(Quantized) TEST_SUITE(QSYMM16) |