aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConcatenateLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ConcatenateLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConcatenateLayerFixture.h24
1 files changed, 13 insertions, 11 deletions
diff --git a/tests/validation/fixtures/ConcatenateLayerFixture.h b/tests/validation/fixtures/ConcatenateLayerFixture.h
index d1eed63d41..3a021661ac 100644
--- a/tests/validation/fixtures/ConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/ConcatenateLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -43,11 +43,13 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T, bool CI = true>
class ConcatenateLayerValidationFixture : public framework::Fixture
{
+private:
+ using SrcITensorType = typename std::conditional<CI, const ITensorType, ITensorType>::type;
+
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, unsigned int axis)
{
// Create input shapes
@@ -67,8 +69,8 @@ public:
{
qi = QuantizationInfo(1.f / 255.f, offset_dis(gen));
}
- std::bernoulli_distribution mutate_dis(0.5f);
- std::uniform_real_distribution<> change_dis(-0.25f, 0.f);
+ std::bernoulli_distribution mutate_dis(0.5f);
+ std::uniform_real_distribution<float> change_dis(-0.25f, 0.f);
// Generate more shapes based on the input
for(auto &s : shapes)
@@ -95,8 +97,8 @@ protected:
TensorType compute_target(const std::vector<TensorShape> &shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type, unsigned int axis)
{
- std::vector<TensorType> srcs;
- std::vector<ITensorType *> src_ptrs;
+ std::vector<TensorType> srcs;
+ std::vector<SrcITensorType *> src_ptrs;
// Create tensors
srcs.reserve(shapes.size());
@@ -116,20 +118,20 @@ protected:
for(auto &src : srcs)
{
- ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(src.info()->is_resizable());
}
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
// Allocate tensors
for(auto &src : srcs)
{
src.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
}
dst.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
// Fill tensors
int i = 0;