aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/WorkloadFactory.cpp')
-rw-r--r--src/armnn/backends/WorkloadFactory.cpp418
1 files changed, 377 insertions, 41 deletions
diff --git a/src/armnn/backends/WorkloadFactory.cpp b/src/armnn/backends/WorkloadFactory.cpp
index 4e94d7701c..1b3f29421a 100644
--- a/src/armnn/backends/WorkloadFactory.cpp
+++ b/src/armnn/backends/WorkloadFactory.cpp
@@ -20,7 +20,40 @@
namespace armnn
{
-bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, DataType dataType,
+namespace
+{
+ const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
+ {
+ if (type == boost::none)
+ {
+ return info;
+ }
+
+ return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+ }
+
+ boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
+ {
+ if (weightsType == boost::none)
+ {
+ return weightsType;
+ }
+
+ switch(weightsType.get())
+ {
+ case DataType::Float16:
+ case DataType::Float32:
+ return weightsType;
+ case DataType::QuantisedAsymm8:
+ return DataType::Signed32;
+ default:
+ BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+ }
+ return boost::none;
+ }
+}
+
+bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
constexpr size_t reasonCapacity = 1024;
@@ -32,7 +65,13 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
{
auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsActivationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsActivationSupported(compute,
+ OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ cLayer->GetParameters(),
+ reason,
+ reasonCapacity);
break;
}
case LayerType::Addition:
@@ -40,30 +79,64 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsAdditionSupported(compute, input0, input1, output, reason, reasonCapacity);
+ result = IsAdditionSupported(compute,
+ OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType),
+ reason,
+ reasonCapacity);
break;
}
case LayerType::BatchNormalization:
{
auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsBatchNormalizationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
+ const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
+ const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
+ const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
+ result = IsBatchNormalizationSupported(compute,
+ OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ OverrideDataType(mean, dataType),
+ OverrideDataType(var, dataType),
+ OverrideDataType(beta, dataType),
+ OverrideDataType(gamma, dataType),
+ cLayer->GetParameters(),
+ reason, reasonCapacity);
break;
}
case LayerType::Constant:
{
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsConstantSupported(compute, output, reason, reasonCapacity);
+ result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
break;
}
- case LayerType::Convolution2d:
+ case LayerType::ConvertFp16ToFp32:
{
- auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
+ break;
+ }
+ case LayerType::ConvertFp32ToFp16:
+ {
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
+ break;
+ }
+ case LayerType::Convolution2d:
+ {
+ auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
+ const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
+ const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
- const TensorInfo * biasInfo = nullptr;
+ TensorInfo biasInfo;
+ const TensorInfo * biasInfoPtr = nullptr;
+ static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
@@ -72,21 +145,27 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
if (descriptor.m_BiasEnabled)
{
BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
- biasInfo = &(cLayer->m_Bias->GetTensorInfo());
+ biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+ biasInfoPtr = &biasInfo;
}
else
{
- // If biases are not enabled I pass a dummy tensorinfo for the validation
+ // If biases are not enabled pass a dummy tensorinfo for the validation.
switch(input.GetDataType())
{
+ case DataType::Float16:
+ {
+ biasInfoPtr = &dummyFloat16Bias;
+ break;
+ }
case DataType::Float32:
{
- biasInfo = &dummyFloat32Bias;
+ biasInfoPtr = &dummyFloat32Bias;
break;
}
case DataType::QuantisedAsymm8:
{
- biasInfo = &dummyQA8Bias;
+ biasInfoPtr = &dummyQA8Bias;
break;
}
default:
@@ -100,16 +179,16 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
input,
output,
descriptor,
- cLayer->m_Weight->GetTensorInfo(),
- *biasInfo,
+ OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+ *biasInfoPtr,
reason,
reasonCapacity);
break;
}
case LayerType::MemCopy:
{
- // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends
- // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests)
+ // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
+ // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
result = compute == Compute::CpuRef || compute == Compute::Undefined
|| compute == Compute::CpuAcc || compute == Compute::GpuAcc;
strcpy(reason, "Unsupported backend type");
@@ -118,66 +197,314 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
case LayerType::DepthwiseConvolution2d:
{
auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
- const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsDepthwiseConvolutionSupported(compute, input, cLayer->GetParameters(),
- cLayer->m_Weight->GetTensorInfo(), reason, reasonCapacity);
+ const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+ BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+
+ TensorInfo biasInfo;
+ const TensorInfo * biasInfoPtr = nullptr;
+ static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
+ static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
+ static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
+
+ const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
+ if (descriptor.m_BiasEnabled)
+ {
+ BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+ biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+ biasInfoPtr = &biasInfo;
+ }
+ else
+ {
+ // If biases are not enabled pass a dummy tensorinfo for the validation
+ switch(input.GetDataType())
+ {
+ case DataType::Float16:
+ {
+ biasInfoPtr = &dummyFloat16Bias;
+ break;
+ }
+ case DataType::Float32:
+ {
+ biasInfoPtr = &dummyFloat32Bias;
+ break;
+ }
+ case DataType::QuantisedAsymm8:
+ {
+ biasInfoPtr = &dummyQA8Bias;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unexpected bias type");
+ }
+ }
+ }
+
+
+ result = IsDepthwiseConvolutionSupported(compute,
+ input,
+ output,
+ descriptor,
+ OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+ *biasInfoPtr,
+ reason,
+ reasonCapacity);
break;
}
case LayerType::FakeQuantization:
{
auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsFakeQuantizationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
+ reason, reasonCapacity);
break;
}
case LayerType::Floor:
{
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsFloorSupported(compute, input, output, reason, reasonCapacity);
+ result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+ reason, reasonCapacity);
break;
}
case LayerType::FullyConnected:
{
auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsFullyConnectedSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+
+ TensorInfo biasInfo;
+ const TensorInfo * biasInfoPtr = nullptr;
+ static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
+ static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
+ static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
+
+ const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
+ if (descriptor.m_BiasEnabled)
+ {
+ BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+ biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+ biasInfoPtr = &biasInfo;
+ }
+ else
+ {
+ // If biases are not enabled pass a dummy tensorinfo for the validation
+ switch(input.GetDataType())
+ {
+ case DataType::Float16:
+ {
+ biasInfoPtr = &dummyFloat16Bias;
+ break;
+ }
+ case DataType::Float32:
+ {
+ biasInfoPtr = &dummyFloat32Bias;
+ break;
+ }
+ case DataType::QuantisedAsymm8:
+ {
+ biasInfoPtr = &dummyQA8Bias;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unexpected bias type");
+ }
+ }
+ }
+
+ result = IsFullyConnectedSupported(compute,
+ OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+ *biasInfoPtr,
+ descriptor,
+ reason,
+ reasonCapacity);
break;
}
case LayerType::Input:
{
const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsInputSupported(compute, input, reason, reasonCapacity);
+ result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
break;
}
case LayerType::L2Normalization:
{
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsL2NormalizationSupported(compute, input, reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsL2NormalizationSupported(compute, OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType), reason, reasonCapacity);
+ break;
+ }
+ case LayerType::Lstm:
+ {
+ auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
+ const LstmDescriptor& descriptor = cLayer->GetParameters();
+
+ // All inputs.
+ const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+ dataType);
+ // All outputs
+ const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+ const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
+ const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
+ const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
+
+ // Basic parameters
+ const TensorInfo& inputToForgetWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToCellWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToOutputWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToForgetWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToCellWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToOutputWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& forgetGateBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+ const TensorInfo& cellBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
+ const TensorInfo& outputGateBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+ // Optional parameters
+ const TensorInfo* inputToInputWeights = nullptr;
+ const TensorInfo* recurrentToInputWeights = nullptr;
+ const TensorInfo* cellToInputWeights = nullptr;
+ const TensorInfo* inputGateBias = nullptr;
+ const TensorInfo* projectionWeights = nullptr;
+ const TensorInfo* projectionBias = nullptr;
+ const TensorInfo* cellToForgetWeights = nullptr;
+ const TensorInfo* cellToOutputWeights = nullptr;
+
+ TensorInfo optInputToInputWeights;
+ TensorInfo optRecurrentToInputWeights;
+ TensorInfo optCellToInputWeights;
+ TensorInfo optInputGateBias;
+ TensorInfo optProjectionWeights;
+ TensorInfo optProjectionBias;
+ TensorInfo optCellToForgetWeights;
+ TensorInfo optCellToOutputWeights;
+
+ if(!descriptor.m_CifgEnabled)
+ {
+ optInputToInputWeights =
+ OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+ inputToInputWeights = &optInputToInputWeights;
+
+ optRecurrentToInputWeights =
+ OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
+ recurrentToInputWeights = &optRecurrentToInputWeights;
+ if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
+ {
+ optCellToInputWeights =
+ OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
+ cellToInputWeights = &optCellToInputWeights;
+ }
+ optInputGateBias =
+ OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
+ inputGateBias = &optInputGateBias;
+ }
+
+ if(descriptor.m_ProjectionEnabled)
+ {
+ optProjectionWeights =
+ OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
+ projectionWeights = &optProjectionWeights;
+ if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+ {
+ optProjectionBias =
+ OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
+ projectionBias = &optProjectionBias;
+ }
+ }
+
+ if(descriptor.m_PeepholeEnabled)
+ {
+ optCellToForgetWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
+ cellToForgetWeights = &optCellToForgetWeights;
+ optCellToOutputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
+ cellToOutputWeights = &optCellToOutputWeights;
+ }
+
+ result = IsLstmSupported(compute,
+ input,
+ outputStateIn,
+ cellStateIn,
+ scratchBuffer,
+ outputStateOut,
+ cellStateOut,
+ output,
+ descriptor,
+ inputToForgetWeights,
+ inputToCellWeights,
+ inputToOutputWeights,
+ recurrentToForgetWeights,
+ recurrentToCellWeights,
+ recurrentToOutputWeights,
+ forgetGateBias,
+ cellBias,
+ outputGateBias,
+ inputToInputWeights,
+ recurrentToInputWeights,
+ cellToInputWeights,
+ inputGateBias,
+ projectionWeights,
+ projectionBias,
+ cellToForgetWeights,
+ cellToOutputWeights,
+ reason,
+ reasonCapacity);
break;
}
case LayerType::Merger:
{
auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
- // Get vector of all inputs
- auto getTensorInfo = [](const InputSlot& slot)
+ // Get vector of all inputs.
+ auto getTensorInfo = [&dataType](const InputSlot& slot)
{
- return &slot.GetConnectedOutputSlot()->GetTensorInfo();
+ return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
};
- auto begin = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
- auto end = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
+ auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
+ auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
+ std::vector<TensorInfo> inputs(beginI, endI);
- std::vector<const TensorInfo*> inputs(begin, end);
+ auto getTensorInfoPtr = [](const TensorInfo& info)
+ {
+ return &info;
+ };
+ auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
+ auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
+ std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
- result = IsMergerSupported(compute, inputs, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
break;
}
case LayerType::Multiplication:
{
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
- result = IsMultiplicationSupported(compute, input0, input1, reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsMultiplicationSupported(compute,
+ OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType),
+ reason,
+ reasonCapacity);
break;
}
case LayerType::Normalization:
@@ -185,13 +512,15 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsNormalizationSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
+ reasonCapacity);
break;
}
case LayerType::Output:
{
const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsOutputSupported(compute, output, reason, reasonCapacity);
+ result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
break;
}
case LayerType::Permute:
@@ -199,7 +528,8 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsPermuteSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+ cLayer->GetParameters(), reason, reasonCapacity);
break;
}
case LayerType::Pooling2d:
@@ -207,33 +537,38 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = IsPooling2dSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
+ reasonCapacity);
break;
}
case LayerType::Reshape:
{
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsReshapeSupported(compute, input, reason, reasonCapacity);
+ result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
break;
}
case LayerType::ResizeBilinear:
{
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsResizeBilinearSupported(compute, input, reason, reasonCapacity);
+ result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
break;
}
case LayerType::Softmax:
{
auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsSoftmaxSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+ cLayer->GetParameters(), reason, reasonCapacity);
break;
}
case LayerType::Splitter:
{
auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- result = IsSplitterSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+ result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
+ reasonCapacity);
break;
}
default:
@@ -248,7 +583,8 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
return result;
}
-bool IWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, std::string& outReasonIfUnsupported)
+bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
+ std::string& outReasonIfUnsupported)
{
return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
}