diff options
Diffstat (limited to 'tests/validation_new/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r-- | tests/validation_new/fixtures/FullyConnectedLayerFixture.h | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/validation_new/fixtures/FullyConnectedLayerFixture.h b/tests/validation_new/fixtures/FullyConnectedLayerFixture.h index eb4aad8952..0953b0b67e 100644 --- a/tests/validation_new/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation_new/fixtures/FullyConnectedLayerFixture.h @@ -76,7 +76,7 @@ RawTensor transpose(const RawTensor &src, int interleave = 1) } } // namespace -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> class FullyConnectedLayerValidationFixedPointFixture : public framework::Fixture { public: @@ -131,7 +131,7 @@ protected: // Weights have to be passed reshaped // Transpose 1xW for batched version - if(!reshape_weights && output_shape.y() > 1) + if(!reshape_weights && output_shape.y() > 1 && run_interleave) { const int transpose_width = 16 / data_size_from_type(data_type); const float shape_x = reshaped_weights_shape.x(); @@ -182,7 +182,7 @@ protected: tmp = transpose(tmp); // Reshape weights for batched runs - if(!reshape_weights && output_shape.y() > 1) + if(!reshape_weights && output_shape.y() > 1 && run_interleave) { // Transpose with interleave const int interleave_size = 16 / tmp.element_size(); @@ -232,15 +232,16 @@ protected: DataType _data_type{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> +class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, - 0); + FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + 0); } }; } // namespace validation |