Skip to content

Commit 0f41fd2

Browse files
committed
Parallelised the greedy polarisation algorithm
1 parent f0bda2f commit 0f41fd2

File tree

1 file changed

+108
-94
lines changed

1 file changed

+108
-94
lines changed

dna/src/main/java/dna/export/Polarisation.java

Lines changed: 108 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ public PolarisationResultTimeSeries getResults() {
184184
* @param congruenceNetwork A 2D array representing the congruence network.
185185
* @param conflictNetwork A 2D array representing the conflict network.
186186
* @param normalise Should the result be divided by its theoretical maximum (the sum of the two matrix norms)?
187+
* @param numClusters The number of clusters.
187188
* @return The quality of polarization as a double value.
188189
*/
189-
private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) {
190+
private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) {
190191
double congruenceNorm = calculateMatrixNorm(congruenceNetwork);
191192
double conflictNorm = calculateMatrixNorm(conflictNetwork);
192193

@@ -586,10 +587,11 @@ private class GeneticIteration {
586587
* @param congruenceNetwork The congruence matrix.
587588
* @param conflictNetwork The conflict matrix.
588589
* @param normalise Should the quality/fitness scores be normalised?
590+
* @param numClusters The number of clusters.
589591
* @param rng The random number generator to use.
590592
* @return A list of children cluster solutions.
591593
*/
592-
GeneticIteration(ArrayList<ClusterSolution> clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, Random rng) {
594+
GeneticIteration(ArrayList<ClusterSolution> clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters, Random rng) {
593595
this.clusterSolutions = new ArrayList<>(clusterSolutions);
594596
this.normalise = normalise;
595597
this.congruenceNetwork = congruenceNetwork.clone();
@@ -608,7 +610,7 @@ private class GeneticIteration {
608610
"Number of mutations based on the mutation percentage.");
609611
Dna.logger.log(log);
610612

611-
this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, this.normalise);
613+
this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, normalise, numClusters);
612614
this.children = eliteRetentionStep(this.clusterSolutions, this.q, this.numElites);
613615
this.children = crossoverStep(this.clusterSolutions, this.q, this.children, rng);
614616
this.children = mutationStep(this.children, this.numMutations, this.n, rng);
@@ -622,13 +624,14 @@ private class GeneticIteration {
622624
* @param congruenceNetwork The congruence network matrix.
623625
* @param conflictNetwork The conflict network matrix.
624626
* @param normalise Normalise the results?
627+
* @param numClusters The number of clusters.
625628
* @return An array of quality scores for each cluster solution.
626629
*/
627-
private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) {
630+
private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) {
628631
double[] q = new double[clusterSolutions.size()];
629632
for (int i = 0; i < clusterSolutions.size(); i++) {
630633
int[] mem = clusterSolutions.get(i).getMemberships();
631-
q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise);
634+
q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise, numClusters);
632635
}
633636
return q;
634637
}
@@ -846,7 +849,7 @@ public PolarisationResultTimeSeries geneticAlgorithm () {
846849
// Run through iterations and do the breeding, then collect results and stats
847850
lastIndex = numIterations - 1; // choose last possible value here as a default if early convergence does not happen
848851
for (int i = 0; i < numIterations; i++) {
849-
GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, rng);
852+
GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, this.numClusters, rng);
850853
cs = geneticIteration.getChildren();
851854

