diff options
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaBackend.cpp')
-rw-r--r-- | src/backends/gpuFsa/GpuFsaBackend.cpp | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/src/backends/gpuFsa/GpuFsaBackend.cpp b/src/backends/gpuFsa/GpuFsaBackend.cpp index 8ea9e8e7d3..9886a6e187 100644 --- a/src/backends/gpuFsa/GpuFsaBackend.cpp +++ b/src/backends/gpuFsa/GpuFsaBackend.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -20,10 +20,7 @@ #include <arm_compute/core/CL/CLKernelLibrary.h> #include <arm_compute/runtime/CL/CLBufferAllocator.h> -#include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h> -#include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h> - -#include "layerValidators/GpuFsaConvolution2dValidate.hpp" +#include "layers/GpuFsaConvolution2d.hpp" namespace armnn { @@ -218,9 +215,6 @@ OptimizationViews GpuFsaBackend::OptimizeSubgraphView(const SubgraphView& subgra OptimizationViews optimizationViews(modelOptions); using namespace arm_compute::experimental::dynamic_fusion; - // Create a new workload sketch, for validation purposes - auto compileCtx = arm_compute::CLKernelLibrary::get().get_compile_context(); - auto gpuCtx = GpuWorkloadContext(&compileCtx); auto it = subgraph.end(); std::map<LayerGuid, Layer*> untouched; @@ -233,32 +227,41 @@ OptimizationViews GpuFsaBackend::OptimizeSubgraphView(const SubgraphView& subgra GpuFsaLayerSupport supportChecker; it = subgraph.end(); + arm_compute::CLCompileContext* compileCtx = &(arm_compute::CLKernelLibrary::get().get_compile_context()); + + // Set up the GpuWorkloadContext which will exist for the lifetime of the Graph. 
This contains the TensorInfos + std::shared_ptr<GpuWorkloadContext> workloadContext = std::make_shared<GpuWorkloadContext>(compileCtx); while (it != subgraph.begin()) { --it; Layer& base = *(PolymorphicDowncast<Layer*>(*it)); + // Create a GpuFsaPreCompiledBlob, this contains all of the information needed to execute an operator + GpuFsaPreCompiledBlob* preCompiledBlobPtr = new GpuFsaPreCompiledBlob(); + preCompiledBlobPtr->workloadContext = workloadContext; + preCompiledBlobPtr->sketch = std::make_unique<GpuWorkloadSketch>(workloadContext.get()); - std::unique_ptr<GpuWorkloadSketch> sketch = std::make_unique<GpuWorkloadSketch>(&gpuCtx); + // Configure and setup the sketch for each supported op. Their data will be wrapped into a PreCompiled layer switch (base.GetType()) { case (LayerType::Convolution2d): { auto input = base.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); auto weights = base.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(); - //std::vector<TensorInfo> infos = {input, weights}; auto desc = PolymorphicDowncast<const Convolution2dDescriptor*>(&base.GetParameters()); if (desc->m_BiasEnabled) { auto bias = base.GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo(); - GpuFsaConvolution2dCreateOp(input, + GpuFsaConvolution2dCreateOp(preCompiledBlobPtr, + input, *desc, weights, bias); } else { - GpuFsaConvolution2dCreateOp(input, + GpuFsaConvolution2dCreateOp(preCompiledBlobPtr, + input, *desc, weights, EmptyOptional()); @@ -270,7 +273,8 @@ OptimizationViews GpuFsaBackend::OptimizeSubgraphView(const SubgraphView& subgra continue; } - auto compiledBlob = std::make_unique<PreCompiledObjectPtr>(sketch.release(), DeleteAsType<GpuWorkloadSketch>); + auto compiledBlob = + std::make_unique<PreCompiledObjectPtr>(preCompiledBlobPtr, DeleteAsType<GpuFsaPreCompiledBlob>); IConnectableLayer* preCompiledLayer = optimizationViews.GetINetwork()->AddPrecompiledLayer( PreCompiledDescriptor(base.GetNumInputSlots(), base.GetNumOutputSlots()), @@ -289,7 
+293,7 @@ OptimizationViews GpuFsaBackend::OptimizeSubgraphView(const SubgraphView& subgra CreateOutputsFrom(&base), {&base}); - optimizationViews.AddSubstitution({ *substituteSubgraph, SubgraphView(preCompiledLayer) }); + optimizationViews.AddSubstitution({ std::move(*substituteSubgraph), SubgraphView(preCompiledLayer) }); untouched.erase(base.GetGuid()); } |