Diffstat (limited to 'shim/sl/canonical/ConversionUtils.hpp')
-rw-r--r-- | shim/sl/canonical/ConversionUtils.hpp | 16 |
1 file changed, 13 insertions, 3 deletions
diff --git a/shim/sl/canonical/ConversionUtils.hpp b/shim/sl/canonical/ConversionUtils.hpp
index beee00d11a..91a8e3080c 100644
--- a/shim/sl/canonical/ConversionUtils.hpp
+++ b/shim/sl/canonical/ConversionUtils.hpp
@@ -150,7 +150,7 @@ static bool Fail(const char* formatStr, Args&&... args)
 
 // Convenience macro to call an Is*Supported function and log caller name together with reason for lack of support.
 // Called as: FORWARD_LAYER_SUPPORT_FUNC(__func__, Is*Supported, backends, a, b, c, d, e)
-#define FORWARD_LAYER_SUPPORT_FUNC(funcName, func, backends, supported, ...) \
+#define FORWARD_LAYER_SUPPORT_FUNC(funcName, func, backends, supported, setBackend, ...) \
 try \
 { \
     for (auto&& backendId : backends) \
@@ -163,6 +163,7 @@ try \
             layerSupportObject.func(__VA_ARGS__, armnn::Optional<std::string&>(reasonIfUnsupported)); \
             if (supported) \
             { \
+                setBackend = backendId; \
                 break; \
             } \
             else \
@@ -322,10 +323,12 @@ bool BroadcastTensor(LayerInputHandle& input0,
     armnn::ReshapeDescriptor reshapeDescriptor;
 
     bool isSupported = false;
+    armnn::BackendId setBackend;
     FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                IsReshapeSupported,
                                data.m_Backends,
                                isSupported,
+                               setBackend,
                                smallInfo,
                                reshapedInfo,
                                reshapeDescriptor);
@@ -336,6 +339,7 @@ bool BroadcastTensor(LayerInputHandle& input0,
 
     ARMNN_ASSERT(data.m_Network != nullptr);
     armnn::IConnectableLayer& reshapeLayer = AddReshapeLayer(*data.m_Network, smallInputHandle, reshapedInfo);
+    reshapeLayer.SetBackendId(setBackend);
 
     if (input0IsSmaller)
     {
@@ -527,7 +531,8 @@ inline bool RequiresReshape(armnn::TensorShape & inputShape)
 inline void SwizzleInputs(armnn::INetwork& network,
                           std::vector<LayerInputHandle>& inputs,
                           std::vector<armnn::TensorShape>& inputShapes,
-                          const armnn::PermutationVector& mapping)
+                          const armnn::PermutationVector& mapping,
+                          std::vector<armnn::BackendId>& setBackends)
 {
     if (!mapping.IsEqual(IdentityPermutation4D))
     {
@@ -536,6 +541,7 @@ inline void SwizzleInputs(armnn::INetwork& network,
         {
             // add swizzle layer
             armnn::IConnectableLayer& swizzleLayer = AddTransposeLayer(network, inputs[i], mapping);
+            swizzleLayer.SetBackendId(setBackends[i]);
             auto& outputSlot = swizzleLayer.GetOutputSlot(0);
             auto& outputInfo = outputSlot.GetTensorInfo();
             // replace inputs with the swizzled ones
@@ -553,6 +559,7 @@ bool TransposeInputTensors(ConversionData& data,
     // If we have a IdentityPermutation4D or IdentityPermutation3D then we are not permuting
     if (!mapping.IsEqual(IdentityPermutation4D) && !mapping.IsEqual(IdentityPermutation3D))
     {
+        std::vector<armnn::BackendId> setBackendsVec;
         armnn::TensorInfo outputTransposeInfo;
         size_t nInputs = inputs.size();
         for (size_t i=0; i<nInputs; ++i)
@@ -563,20 +570,23 @@ bool TransposeInputTensors(ConversionData& data,
             outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping);
 
             bool isSupported = false;
+            armnn::BackendId setBackend;
             FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                        IsTransposeSupported,
                                        data.m_Backends,
                                        isSupported,
+                                       setBackend,
                                        inputs[i].GetTensorInfo(),
                                        outputTransposeInfo,
                                        transposeDesc);
+            setBackendsVec.push_back(setBackend);
             if (!isSupported)
             {
                 return false;
             }
         }
 
-        SwizzleInputs(*data.m_Network, inputs, inputShapes, mapping);
+        SwizzleInputs(*data.m_Network, inputs, inputShapes, mapping, setBackendsVec);
     }
     return true;
 }
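
For context, a minimal usage sketch of the updated macro (not part of the patch above): after this change, callers pass an additional armnn::BackendId that the macro fills in with the backend that reported support, and the layer created afterwards is assigned to that backend via SetBackendId. The pattern below mirrors the BroadcastTensor changes in the diff; inputInfo, outputInfo and inputHandle are illustrative placeholders, not names from the file.

    bool isSupported = false;
    armnn::BackendId setBackend;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsReshapeSupported,
                               data.m_Backends,
                               isSupported,
                               setBackend,          // new argument: receives the backend that supports the layer
                               inputInfo,           // illustrative tensor infos
                               outputInfo,
                               reshapeDescriptor);
    if (!isSupported)
    {
        return false;
    }
    armnn::IConnectableLayer& reshapeLayer = AddReshapeLayer(*data.m_Network, inputHandle, outputInfo);
    reshapeLayer.SetBackendId(setBackend);           // pin the created layer to the supporting backend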