diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/FullyConnectedLayerFixture.h | 54 |
1 files changed, 18 insertions, 36 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 18321480f8..49c3be0c2e 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -45,7 +45,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class FullyConnectedLayerValidationGenericFixture : public framework::Fixture { public: @@ -103,8 +103,8 @@ protected: // -----------+-----------+--------------------------- // transpose | | *** // -----------+-----------+--------------------------- - // !transpose | transpose | transpose & - // | | transpose1xW (if required) + // !transpose | transpose | transpose + // | | // // ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same if(!reshape_weights || !transpose_weights) @@ -112,16 +112,6 @@ protected: const size_t shape_x = reshaped_weights_shape.x(); reshaped_weights_shape.set(0, reshaped_weights_shape.y()); reshaped_weights_shape.set(1, shape_x); - - // Weights have to be passed reshaped - // Transpose 1xW for batched version - if(!reshape_weights && output_shape.y() > 1 && run_interleave) - { - const int transpose_width = 16 / data_size_from_type(_data_type); - const float shape_x = reshaped_weights_shape.x(); - reshaped_weights_shape.set(0, reshaped_weights_shape.y() * transpose_width); - reshaped_weights_shape.set(1, static_cast<unsigned int>(std::ceil(shape_x / transpose_width))); - } } // Create tensors @@ -170,14 +160,6 @@ protected: // Transpose elementwise tmp = transpose(tmp); - // Reshape weights for batched runs - if(!reshape_weights && output_shape.y() > 1 && run_interleave) - { - // Transpose with interleave - const int interleave_size = 16 / tmp.element_size(); - tmp = transpose(tmp, interleave_size); - } - AccessorType weights_accessor(weights); for(int i = 0; i < tmp.num_elements(); ++i) @@ -222,43 +204,43 @@ protected: QuantizationInfo _quantization_info{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> -class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - QuantizationInfo()); + FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + QuantizationInfo()); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> -class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - QuantizationInfo()); + FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + QuantizationInfo()); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave> -class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, QuantizationInfo quantization_info) { - FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - quantization_info); + FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + quantization_info); } }; } // namespace validation |