# 308 Range Sum Query 2D - Mutable
給一個2D array,還有左上角、右下角的index,問範圍內的總和。 另有update function,可以指定更新某index的值。
Given matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5]
]
sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
Concept: 和307一樣用Binary Index Tree。 複習一下BIT的存法: 當tree[i][j]有用到nums[i][j],它的ancestor就也有用到nums[i][j],所以必須找出這些tree[i][j]所有的ancestor並更新。找到ancestor的方法為:idx = i, idx = i+(i&-i), ... 直到idx超過nums.length,根據binary indexed tree的定義,tree[i]的parent即為tree[i+(i&-i)]。 實作上先宣告變數int diff = val - nums[i][j]。再用上面的方法找到有用到nums[i][j]的tree nodes。 range sum的算法以此圖為例,如果要求(2,3)和(4,5)之間的range sum(紅色範圍),取(4,5)以內的range sum(藍色大框),減掉(4,2)和(1,5)的range sum(綠色範圍),再加回(1,2)的range sum(藍色小框),即可得到。
Implementation: int[][] tree 儲存binary indexed tree中每個node的值 int[][] nums 儲存input array 程式分成三個部分:
- Constructor: new出tree和nums的空間。再用update()初始化tree和nums。
- update(): int diff = val - nums[i][j],nums[i][j]在初始化之前是0,所以diff就會是val。用兩層for迴圈,每個迴圈i+(i&-i)還有j+(j&-j),把所有有用到nums[i][j]的node都加上diff。還要把nums[i][j]+=diff。
- rangeSum(): 寫另一個function sum(row,col),取(row,col)以內的range sum。return sum(藍色大框)+sum(藍色小框)-sum(綠色框1)-sum(綠色框2)
Code:
public class NumMatrix {
int m;
int n;
int[][] tree;
int[][] nums;
public NumMatrix(int[][] matrix) {
if (matrix.length == 0 || matrix[0].length == 0) return;
m = matrix.length;
n = matrix[0].length;
tree = new int[m+1][n+1];
nums = new int[m][n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
update(i, j, matrix[i][j]);
}
}
}
public void update(int row, int col, int val) {
if(m==0 || n==0) return;
int diff = val - nums[row][col];
nums[row][col] = val;
for(int i= row+1; i<=m; i+=(i&-i)){
for(int j= col+1; j<=n; j+=(j&-j)){
tree[i][j] += diff;
}
}
}
public int sumRegion(int row1, int col1, int row2, int col2) {
if (m == 0 || n == 0) return 0;
return sum(row2+1,col2+1)+sum(row1,col1)-sum(row1,col2+1)-sum(row2+1,col1);
}
public int sum(int row, int col){
int sum = 0;
for (int i = row; i > 0; i -= i & (-i)) {
for (int j = col; j > 0; j -= j & (-j)) {
sum += tree[i][j];
}
}
return sum;
}
}
/**
* Your NumMatrix object will be instantiated and called as such:
* NumMatrix obj = new NumMatrix(matrix);
* obj.update(row,col,val);
* int param_2 = obj.sumRegion(row1,col1,row2,col2);
*/