summaryrefslogtreecommitdiff
path: root/base/src/main/java/bjc/utils/gen/WeightedRandom.java
blob: c9bdad83eadd0c4eb29f7d9db10434df0ff46d66 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package bjc.utils.gen;

import java.util.Objects;
import java.util.Random;

import bjc.utils.data.IHolder;
import bjc.utils.data.IPair;
import bjc.utils.data.Identity;
import bjc.utils.funcdata.FunctionalList;
import bjc.utils.funcdata.IList;

/**
 * Represents a random number generator where certain results are weighted more
 * heavily than others.
 *
 * <p>
 * Results are stored in insertion order, each with a non-negative integer
 * weight. {@link #generateValue()} picks a result with probability proportional
 * to its weight. {@link #getDescent(int)} and {@link #getBinomial(int, int, int)}
 * ignore the weights and select purely by insertion position.
 *
 * <p>
 * This class performs no synchronization of its own.
 *
 * @author ben
 *
 * @param <E>
 *        The type of values that are randomly selected.
 */
public class WeightedRandom<E> {
	/* The weight of each result; index i corresponds to index i of results. */
	private final IList<Integer> probabilities;
	/* The list of possible results to pick from, in insertion order. */
	private final IList<E> results;

	/* The default source for any needed random numbers. */
	private final Random source;
	/* The sum of every weight added so far; never negative. */
	private int totalChance;

	/* Shared source of randomness for generators built with the no-arg constructor. */
	private static final Random BASE = new Random();

	/**
	 * Create a new weighted random generator with the specified source of
	 * randomness.
	 *
	 * @param src
	 *        The source of randomness to use.
	 *
	 * @throws NullPointerException
	 *         If {@code src} is null.
	 */
	public WeightedRandom(Random src) {
		probabilities = new FunctionalList<>();
		results       = new FunctionalList<>();

		source = Objects.requireNonNull(src, "Source of randomness must not be null");
	}

	/**
	 * Create a new weighted random generator that draws from a source of
	 * randomness shared by all generators created through this constructor.
	 */
	public WeightedRandom() {
		this(BASE);
	}

	/**
	 * Add a probability for a specific result to be given.
	 *
	 * <p>
	 * A chance of zero is allowed; such a result is never picked by
	 * {@link #generateValue()}, but still takes part in
	 * {@link #getDescent(int)} and {@link #getBinomial(int, int, int)}.
	 *
	 * @param chance
	 *        The chance (relative weight) to get this result. Must not be
	 *        negative.
	 *
	 * @param result
	 *        The result to get when the chance comes up.
	 *
	 * @throws IllegalArgumentException
	 *         If {@code chance} is negative.
	 *
	 * @throws ArithmeticException
	 *         If adding {@code chance} would overflow the total weight.
	 */
	public void addProbability(final int chance, final E result) {
		if(chance < 0) {
			throw new IllegalArgumentException("Chance must not be negative, was " + chance);
		}

		/*
		 * Compute the new total before touching any state, so that an
		 * overflow leaves the generator unchanged instead of wrapping to a
		 * negative total that would break generateValue.
		 */
		final int newTotal = Math.addExact(totalChance, chance);

		probabilities.add(chance);
		results.add(result);

		totalChance = newTotal;
	}

	/**
	 * Generate a weighted random value, using this generator's own source of
	 * randomness.
	 *
	 * @return A random value selected in a weighted fashion.
	 *
	 * @throws IllegalStateException
	 *         If no result with a positive chance has been added.
	 */
	public E generateValue() {
		return generateValue(source);
	}

	/**
	 * Generate a weighted random value using the given source of randomness.
	 *
	 * <p>
	 * Draws a number uniformly from {@code [0, totalChance)} and walks the
	 * weights in insertion order, returning the first result whose cumulative
	 * weight exceeds the drawn number.
	 *
	 * @param rn
	 *        The source of randomness to use for this call.
	 *
	 * @return A random value selected in a weighted fashion.
	 *
	 * @throws IllegalStateException
	 *         If no result with a positive chance has been added.
	 */
	public E generateValue(final Random rn) {
		if(totalChance == 0) {
			/* nextInt(0) would otherwise fail with an unhelpful message. */
			throw new IllegalStateException("No results with a positive chance have been added");
		}

		int target = rn.nextInt(totalChance);
		int index  = 0;

		for(final int prob : probabilities) {
			if(target < prob) return results.getByIndex(index);

			target -= prob;
			index  += 1;
		}

		/*
		 * Unreachable while weights are non-negative and sum to totalChance;
		 * reaching it means the internal lists were modified externally
		 * (e.g. through getResults).
		 */
		throw new IllegalStateException("Fell off the end of the results list");
	}

