-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGeneticLearner.java
373 lines (306 loc) · 9.36 KB
/
GeneticLearner.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import java.util.*;
import java.util.concurrent.CountDownLatch;
/**
 * Driver for the genetic algorithm that evolves heuristic weight vectors.
 * PlayerSkeleton.java can be run without this file.
 */
public class GeneticLearner {
    /**
     * Builds the initial population, then repeatedly breeds offspring and purges the
     * population, printing a timing profile after every generation.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final int populationSize = 1000;
        final double offspringFraction = 0.3;
        final int maxGenerations = Integer.MAX_VALUE;
        StopWatch timer = new StopWatch();
        try {
            timer.start();
            Population population = new Population(populationSize);
            System.out.println("------------------------------------------------------");
            System.out.println("PROFILE: Population created in " + timer.getTime() + "ms");
            System.out.println("------------------------------------------------------");
            timer.start();

            // Number of offspring a generation must produce before it is purged.
            final double offspringTarget = populationSize * offspringFraction;

            // Loop for at most maxGenerations generations.
            for (int generation = 0; generation < maxGenerations; generation++) {
                // Breed until this generation has produced enough offspring.
                while (population.offspringProduced < offspringTarget) {
                    population.crossover();
                }

                System.out.println("---------------GENERATION PROFILE------------------");
                System.out.print("Generation " + (generation + 1) + ": ");
                population.getFittest();

                // Enough offspring exist: trim the population and report the timings.
                population.purge();
                population.profile();
                System.out.println("Total time elapsed: " + timer.getTime());
                System.out.println("---------------------------------------------------");
            }
        } catch (Exception failure) {
            System.out.println("error: " + failure);
        }
    }
}
/**
 * Pool of candidate weight vectors for the genetic algorithm, kept in a max-heap ordered
 * by fitness. {@link #crossover()} adds one offspring per successful call and
 * {@link #purge()} trims the pool back to its original size, keeping the fittest vectors.
 */
class Population {
    public static final int HEURISTICS = 5;
    /** Fraction of the original population size drawn for each crossover sample. */
    public static final double sampleProp = 0.1;

    public int originalSize;
    public int offspringProduced;
    public PriorityQueue<WeightVector> vectors;

    // Used to profile each generation
    public StopWatch purgeClock;
    public StopWatch sampleClock;
    public StopWatch crossClock;

    /** Orders vectors by descending fitness so the heap head is the fittest vector. */
    Comparator<WeightVector> comparator = new Comparator<WeightVector>() {
        @Override
        public int compare(WeightVector a, WeightVector b) {
            return Double.compare(b.fitness, a.fitness);
        }
    };

    /**
     * Constructs the population with the given size.
     * - Creates the given number of random vectors (each computes its own fitness)
     * - Adds the vectors into the max heap of vectors (by fitness)
     * - Prints a progress line and the current best vector roughly every 10%
     *
     * @param populationSize number of vectors to create; also the size purge() restores
     */
    public Population(int populationSize) {
        originalSize = populationSize;
        purgeClock = new StopWatch();
        sampleClock = new StopWatch();
        crossClock = new StopWatch();

        int checkpointSize = 10;
        // Clamp to 1: for populations smaller than checkpointSize the integer division
        // yields 0, and "% 0" would throw ArithmeticException.
        int checkpoint = Math.max(1, populationSize / checkpointSize);
        // PriorityQueue rejects an initial capacity below 1.
        vectors = new PriorityQueue<WeightVector>(Math.max(1, populationSize), comparator);

        System.out.println("Initializing population...");
        for (int i = 0; i < populationSize; i++) {
            // Add first, report second: the heap is then never empty when getFittest()
            // peeks at it, and the printed percentage reflects vectors actually created.
            vectors.add(new WeightVector());
            if ((i + 1) % checkpoint == 0) {
                System.out.println("..." + ((i + 1) * 100 / populationSize) + "%");
                System.out.print("Current best: ");
                getFittest();
            }
        }
        offspringProduced = 0;
        System.out.println("\nInitial population created.");
    }

    /**
     * Draws a random sample of the population (see {@link #samplePopulation()}),
     * takes the two fittest vectors of that sample and crosses them by the formula:
     * v1*fitness(v1) + v2*fitness(v2)
     * The resulting vector is added to the population as a new offspring.
     * Does nothing when the sample holds fewer than two vectors.
     */
    public void crossover() {
        PriorityQueue<WeightVector> sample = samplePopulation();
        crossClock.start();

        // Takes the 2 fittest vectors
        WeightVector a = sample.poll();
        WeightVector b = sample.poll();
        if (a == null || b == null) {
            return;
        }

        double shareA = a.fitness;
        double shareB = b.fitness;
        if (shareA == 0 && shareB == 0) {
            // Fitness weighting would blend both parents into the all-zero vector, which
            // has no direction left to normalise; weight the parents equally instead.
            shareA = 1;
            shareB = 1;
        }

        double[] newWeights = new double[HEURISTICS];
        for (int i = 0; i < HEURISTICS; i++) {
            newWeights[i] = a.weights[i] * shareA + b.weights[i] * shareB;
        }
        WeightVector v = new WeightVector(newWeights);
        crossClock.clock();
        addOffspring(v);
    }

