Skip to content

Commit a14eb56

Browse files
committed
Parallelised genetic algorithm
1 parent 333bf08 commit a14eb56

File tree

3 files changed

+201
-147
lines changed

3 files changed

+201
-147
lines changed

dna/src/main/java/dna/export/Polarisation.java

Lines changed: 150 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -823,130 +823,161 @@ public double[] getQ() {
823823
*
824824
* @return A PolarisationResultTimeSeries object containing the results of the genetic algorithm for each time step and iteration.
825825
*/
826-
public PolarisationResultTimeSeries geneticAlgorithm () {
827-
Random rng = (this.randomSeed == 0) ? new Random() : new Random(this.randomSeed); // Initialize random number generator
828-
ArrayList<PolarisationResult> polarisationResults = new ArrayList<>();
829-
try (ProgressBar pb = new ProgressBar("Genetic algorithm", this.congruence.size())) {
830-
for (int t = 0; t < this.congruence.size(); t++) {
831-
if (this.congruence.get(t).getMatrix().length > 0 && calculateMatrixNorm(this.congruence.get(t).getMatrix()) + calculateMatrixNorm(this.conflict.get(t).getMatrix()) != 0) { // if the network has no nodes or activity, skip this step and return 0 directly
832-
double[] qualityScores; // Quality scores for each time step
833-
double maxQ = -1;
834-
double avgQ, sdQ;
835-
int maxIndex = -1;
836-
boolean earlyConvergence = false;
837-
int lastIndex = -1;
838-
839-
double[] maxQArray = new double[numIterations];
840-
double[] avgQArray = new double[numIterations];
841-
double[] sdQArray = new double[numIterations];
842-
843-
// Create initially random cluster solutions; supply the number of nodes and clusters
844-
ArrayList<ClusterSolution> cs = new ArrayList<ClusterSolution>();
845-
for (int i = 0; i < numParents; i++) {
846-
cs.add(new ClusterSolution(this.congruence.get(t).getMatrix().length, numClusters, rng));
847-
}
848-
849-
// Run through iterations and do the breeding, then collect results and stats
850-
lastIndex = numIterations - 1; // choose last possible value here as a default if early convergence does not happen
851-
for (int i = 0; i < numIterations; i++) {
852-
GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, this.numClusters, rng);
853-
cs = geneticIteration.getChildren();
854-
855-
// compute summary statistics based on iteration step and retain them
856-
qualityScores = geneticIteration.getQ();
857-
maxQ = -1.0;
858-
avgQ = 0.0;
859-
sdQ = 0.0;
860-
maxIndex = -1;
861-
for (int j = 0; j < cs.size(); j++) {
862-
avgQ += qualityScores[j];
863-
if (qualityScores[j] > maxQ) {
864-
maxQ = qualityScores[j];
865-
maxIndex = j;
866-
}
867-
}
868-
avgQ = avgQ / numParents;
869-
for (int j = 0; j < numParents; j++) {
870-
sdQ = sdQ + Math.sqrt(((qualityScores[j] - avgQ) * (qualityScores[j] - avgQ)) / numParents);
871-
}
872-
maxQArray[i] = maxQ;
873-
avgQArray[i] = avgQ;
874-
sdQArray[i] = sdQ;
875-
876-
// check early convergence
877-
earlyConvergence = true;
878-
if (i >= 10 && (double) Math.round(sdQ * 100) / 100 == 0.00 && (double) Math.round(maxQ * 100) / 100 == (double) Math.round(avgQ * 100) / 100) {
879-
for (int j = i - 10; j < i; j++) {
880-
if ((double) Math.round(maxQArray[j] * 100) / 100 != (double) Math.round(maxQ * 100) / 100 ||
881-
(double) Math.round(avgQArray[j] * 100) / 100 != (double) Math.round(avgQ * 100) / 100 ||
882-
(double) Math.round(sdQArray[j] * 100) / 100 != 0.00) {
883-
earlyConvergence = false;
884-
}
885-
}
886-
} else {
887-
earlyConvergence = false;
888-
}
889-
if (earlyConvergence == true) {
890-
lastIndex = i;
891-
break;
892-
}
893-
}
894-
895-
// correct for early convergence in results vectors
896-
int finalIndex = lastIndex;
897-
for (int i = lastIndex; i >= 0; i--) {
898-
if (maxQArray[i] == maxQArray[lastIndex]) {
899-
finalIndex = i;
900-
} else {
901-
break;
902-
}
903-
}
904-
905-
double[] maxQArrayTemp = new double[finalIndex + 1];
906-
double[] avgQArrayTemp = new double[finalIndex + 1];
907-
double[] sdQArrayTemp = new double[finalIndex + 1];
908-
for (int i = 0; i < finalIndex + 1; i++) {
909-
maxQArrayTemp[i] = maxQArray[i];
910-
avgQArrayTemp[i] = avgQArray[i];
911-
sdQArrayTemp[i] = sdQArray[i];
826+
public PolarisationResultTimeSeries geneticAlgorithm() {
827+
Random r = (this.randomSeed == 0) ? new Random() : new Random(this.randomSeed); // Initialize RNG
828+
829+
ArrayList<PolarisationResult> polarisationResults = ProgressBar
830+
.wrap(IntStream.range(0, Polarisation.this.congruence.size()).parallel(), "Genetic algorithm")
831+
.map(t -> geneticTimeStep(t, r.nextLong()))
832+
.collect(Collectors.toCollection(ArrayList::new));
833+
834+
return new PolarisationResultTimeSeries(polarisationResults);
835+
}
836+
837+
/**
838+
* Runs the genetic algorithm for a single time step.
839+
*
840+
* @param t The time step index.
841+
* @param seed A random seed to ensure reproducibility.
842+
* @return The PolarisationResult for the given time step.
843+
*/
844+
private PolarisationResult geneticTimeStep(int t, long seed) {
845+
// Skip empty networks
846+
if (this.congruence.get(t).getMatrix().length == 0 ||
847+
(calculateMatrixNorm(this.congruence.get(t).getMatrix()) + calculateMatrixNorm(this.conflict.get(t).getMatrix())) == 0) {
848+
849+
return new PolarisationResult(
850+
new double[]{0}, new double[]{0}, new double[]{0}, 0.0,
851+
new int[0], new String[0], true,
852+
this.congruence.get(t).getStart(),
853+
this.congruence.get(t).getStop(),
854+
this.congruence.get(t).getDateTime()
855+
);
856+
}
857+
858+
// Genetic Algorithm Variables
859+
Random rng = new Random(seed);
860+
double maxQ = -1, avgQ, sdQ;
861+
int maxIndex = -1;
862+
boolean earlyConvergence = false;
863+
int lastIndex = numIterations - 1; // choose last possible value here as a default if early convergence does not happen
864+
865+
double[] maxQArray = new double[numIterations];
866+
double[] avgQArray = new double[numIterations];
867+
double[] sdQArray = new double[numIterations];
868+
869+
// Initialize random cluster solutions
870+
ArrayList<ClusterSolution> cs = new ArrayList<>();
871+
for (int i = 0; i < numParents; i++) {
872+
cs.add(new ClusterSolution(this.congruence.get(t).getMatrix().length, numClusters, rng));
873+
}
874+
875+
// Iterative breeding process
876+
for (int i = 0; i < numIterations; i++) {
877+
GeneticIteration geneticIteration = new GeneticIteration(
878+
cs, this.congruence.get(t).getMatrix(),
879+
this.conflict.get(t).getMatrix(),
880+
this.normaliseScores, this.numClusters, rng
881+
);
882+
cs = geneticIteration.getChildren();
883+
884+
// Compute quality metrics
885+
double[] qualityScores = geneticIteration.getQ();
886+
maxQ = -1.0;
887+
avgQ = 0.0;
888+
sdQ = 0.0;
889+
maxIndex = -1;
890+
891+
for (int j = 0; j < cs.size(); j++) {
892+
avgQ += qualityScores[j];
893+
if (qualityScores[j] > maxQ) {
894+
maxQ = qualityScores[j];
895+
maxIndex = j;
896+
}
897+
}
898+
avgQ /= numParents;
899+
900+
for (int j = 0; j < numParents; j++) {
901+
sdQ += Math.sqrt(((qualityScores[j] - avgQ) * (qualityScores[j] - avgQ)) / numParents);
902+
}
903+
904+
maxQArray[i] = maxQ;
905+
avgQArray[i] = avgQ;
906+
sdQArray[i] = sdQ;
907+
908+
// Early Convergence Check
909+
earlyConvergence = true;
910+
if (i >= 10 && (double) Math.round(sdQ * 100) / 100 == 0.00 &&
911+
(double) Math.round(maxQ * 100) / 100 == (double) Math.round(avgQ * 100) / 100) {
912+
913+
for (int j = i - 10; j < i; j++) {
914+
if ((double) Math.round(maxQArray[j] * 100) / 100 != (double) Math.round(maxQ * 100) / 100 ||
915+
(double) Math.round(avgQArray[j] * 100) / 100 != (double) Math.round(avgQ * 100) / 100 ||
916+
(double) Math.round(sdQArray[j] * 100) / 100 != 0.00) {
917+
918+
earlyConvergence = false;
912919
}
913-
maxQArray = maxQArrayTemp;
914-
avgQArray = avgQArrayTemp;
915-
sdQArray = sdQArrayTemp;
916-
917-
// save results in array as a complex object
918-
PolarisationResult pr = new PolarisationResult(
919-
maxQArray.clone(),
920-
avgQArray.clone(),
921-
sdQArray.clone(),
922-
maxQ,
923-
cs.get(maxIndex).getMemberships().clone(),
924-
this.congruence.get(t).getRowNames(),
925-
earlyConvergence,
926-
this.congruence.get(t).getStart(),
927-
this.congruence.get(t).getStop(),
928-
this.congruence.get(t).getDateTime());
929-
polarisationResults.add(pr);
930-
} else { // zero result because network is empty
931-
PolarisationResult pr = new PolarisationResult(
932-
new double[] { 0 },
933-
new double[] { 0 },
934-
new double[] { 0 },
935-
0.0,
936-
new int[0],
937-
new String[0],
938-
true,
939-
this.congruence.get(t).getStart(),
940-
this.congruence.get(t).getStop(),
941-
this.congruence.get(t).getDateTime());
942-
polarisationResults.add(pr);
943920
}
944-
pb.step();
921+
} else {
922+
earlyConvergence = false;
923+
}
924+
925+
if (earlyConvergence) {
926+
lastIndex = i;
927+
break;
945928
}
946929
}
947-
PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults);
948-
return polarisationResultTimeSeries;
930+
931+
// Adjust results for early convergence
932+
int finalIndex = lastIndex;
933+
for (int i = lastIndex; i >= 0; i--) {
934+
if (maxQArray[i] == maxQArray[lastIndex]) {
935+
finalIndex = i;
936+
} else {
937+
break;
938+
}
939+
}
940+
941+
double[] maxQArrayTemp = new double[finalIndex + 1];
942+
double[] avgQArrayTemp = new double[finalIndex + 1];
943+
double[] sdQArrayTemp = new double[finalIndex + 1];
944+
945+
for (int i = 0; i < finalIndex + 1; i++) {
946+
maxQArrayTemp[i] = maxQArray[i];
947+
avgQArrayTemp[i] = avgQArray[i];
948+
sdQArrayTemp[i] = sdQArray[i];
949+
}
950+
951+
// Store results
952+
return new PolarisationResult(
953+
maxQArrayTemp, avgQArrayTemp, sdQArrayTemp,
954+
maxQ, cs.get(maxIndex).getMemberships().clone(),
955+
this.congruence.get(t).getRowNames(), earlyConvergence,
956+
this.congruence.get(t).getStart(),
957+
this.congruence.get(t).getStop(),
958+
this.congruence.get(t).getDateTime()
959+
);
949960
}
961+
962+
963+
964+
965+
966+
967+
968+
969+
970+
971+
972+
973+
974+
975+
976+
977+
978+
979+
980+
950981

951982
/** Calculate the entrywise 1-norm (= the sum of absolute values) of a matrix. The
952983
* input matrix is represented by a two-dimensional double array.

0 commit comments

Comments
 (0)