in src/main/java/com/twitter/sbf/generator/GraphGenerator.java [174:353]
/**
 * Generates a random graph together with ground-truth clusters that may overlap.
 *
 * <p>Cluster sizes are drawn uniformly from {@code [minClusterSize, maxClusterSize)}. The number
 * of vertices is {@code ceil(sum of cluster sizes / averageNumberOfClustersPerVertex)}. Every slot
 * of every cluster is filled with a vertex chosen uniformly among those that are not yet in that
 * cluster and belong to fewer than {@code maxClustersForAVertex} clusters.
 *
 * <p>Edges are created in two passes:
 *
 * <ul>
 *   <li>Intra-cluster: for each cluster, pairs of members are selected by jumping ahead by gaps
 *       drawn from a geometric distribution with parameter
 *       {@code getProbabilityInsideCluster(clusterSize)}.
 *   <li>Global: the same skip-sampling over all vertex pairs, with parameter
 *       {@code getGlobalProbability()}. Pairs that are already adjacent are left unchanged.
 * </ul>
 *
 * <p>For each edge, the code records how many clusters' intra-cluster pass selected it. If
 * {@code isWeighted} is set and that count {@code k} is positive, the edge weight is the sum of
 * {@code k} draws of {@code sampleEdgeWeight()}. Otherwise (the edge came only from the global
 * pass) the weight is {@code lowerWeightMode - 1.5 * lowerWeightDist.getStandardDeviation()}.
 * Diagnostic statistics are printed to stderr.
 *
 * @param maxClustersForAVertex upper bound on the number of clusters a single vertex may join
 * @param averageNumberOfClustersPerVertex target mean number of clusters per vertex (asserted to
 *     be greater than 1)
 * @return the generated graph and, for each cluster, the set of its member vertex ids
 * @throws IllegalStateException if some cluster slot cannot be filled because no vertex is
 *     eligible (for example, the cluster is larger than the number of vertices)
 */
