aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/RsqrtLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/RsqrtLayer.cpp')
-rw-r--r--tests/validation/CL/RsqrtLayer.cpp35
1 files changed, 17 insertions, 18 deletions
diff --git a/tests/validation/CL/RsqrtLayer.cpp b/tests/validation/CL/RsqrtLayer.cpp
index ee9e9363b3..82fbed3b5d 100644
--- a/tests/validation/CL/RsqrtLayer.cpp
+++ b/tests/validation/CL/RsqrtLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,25 +48,24 @@ RelativeTolerance<float> tolerance_fp16(0.001f);
TEST_SUITE(CL)
TEST_SUITE(RsqrtLayer)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", DataType::F32)), shape, data_type)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data types
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Valid
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ }),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32),
+ })),
+ framework::dataset::make("Expected", { false, true, false })),
+ input_info, output_info, expected)
{
- // Create tensors
- CLTensor src = create_tensor<CLTensor>(shape, data_type);
- CLTensor dst = create_tensor<CLTensor>(shape, data_type);
-
- ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
-
- // Create and configure function
- CLRsqrtLayer exp_layer;
- exp_layer.configure(&src, &dst);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(src.info()->valid_region(), valid_region);
- validate(dst.info()->valid_region(), valid_region);
+ ARM_COMPUTE_EXPECT(bool(CLRsqrtLayer::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
}
-
+// clang-format on
+// *INDENT-ON*
template <typename T>
using CLRsqrtLayerFixture = RsqrtValidationFixture<CLTensor, CLAccessor, CLRsqrtLayer, T>;