aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/SyntheticDataTypeMutator.cpp
diff options
context:
space:
mode:
authorSiCongLi <sicong.li@arm.com>2021-03-01 15:26:18 +0000
committerSiCong Li <sicong.li@arm.com>2021-03-03 10:45:00 +0000
commitf466d75f85938b96dd14675ec091193bdce12122 (patch)
treeb4dcb0f1c5e2bf15ba129525d0271e48a9b0f2a6 /src/graph/mutators/SyntheticDataTypeMutator.cpp
parent793daa138aee628bf80289191061ec8c81421bd2 (diff)
downloadComputeLibrary-f466d75f85938b96dd14675ec091193bdce12122.tar.gz
Add QASYMM8_SIGNED support to graph examples via graph mutator
Related to COMPMID-4279 Signed-off-by: SiCongLi <sicong.li@arm.com> Change-Id: I6c737536b4e614cc9975003acca766803f55bf0b Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5206 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/mutators/SyntheticDataTypeMutator.cpp')
-rw-r--r--src/graph/mutators/SyntheticDataTypeMutator.cpp59
1 files changed, 49 insertions, 10 deletions
diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp
index 21bafa61e1..74d040b81d 100644
--- a/src/graph/mutators/SyntheticDataTypeMutator.cpp
+++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -120,15 +120,28 @@ void remove_optimized_nodes(Graph &g)
*
* @param[in,out] g Graph to convert tensors of.
*/
-void convert_tensors(Graph &g)
+void convert_tensors(Graph &g, DataType data_type)
{
auto &tensors = g.tensors();
for(auto &tensor : tensors)
{
if(tensor != nullptr)
{
- tensor->desc().data_type = DataType::QASYMM8;
- tensor->desc().quant_info = QuantizationInfo(0.125f, -10);
+ switch(data_type)
+ {
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ {
+ tensor->desc().quant_info = QuantizationInfo(0.125f, -10);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported mutation type");
+ break;
+ }
+ }
+ tensor->desc().data_type = data_type;
}
}
}
@@ -164,20 +177,41 @@ void convert_special_tensors(Graph &g)
auto softmax_func = [](INode * node, Tensor * tensor)
{
ARM_COMPUTE_UNUSED(node);
- tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
+ if(tensor->desc().data_type == DataType::QASYMM8)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
+ }
+ else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128);
+ }
return true;
};
auto act_func = [](INode * node, Tensor * tensor)
{
auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(node);
- if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
+ if(tensor->desc().data_type == DataType::QASYMM8)
{
- tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128);
+ if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128);
+ }
+ else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
+ }
}
- else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED)
{
- tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
+ if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 0);
+ }
+ else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ {
+ tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128);
+ }
}
return true;
};
@@ -232,6 +266,11 @@ void handle_nodes_with_bias(Graph &g)
}
} // namespace
+SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type)
+ : _mutate_type{ mutate_type }
+{
+}
+
const char *SyntheticDataTypeMutator::name()
{
return "SyntheticDataTypeMutator";
@@ -250,7 +289,7 @@ void SyntheticDataTypeMutator::mutate(Graph &g)
remove_optimized_nodes(g);
// Convert tensor
- convert_tensors(g);
+ convert_tensors(g, _mutate_type);
convert_special_tensors(g);
// Handle special nodes