852855
// compute summary statistics based on iteration step and retain them
@@ -1279,105 +1282,116 @@ private ArrayList<ExportStatement>[][][] create3dArray(String[] var1Values, Stri
12791282
/**
12801283
* Prepare the greedy membership swapping algorithm and run all the iterations.
12811284
* Take out the maximum quality measure at the last step and create an object
1282-
* that stores the polarisation results.
1285+
* that stores the polarisation results. Run the algorithm in parallel for all
1286+
* time windows.
12831287
*/
12841288
private PolarisationResultTimeSeries greedyAlgorithm () {
12851289
Random rng = (this.randomSeed == 0) ? new Random() : new Random(this.randomSeed); // Initialize random number generator
1286-
ArrayList<PolarisationResult> polarisationResults = new ArrayList<PolarisationResult>();
1290+
1291+
ArrayList<PolarisationResult> polarisationResults = ProgressBar
1292+
.wrap(IntStream.range(0, Polarisation.this.congruence.size()).parallel(), "Greedy algorithm")
1293+
.map(t -> greedyTimeStep(Polarisation.this.congruence.get(t),
1294+
Polarisation.this.conflict.get(t),
1295+
Polarisation.this.normaliseScores,
1296+
Polarisation.this.numClusters,
1297+
rng.nextLong()))
1298+
.collect(Collectors.toCollection(ArrayList::new));
1299+
1300+
PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults);
1301+
return polarisationResultTimeSeries;
1302+
}
1303+
/**
1304+
* A single run of the greedy algorithm, for one pair of congruence and conflict
1305+
* network, i.e., for one time slice.
1306+
*
1307+
* @param congruence A Matrix object containing the 2D congruence array.
1308+
* @param conflict A Matrix object containing the 2D conflict array.
1309+
* @param normaliseScores Normalise the absdiff quality/fitness scores to 1.0?
1310+
* @param numClusters The number of clusters.
1311+
* @param seed A random seed, which is used to create a new random number generator for this algorithm run. The seed should have been itself generated by a random number generator to ensure variability across time steps and reproducibility.
1312+
* @return a PolarisationResult object
1313+
*/
1314+
private PolarisationResult greedyTimeStep(Matrix congruence, Matrix conflict, boolean normaliseScores, int numClusters, long seed) {
12871315

12881316
// for each time step, run the algorithm over the cluster solutions; retain quality and memberships
1289-
double[][] congruenceMatrix, conflictMatrix;
1290-
int t, oldI, oldJ;
1317+
double[][] congruenceMatrix = congruence.getMatrix();
1318+
double[][] conflictMatrix = conflict.getMatrix();
12911319
ArrayList<Double> maxQArray = new ArrayList<Double>();
1292-
int[] bestMemberships, mem, mem2;
1293-
double maxQ, q1, q2;
1294-
boolean noChanges;
1295-
1296-
try (ProgressBar pb = new ProgressBar("Greedy algorithm", this.congruence.size())) {
1297-
for (t = 0; t < congruence.size(); t++) { // go through all time steps of the time window networks
1298-
maxQArray.clear();
1299-
congruenceMatrix = congruence.get(t).getMatrix();
1300-
conflictMatrix = conflict.get(t).getMatrix();
1301-
double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix);
1302-
1303-
if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly
1304-
1305-
// Create initially random cluster solution to update
1306-
ClusterSolution cs = new ClusterSolution(congruence.get(t).getMatrix().length, numClusters, rng);
1307-
mem = cs.getMemberships();
1308-
1309-
// evaluate quality of initial solution
1310-
maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores));
1311-
bestMemberships = mem.clone();
1312-
maxQ = maxQArray.get(0);
1313-
1314-
boolean convergence = false;
1315-
while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps
1316-
noChanges = true;
1317-
for (int i = 0; i < mem.length; i++) {
1318-
for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness
1319-
if (i < j && mem[i] != mem[j]) {
1320-
mem2 = mem.clone();
1321-
oldI = mem2[i];
1322-
oldJ = mem2[j];
1323-
mem2[i] = oldJ;
1324-
mem2[j] = oldI;
1325-
q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores);
1326-
q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, this.normaliseScores);
1327-
if (q2 > q1) { // candidate solution has higher fitness -> keep it
1328-
mem = mem2.clone(); // accept the new solution if it was better than the previous
1329-
maxQArray.add(q2);
1330-
maxQ = q2;
1331-
bestMemberships = mem.clone();
1332-
noChanges = false;
1333-
}
1334-
}
1320+
double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix);
1321+
1322+
if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly
1323+
1324+
// Create initially random cluster solution to update
1325+
Random random = new Random(seed);
1326+
ClusterSolution cs = new ClusterSolution(congruenceMatrix.length, numClusters, random);
1327+
int[] mem = cs.getMemberships();
1328+
1329+
// evaluate quality of initial solution
1330+
maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters));
1331+
int[] bestMemberships = mem.clone();
1332+
double maxQ = maxQArray.get(0);
1333+
1334+
boolean convergence = false;
1335+
while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps
1336+
boolean noChanges = true;
1337+
for (int i = 0; i < mem.length; i++) {
1338+
for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness
1339+
if (i < j && mem[i] != mem[j]) {
1340+
int[] mem2 = mem.clone();
1341+
int oldI = mem2[i];
1342+
int oldJ = mem2[j];
1343+
mem2[i] = oldJ;
1344+
mem2[j] = oldI;
1345+
double q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters);
1346+
double q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, normaliseScores, numClusters);
1347+
if (q2 > q1) { // candidate solution has higher fitness -> keep it
1348+
mem = mem2.clone(); // accept the new solution if it was better than the previous
1349+
maxQArray.add(q2);
1350+
maxQ = q2;
1351+
bestMemberships = mem.clone();
1352+
noChanges = false;
13351353
}
13361354
}
1337-
if (noChanges) {
1338-
convergence = true;
1339-
}
1340-
}
1341-
1342-
double[] maxQArray2 = new double[maxQArray.size()];
1343-
for (int i = 0; i < maxQArray.size(); i++) {
1344-
maxQArray2[i] = maxQArray.get(i);
13451355
}
1346-
1347-
// save results in array as a complex object
1348-
double[] avgQArray = maxQArray2;
1349-
double[] sdQArray = new double[maxQArray.size()];
1350-
PolarisationResult pr = new PolarisationResult(
1351-
maxQArray2,
1352-
avgQArray,
1353-
sdQArray,
1354-
maxQ,
1355-
bestMemberships,
1356-
congruence.get(t).getRowNames(),
1357-
true,
1358-
congruence.get(t).getStart(),
1359-
congruence.get(t).getStop(),
1360-
congruence.get(t).getDateTime());
1361-
polarisationResults.add(pr);
1362-
} else { // zero result because network is empty
1363-
PolarisationResult pr = new PolarisationResult(
1364-
new double[] { 0 },
1365-
new double[] { 0 },
1366-
new double[] { 0 },
1367-
0.0,
1368-
new int[0],
1369-
new String[0],
1370-
true,
1371-
congruence.get(t).getStart(),
1372-
congruence.get(t).getStop(),
1373-
congruence.get(t).getDateTime());
1374-
polarisationResults.add(pr);
13751356
}
1376-
pb.step();
1357+
if (noChanges) {
1358+
convergence = true;
1359+
}
13771360
}
1378-
}
13791361

