aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/FFTFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/FFTFixture.h')
-rw-r--r--tests/validation/fixtures/FFTFixture.h37
1 files changed, 30 insertions, 7 deletions
diff --git a/tests/validation/fixtures/FFTFixture.h b/tests/validation/fixtures/FFTFixture.h
index 86a97272a0..199730d5d0 100644
--- a/tests/validation/fixtures/FFTFixture.h
+++ b/tests/validation/fixtures/FFTFixture.h
@@ -134,8 +134,9 @@ class FFTConvolutionValidationGenericFixture : public framework::Fixture
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation,
- DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info)
+ DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info, bool mixed_layout = false)
{
+ _mixed_layout = mixed_layout;
_data_type = data_type;
_data_layout = data_layout;
@@ -144,6 +145,21 @@ public:
}
protected:
+
+ void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst)
+ {
+ // Test Multi DataLayout graph cases, when the data layout changes after configure
+ src.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
+ dst.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
+
+ // Compute Convolution function
+ layer.run();
+
+ // Reinstating original data layout for the test suite to properly check the values
+ src.info()->set_data_layout(_data_layout);
+ dst.info()->set_data_layout(_data_layout);
+ }
+
template <typename U>
void fill(U &&tensor, int i)
{
@@ -209,10 +225,16 @@ protected:
fill(AccessorType(src), 0);
fill(AccessorType(weights), 1);
fill(AccessorType(bias), 2);
-
- // Compute convolution function
- conv.run();
-
+
+ if(_mixed_layout)
+ {
+ mix_layout(conv, src, dst);
+ }
+ else
+ {
+ // Compute Convolution function
+ conv.run();
+ }
return dst;
}
@@ -239,9 +261,10 @@ protected:
SimpleTensor<T> _reference{};
DataType _data_type{};
DataLayout _data_layout{};
+ bool _mixed_layout{false};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
class FFTConvolutionValidationFixture : public FFTConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -250,7 +273,7 @@ public:
DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info)
{
FFTConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation,
- data_type, data_layout, act_info);
+ data_type, data_layout, act_info, mixed_layout);
}
};
} // namespace validation