public GraphAndGroundTruth generateWithOverlappingClusters(
    int maxClustersForAVertex, double averageNumberOfClustersPerVertex) {
  Random r = this.rng;
  assert averageNumberOfClustersPerVertex > 1.0;

  // Sample each cluster's size; r.nextInt(bound) is exclusive, so sizes lie in [min, max).
  int sumOfClusterSizes = 0;
  int[][] clusterToVertices = new int[numClusters][];
  for (int clusterId = 0; clusterId < numClusters; clusterId++) {
    int sampledClusterSize = minClusterSize + r.nextInt(maxClusterSize - minClusterSize);
    clusterToVertices[clusterId] = new int[sampledClusterSize];
    sumOfClusterSizes += sampledClusterSize;
  }

  // Pick the vertex count so that vertices belong to the requested number of clusters on average.
  int numNodes = (int) Math.ceil(sumOfClusterSizes / averageNumberOfClustersPerVertex);
  IntSet[] vertexToClusterIds = new IntOpenHashSet[numNodes];
  for (int vertexId = 0; vertexId < numNodes; vertexId++) {
    vertexToClusterIds[vertexId] = new IntOpenHashSet(maxClustersForAVertex);
  }

  // Fill every cluster slot with a distinct vertex that still has spare cluster capacity.
  for (int clusterId = 0; clusterId < numClusters; clusterId++) {
    int[] members = clusterToVertices[clusterId];
    for (int j = 0; j < members.length; j++) {
      int sampledNode = sampleVertexForOverlappingCluster(
          r, vertexToClusterIds, clusterId, maxClustersForAVertex);
      members[j] = sampledNode;
      vertexToClusterIds[sampledNode].add(clusterId);
    }
  }

  ArrayList<HashSet<Integer>> adjLists = new ArrayList<>(numNodes);
  // For an edge (smaller, bigger), stored under the smaller endpoint: how many clusters'
  // intra-cluster pass selected that pair.
  ArrayList<HashMap<Integer, Integer>> intersectionSizesForEachEdge = new ArrayList<>(numNodes);
  for (int i = 0; i < numNodes; i++) {
    adjLists.add(new HashSet<>());
    intersectionSizesForEachEdge.add(new HashMap<>());
  }
  int numEdges = 0;
  int numIntraClusterEdges = 0;
  int numRepeatedIntraClusterEdges = 0;

  // Pass 1: intra-cluster edges.
  for (int clusterId = 0; clusterId < numClusters; clusterId++) {
    int[] members = clusterToVertices[clusterId];
    int sizeOfThisCluster = members.length;
    double probForThisCluster = getProbabilityInsideCluster(sizeOfThisCluster);
    GeometricDistribution gd = new GeometricDistribution(rng, probForThisCluster);
    int numEdgesInsideThisCluster = 0;
    for (int nodeIdInsideCluster = 0; nodeIdInsideCluster < sizeOfThisCluster;
        nodeIdInsideCluster++) {
      // Walk the later members of the cluster, jumping ahead by geometric gaps.
      for (int neighborNode = nodeIdInsideCluster + 1; neighborNode < sizeOfThisCluster; ) {
        int numNodesToSkip = gd.sample();
        neighborNode += numNodesToSkip;
        if (neighborNode >= sizeOfThisCluster) {
          break;
        }
        int nId1 = members[nodeIdInsideCluster];
        int nId2 = members[neighborNode];
        neighborNode++;
        numEdgesInsideThisCluster++;
        numIntraClusterEdges++;
        if (adjLists.get(nId1).add(nId2)) {
          adjLists.get(nId2).add(nId1);
          numEdges++;
        } else {
          // The pair was already connected through an earlier, overlapping cluster.
          numRepeatedIntraClusterEdges++;
        }
        intersectionSizesForEachEdge
            .get(Math.min(nId1, nId2))
            .merge(Math.max(nId1, nId2), 1, Integer::sum);
      }
    }
    // Sanity check: realized density should be close to the requested probability. Clusters with
    // fewer than two members have no pairs (the ratio would be NaN), so they are skipped; the pair
    // count is computed in long arithmetic to avoid int overflow for very large clusters.
    long possiblePairs = (long) sizeOfThisCluster * (sizeOfThisCluster - 1) / 2;
    if (possiblePairs > 0) {
      float actualProbInsideCluster = numEdgesInsideThisCluster * 1.0f / possiblePairs;
      assert Math.abs(actualProbInsideCluster - probForThisCluster) < 0.1
          : String.format("cluster %d: size %d, assigned prob. %f, actual prob. %f\n",
              clusterId + 1, sizeOfThisCluster, probForThisCluster, actualProbInsideCluster);
    }
  }

  // Pass 2: global edges between arbitrary vertex pairs.
  int numGlobalEdges = 0;
  GeometricDistribution gd = new GeometricDistribution(rng, getGlobalProbability());
  for (int nodeId = 0; nodeId < numNodes; nodeId++) {
    for (int neighborNodePosition = nodeId + 1; neighborNodePosition < numNodes; ) {
      int numNodesToSkip = gd.sample();
      neighborNodePosition += numNodesToSkip;
      if (neighborNodePosition >= numNodes) {
        break;
      }
      if (adjLists.get(nodeId).add(neighborNodePosition)) {
        adjLists.get(neighborNodePosition).add(nodeId);
        numEdges++;
        numGlobalEdges++;
      }
      // Step past the selected position even when the pair was already adjacent, as in the
      // intra-cluster pass. Otherwise a zero gap re-selects the same vertex, and with a global
      // probability of 1 (gap always 0) the loop would never advance.
      neighborNodePosition++;
    }
  }

  // Convert adjacency sets to sorted arrays.
  int[][] nbrs = new int[numNodes][];
  int numEdges2 = 0;
  for (int i = 0; i < numNodes; i++) {
    HashSet<Integer> s = adjLists.get(i);
    nbrs[i] = new int[s.size()];
    int j = 0;
    for (int neighbor : s) {
      nbrs[i][j++] = neighbor;
    }
    numEdges2 += j;
    Arrays.sort(nbrs[i]);
  }

  float[][] wts = null;
  // Indexed by per-edge cluster count; each undirected edge is counted once per endpoint. Only
  // filled in when the graph is weighted (it is printed either way).
  int[] histogramOfIntersectionSizes = new int[maxClustersForAVertex + 1];
  if (this.isWeighted) {
    wts = new float[numNodes][];
    for (int i = 0; i < numNodes; i++) {
      wts[i] = new float[nbrs[i].length];
      for (int j = 0; j < wts[i].length; j++) {
        int nbr = nbrs[i][j];
        int intersectionCount =
            intersectionSizesForEachEdge.get(Math.min(i, nbr)).getOrDefault(Math.max(i, nbr), 0);
        histogramOfIntersectionSizes[intersectionCount]++;
        // Consider only cells in the lower-triangular part of the adjacency matrix, and then
        // replicate the weight to the corresponding cell in the upper-triangular part (row nbr
        // has already been allocated because nbr < i).
        if (nbr < i) {
          float newWt;
          if (intersectionCount == 0) {
            newWt = (float) (lowerWeightMode - 1.5 * lowerWeightDist.getStandardDeviation());
          } else {
            newWt = 0;
            for (int sampleCount = 0; sampleCount < intersectionCount; sampleCount++) {
              newWt += (float) sampleEdgeWeight();
            }
          }
          wts[i][j] = newWt;
          int indexOfIInJ = Arrays.binarySearch(nbrs[nbr], i);
          wts[nbr][indexOfIInJ] = newWt;
        }
      }
    }
  }

  System.err.print("Histogram of intersection sizes: ");
  for (int i = 0; i <= maxClustersForAVertex; i++) {
    System.err.format("%d -> %d, ", i, histogramOfIntersectionSizes[i]);
  }
  System.err.println();
  System.err.format("Num repeated intra-cluster edges: %d\n", numRepeatedIntraClusterEdges);
  System.err.format(
      "Num intra-cluster edges: %d, num global edges: %d, fraction global edges: %f\n",
      numIntraClusterEdges, numGlobalEdges, numGlobalEdges * 1.0f / numEdges);
  assert numEdges * 2 == numEdges2
      : String.format("numEdges %d, numEdges2 %d", numEdges, numEdges2);

  IntSet[] clustersAsSets = new IntSet[numClusters];
  for (int i = 0; i < numClusters; i++) {
    clustersAsSets[i] = new IntOpenHashSet(clusterToVertices[i]);
  }
  return new GraphAndGroundTruth(new Graph(numNodes, numEdges, nbrs, wts), clustersAsSets);
}

