13 #include <doctest/doctest.h> 20 const std::vector<int32_t>& indicesData)
29 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
30 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
31 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
36 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
37 void GatherEndToEnd(
const std::vector<BackendId>& backends)
51 std::vector<T> paramsData{
52 1, 2, 3, 4, 5, 6, 7, 8
55 std::vector<int32_t> indicesData{
59 std::vector<T> expectedOutput{
64 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
68 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
69 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
71 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
74 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
75 void GatherMultiDimEndToEnd(
const std::vector<BackendId>& backends)
89 std::vector<T> paramsData{
100 std::vector<int32_t> indicesData{
105 std::vector<T> expectedOutput{
122 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
124 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
125 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
127 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
A GatherDescriptor for the GatherLayer.
void SetQuantizationScale(float scale)
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
void SetQuantizationOffset(int32_t offset)
void Connect(armnn::IConnectableLayer *from, armnn::IConnectableLayer *to, const armnn::TensorInfo &tensorInfo, unsigned int fromIndex, unsigned int toIndex)
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
static INetworkPtr Create(NetworkOptions networkOptions={})