aboutsummaryrefslogtreecommitdiff
path: root/tests/validation_new/fixtures/FullyConnectedLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation_new/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--tests/validation_new/fixtures/FullyConnectedLayerFixture.h15
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/validation_new/fixtures/FullyConnectedLayerFixture.h b/tests/validation_new/fixtures/FullyConnectedLayerFixture.h
index eb4aad8952..0953b0b67e 100644
--- a/tests/validation_new/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation_new/fixtures/FullyConnectedLayerFixture.h
@@ -76,7 +76,7 @@ RawTensor transpose(const RawTensor &src, int interleave = 1)
}
} // namespace
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
class FullyConnectedLayerValidationFixedPointFixture : public framework::Fixture
{
public:
@@ -131,7 +131,7 @@ protected:
// Weights have to be passed reshaped
// Transpose 1xW for batched version
- if(!reshape_weights && output_shape.y() > 1)
+ if(!reshape_weights && output_shape.y() > 1 && run_interleave)
{
const int transpose_width = 16 / data_size_from_type(data_type);
const float shape_x = reshaped_weights_shape.x();
@@ -182,7 +182,7 @@ protected:
tmp = transpose(tmp);
// Reshape weights for batched runs
- if(!reshape_weights && output_shape.y() > 1)
+ if(!reshape_weights && output_shape.y() > 1 && run_interleave)
{
// Transpose with interleave
const int interleave_size = 16 / tmp.element_size();
@@ -232,15 +232,16 @@ protected:
DataType _data_type{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
+class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type)
{
- FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type,
- 0);
+ FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
+ reshape_weights, data_type,
+ 0);
}
};
} // namespace validation