ArmNN
 21.11
MemCopyTestImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <ResolveType.hpp>

#include <armnn/Deprecated.hpp>

#include <test/TensorHelpers.hpp>

#include <array>
#include <utility>
#include <vector>
17 
18 namespace
19 {
20 
21 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
22 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
23  armnn::IWorkloadFactory& dstWorkloadFactory,
24  bool withSubtensors)
25 {
26  const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
27  const armnn::TensorShape tensorShape(4, shapeData.data());
28  const armnn::TensorInfo tensorInfo(tensorShape, dataType);
29  std::vector<T> inputData =
30  {
31  1, 2, 3, 4, 5,
32  6, 7, 8, 9, 10,
33  11, 12, 13, 14, 15,
34  16, 17, 18, 19, 20,
35  21, 22, 23, 24, 25,
36  26, 27, 28, 29, 30,
37  };
38 
39  LayerTestResult<T, 4> ret(tensorInfo);
40  ret.m_ExpectedData = inputData;
41 
42  std::vector<T> actualOutput(tensorInfo.GetNumElements());
43 
45  auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
46  auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
48 
49  AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
50  outputTensorHandle->Allocate();
51 
52  armnn::MemCopyQueueDescriptor memCopyQueueDesc;
53  armnn::WorkloadInfo workloadInfo;
54 
55  const unsigned int origin[4] = {};
56 
58  auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
59  ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
60  : std::move(inputTensorHandle);
61  auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
62  ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
63  : std::move(outputTensorHandle);
65 
66  AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
67  AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
68 
69  dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
70 
71  CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
72  ret.m_ActualData = actualOutput;
73 
74  return ret;
75 }
76 
77 template<typename SrcWorkloadFactory,
78  typename DstWorkloadFactory,
79  armnn::DataType dataType,
80  typename T = armnn::ResolveType<dataType>>
81 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
82 {
84  WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
85 
87  WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
88 
89  SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
90  DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
91 
92  return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
93 }
94 
95 } // anonymous namespace
virtual std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
virtual std::unique_ptr< ITensorHandle > CreateSubTensorHandle(ITensorHandle &parent, TensorShape const &subTensorShape, unsigned int const *subTensorOrigin) const =0
typename ResolveTypeImpl< DT >::Type ResolveType
Definition: ResolveType.hpp:79
DataType
Definition: Types.hpp:35
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
std::shared_ptr< IMemoryManager > IMemoryManagerSharedPtr
void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
void CopyDataFromITensorHandle(void *memory, const armnn::ITensorHandle *tensorHandle)
virtual std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo, const bool IsMemoryManaged=true) const =0
Contains information about TensorInfos of a layer.
virtual bool SupportsSubTensors() const =0