LeetCode 528. Random Pick with Weight

Question

You are given a 0-indexed array of positive integers w where w[i] describes the weight of the ith index.

You need to implement the function pickIndex(), which randomly picks an index in the range [0, w.length - 1] (inclusive) and returns it. The probability of picking an index i is w[i] / sum(w).

  • For example, if w = [1, 3], the probability of picking index 0 is 1 / (1 + 3) = 0.25 (i.e., 25%), and the probability of picking index 1 is 3 / (1 + 3) = 0.75 (i.e., 75%).

Constraints:

  • 1 <= w.length <= 10^4
  • 1 <= w[i] <= 10^5
  • pickIndex will be called at most 10^4 times.

Source: https://leetcode.com/problems/random-pick-with-weight/

Solution

The idea is to split the range [0, sum(w)) into consecutive intervals, one per index, where the interval for index i has length w[i]. We then generate a uniformly random value in that range and return the index of the interval it falls into.
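To make the interval idea concrete, here is a minimal sketch that resolves a value to its index with a plain linear scan. The class IntervalDemo and the method indexFor are hypothetical names used only for illustration, and the scan is O(n) per lookup, so this is not the approach used in the solutions below.

```java
// Hypothetical helper for illustration; not part of the original solution.
class IntervalDemo {
    // Returns the index whose interval [start, start + w[i]) contains r.
    // Assumes 0 <= r < sum(w).
    static int indexFor(int[] w, int r) {
        int start = 0;
        for (int i = 0; i < w.length; i++) {
            int end = start + w[i]; // exclusive upper bound of interval i
            if (r < end) {
                return i;
            }
            start = end;
        }
        throw new IllegalArgumentException("r must be in [0, sum(w))");
    }

    public static void main(String[] args) {
        int[] w = {1, 3};
        // Index 0 owns [0, 1), index 1 owns [1, 4).
        for (int r = 0; r < 4; r++) {
            System.out.println("r=" + r + " -> index " + indexFor(w, r));
        }
    }
}
```

With w = [1, 3], one of the four possible values maps to index 0 and three map to index 1, which matches the 25% / 75% split in the problem statement.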

The key point is how to determine which interval the random value falls into. Binary search over the sorted interval bounds answers this in O(log n) time per pick. In Java, TreeMap is implemented as a red-black tree, a kind of self-balancing binary search tree, so a TreeMap lookup gives the same time complexity.

class RandomPickWithWeight

Use TreeMap to store lower bounds.

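import java.util.Random;
import java.util.TreeMap;

class RandomPickWithWeight {
    // lower bound -> index; index i owns [lowerBound, lowerBound + w[i])
    private final TreeMap<Integer, Integer> treeMap;
    private final Random rand;
    // sum of all weights; at most 10^4 * 10^5 = 10^9, so it fits in an int
    private final int scope;

    public RandomPickWithWeight(int[] w) {
        treeMap = new TreeMap<>();
        rand = new Random(System.currentTimeMillis());
        int lowerBound = 0;
        for (int i = 0; i < w.length; i++) {
            treeMap.put(lowerBound, i);
            lowerBound += w[i];
        }
        scope = lowerBound;
    }

    public int pickIndex() {
        // nextInt(n) returns a uniform value in [0, n), so no sign handling
        // is needed and the result is not skewed by a modulo of nextInt()
        int r = rand.nextInt(scope);
        // greatest lower bound <= r identifies the interval containing r
        return treeMap.floorEntry(r).getValue();
    }
}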
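To see how floorEntry resolves a value to an index, the following sketch builds the same lower-bound map for a fixed weight array and looks up every possible value deterministically. The class FloorEntryDemo and its methods build and lookup are hypothetical names introduced here, not part of the original solution.

```java
import java.util.TreeMap;

// Hypothetical demo class; mirrors the lower-bound map without randomness.
class FloorEntryDemo {
    // Maps the lower bound of each interval to its index.
    static TreeMap<Integer, Integer> build(int[] w) {
        TreeMap<Integer, Integer> map = new TreeMap<>();
        int lowerBound = 0;
        for (int i = 0; i < w.length; i++) {
            map.put(lowerBound, i);
            lowerBound += w[i];
        }
        return map;
    }

    // floorEntry(r) returns the entry with the greatest key <= r.
    // Assumes 0 <= r < sum(w).
    static int lookup(TreeMap<Integer, Integer> map, int r) {
        return map.floorEntry(r).getValue();
    }

    public static void main(String[] args) {
        // Keys are 0, 2, 7: index 0 owns [0, 2), index 1 owns [2, 7),
        // index 2 owns [7, 10).
        TreeMap<Integer, Integer> map = build(new int[]{2, 5, 3});
        for (int r = 0; r < 10; r++) {
            System.out.println("r=" + r + " -> index " + lookup(map, r));
        }
    }
}
```

For w = [2, 5, 3], values 0 and 1 resolve to index 0, values 2 through 6 to index 1, and values 7 through 9 to index 2, so each index receives exactly w[i] of the 10 possible values.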

Use prefix sums and binary search. The prefix-sum array stores the inclusive upper bound of each interval.

import java.util.Random;

class RandomPickWithWeightAlt1 {
    private final Random rand;
    private final int scope;
    // upper bounds: index i owns (prefixSums[i - 1], prefixSums[i]]
    private final int[] prefixSums;

    public RandomPickWithWeightAlt1(int[] w) {
        rand = new Random(System.currentTimeMillis());
        prefixSums = new int[w.length];
        prefixSums[0] = w[0];
        for (int i = 1; i < w.length; i++) {
            prefixSums[i] = prefixSums[i - 1] + w[i];
        }
        scope = prefixSums[w.length - 1];
    }

    // Finds the first index whose prefix sum is >= r.
    public int pickIndex() {
        // nextInt(n) returns a random number in [0, n), so r is in [1, scope]
        int r = rand.nextInt(scope) + 1;
        int len = prefixSums.length;
        // both ends of the search window are inclusive
        int left = 0, right = len - 1;

        // mid is rounded down, so mid < right and the window always shrinks
        while (left < right) {
            int mid = left + (right - left) / 2; // prevent overflow
            if (r == prefixSums[mid]) {
                return mid;
            } else if (r < prefixSums[mid]) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        // left == right here, so either one can be returned
        return right;
    }

    public int pickIndex2() {
        // nextInt(n) returns a random number in [0, n), so r is in [1, scope]
        int r = rand.nextInt(scope) + 1;
        int len = prefixSums.length;
        // both ends of the search window are inclusive
        int left = 0, right = len - 1;

        // another implementation of binary search:
        // stop once the window has at most two elements, so mid != left
        while (left < right - 1) {
            int mid = left + (right - left) / 2; // prevent overflow
            if (r == prefixSums[mid]) {
                return mid;
            } else if (r < prefixSums[mid]) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }

        // out of the loop, left == right or left is adjacent to right;
        // prefix sums are strictly increasing, so checking left is enough
        if (r <= prefixSums[left]) {
            return left;
        } else {
            return right;
        }
    }
}
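As an alternative to a hand-written loop, the same lookup can be sketched with the standard library's Arrays.binarySearch, which returns the index of the key if it is found and -(insertion point) - 1 otherwise. This is a swapped-in variant, not the author's implementation; the class BinarySearchDemo and the method lookup are hypothetical names. It relies on the prefix sums being strictly increasing, which holds because every weight is positive.

```java
import java.util.Arrays;

// Hypothetical demo class; a library-based variant of the lookup step.
class BinarySearchDemo {
    // Returns the first index whose prefix sum is >= r.
    // Assumes 1 <= r <= prefixSums[prefixSums.length - 1].
    static int lookup(int[] prefixSums, int r) {
        int pos = Arrays.binarySearch(prefixSums, r);
        // Found: r equals an upper bound, so pos is the answer.
        // Not found: pos == -(insertionPoint) - 1, and the insertion point
        // is the first index whose prefix sum is greater than r.
        return pos >= 0 ? pos : -pos - 1;
    }

    public static void main(String[] args) {
        int[] prefixSums = {2, 7, 10}; // from w = [2, 5, 3]
        for (int r = 1; r <= 10; r++) {
            System.out.println("r=" + r + " -> index " + lookup(prefixSums, r));
        }
    }
}
```

For prefix sums [2, 7, 10], values 1 and 2 resolve to index 0, values 3 through 7 to index 1, and values 8 through 10 to index 2.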
Author

Weihao Ye

Posted on

2022-03-09

Updated on

2022-03-09

Licensed under