Random Pick with Weight
You are given a 0-indexed array of positive integers w where w[i] describes the weight of the ith index.
You need to implement the function pickIndex(), which randomly picks an index in the range [0, w.length - 1] (inclusive) and returns it. The probability of picking an index i is w[i] / sum(w).
For example, if w = [1, 3], the probability of picking index 0 is 1 / (1 + 3) = 0.25 (i.e., 25%), and the probability of picking index 1 is 3 / (1 + 3) = 0.75 (i.e., 75%).
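As a quick check of the formula, here is a minimal Python sketch that computes the probability each index should be picked with. The helper name `target_probabilities` is hypothetical and is not part of the problem's API.

```python
from typing import List


def target_probabilities(w: List[int]) -> List[float]:
    """Return w[i] / sum(w) for every index i (hypothetical helper)."""
    total = sum(w)
    return [weight / total for weight in w]


print(target_probabilities([1, 3]))  # [0.25, 0.75]
```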
Example 1:
Input
["Solution","pickIndex"]
[[[1]],[]]
Output
[null,0]
Explanation
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. The only option is to return 0 since there is only one element in w.
Example 2:
Input
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output
[null,1,1,1,1,0]
Explanation
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // return 1. It is returning the second element (index = 1) that has a probability of 3/4.
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 0. It is returning the first element (index = 0) that has a probability of 1/4.
Since this is a randomization problem, multiple answers are allowed.
All of the following outputs can be considered correct:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
... and so on.
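Because any sequence of picks is acceptable, an implementation can only be judged on its distribution. One rough check is to draw many samples and compare the observed frequencies with w[i] / sum(w). The sketch below assumes the prefix-sum plus `bisect` approach used in the Python solution; `empirical_frequencies`, the seed and the number of draws are illustrative choices, not part of the problem.

```python
import bisect
import itertools
import random
from typing import List


def empirical_frequencies(w: List[int], draws: int, seed: int = 0) -> List[float]:
    """Simulate pickIndex-style draws and return the observed frequency of each index."""
    rng = random.Random(seed)             # seeded so the experiment is repeatable
    sums = list(itertools.accumulate(w))  # prefix sums, e.g. [1, 3] -> [1, 4]
    counts = [0] * len(w)
    for _ in range(draws):
        target = rng.randint(1, sums[-1])  # uniform over 1..sum(w)
        counts[bisect.bisect_left(sums, target)] += 1
    return [c / draws for c in counts]


freqs = empirical_frequencies([1, 3], draws=100_000)
print(freqs)  # close to [0.25, 0.75]; the exact values depend on the seed
```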
Constraints:
1 <= w.length <= 10^4
1 <= w[i] <= 10^5
pickIndex will be called at most 10^4 times.
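These bounds matter for the Java and C++ solutions, which store the prefix sums in `int`. A quick worst-case calculation (every weight at its maximum, and assuming a 32-bit `int` in C++) suggests the total always fits in a signed 32-bit integer:

```python
MAX_LEN = 10**4        # upper bound on w.length
MAX_WEIGHT = 10**5     # upper bound on w[i]
INT32_MAX = 2**31 - 1  # largest value of a Java int (and of a 32-bit C++ int)

worst_total = MAX_LEN * MAX_WEIGHT
print(worst_total, INT32_MAX)    # 1000000000 2147483647
print(worst_total <= INT32_MAX)  # True
```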
Solution:
- Java
- JavaScript
- Python
- C++
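Java:
import java.util.Arrays;
import java.util.Random;

class Solution {
    private final Random random = new Random();
    // sums[i] = w[0] + w[1] + ... + w[i]; strictly increasing because every weight is positive
    private final int[] sums;

    public Solution(int[] w) {
        sums = w.clone(); // copy so the caller's array is not modified
        for (int i = 1; i < sums.length; ++i) {
            sums[i] += sums[i - 1];
        }
    }

    public int pickIndex() {
        // Uniform draw from 1..sum(w)
        int idx = random.nextInt(sums[sums.length - 1]) + 1;
        // Exact match: its index; otherwise -(insertion point) - 1
        int i = Arrays.binarySearch(sums, idx);
        return i >= 0 ? i : ~i;
    }
}
/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */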
JavaScript:
/**
 * @param {number[]} w
 */
var Solution = function (w) {
  // sums[i] = w[0] + ... + w[i]
  this.sums = [];
  let sum = 0;
  for (const v of w) {
    sum += v;
    this.sums.push(sum);
  }
  this.totalSum = sum;
};
/**
 * @return {number}
 */
Solution.prototype.pickIndex = function () {
  // Uniform integer draw from 0..totalSum-1
  const target = Math.floor(Math.random() * this.totalSum);
  // Binary search for the first prefix sum strictly greater than target
  let start = 0,
    end = this.sums.length - 1;
  while (start < end) {
    const mid = Math.floor((start + end) / 2);
    if (this.sums[mid] > target) end = mid;
    else start = mid + 1;
  }
  return start;
};
/**
 * Your Solution object will be instantiated and called as such:
 * var obj = new Solution(w)
 * var param_1 = obj.pickIndex()
 */
Python:
import bisect
import random
from typing import List


class Solution:
    def __init__(self, w: List[int]):
        # sums[i] = w[0] + ... + w[i]
        self.sums = []
        total_sum = 0
        for weight in w:
            total_sum += weight
            self.sums.append(total_sum)

    def pickIndex(self) -> int:
        # Uniform draw from 1..sum(w)
        target = random.randint(1, self.sums[-1])
        # First index whose prefix sum is >= target
        return bisect.bisect_left(self.sums, target)

# Your Solution object will be instantiated and called as such:
# obj = Solution(w)
# param_1 = obj.pickIndex()
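C++:
#include <algorithm>
#include <random>
#include <vector>
using namespace std;

class Solution {
public:
    Solution(vector<int>& w) : gen(random_device{}()) {
        // sums[i] = w[0] + ... + w[i]; fits in int since sum(w) <= 10^4 * 10^5 = 10^9
        sums.push_back(w[0]);
        for (size_t i = 1; i < w.size(); ++i) {
            sums.push_back(sums[i - 1] + w[i]);
        }
    }

    int pickIndex() {
        // Uniform draw from 0..sum(w)-1. rand() % total is biased, and RAND_MAX
        // may be as small as 32767, which is far below the maximum total.
        uniform_int_distribution<int> dist(0, sums.back() - 1);
        int n = dist(gen);
        // First prefix sum strictly greater than n
        auto it = upper_bound(sums.begin(), sums.end(), n);
        return it - sums.begin();
    }

private:
    vector<int> sums;
    mt19937 gen;
};
/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(w);
 * int param_1 = obj->pickIndex();
 */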
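The four solutions use two conventions that should be equivalent: Java and Python draw a target from 1..sum(w) and look for the first prefix sum >= target, while JavaScript and C++ draw from 0..sum(w)-1 and look for the first prefix sum > target. The Python sketch below checks this exhaustively for one example array, with `bisect_left` standing in for the ">=" search and `bisect_right` for the ">" search (the role `std::upper_bound` plays in C++). It is an illustration for a single input, not a proof.

```python
import bisect
import itertools

w = [3, 6, 4, 5]
sums = list(itertools.accumulate(w))  # [3, 9, 13, 18]
total = sums[-1]

# Java/Python convention: target in 1..total, first prefix sum >= target
one_based = [bisect.bisect_left(sums, t) for t in range(1, total + 1)]
# JavaScript/C++ convention: target in 0..total-1, first prefix sum > target
zero_based = [bisect.bisect_right(sums, t) for t in range(total)]

print(one_based == zero_based)  # True
```

Since all sums are integers, "the first prefix sum >= t" and "the first prefix sum > t - 1" always select the same index.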
Time/Space Complexity:
- Time Complexity: O(n + Q log n): O(n) to build the prefix sums on initialization and O(log n) for each of the Q calls to pickIndex
- Space Complexity: O(n) for the prefix-sum array
Explanation:
On initialization we build the array of accumulated (prefix) sums, best seen with an example:
w = [3, 6, 4, 5] => sums = [3, 9, 13, 18]. This array has n entries, which gives the linear space complexity. Each prefix sum now owns the gap between itself and the previous prefix sum: for example, 9 owns the values 4..9, which is 6 numbers, exactly the original weight w[1] = 6 whose accumulated sum is 9.
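The prefix sums and the range of draws each index owns can be computed directly. Here is a short Python sketch for the example above; the `owned_ranges` helper is illustrative and not part of any of the solutions.

```python
import itertools
from typing import List, Tuple


def owned_ranges(w: List[int]) -> List[Tuple[int, int]]:
    """Return, for each index, the inclusive range of draws in 1..sum(w) it owns."""
    ranges = []
    prev = 0
    for s in itertools.accumulate(w):
        ranges.append((prev + 1, s))  # index i owns sums[i-1]+1 .. sums[i]
        prev = s
    return ranges


w = [3, 6, 4, 5]
print(list(itertools.accumulate(w)))  # [3, 9, 13, 18]
print(owned_ranges(w))                # [(1, 3), (4, 9), (10, 13), (14, 18)]
```

The size of each range, hi - lo + 1, equals the corresponding weight.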
On each call we draw idx uniformly from 1..18 (random.nextInt(sums[sums.length-1]) + 1 in the Java version) and binary-search for it in sums. If idx falls in 4..9, the built-in binary search returns either the exact index of the element (when idx == 9) or the insertion point, i.e. the position where idx would be inserted to keep the array sorted; in both cases that is index 1. Index 1 is the position of weight 6 in the original array w, so every draw in 4..9 selects index 1. The same holds for every other range: each draw falls into exactly one gap, and that gap is owned by exactly one element of sums (and therefore of w). Because the size of each gap equals the corresponding weight, index i is returned with probability w[i] / sum(w). Each binary search takes O(log n) time.
When the key is not found, Java's Arrays.binarySearch returns -(insertion point) - 1. To recover the insertion point we use the bitwise complement ~, since ~x = -x - 1. For example, a draw of idx = 5 has insertion point 1, so the search returns -(1) - 1 = -2, and ~(-2) = -(-2) - 1 = 2 - 1 = 1, correctly mapping to the second number at index 1.
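To see the `~` trick in isolation, the Python sketch below imitates the return convention of Java's `Arrays.binarySearch` on a strictly increasing list (the index if found, otherwise `-(insertion point) - 1`) using `bisect`. `java_style_binary_search` and `pick_from_draw` are hypothetical helpers written only for this illustration. Counting how many draws in 1..18 land on each index recovers the original weights.

```python
import bisect
import itertools
from typing import List


def java_style_binary_search(arr: List[int], key: int) -> int:
    """Mimic java.util.Arrays.binarySearch's return value on a strictly increasing list."""
    pos = bisect.bisect_left(arr, key)
    if pos < len(arr) and arr[pos] == key:
        return pos   # exact match: its index
    return -pos - 1  # miss: -(insertion point) - 1


def pick_from_draw(sums: List[int], idx: int) -> int:
    """Map a draw idx in 1..sums[-1] to an index, as the Java pickIndex does."""
    i = java_style_binary_search(sums, idx)
    return i if i >= 0 else ~i  # ~i == -i - 1 recovers the insertion point


sums = list(itertools.accumulate([3, 6, 4, 5]))  # [3, 9, 13, 18]
print(java_style_binary_search(sums, 5))         # -2 (insertion point 1)
print(pick_from_draw(sums, 5))                   # 1

counts = [0] * len(sums)
for idx in range(1, sums[-1] + 1):
    counts[pick_from_draw(sums, idx)] += 1
print(counts)                                    # [3, 6, 4, 5]
```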