From c577f2c6a3b4ddb6ba87a882723c53a248afbeba Mon Sep 17 00:00:00 2001
From: telsoa01
Date: Fri, 31 Aug 2018 09:22:23 +0100
Subject: Release 18.08

---
 src/armnn/backends/WorkloadFactory.cpp | 418 +++++++++++++++++++++++++++++----
 1 file changed, 377 insertions(+), 41 deletions(-)

diff --git a/src/armnn/backends/WorkloadFactory.cpp b/src/armnn/backends/WorkloadFactory.cpp
index 4e94d7701c..1b3f29421a 100644
--- a/src/armnn/backends/WorkloadFactory.cpp
+++ b/src/armnn/backends/WorkloadFactory.cpp
@@ -20,7 +20,40 @@
 namespace armnn
 {
 
-bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, DataType dataType,
+namespace
+{
+    const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
+    {
+        if (type == boost::none)
+        {
+            return info;
+        }
+
+        return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+    }
+
+    boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
+    {
+        if (weightsType == boost::none)
+        {
+            return weightsType;
+        }
+
+        switch(weightsType.get())
+        {
+            case DataType::Float16:
+            case DataType::Float32:
+                return weightsType;
+            case DataType::QuantisedAsymm8:
+                return DataType::Signed32;
+            default:
+                BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+        }
+        return boost::none;
+    }
+}
+
+bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
                                         std::string& outReasonIfUnsupported)
 {
     constexpr size_t reasonCapacity = 1024;
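The two helpers added above carry the whole patch: OverrideDataType lets every Is*Supported query be asked either with a layer's real tensor types (dataType == boost::none) or with a hypothetical replacement type, and GetBiasTypeFromWeightsType keeps bias types consistent under such an override. A minimal standalone sketch of the override semantics, using std::optional in place of boost::optional and a stripped-down TensorInfo (the names here are illustrative, not ArmNN's):

    #include <cassert>
    #include <optional>

    // Stand-ins for armnn::DataType / armnn::TensorInfo, reduced to the one
    // field that matters for this sketch.
    enum class DataType { Float16, Float32, QuantisedAsymm8, Signed32 };
    struct TensorInfo { DataType type; };

    // Same contract as the patch's OverrideDataType: "no value" means keep
    // the tensor's own type, a value means re-type it for this query only.
    TensorInfo OverrideDataType(const TensorInfo& info, std::optional<DataType> type)
    {
        return type ? TensorInfo{*type} : info;
    }

    int main()
    {
        TensorInfo fp32{DataType::Float32};
        assert(OverrideDataType(fp32, std::nullopt).type == DataType::Float32);      // passthrough
        assert(OverrideDataType(fp32, DataType::Float16).type == DataType::Float16); // "what-if" re-typing
    }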
@@ -32,7 +65,13 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
         {
             auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsActivationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsActivationSupported(compute,
+                                           OverrideDataType(input, dataType),
+                                           OverrideDataType(output, dataType),
+                                           cLayer->GetParameters(),
+                                           reason,
+                                           reasonCapacity);
             break;
         }
         case LayerType::Addition:
@@ -40,30 +79,64 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
         {
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsAdditionSupported(compute, input0, input1, output, reason, reasonCapacity);
+            result = IsAdditionSupported(compute,
+                                         OverrideDataType(input0, dataType),
+                                         OverrideDataType(input1, dataType),
+                                         OverrideDataType(output, dataType),
+                                         reason,
+                                         reasonCapacity);
             break;
         }
         case LayerType::BatchNormalization:
         {
             auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsBatchNormalizationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
+            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
+            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
+            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
+            result = IsBatchNormalizationSupported(compute,
+                                                   OverrideDataType(input, dataType),
+                                                   OverrideDataType(output, dataType),
+                                                   OverrideDataType(mean, dataType),
+                                                   OverrideDataType(var, dataType),
+                                                   OverrideDataType(beta, dataType),
+                                                   OverrideDataType(gamma, dataType),
+                                                   cLayer->GetParameters(),
+                                                   reason, reasonCapacity);
             break;
         }
         case LayerType::Constant:
         {
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsConstantSupported(compute, output, reason, reasonCapacity);
+            result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
             break;
         }
-        case LayerType::Convolution2d:
+        case LayerType::ConvertFp16ToFp32:
         {
-            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
+            break;
+        }
+        case LayerType::ConvertFp32ToFp16:
+        {
+            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
+            break;
+        }
+        case LayerType::Convolution2d:
+        {
+            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
+            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
+            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
-            const TensorInfo * biasInfo = nullptr;
+
+            TensorInfo biasInfo;
+            const TensorInfo * biasInfoPtr = nullptr;
+            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
             static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
             static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
@@ -72,21 +145,27 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
             if (descriptor.m_BiasEnabled)
             {
                 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
-                biasInfo = &(cLayer->m_Bias->GetTensorInfo());
+                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+                biasInfoPtr = &biasInfo;
             }
             else
             {
-                // If biases are not enabled I pass a dummy tensorinfo for the validation
+                // If biases are not enabled pass a dummy tensorinfo for the validation.
                 switch(input.GetDataType())
                 {
+                    case DataType::Float16:
+                    {
+                        biasInfoPtr = &dummyFloat16Bias;
+                        break;
+                    }
                    case DataType::Float32:
                     {
-                        biasInfo = &dummyFloat32Bias;
+                        biasInfoPtr = &dummyFloat32Bias;
                         break;
                     }
                     case DataType::QuantisedAsymm8:
                     {
-                        biasInfo = &dummyQA8Bias;
+                        biasInfoPtr = &dummyQA8Bias;
                         break;
                     }
                     default:
@@ -100,16 +179,16 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
                                                input,
                                                output,
                                                descriptor,
-                                               cLayer->m_Weight->GetTensorInfo(),
-                                               *biasInfo,
+                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+                                               *biasInfoPtr,
                                                reason,
                                                reasonCapacity);
             break;
         }
         case LayerType::MemCopy:
         {
-            // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends
-            // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests)
+            // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
+            // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
             result = compute == Compute::CpuRef || compute == Compute::Undefined
                      || compute == Compute::CpuAcc || compute == Compute::GpuAcc;
             strcpy(reason, "Unsupported backend type");
             break;
         }
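The biasInfo/biasInfoPtr pair in the Convolution2d block is worth a closer look: OverrideDataType returns by value, so the overridden copy needs local storage, while the static dummy infos give the validation call a well-typed bias even when the layer has none. A reduced sketch of the idiom, in which Validate is a hypothetical stand-in for IsConvolution2dSupported:

    enum class DataType { Float16, Float32, QuantisedAsymm8, Signed32 };
    struct TensorInfo { DataType type = DataType::Float32; };

    // Hypothetical stand-in for a backend validation call that always takes
    // a bias TensorInfo by reference, even for bias-less layers.
    bool Validate(const TensorInfo& bias)
    {
        return bias.type == DataType::Float32 || bias.type == DataType::Float16
            || bias.type == DataType::Signed32;
    }

    bool CheckWithOptionalBias(const TensorInfo* realBias, DataType inputType)
    {
        TensorInfo biasInfo;                     // storage for an (overridden) copy
        const TensorInfo* biasInfoPtr = nullptr; // what actually gets passed on

        // One static dummy per input type, so the call below is always valid.
        static const TensorInfo dummyFloat32Bias{DataType::Float32};
        static const TensorInfo dummyQA8Bias{DataType::Signed32};

        if (realBias != nullptr)
        {
            biasInfo = *realBias;  // the patch applies OverrideDataType(...) here
            biasInfoPtr = &biasInfo;
        }
        else
        {
            biasInfoPtr = (inputType == DataType::QuantisedAsymm8) ? &dummyQA8Bias
                                                                   : &dummyFloat32Bias;
        }
        return Validate(*biasInfoPtr);
    }

    int main()
    {
        TensorInfo realBias{DataType::Float32};
        bool ok = CheckWithOptionalBias(&realBias, DataType::Float32)   // real bias path
               && CheckWithOptionalBias(nullptr, DataType::QuantisedAsymm8); // dummy path
        return ok ? 0 : 1;
    }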
@@ -118,66 +197,314 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
         case LayerType::DepthwiseConvolution2d:
         {
             auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
-            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsDepthwiseConvolutionSupported(compute, input, cLayer->GetParameters(),
-                                                     cLayer->m_Weight->GetTensorInfo(), reason, reasonCapacity);
+            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+                                                       dataType);
+            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+
+            TensorInfo biasInfo;
+            const TensorInfo * biasInfoPtr = nullptr;
+            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
+            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
+            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
+
+            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
+            if (descriptor.m_BiasEnabled)
+            {
+                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+                biasInfoPtr = &biasInfo;
+            }
+            else
+            {
+                // If biases are not enabled pass a dummy tensorinfo for the validation
+                switch(input.GetDataType())
+                {
+                    case DataType::Float16:
+                    {
+                        biasInfoPtr = &dummyFloat16Bias;
+                        break;
+                    }
+                    case DataType::Float32:
+                    {
+                        biasInfoPtr = &dummyFloat32Bias;
+                        break;
+                    }
+                    case DataType::QuantisedAsymm8:
+                    {
+                        biasInfoPtr = &dummyQA8Bias;
+                        break;
+                    }
+                    default:
+                    {
+                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
+                    }
+                }
+            }
+
+
+            result = IsDepthwiseConvolutionSupported(compute,
+                                                     input,
+                                                     output,
+                                                     descriptor,
+                                                     OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+                                                     *biasInfoPtr,
+                                                     reason,
+                                                     reasonCapacity);
             break;
         }
         case LayerType::FakeQuantization:
         {
             auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsFakeQuantizationSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
+                                                 reason, reasonCapacity);
             break;
         }
         case LayerType::Floor:
         {
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsFloorSupported(compute, input, output, reason, reasonCapacity);
+            result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+                                      reason, reasonCapacity);
             break;
         }
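One C++ subtlety in the DepthwiseConvolution2d block above: OverrideDataType returns a TensorInfo by value, yet the result is bound to `const TensorInfo& input`. That is well-defined, because binding a temporary to a const reference extends the temporary's lifetime to the reference's scope, as this isolated sketch shows:

    #include <cassert>
    #include <string>

    std::string MakeTemporary() { return "still alive"; }

    int main()
    {
        // The temporary returned by MakeTemporary() would normally be
        // destroyed at the end of the full expression; binding it to a
        // const reference extends its lifetime to the end of this scope.
        const std::string& ref = MakeTemporary();
        assert(ref == "still alive"); // valid: the temporary persists
    }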
         case LayerType::FullyConnected:
         {
             auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsFullyConnectedSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+
+            TensorInfo biasInfo;
+            const TensorInfo * biasInfoPtr = nullptr;
+            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
+            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
+            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
+
+            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
+            if (descriptor.m_BiasEnabled)
+            {
+                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
+                biasInfoPtr = &biasInfo;
+            }
+            else
+            {
+                // If biases are not enabled pass a dummy tensorinfo for the validation
+                switch(input.GetDataType())
+                {
+                    case DataType::Float16:
+                    {
+                        biasInfoPtr = &dummyFloat16Bias;
+                        break;
+                    }
+                    case DataType::Float32:
+                    {
+                        biasInfoPtr = &dummyFloat32Bias;
+                        break;
+                    }
+                    case DataType::QuantisedAsymm8:
+                    {
+                        biasInfoPtr = &dummyQA8Bias;
+                        break;
+                    }
+                    default:
+                    {
+                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
+                    }
+                }
+            }
+
+            result = IsFullyConnectedSupported(compute,
+                                               OverrideDataType(input, dataType),
+                                               OverrideDataType(output, dataType),
+                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
+                                               *biasInfoPtr,
+                                               descriptor,
+                                               reason,
+                                               reasonCapacity);
             break;
         }
         case LayerType::Input:
         {
             const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsInputSupported(compute, input, reason, reasonCapacity);
+            result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
             break;
         }
         case LayerType::L2Normalization:
         {
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsL2NormalizationSupported(compute, input, reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsL2NormalizationSupported(compute, OverrideDataType(input, dataType),
+                                                OverrideDataType(output, dataType), reason, reasonCapacity);
+            break;
+        }
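The FullyConnected block is the clearest consumer of GetBiasTypeFromWeightsType: float weights keep a float bias, but QuantisedAsymm8 weights pair with a Signed32 bias, since 8-bit quantized multiply-accumulates are accumulated in 32-bit integers. A standalone mirror of that mapping (std::optional substituted for boost::optional):

    #include <cassert>
    #include <optional>

    enum class DataType { Float16, Float32, QuantisedAsymm8, Signed32 };

    // Mirrors the patch's GetBiasTypeFromWeightsType: float weights keep
    // their type, 8-bit quantized weights imply 32-bit integer biases, and
    // "no override" passes straight through.
    std::optional<DataType> GetBiasTypeFromWeightsType(std::optional<DataType> weightsType)
    {
        if (!weightsType)
        {
            return weightsType;
        }
        switch (*weightsType)
        {
            case DataType::Float16:
            case DataType::Float32:
                return weightsType;
            case DataType::QuantisedAsymm8:
                return DataType::Signed32;
            default:
                return std::nullopt; // the patch asserts here instead
        }
    }

    int main()
    {
        assert(GetBiasTypeFromWeightsType(DataType::Float32) == DataType::Float32);
        assert(GetBiasTypeFromWeightsType(DataType::QuantisedAsymm8) == DataType::Signed32);
        assert(!GetBiasTypeFromWeightsType(std::nullopt).has_value());
    }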
+        case LayerType::Lstm:
+        {
+            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
+            const LstmDescriptor& descriptor = cLayer->GetParameters();
+
+            // All inputs.
+            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+                                                       dataType);
+            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+                                                               dataType);
+            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+                                                             dataType);
+            // All outputs
+            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
+            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
+            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
+
+            // Basic parameters
+            const TensorInfo& inputToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
+            const TensorInfo& forgetGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+            const TensorInfo& cellBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
+            const TensorInfo& outputGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+            // Optional parameters
+            const TensorInfo* inputToInputWeights = nullptr;
+            const TensorInfo* recurrentToInputWeights = nullptr;
+            const TensorInfo* cellToInputWeights = nullptr;
+            const TensorInfo* inputGateBias = nullptr;
+            const TensorInfo* projectionWeights = nullptr;
+            const TensorInfo* projectionBias = nullptr;
+            const TensorInfo* cellToForgetWeights = nullptr;
+            const TensorInfo* cellToOutputWeights = nullptr;
+
+            TensorInfo optInputToInputWeights;
+            TensorInfo optRecurrentToInputWeights;
+            TensorInfo optCellToInputWeights;
+            TensorInfo optInputGateBias;
+            TensorInfo optProjectionWeights;
+            TensorInfo optProjectionBias;
+            TensorInfo optCellToForgetWeights;
+            TensorInfo optCellToOutputWeights;
+
+            if(!descriptor.m_CifgEnabled)
+            {
+                optInputToInputWeights =
+                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+                inputToInputWeights = &optInputToInputWeights;
+
+                optRecurrentToInputWeights =
+                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
+                recurrentToInputWeights = &optRecurrentToInputWeights;
+                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
+                {
+                    optCellToInputWeights =
+                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
+                    cellToInputWeights = &optCellToInputWeights;
+                }
+                optInputGateBias =
+                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
+                inputGateBias = &optInputGateBias;
+            }
+
+            if(descriptor.m_ProjectionEnabled)
+            {
+                optProjectionWeights =
+                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
+                projectionWeights = &optProjectionWeights;
+                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+                {
+                    optProjectionBias =
+                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
+                    projectionBias = &optProjectionBias;
+                }
+            }
+
+            if(descriptor.m_PeepholeEnabled)
+            {
+                optCellToForgetWeights =
+                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
+                cellToForgetWeights = &optCellToForgetWeights;
+                optCellToOutputWeights =
+                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
+                cellToOutputWeights = &optCellToOutputWeights;
+            }
+
+            result = IsLstmSupported(compute,
+                                     input,
+                                     outputStateIn,
+                                     cellStateIn,
+                                     scratchBuffer,
+                                     outputStateOut,
+                                     cellStateOut,
+                                     output,
+                                     descriptor,
+                                     inputToForgetWeights,
+                                     inputToCellWeights,
+                                     inputToOutputWeights,
+                                     recurrentToForgetWeights,
+                                     recurrentToCellWeights,
+                                     recurrentToOutputWeights,
+                                     forgetGateBias,
+                                     cellBias,
+                                     outputGateBias,
+                                     inputToInputWeights,
+                                     recurrentToInputWeights,
+                                     cellToInputWeights,
+                                     inputGateBias,
+                                     projectionWeights,
+                                     projectionBias,
+                                     cellToForgetWeights,
+                                     cellToOutputWeights,
+                                     reason,
+                                     reasonCapacity);
             break;
         }
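Every optional LSTM tensor above follows the same two-variable idiom: a TensorInfo local (optXxx) provides storage for the overridden copy, and a pointer stays null when the corresponding feature (CIFG, projection, peephole) leaves the parameter absent. A reduced sketch of the idiom:

    #include <iostream>

    struct TensorInfo { int id = 0; };

    // Stand-in for OverrideDataType(...): produces the copy that needs storage.
    TensorInfo MakeOverriddenCopy(int id) { return TensorInfo{id}; }

    // Stand-in for IsLstmSupported: a null pointer means "parameter absent".
    void Probe(const TensorInfo* inputGateBias)
    {
        std::cout << (inputGateBias ? "input gate bias present\n"
                                    : "CIFG enabled: no input gate bias\n");
    }

    int main()
    {
        bool cifgEnabled = false;

        const TensorInfo* inputGateBias = nullptr; // absent until proven otherwise
        TensorInfo optInputGateBias;               // storage that outlives the call

        if (!cifgEnabled)
        {
            optInputGateBias = MakeOverriddenCopy(42);
            inputGateBias = &optInputGateBias;
        }
        Probe(inputGateBias);
    }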
         case LayerType::Merger:
         {
             auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
-            // Get vector of all inputs
-            auto getTensorInfo = [](const InputSlot& slot)
+            // Get vector of all inputs.
+            auto getTensorInfo = [&dataType](const InputSlot& slot)
             {
-                return &slot.GetConnectedOutputSlot()->GetTensorInfo();
+                return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
             };
-            auto begin = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
-            auto end = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
+            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
+            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
+            std::vector<TensorInfo> inputs(beginI, endI);
 
-            std::vector<const TensorInfo*> inputs(begin, end);
+            auto getTensorInfoPtr = [](const TensorInfo& info)
+            {
+                return &info;
+            };
+            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
+            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
+            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
 
-            result = IsMergerSupported(compute, inputs, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
             break;
         }
         case LayerType::Multiplication:
         {
             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
-            result = IsMultiplicationSupported(compute, input0, input1, reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsMultiplicationSupported(compute,
+                                               OverrideDataType(input0, dataType),
+                                               OverrideDataType(input1, dataType),
+                                               OverrideDataType(output, dataType),
+                                               reason,
+                                               reasonCapacity);
             break;
         }
         case LayerType::Normalization:
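The Merger rewrite above is easy to misread. The lambda now returns TensorInfo by value, so the overridden infos are first materialised into a std::vector<TensorInfo> that owns them, and only then does a second transform pass build the std::vector<const TensorInfo*> that IsMergerSupported expects; taking pointers straight off the first transform iterator would leave them dangling on temporaries. The same two-stage pattern using only the standard library:

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <vector>

    struct TensorInfo { int id; };

    int main()
    {
        std::vector<int> slots{1, 2, 3};

        // Stage 1: materialise the transformed values into stable storage.
        std::vector<TensorInfo> inputs;
        std::transform(slots.begin(), slots.end(), std::back_inserter(inputs),
                       [](int slot) { return TensorInfo{slot}; });

        // Stage 2: build the pointer view the consuming interface wants.
        // Safe only because `inputs` owns the objects and outlives `inputPtrs`.
        std::vector<const TensorInfo*> inputPtrs;
        std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputPtrs),
                       [](const TensorInfo& info) { return &info; });

        for (const TensorInfo* p : inputPtrs)
        {
            std::cout << p->id << '\n';
        }
    }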
@@ -185,13 +512,15 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
             auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsNormalizationSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
+                                              OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
+                                              reasonCapacity);
             break;
         }
         case LayerType::Output:
         {
             const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsOutputSupported(compute, output, reason, reasonCapacity);
+            result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
             break;
         }
         case LayerType::Permute:
@@ -199,7 +528,8 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
             auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsPermuteSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+                                        cLayer->GetParameters(), reason, reasonCapacity);
             break;
         }
         case LayerType::Pooling2d:
@@ -207,33 +537,38 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
             auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
-            result = IsPooling2dSupported(compute, input, output, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
+                                          OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
+                                          reasonCapacity);
             break;
         }
         case LayerType::Reshape:
         {
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsReshapeSupported(compute, input, reason, reasonCapacity);
+            result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
             break;
         }
         case LayerType::ResizeBilinear:
         {
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsResizeBilinearSupported(compute, input, reason, reasonCapacity);
+            result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
             break;
         }
         case LayerType::Softmax:
         {
             auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsSoftmaxSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+            result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
+                                        cLayer->GetParameters(), reason, reasonCapacity);
             break;
         }
         case LayerType::Splitter:
         {
             auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
-            result = IsSplitterSupported(compute, input, cLayer->GetParameters(), reason, reasonCapacity);
+            result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
+                                         reasonCapacity);
             break;
         }
         default:
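The final hunk below updates the convenience overload to match. For callers, the net effect of the boost::optional<DataType> signature is that the same layer can be probed either with its own tensor types or under a hypothetical re-typing, which is the kind of query the new Float16 plumbing in this release needs. A hedged usage sketch (it assumes a valid armnn::Layer reference is in hand and that these overloads are static members in the 18.08 header, which this .cpp diff alone does not show):

    // Sketch only - `layer` is assumed to come from a loaded armnn graph.
    std::string reason;

    // Probe the layer with its own tensor data types (no override).
    bool supportedAsIs =
        armnn::IWorkloadFactory::IsLayerSupported(layer, boost::none, reason);

    // Probe the same layer as if every tensor were Float16, e.g. to decide
    // whether it could run in half precision on this backend.
    bool supportedAsFp16 =
        armnn::IWorkloadFactory::IsLayerSupported(layer, armnn::DataType::Float16, reason);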
@@ -248,7 +583,8 @@ bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, Dat
     return result;
 }
 
-bool IWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, std::string& outReasonIfUnsupported)
+bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
+                                        std::string& outReasonIfUnsupported)
 {
     return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
 }
-- 
cgit v1.2.1