in src/main/java/com/twitter/sbf/core/SimClustersInMemory.java [471:573]
/**
 * Computes the n X k' matrix of latent representations for the right vertices of the bipartite
 * graph {@code g} from the m X k' matrix of representations for the left vertices.
 *
 * <p>The computation has three stages:
 * <ol>
 *   <li>Accumulate the product of the transpose of the bi-adjacency matrix of {@code g} with
 *       {@code representationsForLeft}: for every left vertex l, every column j in the support
 *       of row l and every right neighbor i of l, the value at (row l, column j) is added to
 *       entry (row i, column j) of the product.</li>
 *   <li>Divide every entry (i, j) by
 *       {@code representationsForLeft.getColumnNorm(j) * sqrt(degree of right vertex i)} and
 *       keep it only if the result exceeds {@code thresholdForStepFour} by more than 1e-5.</li>
 *   <li>Pack the surviving entries into a {@link SparseRealMatrix} with n rows and k'
 *       columns.</li>
 * </ol>
 *
 * @param representationsForLeft matrix with one row per left vertex and k' columns
 * @return matrix with one row per right vertex and k' columns holding the thresholded,
 *     normalized product
 */
private SparseRealMatrix updateRepresentationsForRight(SparseRealMatrix representationsForLeft) {
  // m = number of left vertices, k' = number of columns of the left representations matrix.
  // Primitive ints: these are only used as loop bounds and array sizes.
  int m = g.getNumLeftVertices();
  int kp = representationsForLeft.getNumCols();

  // Stage 1: columns of the (unnormalized) product matrix. Outer key is the column index
  // 0 <= j < k', inner key is the row index 0 <= i < n (a right vertex id).
  HashMap<Integer, HashMap<Integer, Double>> columnsOfProductMatrix = new HashMap<>();
  for (int l = 0; l < m; l++) {
    // Support and non-zero values of row l of the left representations matrix.
    int[] rowLRepLeftSupport = representationsForLeft.getColIdsForRow(l);
    double[] rowLRepLeftValues = representationsForLeft.getValuesForRow(l);
    // Right neighbors of left vertex l.
    int[] neighborsOfVertexL = g.getNeighborsForLeftById(l);
    // For every (j, i) in rowLRepLeftSupport X neighborsOfVertexL add
    // representationsForLeft[l][j] to product[(j, i)], treating a missing entry as 0.
    for (int it = 0; it < rowLRepLeftSupport.length; it++) {
      // computeIfAbsent only allocates a column map the first time column j is seen and
      // registers it in the outer map, so no explicit put is needed afterwards.
      HashMap<Integer, Double> colJInProductMatrix =
          columnsOfProductMatrix.computeIfAbsent(rowLRepLeftSupport[it], key -> new HashMap<>());
      for (int i : neighborsOfVertexL) {
        double curValOfMatrixAtColJRowI = colJInProductMatrix.getOrDefault(i, 0.0);
        colJInProductMatrix.put(i, curValOfMatrixAtColJRowI + rowLRepLeftValues[it]);
      }
    }
  }

  // Stage 2: normalize each entry (j, i) to
  //   product[(j, i)] / (sqrt(degree of right vertex i) * norm of column j of the left matrix)
  // and drop entries that do not clear the threshold. The result is written to a fresh map
  // instead of mutating the map being iterated.
  HashMap<Integer, HashMap<Integer, Double>> thresholdedColumns = new HashMap<>();
  for (int j : columnsOfProductMatrix.keySet()) {
    HashMap<Integer, Double> columnJOfProductMatrix = columnsOfProductMatrix.get(j);
    HashMap<Integer, Double> newColumnJOfProductMatrix = new HashMap<>();
    double normOfColumnJInLeftRepMatrix = representationsForLeft.getColumnNorm(j);
    for (int i : columnJOfProductMatrix.keySet()) {
      double sqrtOfDegreeOfVertexI = Math.sqrt(g.getNeighborsForRightById(i).length);
      double normalizationFactor = normOfColumnJInLeftRepMatrix * sqrtOfDegreeOfVertexI;
      double newValueOfProductMatrixAtColJRowI =
          columnJOfProductMatrix.get(i) / normalizationFactor;
      // Keep the entry only if it exceeds thresholdForStepFour by more than 1e-5. A NaN
      // result (0/0) fails this comparison and is therefore dropped as well.
      if (newValueOfProductMatrixAtColJRowI - thresholdForStepFour > 1e-5) {
        newColumnJOfProductMatrix.put(i, newValueOfProductMatrixAtColJRowI);
      }
    }
    // Columns that end up with no surviving entries are not stored at all.
    if (newColumnJOfProductMatrix.size() > 0) {
      thresholdedColumns.put(j, newColumnJOfProductMatrix);
    }
  }

  // Stage 3: build the sparse n X k' return matrix. Its support is described column by
  // column as IntSets of row indices; columns without surviving entries get an empty set.
  IntSet[] columnsOfReturnMatrix = new IntSet[kp];
  for (int j = 0; j < kp; j++) {
    if (!thresholdedColumns.containsKey(j)) {
      columnsOfReturnMatrix[j] = new IntOpenHashSet(0);
    } else {
      Set<Integer> curCol = thresholdedColumns.get(j).keySet();
      columnsOfReturnMatrix[j] = new IntOpenHashSet(curCol.size());
      for (int rowIndex : curCol) {
        columnsOfReturnMatrix[j].add(rowIndex);
      }
    }
  }

  // n = number of right vertices = number of rows of the return matrix.
  int n = this.g.getNumRightVertices();
  SparseBinaryMatrix supportOfReturnMatrix = new SparseBinaryMatrix(n, kp);
  supportOfReturnMatrix.initFromColSets(columnsOfReturnMatrix);

  // Non-zero values of each row, stored in the same order as the column ids that
  // supportOfReturnMatrix.getRow(row) reports for that row.
  double[][] valuesOfReturnMatrix = new double[n][];
  for (int row = 0; row < n; row++) {
    int[] nonZeroColsInRow = supportOfReturnMatrix.getRow(row);
    valuesOfReturnMatrix[row] = new double[nonZeroColsInRow.length];
    for (int index = 0; index < nonZeroColsInRow.length; index++) {
      int colId = nonZeroColsInRow[index];
      valuesOfReturnMatrix[row][index] = thresholdedColumns.get(colId).get(row);
    }
  }

  return new SparseRealMatrix(supportOfReturnMatrix, valuesOfReturnMatrix);
}