    /**
     * Adds a new vector to the population
     * Increments the number of offspring produced
     *
     * @param v the offspring to add
     */
    public void addOffspring(WeightVector v) {
        vectors.add(v);
        offspringProduced += 1;
    }

    /**
     * Returns a max heap holding a random sample of the current population.
     * The sample size is sampleProp (10%) of the ORIGINAL population size, raised to at
     * least 2 so a crossover always has two parents, and capped at the number of vectors
     * currently held.
     *
     * @return a new heap containing the sampled vectors, fittest first
     */
    public PriorityQueue<WeightVector> samplePopulation() {
        sampleClock.start();
        int requested = Math.max(2, (int) (originalSize * sampleProp));
        int sampleSize = Math.min(requested, vectors.size());
        // PriorityQueue rejects an initial capacity below 1.
        PriorityQueue<WeightVector> sample =
                new PriorityQueue<WeightVector>(Math.max(1, sampleSize), comparator);

        // Shuffle a copy of the population and keep its first sampleSize entries;
        // the heap itself is left untouched.
        List<WeightVector> population = new ArrayList<WeightVector>(vectors);
        Collections.shuffle(population);
        sample.addAll(population.subList(0, sampleSize));

        sampleClock.clock();
        return sample;
    }

    /**
     * Gets the size of the population
     *
     * @return number of vectors currently held
     */
    public int size() {
        return vectors.size();
    }

    /**
     * Creates a new max heap of vectors and moves the top (originalSize) vectors into it.
     * Sets this heap to be the new heap of the population.
     * Essentially removes the least fit vectors until we get back the original population size.
     * Resets the number of offspring produced to 0.
     */
    public void purge() {
        purgeClock.start();
        PriorityQueue<WeightVector> q =
                new PriorityQueue<WeightVector>(Math.max(1, originalSize), comparator);
        // Stop early if the population is already smaller than originalSize: poll()
        // returns null on an empty heap and PriorityQueue does not accept null elements.
        while (q.size() < originalSize && !vectors.isEmpty()) {
            q.add(vectors.poll());
        }
        vectors = q;
        offspringProduced = 0;
        purgeClock.clock();
    }

    /**
     * Prints the weights and fitness of the best vector in the current population
     * and returns that fitness.
     *
     * @return fitness of the fittest vector
     * @throws IllegalStateException if the population is empty
     */
    public double getFittest() {
        WeightVector v = vectors.peek();
        if (v == null) {
            throw new IllegalStateException("population is empty; there is no fittest vector");
        }
        System.out.print(Arrays.toString(v.weights) + ", fitness: " + v.fitness + "\n");
        return v.fitness;
    }

    /**
     * Prints the accumulated time of purge(), samplePopulation() and (crossing and creating a new vector)
     */
    public void profile() {
        System.out.println("Sample total elapsed: " + sampleClock.getElapsedTime() + "ms");
        System.out.println("Purge total elapsed: " + purgeClock.getElapsedTime() + "ms");
        System.out.println("Crossing total elapsed: " + crossClock.getElapsedTime() + "ms");
    }
}
/**
 * A vector of HEURISTICS weights together with the fitness it achieved.
 * Both constructors normalise the weights and immediately compute the fitness by
 * playing games, so creating an instance is expensive.
 */
class WeightVector {
    public static final int HEURISTICS = 5;
    public static final double randomMin = -1;
    public static final double randomMax = 1;
    public static final double mutationThreshold = 0.2;
    public static final double mutationChance = 0.05;

    public double[] weights;
    public double fitness;

    /**
     * Generates a vector with random weights, each drawn from [randomMin, randomMax)
     * before normalisation.
     */
    public WeightVector() {
        weights = new double[HEURISTICS];
        for (int i = 0; i < HEURISTICS; i++) {
            weights[i] = randomMin + (randomMax - randomMin) * Math.random();
        }
        normalize();
        calculateFitness();
    }

    /**
     * Generates a vector from the specified weights. The first HEURISTICS entries are
     * copied, normalised and possibly mutated before the fitness is computed.
     *
     * @param w source weights; must hold at least HEURISTICS entries
     * @throws IllegalArgumentException if w is null or shorter than HEURISTICS
     */
    public WeightVector(double[] w) {
        if (w == null || w.length < HEURISTICS) {
            throw new IllegalArgumentException(
                    "expected at least " + HEURISTICS + " weights but got "
                            + (w == null ? "null" : String.valueOf(w.length)));
        }
        weights = Arrays.copyOf(w, HEURISTICS);
        normalize();
        mutate();
        calculateFitness();
    }