/**
 * Returns a vertex chosen uniformly at random among those that are not yet members of
 * {@code clusterId} and belong to fewer than {@code maxClustersForAVertex} clusters.
 *
 * <p>First tries plain rejection sampling. After {@code vertexToClusterIds.length} consecutive
 * rejections it scans all vertices and picks uniformly among the eligible ones. That yields the
 * same distribution but always terminates.
 *
 * @throws IllegalStateException if no vertex is eligible
 */
private static int sampleVertexForOverlappingCluster(
    Random r, IntSet[] vertexToClusterIds, int clusterId, int maxClustersForAVertex) {
  int numNodes = vertexToClusterIds.length;
  for (int attempt = 0; attempt < numNodes; attempt++) {
    int candidate = r.nextInt(numNodes);
    if (canJoinOverlappingCluster(vertexToClusterIds[candidate], clusterId, maxClustersForAVertex)) {
      return candidate;
    }
  }
  int[] eligible = new int[numNodes];
  int numEligible = 0;
  for (int vertexId = 0; vertexId < numNodes; vertexId++) {
    if (canJoinOverlappingCluster(vertexToClusterIds[vertexId], clusterId, maxClustersForAVertex)) {
      eligible[numEligible++] = vertexId;
    }
  }
  if (numEligible == 0) {
    throw new IllegalStateException(String.format(
        "No vertex can be added to cluster %d: each of the %d vertices is already in it or "
            + "already belongs to %d clusters; lower averageNumberOfClustersPerVertex or raise "
            + "maxClustersForAVertex",
        clusterId, numNodes, maxClustersForAVertex));
  }
  return eligible[r.nextInt(numEligible)];
}

/** Returns whether a vertex with the given cluster memberships may still join {@code clusterId}. */
private static boolean canJoinOverlappingCluster(
    IntSet clusterIds, int clusterId, int maxClustersForAVertex) {
  return !clusterIds.contains(clusterId) && clusterIds.size() < maxClustersForAVertex;
}