in src/main/java/com/twitter/sbf/core/SimClustersInMemory.java [327:450]
/**
 * Computes the latent representation vectors of the left vertices of the bipartite graph
 * {@code g} from the representation vectors of its right vertices.
 *
 * <p>Conceptually this is the product {@code A * representationsForRight}, where {@code A} is the
 * m X n adjacency matrix of {@code g} (rows = left vertices, columns = right vertices), followed
 * by a thresholding step: entry (i, l) of the product is kept only if left vertex i is adjacent to
 * at least {@code thresholdForStepThree} right vertices that have a non-zero value in column l of
 * their representation (i.e. that "belong" to community l).
 *
 * <p>The graph is never converted into a matrix, and dot products are not computed for every
 * (row, column) pair. Instead, for every right vertex j and every pair (i, l) such that i is a
 * left-neighbor of j and l is in the support of row j of {@code representationsForRight}, the
 * value {@code representationsForRight[j][l]} is added to entry (i, l). The same pass counts
 * the number of such edges for each (i, l).
 *
 * @param representationsForRight n X k' matrix whose row j is the representation of right
 *     vertex j; k' is its actual number of columns
 * @return m X k' matrix whose row i is the thresholded representation of left vertex i
 */
private SparseRealMatrix getRepresentationsForLeft(SparseRealMatrix representationsForRight) {
  // k' (the actual number of communities) is denoted by kp below.
  int kp = representationsForRight.getNumCols();
  System.out.println(
    "Expected k: " + this.k + ", Actual k: " + kp
  );

  // Non-zero entries of the product, keyed first by column l and then by row i.
  HashMap<Integer, HashMap<Integer, Double>> columnsOfProductMatrix = new HashMap<>();
  // Same keys as columnsOfProductMatrix; the value at (l, i) is the number of edges from left
  // vertex i to right vertices in community l. Used for thresholding below.
  HashMap<Integer, HashMap<Integer, Integer>> numEdgesFromLeftToCommunities = new HashMap<>();

  // n = number of columns of A = number of right vertices in g.
  int n = this.g.getNumRightVertices();
  for (int j = 0; j < n; j++) {
    // Column j of A: the left-neighbors of right vertex j.
    int[] neighborsOfVertexJ = this.g.getNeighborsForRightById(j);
    // Support of row j of representationsForRight, and the values on that support.
    int[] nonZeroColsInRowJ = representationsForRight.getColIdsForRow(j);
    double[] nonZeroValuesInRowJ = representationsForRight.getValuesForRow(j);
    if (neighborsOfVertexJ.length == 0) {
      // Nothing to accumulate; skipping avoids creating empty per-community maps.
      continue;
    }
    for (int it = 0; it < nonZeroColsInRowJ.length; it++) {
      int l = nonZeroColsInRowJ[it];
      double valueAtJL = nonZeroValuesInRowJ[it];
      // computeIfAbsent allocates only the first time community l is seen, unlike
      // getOrDefault(l, new HashMap<>()), which builds a throwaway map on every call.
      HashMap<Integer, Double> columnLOfProductMatrix =
          columnsOfProductMatrix.computeIfAbsent(l, key -> new HashMap<>());
      HashMap<Integer, Integer> numEdgesFromVerticesToCommunityL =
          numEdgesFromLeftToCommunities.computeIfAbsent(l, key -> new HashMap<>());
      for (int i : neighborsOfVertexJ) {
        // product[i][l] += representationsForRight[j][l]
        columnLOfProductMatrix.merge(i, valueAtJL, Double::sum);
        // one more edge from left vertex i into community l
        numEdgesFromVerticesToCommunityL.merge(i, 1, Integer::sum);
      }
    }
  }

  // Thresholding: drop (i, l) when left vertex i has fewer than thresholdForStepThree edges into
  // community l. Both maps were filled in lockstep, so every row present in a column of
  // columnsOfProductMatrix has a count in numEdgesFromLeftToCommunities.
  for (int l : columnsOfProductMatrix.keySet()) {
    HashMap<Integer, Integer> numEdgesFromVerticesToCommunityL =
        numEdgesFromLeftToCommunities.get(l);
    columnsOfProductMatrix.get(l).keySet().removeIf(
        row -> numEdgesFromVerticesToCommunityL.get(row) < this.thresholdForStepThree);
  }

  // Support of each of the kp columns of the m X kp return matrix.
  IntSet[] columnsOfReturnMatrix = new IntSet[kp];
  for (int l = 0; l < kp; l++) {
    HashMap<Integer, Double> columnL = columnsOfProductMatrix.get(l);
    if (columnL == null) {
      columnsOfReturnMatrix[l] = new IntOpenHashSet(0);
    } else {
      columnsOfReturnMatrix[l] = new IntOpenHashSet(columnL.size());
      for (int rowIndex : columnL.keySet()) {
        columnsOfReturnMatrix[l].add(rowIndex);
      }
    }
  }

  // m = number of left vertices.
  int m = this.g.getNumLeftVertices();
  SparseBinaryMatrix supportOfReturnMatrix = new SparseBinaryMatrix(m, kp);
  supportOfReturnMatrix.initFromColSets(columnsOfReturnMatrix);

  // Fill in the values row by row, in the column order reported by the support matrix.
  double[][] valuesOfReturnMatrix = new double[m][];
  for (int row = 0; row < m; row++) {
    int[] nonZeroColsInRow = supportOfReturnMatrix.getRow(row);
    double[] valuesInRow = new double[nonZeroColsInRow.length];
    for (int index = 0; index < nonZeroColsInRow.length; index++) {
      valuesInRow[index] = columnsOfProductMatrix.get(nonZeroColsInRow[index]).get(row);
    }
    valuesOfReturnMatrix[row] = valuesInRow;
  }
  return new SparseRealMatrix(supportOfReturnMatrix, valuesOfReturnMatrix);
}