    /**
     * Plays numGames games, each on its own thread, and sets the fitness to the average
     * of the scores that were recorded. If the wait is interrupted, the interrupt status
     * is restored and the average covers only the games finished so far (0 if none).
     */
    public void calculateFitness() {
        int numGames = 10;
        // NOTE(review): one PlayerSkeleton instance is shared by every game thread;
        // confirm that playGame() is safe to call concurrently.
        PlayerSkeleton p = new PlayerSkeleton(weights);
        Vector<Integer> scores = new Vector<>();
        CountDownLatch completionSignal = new CountDownLatch(numGames);

        for (int i = 0; i < numGames; i++) {
            Thread t = new Thread(new GameRunner(completionSignal, scores, p));
            t.start();
        }

        try {
            // Wait for all threads to complete.
            completionSignal.await();
        } catch (InterruptedException e) {
            System.out.println("Thread Interrupted");
            // Keep the interrupt visible to callers instead of silently consuming it.
            Thread.currentThread().interrupt();
        }

        // Snapshot only the scores actually recorded; there may be fewer than numGames,
        // and unboxing an unfilled slot would throw a NullPointerException.
        Integer[] finished = scores.toArray(new Integer[0]);
        double total = 0;
        for (Integer score : finished) {
            total += score;
        }
        this.fitness = finished.length == 0 ? 0 : total / finished.length;
    }

    /**
     * This is called when the vector is initialized with weights
     * (after being crossed by two parents).
     * With probability mutationChance (5%) one randomly chosen weight is shifted by a
     * random amount in [-mutationThreshold, +mutationThreshold), i.e. up to +/- 0.2.
     * The vector is then normalized.
     */
    public void mutate() {
        if (Math.random() <= mutationChance) {
            Random rand = new Random();
            int randomIndex = rand.nextInt(HEURISTICS);
            double mutationAmount = Math.random() * (mutationThreshold * 2) - mutationThreshold;
            weights[randomIndex] += mutationAmount;
        }
        normalize();
    }

    /**
     * Normalizes the vector to unit length based on its Euclidean magnitude.
     * A vector of magnitude zero is left unchanged, because dividing by zero would turn
     * every weight into NaN.
     */
    public void normalize() {
        double sumOfSquares = 0;
        for (int i = 0; i < HEURISTICS; i++) {
            sumOfSquares += weights[i] * weights[i];
        }
        double magnitude = Math.sqrt(sumOfSquares);
        if (magnitude == 0) {
            return;
        }
        for (int i = 0; i < HEURISTICS; i++) {
            weights[i] /= magnitude;
        }
    }
}
/**
 * Runnable that plays one game with the given player, appends the score to a shared
 * list and then signals completion on a latch.
 */
class GameRunner implements Runnable {
    // Require thread-safe list
    private final Vector<Integer> scores;
    private final PlayerSkeleton player;
    private final CountDownLatch doneSignal;

    /**
     * @param doneSignal latch counted down once when this runner finishes
     * @param scores     shared list that receives the game's score
     * @param player     player used to play the game
     */
    public GameRunner(CountDownLatch doneSignal, Vector<Integer> scores, PlayerSkeleton player) {
        this.doneSignal = doneSignal;
        this.scores = scores;
        this.player = player;
    }

    /**
     * Plays the game and adds the score to scores.
     * The latch is counted down even when the game throws, so a thread waiting on the
     * latch is never blocked forever by a failed game; in that case no score is added.
     */
    @Override
    public void run() {
        try {
            int score = player.playGame();
            scores.add(score);
        } finally {
            doneSignal.countDown();
        }
    }
}
/**
 * Minimal stopwatch built on System.nanoTime().
 * Supports a one-shot reading since the last start() as well as an accumulated total
 * that grows with every clock() call.
 * @author tanzh
 *
 */
class StopWatch {
    private static final long NANOS_PER_MILLI = 1000000L;

    long startTime = 0;
    long elapsedTime = 0;

    public StopWatch() {
        // Both fields start at zero via their initialisers.
    }

    /** Records the current time as the new starting point. */
    public void start() {
        this.startTime = System.nanoTime();
    }

    /** Clears the accumulated total; the starting point is left as it is. */
    public void reset() {
        this.elapsedTime = 0;
    }

    /**
     * Simple start-stop reading.
     * @return milliseconds passed since the last start()
     */
    public long getTime() {
        long sinceStart = System.nanoTime() - this.startTime;
        return sinceStart / NANOS_PER_MILLI;
    }

    /** Adds the time passed since the last start() to the accumulated total. */
    public void clock() {
        long sinceStart = System.nanoTime() - this.startTime;
        this.elapsedTime = this.elapsedTime + sinceStart;
    }

    /**
     * @return the accumulated total in milliseconds
     */
    public long getElapsedTime() {
        return this.elapsedTime / NANOS_PER_MILLI;
    }
}