
Kth Smallest Element in a Sorted Matrix

Given an n x n matrix where each row and each column is sorted in ascending order, return the kth smallest value in the matrix.

The rank is taken over the sorted list of all n^2 entries, so a value that appears several times occupies several consecutive ranks.

Example:

Input:

matrix = [
  [ 1,  5,  9],
  [10, 11, 13],
  [12, 13, 15]
],
k = 8,

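Output: 13

Explanation: sorting all nine entries gives 1, 5, 9, 10, 11, 12, 13, 13, 15, and the 8th value in that list is 13.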

Note:

You may assume k is always valid: 1 ≤ k ≤ n^2 and 1 ≤ n ≤ 1000.


Solution

Start with the direct method

The direct method is to flatten all n^2 values into one array, sort that array, and return the element at index k - 1. This works, but sorting n^2 values costs O(n^2 log n) time (since log(n^2) = 2 log n) and O(n^2) extra space.
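For reference, the direct method fits in a few lines of Python. This is a minimal baseline sketch. The name `kth_smallest_by_sorting` is hypothetical and is not part of the solution code. It is also handy for cross-checking the heap-based versions on small inputs.

```python
def kth_smallest_by_sorting(matrix: list[list[int]], k: int) -> int:
    # Flatten all n^2 values; duplicates are kept, so each copy gets its own rank.
    values = [value for row in matrix for value in row]
    # Sorting n^2 values costs O(n^2 log n) time and O(n^2) extra space.
    values.sort()
    # Ranks are 1-based, list indices are 0-based.
    return values[k - 1]


if __name__ == "__main__":
    sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    print(kth_smallest_by_sorting(sample, 8))  # 13
    print(kth_smallest_by_sorting([[1, 2], [1, 3]], 2))  # 1, because the two 1s take ranks 1 and 2
```

The second call shows the duplicate rule from the problem statement: the flattened, sorted list is [1, 1, 2, 3], so rank 2 is still 1.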

The matrix already gives us useful order: every row is sorted, and every column is sorted. We should use that structure instead of ignoring it.

Row-pointer min heap

Think of each row as a sorted list. If we place one pointer at the first column of every row, the value under each pointer is the smallest unprocessed value in that row. The next smallest value overall must therefore be the minimum of the values under the n pointers.

A min heap is exactly the right tool for this "next smallest among several candidates" task. Each heap entry stores (value, row, col), so after popping one value we know which row it came from and can push the next column from that same row.

For the sample matrix, the heap starts with 1, 10, and 12 (the first value of each row). Pop 1, then push 5 from row 0. Pop 5, then push 9. Pop 9; row 0 has no columns left, so nothing is pushed. Continue this pop-then-push process until the kth pop; that popped value is the answer.
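Here is a minimal Python sketch of the row-pointer heap. It assumes a square list-of-lists input, and the name `kth_smallest_row_pointers` is hypothetical. It builds the initial heap with `heapq.heapify` and pops k times, returning the kth popped value as in the walkthrough above.

```python
import heapq


def kth_smallest_row_pointers(matrix: list[list[int]], k: int) -> int:
    n = len(matrix)
    # One pointer per row: heap entries are (value, row, col), all starting at column 0.
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)  # builds the initial heap in O(n)
    # Pop k - 1 values; after each pop, advance that row's pointer by one column.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    # The kth pop returns the kth smallest value.
    return heapq.heappop(heap)[0]


if __name__ == "__main__":
    sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    print(kth_smallest_row_pointers(sample, 8))  # 13
```

On the sample with k = 8, the popped values are 1, 5, 9, 10, 11, 12, 13, 13, so the function returns 13. Each entry carries its row and column, so ties between equal values are broken by position and never cause a comparison error.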


The row-pointer approach works, but it puts one entry per row into the heap before popping anything. The remaining question is whether we can push fewer elements into the heap while still popping values in sorted order.

Optimized min heap: diagonal wave expansion

Instead of pushing the entire first column (or first row) at the beginning, we can start from only (0, 0) and expand outward. When a cell (r, c) is popped, two neighbors become candidates: right (r, c + 1) and down (r + 1, c). We push a neighbor only after every cell above it in its column and every cell to its left in its row has been processed. Those cells are never larger than it, so they must be popped first. This keeps the heap limited to the current frontier of possible answers.
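The eligibility rule can be tried out directly with a set of processed cells, before introducing any bookkeeping arrays. In this sketch (the name `kth_smallest_wave_set` is hypothetical), a cell is pushed only when its upper neighbor and its left neighbor, where they exist, are both in the `done` set. Processed cells always form a staircase anchored at (0, 0), so checking those two neighbors covers every cell above and to the left. The set costs O(k) extra space, which the two arrays described next avoid.

```python
import heapq


def kth_smallest_wave_set(matrix: list[list[int]], k: int) -> int:
    n = len(matrix)
    done: set[tuple[int, int]] = set()  # cells that have already been popped
    heap = [(matrix[0][0], 0, 0)]  # the wave starts at the top-left cell
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        done.add((r, c))
        # Right neighbor: (r, c) is its left neighbor; it also needs (r - 1, c + 1) done.
        if c + 1 < n and (r == 0 or (r - 1, c + 1) in done):
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        # Lower neighbor: (r, c) is its upper neighbor; it also needs (r + 1, c - 1) done.
        if r + 1 < n and (c == 0 or (r + 1, c - 1) in done):
            heapq.heappush(heap, (matrix[r + 1][c], r + 1, c))
    # After k - 1 pops, the smallest remaining candidate is the kth smallest value.
    return heap[0][0]
```

A cell is pushed by whichever of its two predecessors is popped second. The first one to be popped sees the other still unprocessed and skips the push, so no cell enters the heap twice.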

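To enforce that rule, track frontier progress with two arrays. rowFirst[r] is the first column in row r that has not been processed yet, and columnTop[c] is the first row in column c that has not been processed yet. A right neighbor (r, c + 1) is eligible exactly when columnTop[c + 1] == r, and a down neighbor (r + 1, c) is eligible exactly when rowFirst[r + 1] == c. Only one side needs checking because the cell just popped, (r, c), completes the other side. It was the last unprocessed cell to the left of (r, c + 1) and the last unprocessed cell above (r + 1, c).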
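To check that the two conditions neither skip a cell nor push one twice, the sketch below drains the whole matrix and counts how often each cell enters the heap. It uses the same array updates as the solution code. The helper name `frontier_push_counts` is hypothetical. On any row- and column-sorted matrix every count should be 1, and the values should come out in sorted order.

```python
import heapq


def frontier_push_counts(matrix: list[list[int]]) -> tuple[list[list[int]], list[int]]:
    n = len(matrix)
    pushes = [[0] * n for _ in range(n)]  # how many times each cell entered the heap
    row_first = [0] * n   # first unprocessed column in each row
    column_top = [0] * n  # first unprocessed row in each column
    heap = [(matrix[0][0], 0, 0)]
    pushes[0][0] = 1
    popped = []
    # Drain every cell, applying the same eligibility checks as the solution code.
    while heap:
        value, r, c = heapq.heappop(heap)
        popped.append(value)
        row_first[r] = c + 1
        if c + 1 < n and column_top[c + 1] == r:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
            pushes[r][c + 1] += 1
        column_top[c] = r + 1
        if r + 1 < n and row_first[r + 1] == c:
            heapq.heappush(heap, (matrix[r + 1][c], r + 1, c))
            pushes[r + 1][c] += 1
    return pushes, popped
```

For the sample matrix, every entry of `pushes` is 1 and `popped` is [1, 5, 9, 10, 11, 12, 13, 13, 15].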

We repeat "pop the smallest, unlock eligible neighbors, push unlocked neighbors" for k - 1 iterations. The top of the heap is then the kth smallest value.
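A short trace makes the wave visible. This sketch (the name `trace_wave` is hypothetical) runs the same loop as the solution code for k - 1 iterations. After each pop, it records the popped value and the values left on the frontier.

```python
import heapq


def trace_wave(matrix: list[list[int]], k: int) -> tuple[list[tuple[int, list[int]]], int]:
    n = len(matrix)
    row_first = [0] * n   # first unprocessed column in each row
    column_top = [0] * n  # first unprocessed row in each column
    heap = [(matrix[0][0], 0, 0)]
    steps = []
    for _ in range(k - 1):
        value, r, c = heapq.heappop(heap)
        row_first[r] = c + 1
        if c + 1 < n and column_top[c + 1] == r:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        column_top[c] = r + 1
        if r + 1 < n and row_first[r + 1] == c:
            heapq.heappush(heap, (matrix[r + 1][c], r + 1, c))
        # Record the popped value and the sorted values currently on the frontier.
        steps.append((value, sorted(entry[0] for entry in heap)))
    # The heap top after k - 1 pops is the answer.
    return steps, heap[0][0]


if __name__ == "__main__":
    sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    steps, answer = trace_wave(sample, 8)
    for popped, frontier in steps:
        print(popped, frontier)
    print("answer:", answer)
```

For the sample with k = 8, the frontier after each pop is [5, 10], [9, 10], [10], [11, 12], [12, 13], [13, 13], [13], and the answer is 13. The heap never holds more than two cells here, while the row-pointer heap starts with three.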


For the row-pointer heap, initialization costs O(n) when the n first-column entries are built into a heap in one step (for example, with heapify). Then we perform k pop operations, each followed by at most one push, and each heap operation costs O(log n). Total time is O(n + k log n), and heap space is O(n).

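For diagonal wave expansion, we process k values and keep only frontier cells in the heap. The heap never grows beyond min(k, n) for two reasons. Each iteration pops one entry and pushes at most two, so the heap gains at most one net entry per pop. Also, every entry (r, c) satisfies rowFirst[r] == c, so no row contributes more than one entry. Time is therefore O(k log(min(k, n))). Extra space is O(n) for the frontier arrays plus O(min(k, n)) for the heap.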
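The two heap-size claims can be compared empirically. This sketch (the name `max_heap_sizes` is hypothetical) runs both methods for the first k values and reports the largest heap each one reaches. The row-pointer heap should report exactly n, and the diagonal wave should stay at or below min(k, n).

```python
import heapq


def max_heap_sizes(matrix: list[list[int]], k: int) -> tuple[int, int]:
    """Return (row-pointer peak, diagonal-wave peak) heap sizes while finding the kth value."""
    n = len(matrix)

    # Row-pointer heap: one entry per row from the start.
    rows = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(rows)
    row_peak = len(rows)
    for _ in range(k - 1):
        _, r, c = heapq.heappop(rows)
        if c + 1 < n:
            heapq.heappush(rows, (matrix[r][c + 1], r, c + 1))
        row_peak = max(row_peak, len(rows))

    # Diagonal wave: frontier cells only, tracked with the two arrays.
    row_first, column_top = [0] * n, [0] * n
    wave = [(matrix[0][0], 0, 0)]
    wave_peak = 1
    for _ in range(k - 1):
        _, r, c = heapq.heappop(wave)
        row_first[r] = c + 1
        if c + 1 < n and column_top[c + 1] == r:
            heapq.heappush(wave, (matrix[r][c + 1], r, c + 1))
        column_top[c] = r + 1
        if r + 1 < n and row_first[r + 1] == c:
            heapq.heappush(wave, (matrix[r + 1][c], r + 1, c))
        wave_peak = max(wave_peak, len(wave))
    return row_peak, wave_peak
```

On the sample with k = 8, this returns (3, 2). Consider instead a 200 x 200 matrix such as `[[i + j for j in range(200)] for i in range(200)]` with k = 10. There, the row-pointer heap holds 200 entries from the start, while the wave heap holds at most 10.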

Both methods are correct and fast enough for n ≤ 1000. The row-pointer method is simpler to reason about, while the diagonal wave method does less heap work when k is small compared with n.

from heapq import heappop, heappush

def kth_smallest(matrix: list[list[int]], k: int) -> int:
    n = len(matrix)
    # Heap entries are (value, row, column); the wave starts at the top-left cell
    heap = [(matrix[0][0], 0, 0)]
    # column_top[c] is the first row in column c that has not been processed yet
    column_top = [0] * n
    # row_first[r] is the first column in row r that has not been processed yet
    row_first = [0] * n
    # Pop k - 1 values; the heap top is then the kth smallest value
    while k > 1:
        k -= 1
        _, row, column = heappop(heap)
        row_first[row] = column + 1
        # Push the cell on the right if every cell above it in its column is processed
        if column + 1 < n and column_top[column + 1] == row:
            heappush(heap, (matrix[row][column + 1], row, column + 1))
        column_top[column] = row + 1
        # Push the cell below if every cell to its left in its row is processed
        if row + 1 < n and row_first[row + 1] == column:
            heappush(heap, (matrix[row + 1][column], row + 1, column))
    return heap[0][0]

if __name__ == "__main__":
    matrix = [[int(x) for x in input().split()] for _ in range(int(input()))]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    public static int kthSmallest(List<List<Integer>> matrix, int k) {
        int n = matrix.size();
        // Heap entries are {row, column}, ordered by the matrix value at that cell,
        // so the cell holding the smallest value is at the top
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1])));
        heap.offer(new int[] {0, 0});
        // columnTop[c] is the first row in column c that has not been processed yet
        int[] columnTop = new int[n];
        // rowFirst[r] is the first column in row r that has not been processed yet
        int[] rowFirst = new int[n];
        // Pop k - 1 cells; the heap top is then the kth smallest value
        while (k > 1) {
            k--;
            int[] coords = heap.poll();
            int row = coords[0], column = coords[1];
            rowFirst[row] = column + 1;
            // Push the cell on the right if every cell above it in its column is processed
            if (column + 1 < n && columnTop[column + 1] == row) {
                heap.offer(new int[] {row, column + 1});
            }
            columnTop[column] = row + 1;
            // Push the cell below if every cell to its left in its row is processed
            if (row + 1 < n && rowFirst[row + 1] == column) {
                heap.offer(new int[] {row + 1, column});
            }
        }
        int[] resCoords = heap.poll();
        return matrix.get(resCoords[0]).get(resCoords[1]);
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int matrixLength = Integer.parseInt(scanner.nextLine());
        List<List<Integer>> matrix = new ArrayList<>();
        for (int i = 0; i < matrixLength; i++) {
            matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
        }
        int k = Integer.parseInt(scanner.nextLine());
        scanner.close();
        int res = kthSmallest(matrix, k);
        System.out.println(res);
    }
}

"use strict";

class HeapItem {
    constructor(item, priority = item) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap {
    constructor() {
        this.heap = [];
    }

    push(node) {
        // insert the new node at the end of the heap array
        this.heap.push(node);
        // find the correct position for the new node
        this.bubbleUp();
    }

    bubbleUp() {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            // if the parent is bigger than the child then swap the parent and child
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop() {
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubbleDown();
        return min;
    }

    bubbleDown() {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek() {
        return this.heap[0];
    }

    size() {
        return this.heap.length;
    }
}

function kthSmallest(matrix, k) {
    const n = matrix.length;
    // Heap items hold [row, col] and are prioritized by the value at that cell
    const heap = new MinHeap();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    // columnTop[c]: first unprocessed row in column c; rowFirst[r]: first unprocessed column in row r
    const columnTop = Array(n).fill(0);
    const rowFirst = Array(n).fill(0);
    // Pop k - 1 cells; the heap top is then the kth smallest value
    while (k > 1) {
        k -= 1;
        const [row, col] = heap.pop().item;
        rowFirst[row] = col + 1;
        // Push the right neighbor once every cell above it in its column is processed
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        // Push the lower neighbor once every cell to its left in its row is processed
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    const [resRow, resCol] = heap.pop().item;
    return matrix[resRow][resCol];
}

function splitWords(s) {
    return s === "" ? [] : s.split(" ");
}

function* main() {
    const matrixLength = parseInt(yield);
    const matrix = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}

class HeapItem<T> {
    item: T;
    priority: number;

    constructor(item: T, priority: number = Number(item)) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap<T> {
    private heap: HeapItem<T>[];

    constructor() {
        this.heap = [];
    }

    push(node: HeapItem<T>): void {
        this.heap.push(node);
        this.bubbleUp();
    }

    private bubbleUp(): void {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop(): HeapItem<T> | undefined {
        if (this.heap.length === 0) return undefined;
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubbleDown();
        return min;
    }

    private bubbleDown(): void {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek(): HeapItem<T> | undefined {
        return this.heap[0];
    }

    size(): number {
        return this.heap.length;
    }
}

function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;
    // Heap items hold [row, col] and are prioritized by the value at that cell
    const heap = new MinHeap<number[]>();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    // columnTop[c]: first unprocessed row in column c; rowFirst[r]: first unprocessed column in row r
    const columnTop: number[] = Array(n).fill(0);
    const rowFirst: number[] = Array(n).fill(0);
    // Pop k - 1 cells; the heap top is then the kth smallest value
    while (k > 1) {
        k -= 1;
        // The heap is never empty here because k <= n^2, so the non-null assertion is safe
        const [row, col] = heap.pop()!.item;
        rowFirst[row] = col + 1;
        // Push the right neighbor once every cell above it in its column is processed
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        // Push the lower neighbor once every cell to its left in its row is processed
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    const [resRow, resCol] = heap.pop()!.item;
    return matrix[resRow][resCol];
}

function splitWords(s: string): string[] {
    return s === "" ? [] : s.split(" ");
}

// Each `yield` receives one input line as a string
function* main(): Generator<void, void, string> {
    const matrixLength = parseInt(yield);
    const matrix: number[][] = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line?: string) => gen.next(line ?? "").done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop() ?? "";
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}

#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
#include <queue>
#include <sstream>
#include <string>
#include <vector>

int kth_smallest(std::vector<std::vector<int>>& matrix, int k) {
    int n = matrix.size();
    // std::priority_queue is a max-heap by default, so comparing with ">" turns it
    // into a min-heap on the matrix value at each {row, column} position
    auto compare_pos = [&matrix](const std::vector<int>& pos1, const std::vector<int>& pos2) {
        return matrix[pos1[0]][pos1[1]] > matrix[pos2[0]][pos2[1]];
    };
    // Heap entries are {row, column}; the top is the cell holding the smallest value
    std::priority_queue<std::vector<int>, std::vector<std::vector<int>>, decltype(compare_pos)> heap(compare_pos);
    heap.push({0, 0});
    // column_top[c] is the first row in column c that has not been processed yet
    std::vector<int> column_top(n);
    // row_first[r] is the first column in row r that has not been processed yet
    std::vector<int> row_first(n);
    // Pop k - 1 cells; the heap top is then the kth smallest value
    while (k > 1) {
        k--;
        std::vector<int> coords = heap.top();
        heap.pop();
        int row = coords[0];
        int col = coords[1];
        row_first[row] = col + 1;
        // Push the cell on the right if every cell above it in its column is processed
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.push({row, col + 1});
        }
        column_top[col] = row + 1;
        // Push the cell below if every cell to its left in its row is processed
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.push({row + 1, col});
        }
    }
    std::vector<int> res = heap.top();
    return matrix[res[0]][res[1]];
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

void ignore_line() {
    std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}

int main() {
    int matrix_length;
    std::cin >> matrix_length;
    ignore_line();
    std::vector<std::vector<int>> matrix;
    for (int i = 0; i < matrix_length; i++) {
        matrix.emplace_back(get_words<int>());
    }
    int k;
    std::cin >> k;
    ignore_line();
    int res = kth_smallest(matrix, k);
    std::cout << res << '\n';
}
