From 73f19af80aaee8929553739894b8dd8fedb163c3 Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Sun, 23 Oct 2022 11:44:49 +0100 Subject: Add Dynamic Fusion GpuConv2d FP32/FP16 testcase Resolves: COMPMID-5511 Signed-off-by: Ramy Elgammal Change-Id: I0ac0acbf1de7da09f18f7b457307ec3cc99deb3b Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8546 Comments-Addressed: Arm Jenkins Reviewed-by: SiCong Li Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- tests/SConscript | 1 + tests/SimpleTensorPrinter.h | 6 +- tests/Utils.h | 13 +- tests/validation/CL/DirectConvolutionLayer.cpp | 52 ++++- .../dynamic_fusion/gpu/cl/DirectConv2d.cpp | 100 ++++++++++ .../dynamic_fusion/gpu/cl/DirectConv2dFixture.h | 216 +++++++++++++++++++++ utils/TypePrinter.h | 57 ++++++ 7 files changed, 441 insertions(+), 4 deletions(-) create mode 100644 tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp create mode 100644 tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h diff --git a/tests/SConscript b/tests/SConscript index 95ecd27afa..87b654385a 100644 --- a/tests/SConscript +++ b/tests/SConscript @@ -122,6 +122,7 @@ if env['opencl']: if env['experimental_dynamic_fusion']: test_env.Append(CPPDEFINES = ['ENABLE_EXPERIMENTAL_DYNAMIC_FUSION']) files_validation += Glob('validation/dynamic_fusion/gpu/' + filter_pattern) + files_validation += Glob('validation/dynamic_fusion/gpu/cl/' + filter_pattern) filter_pattern = test_env['test_filter'] diff --git a/tests/SimpleTensorPrinter.h b/tests/SimpleTensorPrinter.h index 5d0299a696..e4ca66bb36 100644 --- a/tests/SimpleTensorPrinter.h +++ b/tests/SimpleTensorPrinter.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018, 2021 Arm Limited. + * Copyright (c) 2017-2018, 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,6 +22,9 @@ * SOFTWARE. */ +#ifndef ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER +#define ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER + #include "arm_compute/core/Error.h" #include "tests/RawTensor.h" @@ -152,3 +155,4 @@ void print_simpletensor(const SimpleTensor &tensor, const std::string &title, #endif // PRINT_TENSOR_LIMIT } // namespace test } // namespace arm_compute +#endif /* ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER */ diff --git a/tests/Utils.h b/tests/Utils.h index b62ad4a677..e58b8f7f86 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -48,6 +48,7 @@ #include #include +#include "arm_compute/dynamic_fusion/sketch/OperatorAttributes.h" #include "arm_compute/runtime/CPP/CPPScheduler.h" #include "arm_compute/runtime/RuntimeContext.h" @@ -133,7 +134,7 @@ using make_unsigned_conditional_t = typename std::conditionalset_tensor_dims_state(construct_static_dims_state()); } + +inline experimental::dynamic_fusion::Conv2dAttributes convert_pad_stride_info_to_conv_attr(const PadStrideInfo &info, const Size2D &dialation) +{ + const Padding2D info_pad(info.pad_left(), info.pad_right(), info.pad_top(), info.pad_bottom()); + const Size2D info_stride(info.stride().first, info.stride().second); + return arm_compute::experimental::dynamic_fusion::Conv2dAttributes().pad(info_pad).stride(info_stride).dilation(dialation); +} + } // namespace test } // namespace arm_compute #endif /* ARM_COMPUTE_TEST_UTILS_H */ diff --git a/tests/validation/CL/DirectConvolutionLayer.cpp b/tests/validation/CL/DirectConvolutionLayer.cpp index 324b076482..f026bfe0b0 100644 --- a/tests/validation/CL/DirectConvolutionLayer.cpp +++ b/tests/validation/CL/DirectConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -132,6 +132,56 @@ TEST_CASE(NoBias, framework::DatasetMode::PRECOMMIT) validate(CLAccessor(dst), ref_dst); } +/** Check whether the case of rectangle kernels i.e. when width and height of the weight_shape are not equal + * would lead to successful run + */ +TEST_CASE(NonSquareKernel, framework::DatasetMode::PRECOMMIT) +{ + auto src_shape = TensorShape(33U, 27U, 3U); + auto weights_shape = TensorShape(5U, 7U, 3U, 4U); // non-square kernel + const auto bias_shape = TensorShape(4U); + auto dst_shape = TensorShape(11U, 12U, 4U); + constexpr auto dt = DataType::F32; + + TensorShape src_shape_nhwc(src_shape); + TensorShape weights_shape_nhwc(weights_shape); + TensorShape dst_shape_nhwc(dst_shape); + + // Non-square shapes are only allowed for NHWC + permute(src_shape_nhwc, PermutationVector(2U, 0U, 1U)); + permute(weights_shape_nhwc, PermutationVector(2U, 0U, 1U)); + permute(dst_shape_nhwc, PermutationVector(2U, 0U, 1U)); + + auto src = create_tensor(src_shape_nhwc, dt, 1, QuantizationInfo(), DataLayout::NHWC); + auto weights = create_tensor(weights_shape_nhwc, dt, 1, QuantizationInfo(), DataLayout::NHWC); + auto dst = create_tensor(dst_shape_nhwc, dt, 1, QuantizationInfo(), DataLayout::NHWC); + const auto conv_info = PadStrideInfo(3, 2, 1, 1, 2, 0, DimensionRoundingType::FLOOR); + + // Create direct convolution function + CLDirectConvolutionLayer conv{}; + conv.configure(&src, &weights, nullptr, &dst, conv_info); + + src.allocator()->allocate(); + weights.allocator()->allocate(); + dst.allocator()->allocate(); + + library->fill_tensor_value(CLAccessor(src), 1.f); + library->fill_tensor_value(CLAccessor(weights), 1.f); + + conv.run(); + + // Compute reference to compare + SimpleTensor ref_src{ src_shape, dt }; + SimpleTensor ref_weights{ weights_shape, dt }; + SimpleTensor ref_bias{ bias_shape, dt }; + library->fill_tensor_value(ref_src, 1.f); + library->fill_tensor_value(ref_weights, 1.f); + // No bias + library->fill_tensor_value(ref_bias, 0.f); + auto ref_dst = reference::convolution_layer(ref_src, ref_weights, ref_bias, dst_shape, conv_info); + + validate(CLAccessor(dst), ref_dst); +} // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip( diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp new file mode 100644 index 0000000000..1f9319b10f --- /dev/null +++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" + +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/runtime/CL/CLScheduler.h" +#include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h" +#include "arm_compute/dynamic_fusion/sketch/OperatorAttributes.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h" + +#include "tests/AssetsLibrary.h" +#include "tests/CL/CLAccessor.h" +#include "tests/Globals.h" +#include "tests/IAccessor.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" +#include "tests/validation/reference/ConvolutionLayer.h" + +#include "tests/datasets/SmallConvolutionLayerDataset.h" +#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h" + +#ifdef ARM_COMPUTE_ASSERTS_ENABLED +#include "tests/SimpleTensorPrinter.h" +#endif /* ARM_COMPUTE_ASSERTS_ENABLED */ +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" +#include "tests/validation/Validation.h" +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +TEST_SUITE(CL) +TEST_SUITE(DYNAMIC_FUSION) +TEST_SUITE(GPU_CONV2D) + +RelativeTolerance tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +RelativeTolerance tolerance_f16(half_float::half(0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr float tolerance_num = 0.02f; /**< Tolerance number */ + +template +using DynamicFusionGpuConv2dFixture = DynamicFusionGpuConv2dValidationFixture; +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", QuantizationInfo()))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f32); +} +TEST_SUITE_END() // FP32 + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", QuantizationInfo()))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num); +} +TEST_SUITE_END() // FP16 +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +TEST_SUITE_END() // GPU_CONV2D +TEST_SUITE_END() // DYNAMIC_FUSION +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h new file mode 100644 index 0000000000..b0522488b4 --- /dev/null +++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE +#define ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE + +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" + +#include "arm_compute/runtime/CL/CLScheduler.h" + +#include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h" +#include "arm_compute/dynamic_fusion/sketch/OperatorAttributes.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h" +#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h" + +#include "src/gpu/cl/operators/ClAdd.h" +#include "src/gpu/cl/operators/ClConv2d.h" + +#include "tests/CL/CLAccessor.h" + +#include "tests/framework/Asserts.h" +#include "tests/framework/Fixture.h" +#include "tests/framework/Macros.h" + +#include "tests/validation/Validation.h" +#include "tests/validation/reference/ConvolutionLayer.h" +#include "tests/validation/reference/ElementwiseOperations.h" +#include "tests/validation/reference/Permute.h" + +using namespace arm_compute::experimental::dynamic_fusion; + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +template +class DynamicFusionGpuConv2dValidationGenericFixture : public framework::Fixture +{ +public: + using TBias = typename std::conditional < std::is_same::type, uint8_t>::value + || std::is_same::type, int8_t>::value, + int32_t, T >::type; // If T: uint8_t or int8_t then TBias: int32_t, otherwise TBias: T + + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, const PadStrideInfo &info, const Size2D &dilation, DataType data_type, + DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info) + { + ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NHWC); // Dynamic fusion conv2d only supports NHWC layout + const Conv2dAttributes conv2d_attr = convert_pad_stride_info_to_conv_attr(info, dilation); + _data_type = data_type; + _data_layout = data_layout; + _is_quantized = is_data_type_quantized_asymmetric(data_type); + _quantization_info = quantization_info; + _weight_quantization_info = weight_quantization_info; + _bias_data_type = _is_quantized ? DataType::S32 : data_type; + _target = compute_target(input_shape, weights_shape, bias_shape, conv2d_attr); + _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv2d_attr); + } + +protected: + template + void fill(U &&tensor, int i) + { + switch(tensor.data_type()) + { + case DataType::F16: + { + arm_compute::utils::uniform_real_distribution_16bit distribution{ -1.0f, 1.0f }; + library->fill(tensor, distribution, i); + break; + } + case DataType::F32: + { + std::uniform_real_distribution distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + break; + } + default: + library->fill_tensor_uniform(tensor, i); + } + } + + // Given input is in nchw format + TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, Conv2dAttributes conv2d_attr) + { + ARM_COMPUTE_ERROR_ON(_data_layout != DataLayout::NHWC); + permute(input_shape, PermutationVector(2U, 0U, 1U)); + permute(weights_shape, PermutationVector(2U, 0U, 1U)); + CLScheduler::get().default_reinit(); + + // Create a new workload sketch + auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx }; + GpuWorkloadSketch sketch{ &gpu_ctx }; + + // Create sketch tensors + auto input_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout)); + auto weight_info = sketch.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout)); + auto bias_info = sketch.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout)); + auto dst_info = sketch.create_tensor_info(); + FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, &dst_info, conv2d_attr); + + // Configure runtime + ClWorkloadRuntime runtime; + runtime.configure(sketch); + // (Important) Allocate auxiliary tensor memory if there are any + for(auto &data : runtime.get_auxiliary_tensors()) + { + auto tensor = data.first; + const auto aux_mem_req = data.second; + tensor->allocator()->init(*data.first->info(), aux_mem_req.alignment); + tensor->allocator()->allocate(); // Use ACL allocated memory + } + // Construct user tensors + CLTensor t_input{}; + CLTensor t_weight{}; + CLTensor t_bias{}; + CLTensor t_dst{}; + + // Initialize user tensors + t_input.allocator()->init(input_info); + t_weight.allocator()->init(weight_info); + t_bias.allocator()->init(bias_info); + t_dst.allocator()->init(dst_info); + + // Allocate and fill user tensors + t_input.allocator()->allocate(); + t_weight.allocator()->allocate(); + t_bias.allocator()->allocate(); + t_dst.allocator()->allocate(); + fill(CLAccessor(t_input), 0); + fill(CLAccessor(t_weight), 1); + fill(CLAccessor(t_bias), 2); + + // Run runtime + runtime.run({ &t_input, &t_weight, &t_bias, &t_dst }); + return t_dst; + } + + SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, + const TensorShape &output_shape, Conv2dAttributes conv2d_attr) + { + // Create reference + SimpleTensor src{ input_shape, _data_type, 1, _quantization_info }; + SimpleTensor weight{ weights_shape, _data_type, 1, _weight_quantization_info }; + SimpleTensor bias{ bias_shape, _data_type, 1, _quantization_info }; + + fill(src, 0); + fill(weight, 1); + fill(bias, 2); + + auto src_nchw = src; + auto weights_nchw = weight; + auto bias_nchw = bias; + auto output_shape_nchw = output_shape; + + PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left, conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom, + DimensionRoundingType{}); + auto dst_nchw = reference::convolution_layer(src_nchw, weights_nchw, bias_nchw, output_shape_nchw, legacy_pad_stride, conv2d_attr.dilation()); + return dst_nchw; + } + + TensorType _target{}; + SimpleTensor _reference{}; + DataType _data_type{}; + DataType _weights_data_type{}; + DataType _bias_data_type{}; + DataType _output_data_type{}; + DataLayout _data_layout{}; + QuantizationInfo _quantization_info{}; + QuantizationInfo _weight_quantization_info{}; + bool _is_quantized = false; + bool _is_bfloat16 = false; + bool _mixed_layout = false; +}; + +template +class DynamicFusionGpuConv2dValidationFixture : public DynamicFusionGpuConv2dValidationGenericFixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape output_shape, TensorShape bias_shape, + const PadStrideInfo &info, const Size2D &dialation, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info) + { + DynamicFusionGpuConv2dValidationGenericFixture::setup(input_shape, weights_shape, output_shape, bias_shape, info, dialation, + data_type, data_layout, quantization_info, quantization_info); + } +}; +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* ARM_COMPUTE_TEST_DYNAMIC_FUSION_FIXTURE */ diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index f55f72a4b8..8b50e9d1ef 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -38,6 +38,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/core/experimental/IPostOp.h" #include "arm_compute/core/experimental/PostOps.h" +#include "arm_compute/dynamic_fusion/sketch/OperatorAttributes.h" #include "arm_compute/runtime/CL/CLTunerTypes.h" #include "arm_compute/runtime/CL/CLTypes.h" #include "arm_compute/runtime/FunctionDescriptors.h" @@ -3371,6 +3372,62 @@ inline std::string to_string(const std::tuple