diff options
Diffstat (limited to 'base/src/main/java/bjc/utils/gen')
| -rw-r--r-- | base/src/main/java/bjc/utils/gen/WeightedRandom.java | 82 |
1 files changed, 65 insertions, 17 deletions
diff --git a/base/src/main/java/bjc/utils/gen/WeightedRandom.java b/base/src/main/java/bjc/utils/gen/WeightedRandom.java index c9bdad8..405d685 100644 --- a/base/src/main/java/bjc/utils/gen/WeightedRandom.java +++ b/base/src/main/java/bjc/utils/gen/WeightedRandom.java @@ -5,6 +5,7 @@ import java.util.Random; import bjc.utils.data.IHolder; import bjc.utils.data.IPair; import bjc.utils.data.Identity; +import bjc.utils.data.Pair; import bjc.utils.funcdata.FunctionalList; import bjc.utils.funcdata.IList; @@ -18,10 +19,7 @@ import bjc.utils.funcdata.IList; * The type of values that are randomly selected. */ public class WeightedRandom<E> { - /* The list of probabilities for each result */ - private final IList<Integer> probabilities; - /* The list of possible results to pick from */ - private final IList<E> results; + private final IList<IPair<Integer, E>> values; /* The source for any needed random numbers */ private Random source; @@ -30,6 +28,7 @@ public class WeightedRandom<E> { private final static Random BASE = new Random(); + private boolean exhaust; /** * Create a new weighted random generator with the specified source of * randomness. @@ -38,8 +37,7 @@ public class WeightedRandom<E> { * The source of randomness to use. */ public WeightedRandom(Random src) { - probabilities = new FunctionalList<>(); - results = new FunctionalList<>(); + values = new FunctionalList<>(); if(src == null) throw new NullPointerException("Source of randomness must not be null"); @@ -52,6 +50,15 @@ public class WeightedRandom<E> { public WeightedRandom() { this(BASE); } + + private WeightedRandom(Random src, IList<IPair<Integer, E>> vals, int chance) { + source = src; + + values = vals; + + totalChance = chance; + } + /** * Add a probability for a specific result to be given. * @@ -62,8 +69,7 @@ public class WeightedRandom<E> { * The result to get when the chance comes up. */ public void addProbability(final int chance, final E result) { - probabilities.add(chance); - results.add(result); + values.add(new Pair<>(chance, result)); totalChance += chance; } @@ -81,14 +87,23 @@ public class WeightedRandom<E> { int target = rn.nextInt(totalChance); int i = 0; - for(int prob : probabilities) { - if(target < prob) return results.getByIndex(i); + for(IPair<Integer, E> val : values) { + int prob = val.getLeft(); + + if(target < prob) { + if(exhaust) { + totalChance -= val.getLeft(); + values.removeMatching(val); + } + + return val.getRight(); + } target -= prob; i += 1; } - throw new NullPointerException("Fell off the end of the results list"); + return null; } /** * Return a list of values that can be generated by this generator @@ -96,7 +111,7 @@ public class WeightedRandom<E> { * @return A list of all the values that can be generated */ public IList<E> getResults() { - return results; + return values.map(IPair::getRight); } /** @@ -106,7 +121,7 @@ public class WeightedRandom<E> { * @return A list of pairs of values and value probabilities */ public IList<IPair<Integer, E>> getValues() { - return probabilities.pairWith(results); + return values; } public E getDescent(int factor) { @@ -114,13 +129,24 @@ public class WeightedRandom<E> { } public E getDescent(int factor, Random rn) { - for(E res : results) { + if(values.getSize() == 0) return null; + + for(IPair<Integer, E> val : values) { if(rn.nextInt(factor) == 0) continue; - return res; + if(exhaust) { + totalChance -= val.getLeft(); + + values.removeMatching(val); + } + + return val.getRight(); } - return results.getByIndex(results.getSize() - 1); + IPair<Integer, E> val = values.getByIndex(values.getSize() - 1); + if(exhaust) values.removeMatching(val); + + return val.getRight(); } public E getBinomial(int target, int bound, int trials) { @@ -128,6 +154,8 @@ public class WeightedRandom<E> { } public E getBinomial(int target, int bound, int trials, Random rn) { + if(values.getSize() == 0) return null; + int numSuc = 0; for(int i = 0; i < trials; i++) { @@ -145,6 +173,26 @@ public class WeightedRandom<E> { } //System.err.printf("\tTRACE: got %d success for binomial trials (%d <= 1d%d, %d times)\n", numSuc, target, bound, trials); - return results.getByIndex(Math.min(numSuc, results.getSize() - 1)); + IPair<Integer, E> val = values.getByIndex(Math.min(numSuc, values.getSize() - 1)); + if(exhaust) { + totalChance -= val.getLeft(); + + values.removeMatching(val); + } + + return val.getRight(); + } + + public WeightedRandom<E> exhaustible() { + IList<IPair<Integer, E>> lst = new FunctionalList<>(); + for(IPair<Integer, E> val : values) { + lst.add(val); + } + + WeightedRandom<E> res = new WeightedRandom<>(source, lst, totalChance); + + res.exhaust = true; + + return res; } } |
