Facebook Pixel

Count of Smaller Numbers after Self | Number of Swaps to Sort | Algorithm Swap

Given an integer array nums, return an array counts where counts[i] is the number of elements to the right of index i that are smaller than nums[i].

Example

Input: [5,2,6,1]

Output: [2,1,1,0]

At index 0, value 5 has two smaller values on its right (2 and 1), so counts[0] = 2. At index 1, value 2 has one smaller value on its right (1), so counts[1] = 1. At index 2, value 6 has one smaller value on its right (1), so counts[2] = 1. At index 3, value 1 has no elements on its right, so counts[3] = 0.

Number of swaps to sort

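A closely related question asks: if we sort the array by repeatedly swapping adjacent elements that are out of order (as bubble sort does), how many swaps are needed? The answer is expressed in terms of inversions, where an inversion is a pair (i, j) with i < j and nums[i] > nums[j].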

The minimum number of adjacent swaps equals the total number of inversions, because swapping two adjacent out-of-order elements removes exactly one inversion and a sorted array has none. In this problem, each counts[i] tells you how many inversions start at index i, so summing the counts array gives the total inversion count. For [5,2,6,1], that sum is 2 + 1 + 1 + 0 = 4.

Try it yourself

Explanation

Intuition

The brute force approach checks every pair (i, j) with i < j, which takes O(n^2) time. To do better, we need a way to count many "smaller on the right" relationships at once.

Merge sort gives exactly that. We split by index, recursively solve each half, and count cross-half relationships during merge. Splitting by index is important: every element in the left half originally appears before every element in the right half, so any left_value > right_value pair is a valid "smaller on the right" match for the left element.

During merge, both halves are already sorted. When a right value is smaller than the current left value, it will be placed first in the merged array. We keep a counter of how many right values have moved ahead so far. When we finally place a left value, we add that counter to its answer because those moved right values are exactly the smaller elements that were to its right in the original array.

This counting happens at every recursion level and each level processes all n elements once in merge. That gives the standard recurrence:

T(n) = 2T(n/2) + O(n)

The total time complexity is O(n log n).

Space Complexity: O(n)

Implementation

def count_smaller(nums: list[int]) -> list[int]:
    """Return, for every index i, how many elements to the right of i are smaller than nums[i].

    The values are merge-sorted as (original_index, value) pairs. While two
    sorted halves are merged, every right-half pair that has already been
    placed is strictly smaller than the left-half pair being placed, and it
    originally stood to that pair's right, so it is added to that pair's count.
    """
    smaller_arr = [0] * len(nums)

    def merge(left, right):
        merged = []
        l = r = 0
        while l < len(left) or r < len(right):
            # Ties go to the left half, so equal values are never counted as "smaller".
            take_left = r >= len(right) or (l < len(left) and left[l][1] <= right[r][1])
            if take_left:
                # r right-half pairs have been placed ahead of this left pair.
                smaller_arr[left[l][0]] += r
                merged.append(left[l])
                l += 1
            else:
                merged.append(right[r])
                r += 1
        return merged

    def merge_sort(pairs):
        if len(pairs) <= 1:
            return pairs
        mid = len(pairs) // 2
        return merge(merge_sort(pairs[:mid]), merge_sort(pairs[mid:]))

    merge_sort(list(enumerate(nums)))
    return smaller_arr
27
if __name__ == "__main__":
    # Read one line of whitespace-separated integers from stdin and print the
    # per-index counts separated by single spaces.
    nums = [int(x) for x in input().split()]
    res = count_smaller(nums)
    print(" ".join(map(str, res)))
32
1import java.util.ArrayList;
2import java.util.Arrays;
3import java.util.List;
4import java.util.Scanner;
5import java.util.stream.Collectors;
6
7class Solution {
8    public static class Element {
9        int val;
10        int ind;
11
12        public Element(int val, int ind) {
13            this.val = val;
14            this.ind = ind;
15        }
16    }
17
18    public static List<Integer> smallerArr = new ArrayList<Integer>();
19
20    public static List<Element> mergeSort(List<Element> nums) {
21        if (nums.size() <= 1) {
22            return nums;
23        }
24        int mid = nums.size() / 2;
25        List<Element> splitLeft = new ArrayList<Element>();
26        List<Element> splitRight = new ArrayList<Element>();
27        for (int i = 0; i < nums.size(); i++) {
28            if (i < nums.size() / 2) {
29                splitLeft.add(nums.get(i));
30            } else {
31                splitRight.add(nums.get(i));
32            }
33        }
34        List<Element> left = mergeSort(splitLeft);
35        List<Element> right = mergeSort(splitRight);
36        return merge(left, right);
37    }
38
39    public static List<Element> merge(List<Element> left, List<Element> right) {
40        List<Element> result = new ArrayList<Element>();
41        int l = 0;
42        int r = 0;
43        while (l < left.size() || r < right.size()) {
44            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
45                result.add(left.get(l));
46                smallerArr.set(left.get(l).ind, smallerArr.get(left.get(l).ind) + r);
47                l += 1;
48            } else {
49                result.add(right.get(r));
50                r += 1;
51            }
52        }
53        return result;
54    }
55
56    public static List<Integer> countSmaller(List<Integer> nums) {
57        for (int i = 0; i < nums.size(); i++) {
58            smallerArr.add(0);
59        }
60        List<Element> temp = new ArrayList<Element>();
61        for (int i = 0; i < nums.size(); i++) {
62            temp.add(new Element(nums.get(i), i));
63        }
64        mergeSort(temp);
65        return smallerArr;
66    }
67
68    public static List<String> splitWords(String s) {
69        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
70    }
71
72    public static void main(String[] args) {
73        Scanner scanner = new Scanner(System.in);
74        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
75        scanner.close();
76        List<Integer> res = countSmaller(nums);
77        System.out.println(res.stream().map(String::valueOf).collect(Collectors.joining(" ")));
78    }
79}
80
1"use strict";
2
/**
 * For every index i, counts the elements to the right of i that are smaller than nums[i].
 *
 * The values are merge-sorted as [originalIndex, value] pairs; during each merge the
 * number of right-half pairs already placed is added to the count of the left-half
 * pair being placed.
 *
 * @param {number[]} nums
 * @returns {number[]} counts, in the same order as nums
 */
function countSmaller(nums) {
    const smallerArr = new Array(nums.length).fill(0);

    function merge(left, right) {
        const merged = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            // Ties go to the left half, so equal values are never counted as smaller.
            const takeLeft = r >= right.length || (l < left.length && left[l][1] <= right[r][1]);
            if (takeLeft) {
                smallerArr[left[l][0]] += r;
                merged.push(left[l]);
                l += 1;
            } else {
                merged.push(right[r]);
                r += 1;
            }
        }
        return merged;
    }

    function mergeSort(pairs) {
        if (pairs.length <= 1) return pairs;
        const mid = Math.floor(pairs.length / 2);
        return merge(mergeSort(pairs.slice(0, mid)), mergeSort(pairs.slice(mid)));
    }

    mergeSort(nums.map((value, index) => [index, value]));
    return smallerArr;
}
36
/**
 * Splits a line into words on runs of whitespace.
 * Leading/trailing whitespace (including a trailing "\r") is ignored, and a
 * blank line gives an empty array instead of empty-string tokens.
 */
function splitWords(s) {
    const trimmed = s.trim();
    return trimmed === "" ? [] : trimmed.split(/\s+/);
}
40
// Generator entry point: the bare `yield` receives one input line, which is parsed
// as base-10 integers; the resulting counts are printed space-separated.
function* main() {
    const line = yield;
    const nums = splitWords(line).map((word) => Number.parseInt(word, 10));
    const res = countSmaller(nums);
    console.log(res.join(" "));
}
46
// Thrown into main() when stdin ends while it is still waiting for a line.
class EOFError extends Error {}
{
    const gen = main();
    // Resume main() with one line; exit the process once the generator has finished.
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next(); // run main() up to its first `yield`
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop(); // keep the trailing partial line until more data arrives
        for (const line of lines) {
            next(line);
        }
    });
    process.stdin.on("end", () => {
        if (buf) {
            next(buf); // final line without a terminating newline
        }
        gen.throw(new EOFError());
    });
}
64
/**
 * For every index i, counts the elements to the right of i that are smaller than nums[i].
 *
 * The values are merge-sorted as [originalIndex, value] pairs; during each merge the
 * number of right-half pairs already placed is added to the count of the left-half
 * pair being placed.
 */
function countSmaller(nums: number[]): number[] {
    type Pair = [number, number];
    const smallerArr: number[] = new Array<number>(nums.length).fill(0);

    function merge(left: Pair[], right: Pair[]): Pair[] {
        const merged: Pair[] = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            // Ties go to the left half, so equal values are never counted as smaller.
            const takeLeft = r >= right.length || (l < left.length && left[l][1] <= right[r][1]);
            if (takeLeft) {
                smallerArr[left[l][0]] += r;
                merged.push(left[l]);
                l += 1;
            } else {
                merged.push(right[r]);
                r += 1;
            }
        }
        return merged;
    }

    function mergeSort(pairs: Pair[]): Pair[] {
        if (pairs.length <= 1) return pairs;
        const mid = Math.floor(pairs.length / 2);
        return merge(mergeSort(pairs.slice(0, mid)), mergeSort(pairs.slice(mid)));
    }

    mergeSort(nums.map((value, index): Pair => [index, value]));
    return smallerArr;
}
33
/**
 * Splits a line into words on runs of whitespace.
 * Leading/trailing whitespace (including a trailing "\r") is ignored, and a
 * blank line gives an empty array instead of empty-string tokens.
 */
function splitWords(s: string): string[] {
    const trimmed = s.trim();
    return trimmed === "" ? [] : trimmed.split(/\s+/);
}
37
// Generator entry point: the bare `yield` receives one input line, which is parsed
// as base-10 integers; the resulting counts are printed space-separated.
function* main() {
    const line: string = yield;
    const nums = splitWords(line).map((word) => Number.parseInt(word, 10));
    const res = countSmaller(nums);
    console.log(res.join(" "));
}
43
// Thrown into main() when stdin ends while it is still waiting for a line.
class EOFError extends Error {}
{
    const gen = main();
    // Resume main() with one line; exit the process once the generator has finished.
    const next = (line?: string) => gen.next(line ?? "").done && process.exit();
    let buf = "";
    next(); // run main() up to its first `yield`
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop() ?? ""; // keep the trailing partial line until more data arrives
        for (const line of lines) {
            next(line);
        }
    });
    process.stdin.on("end", () => {
        if (buf) {
            next(buf); // final line without a terminating newline
        }
        gen.throw(new EOFError());
    });
}
61
1#include <algorithm>
2#include <iostream>
3#include <iterator>
4#include <sstream>
5#include <string>
6#include <vector>
7
// Merges two ranges of {original_index, value} pairs, each already sorted by value.
// Whenever a pair from `left` is placed, the number of `right` pairs already placed
// is added to counts[original_index]: those pairs are strictly smaller and, because
// the halves were split by position, originally stood to its right.
std::vector<std::vector<int>> merge(const std::vector<std::vector<int>>& left,
                                    const std::vector<std::vector<int>>& right,
                                    std::vector<int>& counts) {
    std::vector<std::vector<int>> res;
    res.reserve(left.size() + right.size());
    auto l = left.begin();
    auto r = right.begin();
    while (l != left.end() || r != right.end()) {
        // Ties go to the left half, so equal values are never counted as smaller.
        const bool take_left = r == right.end() || (l != left.end() && (*l)[1] <= (*r)[1]);
        if (take_left) {
            counts[(*l)[0]] += static_cast<int>(r - right.begin());
            res.push_back(*l);
            ++l;
        } else {
            res.push_back(*r);
            ++r;
        }
    }
    return res;
}

// Returns the pairs of `nums` sorted by value, accumulating into `counts` on every merge.
std::vector<std::vector<int>> merge_sort(const std::vector<std::vector<int>>& nums, std::vector<int>& counts) {
    if (nums.size() <= 1) return nums;
    const auto mid = nums.begin() + nums.size() / 2;
    const std::vector<std::vector<int>> lower(nums.begin(), mid);
    const std::vector<std::vector<int>> upper(mid, nums.end());
    const std::vector<std::vector<int>> left = merge_sort(lower, counts);
    const std::vector<std::vector<int>> right = merge_sort(upper, counts);
    return merge(left, right, counts);
}

// For every index i, returns how many elements to the right of i are smaller than nums[i].
std::vector<int> count_smaller(const std::vector<int>& nums) {
    std::vector<int> counts(nums.size(), 0);
    std::vector<std::vector<int>> idx_num_mapping;
    idx_num_mapping.reserve(nums.size());
    int idx = 0;
    for (const int value : nums) {
        idx_num_mapping.push_back({idx, value});
        ++idx;
    }
    merge_sort(idx_num_mapping, counts);
    return counts;
}
45
// Reads one line from std::cin and returns its whitespace-separated tokens parsed as T.
// std::boolalpha is set on the line stream, so bool tokens are read as "true"/"false".
template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    return std::vector<T>(std::istream_iterator<T>{ss}, std::istream_iterator<T>{});
}
56
// Writes the elements of v to std::cout separated by single spaces, followed by a
// newline. An empty vector produces just the newline.
template<typename T>
void put_words(const std::vector<T>& v) {
    const char* separator = "";
    for (const T& word : v) {
        std::cout << separator << word;
        separator = " ";
    }
    std::cout << '\n';
}
65
66int main() {
67    std::vector<int> nums = get_words<int>();
68    std::vector<int> res = count_smaller(nums);
69    put_words(res);
70}
71

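If the problem asks only for the number of swaps, we do not need the per-index counts array. A single running total is enough: each time a left element is placed during a merge, add the number of right elements that were placed before it.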

def number_of_swaps_to_sort(nums: list[int]) -> int:
    """Return the number of inversions in nums, i.e. pairs i < j with nums[i] > nums[j].

    The values are merge-sorted; each time a left-half value is placed during a
    merge, the number of right-half values already placed (all strictly smaller
    and originally to its right) is added to the running total.
    """
    count = 0

    def merge(left, right):
        nonlocal count
        merged = []
        l = r = 0
        while l < len(left) or r < len(right):
            # Ties go to the left half, so equal values are not inversions.
            if r >= len(right) or (l < len(left) and left[l] <= right[r]):
                count += r
                merged.append(left[l])
                l += 1
            else:
                merged.append(right[r])
                r += 1
        return merged

    def merge_sort(values):
        if len(values) <= 1:
            return values
        mid = len(values) // 2
        return merge(merge_sort(values[:mid]), merge_sort(values[mid:]))

    # Only the total is needed, so the original indices are not tracked.
    merge_sort(list(nums))
    return count
28
if __name__ == "__main__":
    # Read one line of whitespace-separated integers from stdin and print the total.
    nums = [int(x) for x in input().split()]
    res = number_of_swaps_to_sort(nums)
    print(res)
33
1import java.util.ArrayList;
2import java.util.Arrays;
3import java.util.List;
4import java.util.Scanner;
5import java.util.stream.Collectors;
6
7class Solution {
8    public static class Number {
9        int index;
10        int val;
11        public Number(int i, int v) {
12            index = i;
13            val = v;
14        }
15    };
16
17    private static List<Number> merge(List<Number> left, List<Number> right, int[] countRef) {
18        List<Number> result = new ArrayList<>();
19        int l = 0;
20        int r = 0;
21        while (l < left.size() || r < right.size()) {
22            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
23                result.add(left.get(l));
24                countRef[0] += r;
25                l++;
26            } else {
27                result.add(right.get(r));
28                r++;
29            }
30        }
31        return result;
32    }
33
34    private static List<Number> mergeSort(List<Number> nums, int[] countRef) {
35        if (nums.size() <= 1) {
36            return nums;
37        }
38        int mid = nums.size() / 2;
39        List<Number> left = mergeSort(new ArrayList<>(nums.subList(0, mid)), countRef);
40        List<Number> right = mergeSort(new ArrayList<>(nums.subList(mid, nums.size())), countRef);
41        return merge(left, right, countRef);
42    }
43
44    public static int numberOfSwapsToSort(List<Integer> nums) {
45        List<Number> numbers = new ArrayList<>();
46        for (int i = 0; i < nums.size(); i++) {
47            numbers.add(new Number(i, nums.get(i)));
48        }
49        int[] count = {0};
50        mergeSort(numbers, count);
51        return count[0];
52    }
53
54    public static List<String> splitWords(String s) {
55        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
56    }
57
58    public static void main(String[] args) {
59        Scanner scanner = new Scanner(System.in);
60        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
61        scanner.close();
62        int res = numberOfSwapsToSort(nums);
63        System.out.println(res);
64    }
65}
66
1"use strict";
2
/**
 * Returns the number of inversions in nums, i.e. pairs i < j with nums[i] > nums[j].
 *
 * The values are merge-sorted; each time a left-half value is placed during a merge,
 * the number of right-half values already placed is added to the running total.
 *
 * @param {number[]} nums
 * @returns {number}
 */
function numberOfSwapsToSort(nums) {
    let count = 0;

    function merge(left, right) {
        const merged = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            // Ties go to the left half, so equal values are not inversions.
            if (r >= right.length || (l < left.length && left[l] <= right[r])) {
                count += r;
                merged.push(left[l]);
                l += 1;
            } else {
                merged.push(right[r]);
                r += 1;
            }
        }
        return merged;
    }

    function mergeSort(values) {
        if (values.length <= 1) return values;
        const mid = Math.floor(values.length / 2);
        return merge(mergeSort(values.slice(0, mid)), mergeSort(values.slice(mid)));
    }

    // Only the total is needed, so the original indices are not tracked.
    mergeSort(nums.slice());
    return count;
}
36
/**
 * Splits a line into words on runs of whitespace.
 * Leading/trailing whitespace (including a trailing "\r") is ignored, and a
 * blank line gives an empty array instead of empty-string tokens.
 */
function splitWords(s) {
    const trimmed = s.trim();
    return trimmed === "" ? [] : trimmed.split(/\s+/);
}
40
// Generator entry point: the bare `yield` receives one input line, which is parsed
// as base-10 integers; the inversion total is printed.
function* main() {
    const line = yield;
    const nums = splitWords(line).map((word) => Number.parseInt(word, 10));
    const res = numberOfSwapsToSort(nums);
    console.log(res);
}
46
// Thrown into main() when stdin ends while it is still waiting for a line.
class EOFError extends Error {}
{
    const gen = main();
    // Resume main() with one line; exit the process once the generator has finished.
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next(); // run main() up to its first `yield`
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop(); // keep the trailing partial line until more data arrives
        for (const line of lines) {
            next(line);
        }
    });
    process.stdin.on("end", () => {
        if (buf) {
            next(buf); // final line without a terminating newline
        }
        gen.throw(new EOFError());
    });
}
64
/**
 * Returns the number of inversions in nums, i.e. pairs i < j with nums[i] > nums[j].
 *
 * The values are merge-sorted; each time a left-half value is placed during a merge,
 * the number of right-half values already placed is added to the running total.
 */
function numberOfSwapsToSort(nums: number[]): number {
    let count = 0;

    function merge(left: number[], right: number[]): number[] {
        const merged: number[] = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            // Ties go to the left half, so equal values are not inversions.
            if (r >= right.length || (l < left.length && left[l] <= right[r])) {
                count += r;
                merged.push(left[l]);
                l += 1;
            } else {
                merged.push(right[r]);
                r += 1;
            }
        }
        return merged;
    }

    function mergeSort(values: number[]): number[] {
        if (values.length <= 1) return values;
        const mid = Math.floor(values.length / 2);
        return merge(mergeSort(values.slice(0, mid)), mergeSort(values.slice(mid)));
    }

    // Only the total is needed, so the original indices are not tracked.
    mergeSort(nums.slice());
    return count;
}
34
/**
 * Splits a line into words on runs of whitespace.
 * Leading/trailing whitespace (including a trailing "\r") is ignored, and a
 * blank line gives an empty array instead of empty-string tokens.
 */
function splitWords(s: string): string[] {
    const trimmed = s.trim();
    return trimmed === "" ? [] : trimmed.split(/\s+/);
}
38
// Generator entry point: the bare `yield` receives one input line, which is parsed
// as base-10 integers; the inversion total is printed.
function* main() {
    const line: string = yield;
    const nums = splitWords(line).map((word) => Number.parseInt(word, 10));
    const res = numberOfSwapsToSort(nums);
    console.log(res);
}
44
// Thrown into main() when stdin ends while it is still waiting for a line.
class EOFError extends Error {}
{
    const gen = main();
    // Resume main() with one line; exit the process once the generator has finished.
    const next = (line?: string) => gen.next(line ?? "").done && process.exit();
    let buf = "";
    next(); // run main() up to its first `yield`
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop() ?? ""; // keep the trailing partial line until more data arrives
        for (const line of lines) {
            next(line);
        }
    });
    process.stdin.on("end", () => {
        if (buf) {
            next(buf); // final line without a terminating newline
        }
        gen.throw(new EOFError());
    });
}
62
1#include <algorithm>
2#include <iostream>
3#include <iterator>
4#include <sstream>
5#include <string>
6#include <utility>
7#include <vector>
8
// Merges two ranges of {original_index, value} pairs, each already sorted by value.
// Whenever a pair from `left` is placed, the number of `right` pairs already placed
// (all strictly smaller, and originally to its right) is added to `count`.
std::vector<std::vector<int>> merge(const std::vector<std::vector<int>>& left,
                                    const std::vector<std::vector<int>>& right,
                                    long long& count) {
    std::vector<std::vector<int>> res;
    res.reserve(left.size() + right.size());
    auto l = left.begin();
    auto r = right.begin();
    while (l != left.end() || r != right.end()) {
        // Ties go to the left half, so equal values are not counted as inversions.
        const bool take_left = r == right.end() || (l != left.end() && (*l)[1] <= (*r)[1]);
        if (take_left) {
            count += r - right.begin();
            res.push_back(*l);
            ++l;
        } else {
            res.push_back(*r);
            ++r;
        }
    }
    return res;
}

// Returns the pairs of `nums` sorted by value, accumulating into `count` on every merge.
std::vector<std::vector<int>> merge_sort(const std::vector<std::vector<int>>& nums, long long& count) {
    if (nums.size() <= 1) return nums;
    const auto mid = nums.begin() + nums.size() / 2;
    const std::vector<std::vector<int>> lower(nums.begin(), mid);
    const std::vector<std::vector<int>> upper(mid, nums.end());
    const std::vector<std::vector<int>> left = merge_sort(lower, count);
    const std::vector<std::vector<int>> right = merge_sort(upper, count);
    return merge(left, right, count);
}

// Returns the number of inversions in nums (pairs i < j with nums[i] > nums[j]).
// The total is a long long: n elements can hold n*(n-1)/2 inversions, which no longer
// fits in a 32-bit int from n = 65537.
long long number_of_swaps_to_sort(const std::vector<int>& nums) {
    long long count = 0;
    std::vector<std::vector<int>> idx_num_mapping;
    idx_num_mapping.reserve(nums.size());
    int idx = 0;
    for (const int value : nums) {
        idx_num_mapping.push_back({idx, value});
        ++idx;
    }
    merge_sort(idx_num_mapping, count);
    return count;
}
46
// Reads one line from std::cin and returns its whitespace-separated tokens parsed as T.
// std::boolalpha is set on the line stream, so bool tokens are read as "true"/"false".
template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    return std::vector<T>(std::istream_iterator<T>{ss}, std::istream_iterator<T>{});
}
57
58int main() {
59    std::vector<int> nums = get_words<int>();
60    int res = number_of_swaps_to_sort(nums);
61    std::cout << res << '\n';
62}
63
Invest in Yourself
Your new job is waiting. 83% of people that complete the program get a job offer. Unlock unlimited access to all content and features.
Go Pro