about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>  2019-10-09 15:30:40 +0100
committer Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>  2019-10-09 15:30:40 +0100
commit  a2a7380b43bd0961575b35704713fe914559e7b3 (patch)
tree    459e574b3372a3423f5f52f6c3b394c3aef8959b
parent  8f6429de4c278d79f076ddcea3fe1495e28fb75e (diff)
download  android-nn-driver-a2a7380b43bd0961575b35704713fe914559e7b3.tar.gz
IVGCVSW-3891 Add support for INSTANCE_NORMALIZATION to the HAL1.2 Android driver
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ia4dfbda7aceb4cdfb8f83d49e4df21dedd415b7b
-rw-r--r--  1.2/HalPolicy.cpp  86
-rw-r--r--  1.2/HalPolicy.hpp   2
-rw-r--r--  NnapiSupport.txt    1
3 files changed, 89 insertions, 0 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index cd3d2a6c..019f5054 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -54,6 +54,8 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return ConvertFullyConnected(operation, model, data);
case V1_2::OperationType::GROUPED_CONV_2D:
return ConvertGroupedConv2d(operation, model, data);
+ case V1_2::OperationType::INSTANCE_NORMALIZATION:
+ return ConvertInstanceNormalization(operation, model, data);
case V1_2::OperationType::L2_NORMALIZATION:
return ConvertL2Normalization(operation, model, data);
case V1_2::OperationType::L2_POOL_2D:
@@ -886,6 +888,90 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *endLayer, model, data);
}
+bool HalPolicy::ConvertInstanceNormalization(const Operation& operation, const Model& model, ConversionData& data)
+{
+ ALOGV("hal_1_2::HalPolicy::ConvertInstanceNormalization()");
+
+ LayerInputHandle input = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 0, model, data);
+ if (!input.IsValid())
+ {
+ return Fail("%s: Operation has an invalid input 0", __func__);
+ }
+
+ const Operand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
+ if (!output)
+ {
+ return Fail("%s: Operation has an invalid output", __func__);
+ }
+
+ const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+ if (IsDynamicTensor(outputInfo))
+ {
+ return Fail("%s: Dynamic output tensors are not supported", __func__);
+ }
+
+ // Determine data type of input tensor
+ OperandType inputType;
+ if (!GetOperandType<hal_1_2::HalPolicy>(operation, 0, model, inputType))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ InstanceNormalizationDescriptor desc;
+
+ // Read gamma, beta & epsilon
+ if (inputType == OperandType::TENSOR_FLOAT16)
+ {
+ Half fp16Gamma;
+ Half fp16Beta;
+ Half fp16Epsilon;
+
+ if (!GetInputScalar<hal_1_2::HalPolicy>(operation, 1, OperandType::FLOAT16, fp16Gamma, model, data) ||
+ !GetInputScalar<hal_1_2::HalPolicy>(operation, 2, OperandType::FLOAT16, fp16Beta, model, data) ||
+ !GetInputScalar<hal_1_2::HalPolicy>(operation, 3, OperandType::FLOAT16, fp16Epsilon, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
+ }
+
+ desc.m_Gamma = static_cast<float>(fp16Gamma);
+ desc.m_Beta = static_cast<float>(fp16Beta);
+ desc.m_Eps = static_cast<float>(fp16Epsilon);
+ }
+ else if (inputType == OperandType::TENSOR_FLOAT32)
+ {
+ if (!GetInputScalar<hal_1_2::HalPolicy>(operation, 1, OperandType::FLOAT32, desc.m_Gamma, model, data) ||
+ !GetInputScalar<hal_1_2::HalPolicy>(operation, 2, OperandType::FLOAT32, desc.m_Beta, model, data) ||
+ !GetInputScalar<hal_1_2::HalPolicy>(operation, 3, OperandType::FLOAT32, desc.m_Eps, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
+ }
+ }
+ else
+ {
+ return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
+ }
+
+ desc.m_DataLayout = OptionalDataLayout<hal_1_2::HalPolicy>(operation, 4, model, data);
+
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsInstanceNormalizationSupported,
+ data.m_Backends,
+ isSupported,
+ input.GetTensorInfo(),
+ outputInfo,
+ desc);
+ if (!isSupported)
+ {
+ return false;
+ }
+
+ IConnectableLayer* layer = data.m_Network->AddInstanceNormalizationLayer(desc);
+ input.Connect(layer->GetInputSlot(0));
+
+ return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data);
+}
+
bool HalPolicy::ConvertL2Normalization(const Operation& operation, const Model& model, ConversionData& data)
{
ALOGV("hal_1_2::HalPolicy::ConvertL2Normalization()");
diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp
index 52f4a42e..aa69f127 100644
--- a/1.2/HalPolicy.hpp
+++ b/1.2/HalPolicy.hpp
@@ -59,6 +59,8 @@ private:
static bool ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertInstanceNormalization(const Operation& operation, const Model& model, ConversionData& data);
+
static bool ConvertL2Normalization(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertL2Pool2d(const Operation& operation, const Model& model, ConversionData& data);
diff --git a/NnapiSupport.txt b/NnapiSupport.txt
index be36514f..f159bc46 100644
--- a/NnapiSupport.txt
+++ b/NnapiSupport.txt
@@ -26,6 +26,7 @@ EXPAND_DIMS (FLOAT32,QUANT8_ASYMM)
FLOOR (FLOAT32)
FULLY_CONNECTED (FLOAT32,QUANT8_ASYMM)
GROUPED_CONV_2D (FLOAT32,QUANT8_ASYMM)
+INSTANCE_NORMALIZATION (FLOAT32,FLOAT16)
L2_NORMALIZATION (FLOAT32)
L2_POOL_2D (FLOAT32,QUANT8_ASYMM)
LOCAL_RESPONSE_NORMALIZATION (FLOAT32)