From f466d75f85938b96dd14675ec091193bdce12122 Mon Sep 17 00:00:00 2001 From: SiCongLi Date: Mon, 1 Mar 2021 15:26:18 +0000 Subject: Add QASYMM8_SIGNED support to graph examples via graph mutator Related to COMPMID-4279 Signed-off-by: SiCongLi Change-Id: I6c737536b4e614cc9975003acca766803f55bf0b Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5206 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- arm_compute/graph/Types.h | 4 +- .../graph/mutators/SyntheticDataTypeMutator.h | 7 ++- examples/graph_deepspeech_v0_4_1.cpp | 11 ++-- examples/graph_inception_v3.cpp | 16 +++--- examples/graph_inception_v4.cpp | 13 ++--- examples/graph_mobilenet.cpp | 1 - examples/graph_resnet50.cpp | 13 ++--- examples/graph_resnet_v2_50.cpp | 13 ++--- examples/graph_squeezenet.cpp | 13 ++--- examples/graph_squeezenet_v1_1.cpp | 13 ++--- examples/graph_srcnn955.cpp | 13 ++--- examples/graph_vgg16.cpp | 13 ++--- examples/graph_vgg19.cpp | 13 ++--- examples/graph_vgg_vdsr.cpp | 13 ++--- src/core/Utils.cpp | 3 +- src/graph/Utils.cpp | 19 +++++-- src/graph/mutators/SyntheticDataTypeMutator.cpp | 59 ++++++++++++++++++---- utils/CommonGraphOptions.cpp | 1 + 18 files changed, 155 insertions(+), 83 deletions(-) diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index 77e91b205a..c4afba2dee 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -76,6 +76,7 @@ constexpr EdgeID EmptyEdgeID = std::numeric_limits::max(); // Forward declarations struct TensorDescriptor; + /** Graph configuration structure */ struct GraphConfig { @@ -83,7 +84,8 @@ struct GraphConfig bool use_function_weights_manager{ true }; /**< Use a weights manager to manage transformed weights */ bool use_transition_memory_manager{ true }; /**< Use a memory manager to manager transition buffer memory */ bool use_tuner{ false }; /**< Use a tuner in tunable backends */ - bool convert_to_uint8{ false }; /**< Convert graph to a synthetic uint8 graph */ + bool use_synthetic_type{ false }; /**< Convert graph to a synthetic graph for a data type */ + DataType synthetic_type{ DataType::QASYMM8 }; /**< The data type of the synthetic graph */ CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */ int num_threads{ -1 }; /**< Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is. */ std::string tuner_file{ "acl_tuner.csv" }; /**< File to load/store tuning values from */ diff --git a/arm_compute/graph/mutators/SyntheticDataTypeMutator.h b/arm_compute/graph/mutators/SyntheticDataTypeMutator.h index ed270f894f..2292e52086 100644 --- a/arm_compute/graph/mutators/SyntheticDataTypeMutator.h +++ b/arm_compute/graph/mutators/SyntheticDataTypeMutator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,10 +35,15 @@ namespace graph class SyntheticDataTypeMutator final : public IGraphMutator { public: + // Constructor + SyntheticDataTypeMutator(DataType mutate_type = DataType::QASYMM8); // Inherited methods overridden virtual void mutate(Graph &g) override; MutationType type() const override; const char *name() override; + +private: + DataType _mutate_type; }; } // namespace graph } // namespace arm_compute diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp index a5658625c7..da163b6493 100644 --- a/examples/graph_deepspeech_v0_4_1.cpp +++ b/examples/graph_deepspeech_v0_4_1.cpp @@ -208,11 +208,12 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_inception_v3.cpp b/examples/graph_inception_v3.cpp index 8ceeb5c68e..928efb9124 100644 --- a/examples/graph_inception_v3.cpp +++ b/examples/graph_inception_v3.cpp @@ -22,6 +22,8 @@ * SOFTWARE. */ #include "arm_compute/graph.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/Utils.h" #include "support/ToolchainSupport.h" #include "utils/CommonGraphOptions.h" #include "utils/GraphUtils.h" @@ -197,13 +199,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); - + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); return true; diff --git a/examples/graph_inception_v4.cpp b/examples/graph_inception_v4.cpp index cafa5c9f10..0c67215136 100644 --- a/examples/graph_inception_v4.cpp +++ b/examples/graph_inception_v4.cpp @@ -152,12 +152,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; // Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed // compilation won't be required. diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp index 09b6e6e097..4d4e17715d 100644 --- a/examples/graph_mobilenet.cpp +++ b/examples/graph_mobilenet.cpp @@ -101,7 +101,6 @@ public: config.tuner_mode = common_params.tuner_mode; config.tuner_file = common_params.tuner_file; config.mlgo_file = common_params.mlgo_file; - config.mlgo_file = common_params.mlgo_file; graph.finalize(common_params.target, config); diff --git a/examples/graph_resnet50.cpp b/examples/graph_resnet50.cpp index b585284c60..5834d9be77 100644 --- a/examples/graph_resnet50.cpp +++ b/examples/graph_resnet50.cpp @@ -111,12 +111,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_resnet_v2_50.cpp b/examples/graph_resnet_v2_50.cpp index 472bf02b47..cd4e6fd6df 100644 --- a/examples/graph_resnet_v2_50.cpp +++ b/examples/graph_resnet_v2_50.cpp @@ -114,12 +114,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_squeezenet.cpp b/examples/graph_squeezenet.cpp index 3d32794e8d..82d95143be 100644 --- a/examples/graph_squeezenet.cpp +++ b/examples/graph_squeezenet.cpp @@ -164,12 +164,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_squeezenet_v1_1.cpp b/examples/graph_squeezenet_v1_1.cpp index 6d4ffee994..1a7d752d50 100644 --- a/examples/graph_squeezenet_v1_1.cpp +++ b/examples/graph_squeezenet_v1_1.cpp @@ -164,12 +164,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_srcnn955.cpp b/examples/graph_srcnn955.cpp index f4ffc02130..ccad0b65e4 100644 --- a/examples/graph_srcnn955.cpp +++ b/examples/graph_srcnn955.cpp @@ -115,12 +115,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp index 83e663798b..3e453b3626 100644 --- a/examples/graph_vgg16.cpp +++ b/examples/graph_vgg16.cpp @@ -212,12 +212,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg19.cpp b/examples/graph_vgg19.cpp index 03f7e1606c..d79aa01326 100644 --- a/examples/graph_vgg19.cpp +++ b/examples/graph_vgg19.cpp @@ -223,12 +223,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/examples/graph_vgg_vdsr.cpp b/examples/graph_vgg_vdsr.cpp index bdb898081d..226edcd15b 100644 --- a/examples/graph_vgg_vdsr.cpp +++ b/examples/graph_vgg_vdsr.cpp @@ -136,12 +136,13 @@ public: // Finalize graph GraphConfig config; - config.num_threads = common_params.threads; - config.use_tuner = common_params.enable_tuner; - config.tuner_mode = common_params.tuner_mode; - config.tuner_file = common_params.tuner_file; - config.mlgo_file = common_params.mlgo_file; - config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8); + config.num_threads = common_params.threads; + config.use_tuner = common_params.enable_tuner; + config.tuner_mode = common_params.tuner_mode; + config.tuner_file = common_params.tuner_file; + config.mlgo_file = common_params.mlgo_file; + config.use_synthetic_type = arm_compute::is_data_type_quantized(common_params.data_type); + config.synthetic_type = common_params.data_type; graph.finalize(common_params.target, config); diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index babf1c4b91..8deb5979ac 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -330,6 +330,7 @@ DataType data_type_from_name(const std::string &name) { "f16", DataType::F16 }, { "f32", DataType::F32 }, { "qasymm8", DataType::QASYMM8 }, + { "qasymm8_signed", DataType::QASYMM8_SIGNED }, }; #ifndef ARM_COMPUTE_EXCEPTIONS_DISABLED diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp index 2835af311a..8e12689fb9 100644 --- a/src/graph/Utils.cpp +++ b/src/graph/Utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -81,9 +81,22 @@ PassManager create_default_pass_manager(Target target, const GraphConfig &cfg) const bool is_target_gc = target == Target::GC; // Passes that mutate graph IR - if(cfg.convert_to_uint8) + if(cfg.use_synthetic_type) { - pm.append(std::make_unique(), !is_target_gc); + switch(cfg.synthetic_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + pm.append(std::make_unique(cfg.synthetic_type), !is_target_gc); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported DataType for SyntheticDataTypeMutator"); + break; + } + } } pm.append(std::make_unique(), !is_target_gc); pm.append(std::make_unique()); diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp index 21bafa61e1..74d040b81d 100644 --- a/src/graph/mutators/SyntheticDataTypeMutator.cpp +++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -120,15 +120,28 @@ void remove_optimized_nodes(Graph &g) * * @param[in,out] g Graph to convert tensors of. */ -void convert_tensors(Graph &g) +void convert_tensors(Graph &g, DataType data_type) { auto &tensors = g.tensors(); for(auto &tensor : tensors) { if(tensor != nullptr) { - tensor->desc().data_type = DataType::QASYMM8; - tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + switch(data_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported mutation type"); + break; + } + } + tensor->desc().data_type = data_type; } } } @@ -164,20 +177,41 @@ void convert_special_tensors(Graph &g) auto softmax_func = [](INode * node, Tensor * tensor) { ARM_COMPUTE_UNUSED(node); - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if(tensor->desc().data_type == DataType::QASYMM8) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } + else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } return true; }; auto act_func = [](INode * node, Tensor * tensor) { auto *act_node = arm_compute::utils::cast::polymorphic_downcast(node); - if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + if(tensor->desc().data_type == DataType::QASYMM8) { - tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + } + else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } } - else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED) { - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 0); + } + else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } } return true; }; @@ -232,6 +266,11 @@ void handle_nodes_with_bias(Graph &g) } } // namespace +SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type) + : _mutate_type{ mutate_type } +{ +} + const char *SyntheticDataTypeMutator::name() { return "SyntheticDataTypeMutator"; @@ -250,7 +289,7 @@ void SyntheticDataTypeMutator::mutate(Graph &g) remove_optimized_nodes(g); // Convert tensor - convert_tensors(g); + convert_tensors(g, _mutate_type); convert_special_tensors(g); // Handle special nodes diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp index 44d66fa91b..b8808a476f 100644 --- a/utils/CommonGraphOptions.cpp +++ b/utils/CommonGraphOptions.cpp @@ -145,6 +145,7 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser) DataType::F16, DataType::F32, DataType::QASYMM8, + DataType::QASYMM8_SIGNED, }; std::set supported_data_layouts -- cgit v1.2.1