35 std::unique_ptr<Decoder<float>> params_decoderPtr = MakeDecoder<float>(inputInfo0, inputs[0]->
Map());
37 const int32_t* indicesDataPtr =
reinterpret_cast<int32_t*
>(inputs[1]->Map());
38 std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.
GetNumElements());
40 std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->
Map());
48 std::vector<unsigned int> flattenedCoeff(keyIndices[
"ND"], 1);
49 for (
unsigned int i = 1; i < keyIndices[
"ND"]; ++i)
51 flattenedCoeff[i-1] = paramsShape[i];
53 for (
unsigned int i = keyIndices[
"ND"]-1; i > 0; --i)
55 flattenedCoeff[i-1] *= flattenedCoeff[i];
61 flattenedIndices_Info.
SetShape({ keyIndices[
"W"] });
62 std::vector<int32_t> flattenedIndices(flattenedIndices_Info.
GetNumElements(), 0);
65 for (
unsigned int i = 0; i < keyIndices[
"W"]; ++i)
67 for (
unsigned int j = 0; j < keyIndices[
"ND"]; ++j)
69 flattenedIndices[i] += indices[i * keyIndices[
"ND"] + j] *
static_cast<int32_t
>(flattenedCoeff[j]);
76 params_K_C_Info.
SetShape({ keyIndices[
"K"], keyIndices[
"C"] });
80 indices_N_W_Info.
SetShape({ keyIndices[
"N"], keyIndices[
"W"] });
85 outputGather_Info.
SetShape({ keyIndices[
"N"], keyIndices[
"W"], keyIndices[
"C"] });
88 Gather(params_K_C_Info, indices_N_W_Info, outputGather_Info,
89 *params_decoderPtr, flattenedIndices.data(), *output_encoderPtr, 0);