private SparseRealMatrix getRepresentationsForLeft()

in src/main/java/com/twitter/sbf/core/SimClustersInMemory.java [327:450]


  /**
   * Computes latent representation vectors for the left vertices of the bipartite graph {@code g}.
   *
   * <p>The result is the sparse product {@code A * representationsForRight}, where {@code A} is the
   * m x n adjacency matrix of {@code g} (m left vertices, n right vertices). A thresholding step is
   * applied to that product: entry (i, l) is kept only if left vertex i is adjacent to at least
   * {@code thresholdForStepThree} right vertices whose row in {@code representationsForRight} has
   * a non-zero entry in column l. That count is the number of edges from i into community l.
   *
   * <p>{@code A} is never materialised. The product is accumulated from the right-to-left
   * adjacency lists of {@code g}, so only the non-zero entries of the result are ever stored.
   *
   * @param representationsForRight n x k' sparse matrix whose row j is the representation of
   *     right vertex j
   * @return m x k' sparse matrix whose row i is the (thresholded) representation of left vertex i
   */
  private SparseRealMatrix getRepresentationsForLeft(SparseRealMatrix representationsForRight) {
    // k' ("kp") is the actual number of columns, which may differ from the configured k.
    int kp = representationsForRight.getNumCols();
    System.out.println(
        "Expected k: " + this.k + ", Actual k: " + kp
    );

    // Both tables are nested maps keyed column-first by (l, i), where l is a column (community)
    // index and i is a row (left-vertex) index of the result:
    //   columnsOfProductMatrix[l][i]        accumulates (A * representationsForRight)[i][l];
    //   numEdgesFromLeftToCommunities[l][i] counts the contributing right neighbours of i, i.e.
    //                                       the number of edges from left vertex i into
    //                                       community l (used for thresholding below).
    // Both tables always receive exactly the same (l, i) keys.
    HashMap<Integer, HashMap<Integer, Double>> columnsOfProductMatrix = new HashMap<>();
    HashMap<Integer, HashMap<Integer, Integer>> numEdgesFromLeftToCommunities = new HashMap<>();

    // n = number of columns of A = number of right vertices in g.
    int n = this.g.getNumRightVertices();
    for (int j = 0; j < n; j++) {
      // Column j of A: the left neighbours of right vertex j.
      int[] neighborsOfVertexJ = this.g.getNeighborsForRightById(j);
      if (neighborsOfVertexJ.length == 0) {
        // Right vertex j contributes nothing to the product.
        continue;
      }
      // Sparse row j of representationsForRight: its support and the matching values.
      int[] nonZeroColsInRowJ = representationsForRight.getColIdsForRow(j);
      double[] nonZeroValuesInRowJ = representationsForRight.getValuesForRow(j);
      // Every pair (l, i) in nonZeroColsInRowJ x neighborsOfVertexJ receives
      // representationsForRight[j][l] in the product, plus one edge into community l.
      // The per-column maps are fetched once per (j, l) instead of once per (j, l, i). For any
      // single key the additions still happen in ascending j, so the sums are unchanged.
      for (int it = 0; it < nonZeroColsInRowJ.length; it++) {
        int l = nonZeroColsInRowJ[it];
        double valueAtJL = nonZeroValuesInRowJ[it];
        HashMap<Integer, Double> columnL =
            columnsOfProductMatrix.computeIfAbsent(l, key -> new HashMap<>());
        HashMap<Integer, Integer> edgeCountsIntoL =
            numEdgesFromLeftToCommunities.computeIfAbsent(l, key -> new HashMap<>());
        for (int i : neighborsOfVertexJ) {
          columnL.merge(i, valueAtJL, Double::sum);
          edgeCountsIntoL.merge(i, 1, Integer::sum);
        }
      }
    }

    // Thresholding: drop entry (i, l) when left vertex i has fewer than thresholdForStepThree
    // edges into community l. Because both tables share their keys, the count for every
    // product entry is always present.
    columnsOfProductMatrix.forEach((l, columnL) -> {
      HashMap<Integer, Integer> edgeCountsIntoL = numEdgesFromLeftToCommunities.get(l);
      columnL.keySet().removeIf(i -> edgeCountsIntoL.get(i) < this.thresholdForStepThree);
    });

    // Build the support of the m x kp result, one IntSet of row indices per column.
    IntSet[] columnsOfReturnMatrix = new IntSet[kp];
    for (int l = 0; l < kp; l++) {
      HashMap<Integer, Double> columnL = columnsOfProductMatrix.get(l);
      if (columnL == null) {
        columnsOfReturnMatrix[l] = new IntOpenHashSet(0);
      } else {
        IntSet rowsOfColumnL = new IntOpenHashSet(columnL.size());
        for (int rowIndex : columnL.keySet()) {
          rowsOfColumnL.add(rowIndex);
        }
        columnsOfReturnMatrix[l] = rowsOfColumnL;
      }
    }

    // m = number of left vertices (rows of the result).
    int m = this.g.getNumLeftVertices();
    SparseBinaryMatrix supportOfReturnMatrix = new SparseBinaryMatrix(m, kp);
    supportOfReturnMatrix.initFromColSets(columnsOfReturnMatrix);

    // Fill the values row by row. valuesOfReturnMatrix[row][index] must line up with the
    // column id at position index of supportOfReturnMatrix.getRow(row).
    double[][] valuesOfReturnMatrix = new double[m][];
    for (int row = 0; row < m; row++) {
      int[] nonZeroColsInRow = supportOfReturnMatrix.getRow(row);
      double[] valuesInRow = new double[nonZeroColsInRow.length];
      for (int index = 0; index < nonZeroColsInRow.length; index++) {
        valuesInRow[index] = columnsOfProductMatrix.get(nonZeroColsInRow[index]).get(row);
      }
      valuesOfReturnMatrix[row] = valuesInRow;
    }

    return new SparseRealMatrix(supportOfReturnMatrix, valuesOfReturnMatrix);
  }