private SparseRealMatrix updateRepresentationsForRight()

in src/main/java/com/twitter/sbf/core/SimClustersInMemory.java [471:573]


  /**
   * Computes updated latent representations for the right vertices of the bipartite graph
   * {@code g} from the given left representations.
   *
   * <p>Let {@code A} be the m x n bi-adjacency matrix of {@code g} and {@code V} the m x k'
   * matrix {@code representationsForLeft}. This method forms the n x k' product {@code A^T V},
   * whose entry {@code (i, j)} is the sum of {@code V[l][j]} over every left vertex {@code l}
   * whose neighbor list ({@code g.getNeighborsForLeftById(l)}) contains right vertex {@code i}.
   * Each stored entry is then divided by {@code ||V[:, j]|| * sqrt(deg(i))}, where the column
   * norm comes from {@code representationsForLeft.getColumnNorm(j)} (an L2 norm per the original
   * author's notes) and {@code deg(i)} is {@code g.getNeighborsForRightById(i).length}.
   *
   * <p>Only entries satisfying {@code value - thresholdForStepFour > 1e-5} are kept. Entries at
   * or just above the threshold (within 1e-5) are therefore dropped. A NaN entry (for example
   * 0/0 when a left column is all zeros) also fails this test and is dropped.
   *
   * @param representationsForLeft the m x k' left representation matrix, where m is
   *     {@code g.getNumLeftVertices()}
   * @return a sparse n x k' matrix, where n is {@code g.getNumRightVertices()}, holding the
   *     normalized and thresholded right representations
   */
  private SparseRealMatrix updateRepresentationsForRight(SparseRealMatrix representationsForLeft) {
    final int m = g.getNumLeftVertices();
    final int n = g.getNumRightVertices();
    final int kp = representationsForLeft.getNumCols();

    // Step 1: accumulate the columns of A^T V. The outer key is the column index j
    // (0 <= j < k'). The inner key is the right-vertex row index i (0 <= i < n). Only entries
    // reached by at least one (l, j, i) triple are stored.
    HashMap<Integer, HashMap<Integer, Double>> productColumns = new HashMap<>();
    for (int l = 0; l < m; l++) {
      int[] supportOfRowL = representationsForLeft.getColIdsForRow(l);
      double[] valuesOfRowL = representationsForLeft.getValuesForRow(l);
      int[] neighborsOfL = g.getNeighborsForLeftById(l);
      for (int it = 0; it < supportOfRowL.length; it++) {
        // computeIfAbsent allocates a column map only the first time column j is seen.
        HashMap<Integer, Double> columnJ =
            productColumns.computeIfAbsent(supportOfRowL[it], key -> new HashMap<>());
        double valueAtLJ = valuesOfRowL[it];
        for (int i : neighborsOfL) {
          // Entry (i, j) += V[l][j]. The first contribution is stored as-is.
          columnJ.merge(i, valueAtLJ, Double::sum);
        }
      }
    }

    // Step 2: normalize each entry (i, j) by ||V[:, j]|| * sqrt(deg(i)) and keep only entries
    // with normalized - thresholdForStepFour > 1e-5. Columns left empty are omitted.
    HashMap<Integer, HashMap<Integer, Double>> keptColumns = new HashMap<>();
    productColumns.forEach((j, column) -> {
      double normOfLeftColumnJ = representationsForLeft.getColumnNorm(j);
      HashMap<Integer, Double> keptEntries = new HashMap<>();
      column.forEach((i, sum) -> {
        double sqrtDegreeOfI = Math.sqrt(g.getNeighborsForRightById(i).length);
        double normalized = sum / (normOfLeftColumnJ * sqrtDegreeOfI);
        if (normalized - thresholdForStepFour > 1e-5) {
          keptEntries.put(i, normalized);
        }
      });
      if (!keptEntries.isEmpty()) {
        keptColumns.put(j, keptEntries);
      }
    });

    // Step 3: build the support of the result as one IntSet of row indices per column
    // 0 <= j < k'. Columns with no kept entries get an empty set.
    IntSet[] columnsOfReturnMatrix = new IntSet[kp];
    for (int j = 0; j < kp; j++) {
      HashMap<Integer, Double> keptColumnJ = keptColumns.get(j);
      if (keptColumnJ == null) {
        columnsOfReturnMatrix[j] = new IntOpenHashSet(0);
      } else {
        IntSet rows = new IntOpenHashSet(keptColumnJ.size());
        for (int row : keptColumnJ.keySet()) {
          rows.add(row);
        }
        columnsOfReturnMatrix[j] = rows;
      }
    }

    // Step 4: materialize the n x k' sparse result. valuesOfReturnMatrix[row][index] holds the
    // value for column supportOfReturnMatrix.getRow(row)[index].
    SparseBinaryMatrix supportOfReturnMatrix = new SparseBinaryMatrix(n, kp);
    supportOfReturnMatrix.initFromColSets(columnsOfReturnMatrix);
    double[][] valuesOfReturnMatrix = new double[n][];
    for (int row = 0; row < n; row++) {
      int[] nonZeroColsInRow = supportOfReturnMatrix.getRow(row);
      double[] rowValues = new double[nonZeroColsInRow.length];
      for (int index = 0; index < nonZeroColsInRow.length; index++) {
        rowValues[index] = keptColumns.get(nonZeroColsInRow[index]).get(row);
      }
      valuesOfReturnMatrix[row] = rowValues;
    }
    return new SparseRealMatrix(supportOfReturnMatrix, valuesOfReturnMatrix);
  }