in graphjet-core/src/main/java/com/twitter/graphjet/algorithms/randommultigraphneighbors/RandomMultiGraphNeighbors.java [88:153]
/**
 * Samples right-hand-side neighbors of the request's weighted left seed nodes and returns the
 * most frequently sampled ones.
 *
 * <p>The method proceeds in three phases:
 * <ol>
 *   <li>Draws {@code request.getMaxNumSamples()} seed nodes (with replacement) from an alias
 *   table built over {@code request.getLeftSeedNodesWithWeight()}, counting how many times each
 *   seed was drawn.</li>
 *   <li>For every drawn seed, asks the graph for that many random edges via
 *   {@code getRandomLeftNodeEdges} and counts how often each right-hand node appears.</li>
 *   <li>Drops neighbors rejected by {@code filterChain}, scores the rest as
 *   {@code occurrences / maxNumSamples}, and keeps the top {@code maxNumResults}.</li>
 * </ol>
 *
 * @param request     seed nodes with weights, sample budget and result limit; the result limit
 *                    is capped at {@link RecommendationRequest#MAX_RECOMMENDATION_RESULTS}
 * @param random      randomness used for both seed sampling and edge sampling
 * @param filterChain neighbors for which {@code filter} returns {@code true} are excluded
 * @return the selected neighbors, greatest first under {@code NeighborInfo}'s natural ordering
 *         (NOTE(review): presumably highest score first — confirm against NeighborInfo);
 *         an empty list when there are no seed nodes or the result limit is not positive
 */
public RandomMultiGraphNeighborsResponse getRandomMultiGraphNeighbors(
    RandomMultiGraphNeighborsRequest request,
    Random random,
    RelatedTweetFilterChain filterChain) {
  Long2DoubleMap leftSeedNodeWithWeights = request.getLeftSeedNodesWithWeight();
  int maxNumSamples = request.getMaxNumSamples();
  int maxNumResults = Math.min(
      request.getMaxNumResults(),
      RecommendationRequest.MAX_RECOMMENDATION_RESULTS
  );
  int numSeedNodes = leftSeedNodeWithWeights.size();
  // Nothing can be returned in these cases, so bail out early. With no seeds the index array
  // below is empty, so no sampled index could be valid. With a non-positive result limit,
  // PriorityQueue would reject the initial capacity (it must be at least 1).
  if (numSeedNodes == 0 || maxNumResults <= 0) {
    return new RandomMultiGraphNeighborsResponse(Lists.<NeighborInfo>newArrayList());
  }

  // construct a IndexArray and AliasTableArray to sample from LHS seed nodes
  long[] indexArray = new long[numSeedNodes];
  int[] aliasTableArray = IntArrayAliasTable.generateAliasTableArray(numSeedNodes);
  constructAliasTableArray(leftSeedNodeWithWeights, indexArray, aliasTableArray);

  // first, get number of samples for each of the seed nodes
  Long2IntMap nodeToNumSamples = new Long2IntOpenHashMap(numSeedNodes);
  nodeToNumSamples.defaultReturnValue(0);
  for (int i = 0; i < maxNumSamples; i++) {
    int index = AliasTableUtil.getRandomSampleFromAliasTable(aliasTableArray, random);
    long sampledLHSNode = indexArray[index];
    nodeToNumSamples.put(sampledLHSNode, nodeToNumSamples.get(sampledLHSNode) + 1);
  }

  // now actually retrieve neighbors of sampled seed nodes, counting each right node's hits
  Long2IntMap seedNodeRightNeighbors = new Long2IntOpenHashMap(1024);
  seedNodeRightNeighbors.defaultReturnValue(0);
  for (Long2IntMap.Entry entry : nodeToNumSamples.long2IntEntrySet()) {
    long node = entry.getLongKey();
    int numSamples = entry.getIntValue();
    EdgeIterator currentIterator = bipartiteGraph.getRandomLeftNodeEdges(
        node, numSamples, random);
    // The graph may return null (the original handles that); such seeds contribute nothing.
    if (currentIterator != null) {
      while (currentIterator.hasNext()) {
        long neighbor = currentIterator.nextLong();
        seedNodeRightNeighbors.put(
            neighbor, seedNodeRightNeighbors.get(neighbor) + 1);
      }
    }
  }

  // normalize and select top neighbors
  PriorityQueue<NeighborInfo> topResults = new PriorityQueue<NeighborInfo>(maxNumResults);
  // Stats are recorded before filtering, so they reflect every sampled neighbor.
  numOfUniqueNeighborsCounter.incr(seedNodeRightNeighbors.size());
  for (Long2IntMap.Entry entry : seedNodeRightNeighbors.long2IntEntrySet()) {
    long neighborNode = entry.getLongKey();
    int occurrence = entry.getIntValue();
    numOfNeighborsCounter.incr(occurrence);
    if (filterChain.filter(neighborNode)) {
      continue;
    }
    int neighborNodeDegree = bipartiteGraph.getRightNodeDegree(neighborNode);
    // maxNumSamples > 0 here: a neighbor can only exist if at least one seed was sampled.
    NeighborInfo neighborInfo = new NeighborInfo(
        neighborNode, (double) occurrence / (double) maxNumSamples, neighborNodeDegree);
    addResultToPriorityQueue(topResults, neighborInfo, maxNumResults);
  }

  // poll() yields the least element first; reverse so the greatest comes first.
  List<NeighborInfo> outputResults = Lists.newArrayListWithCapacity(topResults.size());
  while (!topResults.isEmpty()) {
    outputResults.add(topResults.poll());
  }
  Collections.reverse(outputResults);
  return new RandomMultiGraphNeighborsResponse(outputResults);
}