	/**
	 * Return a list of values that can be generated by this generator.
	 *
	 * <p>
	 * NOTE(review): this returns the internal list itself, not a copy.
	 * Modifying it desynchronizes results from their weights.
	 *
	 * @return A list of all the values that can be generated, in insertion
	 *         order.
	 */
	public IList<E> getResults() {
		return results;
	}

	/**
	 * Return a list containing values that can be generated paired with the
	 * probability of those values being generated.
	 *
	 * @return A list of pairs of value weights and values, as built by
	 *         {@code IList.pairWith} (presumably paired index-wise — confirm
	 *         against IList).
	 */
	public IList<IPair<Integer, E>> getValues() {
		return probabilities.pairWith(results);
	}

	/**
	 * Pick a result by "descending" through the results in insertion order,
	 * using this generator's own source of randomness.
	 *
	 * @param factor
	 *        Each result is skipped with probability {@code 1/factor}. Must be
	 *        positive.
	 *
	 * @return The selected result.
	 *
	 * @see #getDescent(int, Random)
	 */
	public E getDescent(int factor) {
		return getDescent(factor, source);
	}

	/**
	 * Pick a result by "descending" through the results in insertion order.
	 *
	 * <p>
	 * Starting at the first result, each result is skipped with probability
	 * {@code 1/factor} and otherwise returned. If every result is skipped, the
	 * last one is returned. With {@code factor == 1} every result is skipped,
	 * so the last one is always returned. Weights are not consulted.
	 *
	 * <p>
	 * NOTE(review): with no results this calls {@code getByIndex(-1)}; what
	 * that throws depends on the IList implementation.
	 *
	 * @param factor
	 *        The inverse probability of skipping each result. Must be positive,
	 *        or {@link Random#nextInt(int)} throws.
	 *
	 * @param rn
	 *        The source of randomness to use for this call.
	 *
	 * @return The selected result.
	 */
	public E getDescent(int factor, Random rn) {
		for(E res : results) {
			if(rn.nextInt(factor) == 0) continue;

			return res;
		}

		return results.getByIndex(results.getSize() - 1);
	}

	/**
	 * Pick a result by index from a binomial number of successes, using this
	 * generator's own source of randomness.
	 *
	 * @param target
	 *        A roll counts as a success when it is at most this value.
	 *
	 * @param bound
	 *        The number of sides on the die being rolled. Must be positive.
	 *
	 * @param trials
	 *        The number of rolls to make.
	 *
	 * @return The selected result.
	 *
	 * @see #getBinomial(int, int, int, Random)
	 */
	public E getBinomial(int target, int bound, int trials) {
		return getBinomial(target, bound, trials, source);
	}

	/**
	 * Pick a result by index from a binomial number of successes.
	 *
	 * <p>
	 * Rolls a {@code bound}-sided die (values {@code 1..bound})
	 * {@code trials} times, and counts each roll at most {@code target} as a
	 * success. The result at index equal to the number of successes is
	 * returned, clamped to the last result. Weights are not consulted.
	 *
	 * <p>
	 * NOTE(review): with no results this calls {@code getByIndex(-1)}; what
	 * that throws depends on the IList implementation.
	 *
	 * @param target
	 *        A roll counts as a success when it is at most this value.
	 *
	 * @param bound
	 *        The number of sides on the die being rolled. Must be positive, or
	 *        {@link Random#nextInt(int)} throws.
	 *
	 * @param trials
	 *        The number of rolls to make. Values of zero or less make no rolls,
	 *        so the first result is returned.
	 *
	 * @param rn
	 *        The source of randomness to use for this call.
	 *
	 * @return The selected result.
	 */
	public E getBinomial(int target, int bound, int trials, Random rn) {
		int numSuc = 0;

		for(int i = 0; i < trials; i++) {
			/*
			 * Shift nextInt's [0, bound) range to [1, bound], so this reads
			 * as rolling a bound-sided die and marking a success for every
			 * roll less than or equal to target.
			 */
			int num = rn.nextInt(bound) + 1;

			if(num <= target) numSuc += 1;
		}

		return results.getByIndex(Math.min(numSuc, results.getSize() - 1));
	}
}