# 308 Range Sum Query 2D - Mutable

給一個2D array,還有左上角、右下角的index,問範圍內的總和。 另有update function,可以指定更新某index的值。

Given matrix = [
  [3, 0, 1, 4, 2],
  [5, 6, 3, 2, 1],
  [1, 2, 0, 1, 5],
  [4, 1, 0, 1, 7],
  [1, 0, 3, 0, 5]
]

sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10

Concept: 和307一樣用Binary Index Tree。 複習一下BIT的存法: 當tree[i][j]有用到nums[i][j],它的ancestor就也有用到nums[i][j],所以必須找出這些tree[i][j]所有的ancestor並更新。找到ancestor的方法為:idx = i, idx = i+(i&-i), ... 直到idx超過nums.length,根據binary indexed tree的定義,tree[i]的parent即為tree[i+(i&-i)]。 實作上先宣告變數int diff = val - nums[i][j]。再用上面的方法找到有用到nums[i][j]的tree nodes。 range sum的算法以此圖為例,如果要求(2,3)和(4,5)之間的range sum(紅色範圍),取(4,5)以內的range sum(藍色大框),減掉(4,2)和(1,5)的range sum(綠色範圍),再加回(1,2)的range sum(藍色小框),即可得到。

Implementation: int[][] tree 儲存binary indexed tree中每個node的值 int[][] nums 儲存input array 程式分成三個部分:

  1. Constructor: new出tree和nums的空間。再用update()初始化tree和nums。
  2. update(): int diff = val - nums[i][j],nums[i][j]在初始化之前是0,所以diff就會是val。用兩層for迴圈,每個迴圈i+(i&-i)還有j+(j&-j),把所有有用到nums[i][j]的node都加上diff。還要把nums[i][j]+=diff。
  3. rangeSum(): 寫另一個function sum(row,col),取(row,col)以內的range sum。return sum(藍色大框)+sum(藍色小框)-sum(綠色框1)-sum(綠色框2)

Code:

public class NumMatrix {
    int m;
    int n;
    int[][] tree;
    int[][] nums;
    public NumMatrix(int[][] matrix) {
        if (matrix.length == 0 || matrix[0].length == 0) return;
        m = matrix.length;
        n = matrix[0].length;
        tree = new int[m+1][n+1];
        nums = new int[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                update(i, j, matrix[i][j]);
            }
        }
    }

    public void update(int row, int col, int val) {
        if(m==0 || n==0)    return;
        int diff = val - nums[row][col];
        nums[row][col] = val;
        for(int i= row+1; i<=m; i+=(i&-i)){
            for(int j= col+1; j<=n; j+=(j&-j)){
                tree[i][j] += diff;
            }
        }
    }

    public int sumRegion(int row1, int col1, int row2, int col2) {
        if (m == 0 || n == 0) return 0;
        return sum(row2+1,col2+1)+sum(row1,col1)-sum(row1,col2+1)-sum(row2+1,col1);
    }

    public int sum(int row, int col){
        int sum = 0;
        for (int i = row; i > 0; i -= i & (-i)) {
            for (int j = col; j > 0; j -= j & (-j)) {
                sum += tree[i][j];
            }
        }
        return sum;
    }
}

/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj = new NumMatrix(matrix);
 * obj.update(row,col,val);
 * int param_2 = obj.sumRegion(row1,col1,row2,col2);
 */

results matching ""

    No results matching ""