Kth Smallest Element in a Sorted Matrix
Given an n x n matrix where each row and each column is sorted in ascending order, return the kth smallest value in the matrix.
The rank is based on the sorted order of all n^2 entries, so duplicate values count multiple times.
Example:
Input:
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]], k = 8
Output: 13
Note:
You may assume k is always valid: 1 ≤ k ≤ n^2 and 1 ≤ n ≤ 1000.
Solution
Start with the direct method
The direct method is to flatten all n^2 values into one array, sort that array, and return index k - 1. This works, but sorting the whole matrix costs O(n^2 log n).
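For reference, here is a minimal Python sketch of the direct method. The function name `kth_smallest_by_sorting` is illustrative and not part of the solution code below; a brute-force version like this is mainly handy for checking the heap-based versions on small inputs.

```python
def kth_smallest_by_sorting(matrix: list[list[int]], k: int) -> int:
    """Direct method: flatten all n^2 values, sort them, and index k - 1."""
    # Duplicates stay in the flattened list, so each copy occupies its own rank.
    flat = sorted(value for row in matrix for value in row)
    return flat[k - 1]


sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_by_sorting(sample, 8))  # 13
print(kth_smallest_by_sorting(sample, 7))  # 13, because 13 appears twice
```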
The matrix already gives us useful order: every row is sorted, and every column is sorted. We should use that structure instead of ignoring it.
Row-pointer min heap
Think of each row as a sorted list. If we place one pointer at the first column of every row, each pointer marks the smallest unprocessed value in its row. The next value in overall sorted order must therefore be the smallest of the values under those pointers.
A min heap is exactly the right tool for this "next smallest among several candidates" task. Each heap entry stores (value, row, col), so after popping one value we know which row it came from and can push the next column from that same row.
For the sample matrix, the heap starts with 1, 10, and 12 (the first value of each row). Pop 1, then push 5 from row 0. Pop 5, then push 9. Pop 9; row 0 is now exhausted, so nothing is pushed. Continue this pop-then-push process until the kth pop; that popped value is the answer.
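The pop-then-push loop described above can be sketched in Python as follows. The function name `row_pointer_pops` is hypothetical; it returns every popped value rather than only the answer, so the pop order can be compared with the trace, and the last element is the kth smallest. It builds the initial heap with `heapq.heapify`, which is one way to get the O(n) initialization discussed in the complexity analysis.

```python
import heapq


def row_pointer_pops(matrix: list[list[int]], k: int) -> list[int]:
    """Return the first k values popped by the row-pointer min heap.

    The last element of the returned list is the kth smallest value.
    """
    n = len(matrix)
    # One entry per row: (value, row, col) for the first column of that row.
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)  # linear-time build instead of n separate pushes
    popped = []
    while len(popped) < k:
        value, row, col = heapq.heappop(heap)
        popped.append(value)
        # Advance this row's pointer; an exhausted row simply drops out.
        if col + 1 < n:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
    return popped


sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(row_pointer_pops(sample, 8))  # [1, 5, 9, 10, 11, 12, 13, 13]
```

Because heap entries are tuples, ties on value are broken by row and then column, which keeps the comparison well defined without a custom key.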
The following figures show this idea:
The visual sequence above shows the row-pointer approach. The remaining question is whether we can push fewer elements into the heap while still preserving sorted order.
Optimized min heap: diagonal wave expansion
Instead of pushing the entire first column (or first row) at the beginning, we can start from only (0, 0) and expand outward. When a cell (r, c) is popped, two neighbors become candidates: right (r, c + 1) and down (r + 1, c). We only push a neighbor once both the cell directly above it and the cell directly to its left have been processed. Because rows and columns are sorted, neither of those cells is larger than the neighbor, so waiting for both of them never causes a smaller value to be missed. This keeps the heap focused on the current frontier of possible answers.
To enforce that rule, track frontier progress with two arrays. rowFirst[r] is the first column in row r that has not been processed yet. columnTop[c] is the first row in column c that has not been processed yet. A right neighbor (r, c + 1) is eligible exactly when columnTop[c + 1] == r, and a down neighbor (r + 1, c) is eligible exactly when rowFirst[r + 1] == c.
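The eligibility rule above can be isolated into a small helper. This is a sketch: the name `pop_and_unlock` is hypothetical, and it assumes the caller only passes cells whose upper and left neighbors are already processed (which the heap loop guarantees). It updates the two arrays in the same order as the solution code and returns the neighbors that become eligible.

```python
def pop_and_unlock(row: int, col: int, row_first: list[int], column_top: list[int]) -> list[tuple[int, int]]:
    """Mark (row, col) as processed and return the neighbors it unlocks.

    row_first[r]: first unprocessed column in row r.
    column_top[c]: first unprocessed row in column c.
    """
    n = len(row_first)
    unlocked = []
    row_first[row] = col + 1
    # Right neighbor: eligible once every cell above it in column col + 1 is processed.
    if col + 1 < n and column_top[col + 1] == row:
        unlocked.append((row, col + 1))
    column_top[col] = row + 1
    # Down neighbor: eligible once every cell to its left in row row + 1 is processed.
    if row + 1 < n and row_first[row + 1] == col:
        unlocked.append((row + 1, col))
    return unlocked


# Replay the first four pops on the 3 x 3 sample (values 1, 5, 9, 10 in that order).
row_first, column_top = [0] * 3, [0] * 3
print(pop_and_unlock(0, 0, row_first, column_top))  # [(0, 1), (1, 0)]
print(pop_and_unlock(0, 1, row_first, column_top))  # [(0, 2)]; (1, 1) still waits for (1, 0)
print(pop_and_unlock(0, 2, row_first, column_top))  # []
print(pop_and_unlock(1, 0, row_first, column_top))  # [(1, 1), (2, 0)]
```

Each cell is unlocked exactly once: whichever of its two prerequisites is popped second is the one that pushes it.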
We repeat "pop the smallest, unlock eligible neighbors, push unlocked neighbors" for k - 1 iterations. The top of the heap is then the kth smallest value.
Here's a visual representation of this optimized process:
For the row-pointer heap, initialization costs O(n) when the n first-column entries are combined with a linear-time heapify (pushing them one at a time costs O(n log n) instead). Then we perform k pop operations, each followed by at most one push, and each heap operation costs O(log n). Total time is O(n + k log n), and heap space is O(n).
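For diagonal wave expansion, we process k values and keep only frontier cells in the heap. The heap never grows beyond min(k, n): every cell in the heap is the first unprocessed cell of its row, so there is at most one per row, and each pop removes one entry while adding at most two. Time is therefore O(k log(min(k, n))). The frontier arrays take O(n) space, and the heap stays within O(min(k, n)).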
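To make the heap-size claims concrete, the hypothetical helper `max_heap_sizes` below runs both methods on the same input and records the largest heap each one reaches. It is instrumentation for illustration only, not part of the solution; the diagonal wave half uses the same unlock rule as the solution code.

```python
import heapq


def max_heap_sizes(matrix: list[list[int]], k: int) -> tuple[int, int, int]:
    """Return (kth smallest, peak row-pointer heap size, peak wave heap size)."""
    n = len(matrix)

    # Row-pointer heap: seeded with the first cell of every row.
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)
    row_peak = len(heap)
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        row_peak = max(row_peak, len(heap))
    answer = heap[0][0]

    # Diagonal wave: seeded with (0, 0) only, then expanded through eligible neighbors.
    heap = [(matrix[0][0], 0, 0)]
    row_first, column_top = [0] * n, [0] * n
    wave_peak = 1
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        row_first[r] = c + 1
        if c + 1 < n and column_top[c + 1] == r:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
        column_top[c] = r + 1
        if r + 1 < n and row_first[r + 1] == c:
            heapq.heappush(heap, (matrix[r + 1][c], r + 1, c))
        wave_peak = max(wave_peak, len(heap))
    # Sanity check: both methods must agree on the kth smallest value.
    assert heap[0][0] == answer
    return answer, row_peak, wave_peak


# Values increase row by row, so for small k the wave only walks along row 0.
big = [[r * 100 + c for c in range(100)] for r in range(100)]
print(max_heap_sizes(big, 5))  # (4, 100, 2)
```

On this input the row-pointer heap holds one entry per row throughout, while the wave heap stays at two entries, which is the "less heap work when k is small" effect in miniature.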
Both methods are correct and fit within the constraints. The row-pointer method is simpler to reason about, while the diagonal wave method does less heap work when k is small compared with n.
```python
from heapq import heappop, heappush


def kth_smallest(matrix: list[list[int]], k: int) -> int:
    n = len(matrix)
    # Heap entries are (value, row, column), so the smallest value is always on top
    heap = [(matrix[0][0], 0, 0)]
    # column_top[c] is the first row in column c that has not been processed
    column_top = [0] * n
    # row_first[r] is the first column in row r that has not been processed
    row_first = [0] * n
    # Pop the smallest cell k - 1 times; the kth smallest is then on top of the heap
    while k > 1:
        k -= 1
        _, row, column = heappop(heap)
        row_first[row] = column + 1
        # Add the item on the right to the heap if everything above it is processed
        if column + 1 < n and column_top[column + 1] == row:
            heappush(heap, (matrix[row][column + 1], row, column + 1))
        column_top[column] = row + 1
        # Add the item below it to the heap if everything to its left is processed
        if row + 1 < n and row_first[row + 1] == column:
            heappush(heap, (matrix[row + 1][column], row + 1, column))
    return heap[0][0]


if __name__ == "__main__":
    matrix = [[int(x) for x in input().split()] for _ in range(int(input()))]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)
```
```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    public static int kthSmallest(List<List<Integer>> matrix, int k) {
        int n = matrix.size();
        // Heap entries are {row, column}; the comparator orders them by matrix value,
        // so the cell holding the smallest value is at the head of the queue
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1])));
        heap.offer(new int[] {0, 0});
        // columnTop[c] is the first row in column c that has not been processed
        int[] columnTop = new int[n];
        // rowFirst[r] is the first column in row r that has not been processed
        int[] rowFirst = new int[n];
        // Repeat the process k - 1 times
        while (k > 1) {
            k--;
            int[] coords = heap.poll();
            int row = coords[0], column = coords[1];
            rowFirst[row] = column + 1;
            // Add the item on the right to the heap if everything above it is processed
            if (column + 1 < n && columnTop[column + 1] == row) {
                heap.offer(new int[] {row, column + 1});
            }
            columnTop[column] = row + 1;
            // Add the item below it to the heap if everything to its left is processed
            if (row + 1 < n && rowFirst[row + 1] == column) {
                heap.offer(new int[] {row + 1, column});
            }
        }
        int[] resCoords = heap.poll();
        return matrix.get(resCoords[0]).get(resCoords[1]);
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int matrixLength = Integer.parseInt(scanner.nextLine());
        List<List<Integer>> matrix = new ArrayList<>();
        for (int i = 0; i < matrixLength; i++) {
            matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
        }
        int k = Integer.parseInt(scanner.nextLine());
        scanner.close();
        int res = kthSmallest(matrix, k);
        System.out.println(res);
    }
}
```
```javascript
"use strict";

class HeapItem {
    constructor(item, priority = item) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap {
    constructor() {
        this.heap = [];
    }

    push(node) {
        // insert the new node at the end of the heap array
        this.heap.push(node);
        // find the correct position for the new node
        this.bubbleUp();
    }

    bubbleUp() {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            // if the parent is bigger than the child then swap the parent and child
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop() {
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubbleDown();
        return min;
    }

    bubbleDown() {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek() {
        return this.heap[0];
    }

    size() {
        return this.heap.length;
    }
}

function kthSmallest(matrix, k) {
    const n = matrix.length;
    // Heap items hold [row, col] and are prioritized by matrix[row][col]
    const heap = new MinHeap();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    // columnTop[c] is the first row in column c that has not been processed
    const columnTop = Array(n).fill(0);
    // rowFirst[r] is the first column in row r that has not been processed
    const rowFirst = Array(n).fill(0);
    // Pop the smallest cell k - 1 times
    while (k > 1) {
        k -= 1;
        const [row, col] = heap.pop().item;
        rowFirst[row] = col + 1;
        // Push the right neighbor if everything above it is processed
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        // Push the lower neighbor if everything to its left is processed
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    // The kth smallest value is now on top of the heap
    const [resRow, resCol] = heap.pop().item;
    return matrix[resRow][resCol];
}

function splitWords(s) {
    return s === "" ? [] : s.split(" ");
}

function* main() {
    const matrixLength = parseInt(yield);
    const matrix = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
```
```typescript
class HeapItem<T> {
    item: T;
    priority: number;

    constructor(item: T, priority: number = Number(item)) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap<T> {
    private heap: HeapItem<T>[];

    constructor() {
        this.heap = [];
    }

    push(node: HeapItem<T>): void {
        this.heap.push(node);
        this.bubbleUp();
    }

    private bubbleUp(): void {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop(): HeapItem<T> | undefined {
        if (this.heap.length === 0) return undefined;
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubbleDown();
        return min;
    }

    private bubbleDown(): void {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek(): HeapItem<T> | undefined {
        return this.heap[0];
    }

    size(): number {
        return this.heap.length;
    }
}

function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;
    // Heap items hold [row, col] and are prioritized by matrix[row][col]
    const heap = new MinHeap<number[]>();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    const columnTop: number[] = Array(n).fill(0);
    const rowFirst: number[] = Array(n).fill(0);
    while (k > 1) {
        k -= 1;
        // The heap is never empty here because k ≤ n^2, hence the non-null assertion
        const [row, col] = heap.pop()!.item;
        rowFirst[row] = col + 1;
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    const [resRow, resCol] = heap.pop()!.item;
    return matrix[resRow][resCol];
}

function splitWords(s: string): string[] {
    return s === "" ? [] : s.split(" ");
}

function* main(): Generator<undefined, void, string> {
    const matrixLength = parseInt(yield);
    const matrix: number[][] = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line?: string) => gen.next(line ?? "").done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop() ?? "";
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
```
```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
#include <queue>
#include <sstream>
#include <string>
#include <vector>

int kth_smallest(std::vector<std::vector<int>>& matrix, int k) {
    int n = matrix.size();
    // Orders positions so that the cell holding the smallest value is on top of the heap
    auto compare_pos = [&matrix](const std::vector<int>& pos1, const std::vector<int>& pos2) {
        return matrix[pos1[0]][pos1[1]] > matrix[pos2[0]][pos2[1]];
    };
    // Keeps track of row and column numbers of items in the heap
    std::priority_queue<std::vector<int>, std::vector<std::vector<int>>, decltype(compare_pos)> heap(compare_pos);
    heap.push({0, 0});
    // column_top[c] is the first row in column c that has not been processed
    std::vector<int> column_top(n);
    // row_first[r] is the first column in row r that has not been processed
    std::vector<int> row_first(n);
    // Repeat the process k - 1 times
    while (k > 1) {
        k--;
        std::vector<int> coords = heap.top();
        heap.pop();
        int row = coords[0];
        int col = coords[1];
        row_first[row] = col + 1;
        // Add the item on the right to the heap if everything above it is processed
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.push({row, col + 1});
        }
        column_top[col] = row + 1;
        // Add the item below it to the heap if everything to its left is processed
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.push({row + 1, col});
        }
    }
    std::vector<int> res = heap.top();
    return matrix[res[0]][res[1]];
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

void ignore_line() {
    std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}

int main() {
    int matrix_length;
    std::cin >> matrix_length;
    ignore_line();
    std::vector<std::vector<int>> matrix;
    for (int i = 0; i < matrix_length; i++) {
        matrix.emplace_back(get_words<int>());
    }
    int k;
    std::cin >> k;
    ignore_line();
    int res = kth_smallest(matrix, k);
    std::cout << res << '\n';
}
```
























