From 91780021e25575086c6c31d014d34b6513649a9d Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Wed, 20 Jul 2022 14:57:37 +0100 Subject: Fix for inclusion of "arm_gemm" from src into "Types.h" from core - Added arm_compute::WeightFormat and converted to/from arm_gemm::WeightFormat when needed through two map function. - Moved to_string(WeightFormat) to TypePrinter.h Resolves: COMPMID-5415 Signed-off-by: Ramy Elgammal Change-Id: I65f7942100bcd4dbf2c5cf6c07f26c8e1e3bf86e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/438511 Tested-by: bsgcomp Reviewed-by: Pablo Tello Reviewed-by: Sicong Li Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7985 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michalis Spyrou Benchmark: Arm Jenkins --- tests/framework/Asserts.h | 6 +- tests/validation/NEON/ConvolutionLayer.cpp | 120 ++++++++++----------- .../validation/fixtures/ConvolutionLayerFixture.h | 30 +++--- 3 files changed, 79 insertions(+), 77 deletions(-) (limited to 'tests') diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h index 5f462773d0..7adfa8f2f3 100644 --- a/tests/framework/Asserts.h +++ b/tests/framework/Asserts.h @@ -30,6 +30,8 @@ #include #include +#include "utils/TypePrinter.h" + namespace arm_compute { namespace test @@ -42,9 +44,9 @@ inline int make_printable(int8_t value) return value; } -inline std::string make_printable(arm_gemm::WeightFormat wf) +inline std::string make_printable(const arm_compute::WeightFormat wf) { - return arm_gemm::to_string(wf); + return arm_compute::to_string(wf); } inline unsigned int make_printable(uint8_t value) diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 940983f42b..0194220e1a 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -511,13 +511,13 @@ TEST_SUITE(VariableWeightUtils) FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 }))) { ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); } FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 }))) { ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); } @@ -527,18 +527,18 @@ FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 }))) { ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS); } FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 }))) { ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS); } // UC3_1_* tests: the user queries for ANY fixed format, but there is @@ -548,14 +548,14 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::S32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); } FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::S32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); } @@ -572,24 +572,24 @@ FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); } FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixture, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), - framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY }))) + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); } namespace { -using TestCaseType = std::tuple; +using TestCaseType = std::tuple; auto prepare_weights_shapes = framework::dataset::make("TensorShape", { // OHWIoi @@ -601,51 +601,51 @@ auto prepare_weights_shapes = framework::dataset::make("TensorShape", // // Change N for OHWIo4 - TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_gemm::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_compute::WeightFormat::OHWIo4 }), // // Change N for OHWIo8 - TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_gemm::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_compute::WeightFormat::OHWIo8 }), // // Change N for OHWIo4 when H, W and C are not 1 - TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_gemm::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_compute::WeightFormat::OHWIo4 }), // // Fix N and move HWI around, with different data layouts and formats - TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_gemm::WeightFormat::OHWIo8 }), - TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_gemm::WeightFormat::OHWIo8 }), + TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_compute::WeightFormat::OHWIo8 }), // // Adding on I (=C) - TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }), - TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }), - TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }), + TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), + TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), + TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), // --------- - TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_gemm::WeightFormat::OHWIo4 }), - TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }), + TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), }); } // unnamed namespace @@ -653,14 +653,14 @@ auto prepare_weights_shapes = framework::dataset::make("TensorShape", DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL, prepare_weights_shapes, shapes) { - const TensorShape input_shape = std::get<0>(shapes); - const TensorShape expected_shape = std::get<1>(shapes); - const arm_gemm::WeightFormat wf = std::get<2>(shapes); - const DataType DT = DataType::F32; - const DataLayout DL = DataLayout::NHWC; - const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL); - const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf); - const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL); + const TensorShape input_shape = std::get<0>(shapes); + const TensorShape expected_shape = std::get<1>(shapes); + const arm_compute::WeightFormat wf = std::get<2>(shapes); + const DataType DT = DataType::F32; + const DataLayout DL = DataLayout::NHWC; + const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL); + const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf); + const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL); ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS); } diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index d3804ee371..c58a0a2c91 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -122,14 +122,14 @@ protected: { case DataType::QASYMM8: { - std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: { - std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; @@ -400,7 +400,7 @@ public: }; #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS -inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format) +inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_compute::WeightFormat weight_format) { const DataLayout data_layout = tensor_info.data_layout(); ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS); @@ -411,8 +411,8 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm:: const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I - const int interleave_by = arm_gemm::interleave_by(weight_format); - const int block_by = arm_gemm::block_by(weight_format); + const int interleave_by = arm_compute::interleave_by(weight_format); + const int block_by = arm_compute::block_by(weight_format); const int Ip = arm_gemm::roundup(C, block_by); // C'=I' const int Op = arm_gemm::roundup(N, interleave_by); // O'=N' @@ -421,12 +421,12 @@ inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm:: } template -inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format) +inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_compute::WeightFormat weight_format) { - ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(weight_format), framework::LogLevel::ERRORS); // Data Layout: OHWIoi - const int interleave_by = arm_gemm::interleave_by(weight_format); - const int block_by = arm_gemm::block_by(weight_format); + const int interleave_by = arm_compute::interleave_by(weight_format); + const int block_by = arm_compute::block_by(weight_format); const TensorShape src_tensor_shape = src.shape(); const DataLayout data_layout = src.data_layout(); ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS); @@ -545,12 +545,12 @@ private: const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)]; const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)]; - const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY); + const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_compute::WeightFormat::ANY); const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info)); // Make surethat the setup founds a fixed-format kernel as requested by the test case. ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS); const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format); configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info, @@ -576,7 +576,7 @@ private: protected: std::unique_ptr conv{}; - arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED }; + arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; TensorClass _target{}; SimpleTensor _reference{}; }; @@ -669,7 +669,7 @@ class HasOptImplFixture : public framework::Fixture { public: template - void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format) + void setup(DataType data_type, arm_compute::WeightFormat query_weight_format) { auto conv = std::make_unique(); const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC); @@ -683,8 +683,8 @@ public: } protected: - bool _kernel_found{ false }; - arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED }; + bool _kernel_found{ false }; + arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; }; #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS -- cgit v1.2.1