1380-
PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults);
1381-
return polarisationResultTimeSeries;
1362+
double[] maxQArray2 = new double[maxQArray.size()];
1363+
for (int i = 0; i < maxQArray.size(); i++) {
1364+
maxQArray2[i] = maxQArray.get(i);
1365+
}
1366+
1367+
// save results in array as a complex object
1368+
double[] avgQArray = maxQArray2;
1369+
double[] sdQArray = new double[maxQArray.size()];
1370+
PolarisationResult pr = new PolarisationResult(
1371+
maxQArray2,
1372+
avgQArray,
1373+
sdQArray,
1374+
maxQ,
1375+
bestMemberships,
1376+
congruence.getRowNames(),
1377+
true,
1378+
congruence.getStart(),
1379+
congruence.getStop(),
1380+
congruence.getDateTime());
1381+
return pr;
1382+
} else { // zero result because network is empty
1383+
PolarisationResult pr = new PolarisationResult(
1384+
new double[] { 0 },
1385+
new double[] { 0 },
1386+
new double[] { 0 },
1387+
0.0,
1388+
new int[0],
1389+
new String[0],
1390+
true,
1391+
congruence.getStart(),
1392+
congruence.getStop(),
1393+
congruence.getDateTime());
1394+
return pr;
1395+
}
13821396
}
13831397
}

0 commit comments

Comments
 (0)