18 const std::vector<int32_t>& indicesData)
26 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
27 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
28 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
33 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
34 void GatherEndToEnd(
const std::vector<BackendId>& backends)
46 std::vector<T> paramsData{
47 1, 2, 3, 4, 5, 6, 7, 8
50 std::vector<int32_t> indicesData{
54 std::vector<T> expectedOutput{
59 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
61 BOOST_TEST_CHECKPOINT(
"create a network");
63 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
64 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
66 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
69 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
70 void GatherMultiDimEndToEnd(
const std::vector<BackendId>& backends)
82 std::vector<T> paramsData{
93 std::vector<int32_t> indicesData{
98 std::vector<T> expectedOutput{
115 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
117 BOOST_TEST_CHECKPOINT(
"create a network");
119 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
120 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
122 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
void SetQuantizationScale(float scale)
void SetQuantizationOffset(int32_t offset)
void Connect(armnn::IConnectableLayer *from, armnn::IConnectableLayer *to, const armnn::TensorInfo &tensorInfo, unsigned int fromIndex, unsigned int toIndex)
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
static INetworkPtr Create()