ArmNN 21.02
MemCopyTestImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <ResolveType.hpp>
8 
10 
15 
16 #include <test/TensorHelpers.hpp>
17 
18 #include <boost/multi_array.hpp>
19 
20 namespace
21 {
22 
23 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
24 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
25  armnn::IWorkloadFactory& dstWorkloadFactory,
26  bool withSubtensors)
27 {
28  const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
29  const armnn::TensorShape tensorShape(4, shapeData.data());
30  const armnn::TensorInfo tensorInfo(tensorShape, dataType);
31  boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>(
32  {
33  1, 2, 3, 4, 5,
34  6, 7, 8, 9, 10,
35  11, 12, 13, 14, 15,
36  16, 17, 18, 19, 20,
37  21, 22, 23, 24, 25,
38  26, 27, 28, 29, 30,
39  })
40  );
41 
42  LayerTestResult<T, 4> ret(tensorInfo);
43  ret.outputExpected = inputData;
44 
45  boost::multi_array<T, 4> outputData(shapeData);
46 
48  auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
49  auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
51 
52  AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
53  outputTensorHandle->Allocate();
54 
55  armnn::MemCopyQueueDescriptor memCopyQueueDesc;
56  armnn::WorkloadInfo workloadInfo;
57 
58  const unsigned int origin[4] = {};
59 
61  auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
62  ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
63  : std::move(inputTensorHandle);
64  auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
65  ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
66  : std::move(outputTensorHandle);
68 
69  AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
70  AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
71 
72  dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
73 
74  CopyDataFromITensorHandle(outputData.data(), workloadOutput.get());
75  ret.output = outputData;
76 
77  return ret;
78 }
79 
80 template<typename SrcWorkloadFactory,
81  typename DstWorkloadFactory,
82  armnn::DataType dataType,
83  typename T = armnn::ResolveType<dataType>>
84 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
85 {
87  WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
88 
90  WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
91 
92  SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
93  DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
94 
95  return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
96 }
97 
98 } // anonymous namespace
virtual std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
virtual std::unique_ptr< ITensorHandle > CreateSubTensorHandle(ITensorHandle &parent, TensorShape const &subTensorShape, unsigned int const *subTensorOrigin) const =0
typename ResolveTypeImpl< DT >::Type ResolveType
Definition: ResolveType.hpp:73
DataType
Definition: Types.hpp:32
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
std::shared_ptr< IMemoryManager > IMemoryManagerSharedPtr
void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)
void CopyDataFromITensorHandle(void *memory, const armnn::ITensorHandle *tensorHandle)
virtual std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo, const bool IsMemoryManaged=true) const =0
Contains information about inputs and outputs to a layer.
virtual bool SupportsSubTensors() const =0