aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h')
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h36
1 files changed, 12 insertions, 24 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
index b0522488b4..e437c440d0 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
@@ -21,32 +21,23 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE
-#define ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE
+#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE
+#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/runtime/CL/CLScheduler.h"
-
#include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
#include "arm_compute/dynamic_fusion/sketch/OperatorAttributes.h"
#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h"
-#include "src/gpu/cl/operators/ClAdd.h"
-#include "src/gpu/cl/operators/ClConv2d.h"
-
#include "tests/CL/CLAccessor.h"
-
-#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/framework/Macros.h"
-
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ConvolutionLayer.h"
-#include "tests/validation/reference/ElementwiseOperations.h"
#include "tests/validation/reference/Permute.h"
using namespace arm_compute::experimental::dynamic_fusion;
@@ -136,10 +127,10 @@ protected:
tensor->allocator()->allocate(); // Use ACL allocated memory
}
// Construct user tensors
- CLTensor t_input{};
- CLTensor t_weight{};
- CLTensor t_bias{};
- CLTensor t_dst{};
+ TensorType t_input{};
+ TensorType t_weight{};
+ TensorType t_bias{};
+ TensorType t_dst{};
// Initialize user tensors
t_input.allocator()->init(input_info);
@@ -152,9 +143,10 @@ protected:
t_weight.allocator()->allocate();
t_bias.allocator()->allocate();
t_dst.allocator()->allocate();
- fill(CLAccessor(t_input), 0);
- fill(CLAccessor(t_weight), 1);
- fill(CLAccessor(t_bias), 2);
+
+ fill(AccessorType(t_input), 0);
+ fill(AccessorType(t_weight), 1);
+ fill(AccessorType(t_bias), 2);
// Run runtime
runtime.run({ &t_input, &t_weight, &t_bias, &t_dst });
@@ -187,15 +179,11 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
DataType _data_type{};
- DataType _weights_data_type{};
DataType _bias_data_type{};
- DataType _output_data_type{};
DataLayout _data_layout{};
QuantizationInfo _quantization_info{};
QuantizationInfo _weight_quantization_info{};
bool _is_quantized = false;
- bool _is_bfloat16 = false;
- bool _mixed_layout = false;
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -207,10 +195,10 @@ public:
const PadStrideInfo &info, const Size2D &dialation, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
{
DynamicFusionGpuConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, output_shape, bias_shape, info, dialation,
- data_type, data_layout, quantization_info, quantization_info);
+ data_type, data_layout, quantization_info, quantization_info);
}
};
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE */
+#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE */