二叉堆/优先级队列代码实现
前文二叉堆的原理 介绍了二叉堆的基本性质、API 和常见应用。
我们先实现一个简化版的优先级队列,用来帮你理解二叉堆的核心操作 sink 和 swim。最后我再用给出一个比较完整的代码实现。
简化版优先级队列
我们实现的这个简化版优先级队列有如下限制:
1、不支持泛型,仅支持存储整数类型的元素。
2、不考虑扩容的问题,队列的容量在创建时固定,假设插入的元素数量不会超过这个容量。
3、底层仅实现一个小顶堆(即根节点是整个堆中的最小值),不支持自定义比较器。
基于上面这些限制,这个简化版优先级队列的 API 如下:
java:
class SimpleMinPQ {
// 创建一个容量为 capacity 的优先级队列
public SimpleMinPQ(int capacity);
// 返回队列中的元素个数
public int size();
// 向队列中插入一个元素
public void push(int x);
// 返回队列中的最小元素(堆顶元素)
public int peek();
// 删除并返回队列中的最小元素(堆顶元素)
public int pop();
}
// 使用方法
SimpleMinPQ pq = new SimpleMinPQ(10);
pq.push(3);
pq.push(4);
pq.push(1);
pq.push(2);
System.out.println(pq.pop()); // 1
System.out.println(pq.pop()); // 2
System.out.println(pq.pop()); // 3
System.out.println(pq.pop()); // 4
go:
type SimpleMinPQ struct {
// 创建一个容量为 capacity 的优先级队列
SimpleMinPQ(capacity int)
// 返回队列中的元素个数
Size() int
// 向队列中插入一个元素
Push(x int)
// 返回队列中的最小元素(堆顶元素)
Peek() int
// 删除并返回队列中的最小元素(堆顶元素)
Pop() int
}
// 使用方法
pq := SimpleMinPQ(10)
pq.Push(3)
pq.Push(4)
pq.Push(1)
pq.Push(2)
fmt.Println(pq.Pop()) // 1
fmt.Println(pq.Pop()) // 2
fmt.Println(pq.Pop()) // 3
fmt.Println(pq.Pop()) // 4
难点分析 在前文二叉堆的原理中你应该也感觉到了,二叉堆的难点在于 你在插入或删除元素时,还要保持堆的性质。
具体来说,看下面这个可视化面板,我在这个小顶堆中调用 push 方法插入元素 4,然后再调用 pop 方法删除堆顶元素 0。
请你先点击 let minHeap 这部分代码,让最小堆以及初始元素构造出来。注意看每个二叉树节点的值都比它的两个子树上的节点的值小,满足小顶堆的性质。
然后点击 push(4) 那行代码,可以看到这个新元素 4 被插入到了原先 6 的位置,而 6 被下沉为 4 的子节点,这样依然保持了小顶堆的性质。如果你直接把 4 放到树的最下层的话,比如作为 6 的子节点,就不满足小顶堆的性质了。
最后点击 pop() 那行代码,可以看到堆顶元素 0 被删除,元素 1 取代了 0 的位置作为新的堆顶元素,而 6 被从最左侧移动元素 1 原先的位置。这样依然保持了小顶堆的性质。
增:push/swim 方法插入元素
核心步骤
以小顶堆为例,向小顶堆中插入新元素遵循两个步骤:
1、先把新元素追加到二叉树底层的最右侧,保持完全二叉树的结构。此时该元素的父节点可能比它大,不满足小顶堆的性质。
2、为了恢复小顶堆的性质,需要将这个新元素不断上浮(`swim`),直到它的父节点比它小为止,或者到达根节点。此时整个二叉树就满足小顶堆的性质了。
删:pop/sink 方法删除元素
核心步骤
以小顶堆为例,删除小顶堆的堆顶元素遵循两个步骤:
1、先把堆顶元素删除,把二叉树底层的最右侧元素摘除并移动到堆顶,保持完全二叉树的结构。此时堆顶元素可能比它的子节点大,不满足小顶堆的性质。
2、为了恢复小顶堆的性质,需要将这个新的堆顶元素不断下沉(sink),直到它比它的子节点小为止,或者到达叶子节点。此时整个二叉树就满足小顶堆的性质了。
查:peek 方法查看堆顶元素
在数组上模拟二叉树
在之前的所有内容中,我都把二叉堆作为一种二叉树来讲解,而且可视化面板中也是通过操作 HeapNode 节点的方式来展示的。但实际上,我们在代码实现的时候,不会用类似 HeapNode 的节点类来实现,而是用数组来模拟二叉树结构。
用数组模拟二叉树的原因
第一个原因是前面介绍
数组 和
链表 时说到的,链表节点需要一个额外的指针存储相邻节点的地址,所以相对数组,链表的内存消耗会大一些。我们这里的 HeapNode 类也是链式存储的例子,和链表节点类似,需要额外的指针存储父节点和子节点的地址。
第二个原因,也是最主要的原因,是时间复杂度的问题。仔细想一下前面我给你展示的 push 和 pop 方法的操作过程,它们的第一步是什么?是不是要找到二叉树最底层的最右侧元素?
因为上面举的场景是我们自己构造的,可以直接用操作 left, right 指针的方式把目标节点拿到。但你想想,正常情况下你如何拿到二叉树的底层最右侧节点?你需要层序遍历或递归遍历二叉树,时间复杂度是 O(N),进而导致 push 和 pop 方法的时间复杂度退化到 O(N),这显然是不可接受的。
如果用数组来模拟二叉树,就可以完美解决这个问题,在 O(1) 时间内找到二叉树的底层最右侧节点。
完全二叉树是关键
想要用数组模拟二叉树,前提是这个二叉树必须是完全二叉树。
我在二叉树基础中介绍过完全二叉树,就是除了最后一层,其他层的节点都是满的,最后一层的节点都靠左排列。
由于完全二叉树上的元素都是紧凑排列的,我们可以用数组来存储。
直接在数组的末尾追加元素,就相当于在完全二叉树的最后一层从左到右依次填充元素;数组中最后一个元素,就是完全二叉树的底层最右侧的元素,完美契合我们实现二叉堆的场景。
看这幅图就明白了:

在这个数组中,索引 0 空着不用,就可以根据任意节点的索引计算出父节点或左右子节点的索引: java:
// 父节点的索引
int parent(int node) {
return node / 2;
}
// 左子节点的索引
int left(int node) {
return node * 2;
}
// 右子节点的索引
int right(int node) {
return node * 2 + 1;
}
go:
// 父节点的索引
func parent(node int) int {
return node / 2
}
// 左子节点的索引
func left(node int) int {
return node * 2
}
// 右子节点的索引
func right(node int) int {
return node * 2 + 1
}
有读者会问,为啥数组中索引 0 要空着不用,从 1 开始存储元素呢?
其实从 0 开始也是可以的,稍微改一改计算公式就行了:
java:
// 父节点的索引
int parent(int node) {
return (node - 1) / 2;
}
// 左子节点的索引
int left(int node) {
return node * 2 + 1;
}
// 右子节点的索引
int right(int node) {
return node * 2 + 2;
}
go:
// 父节点的索引
func parent(node int) int {
return (node - 1) / 2
}
// 左子节点的索引
func left(node int) int {
return node * 2 + 1
}
// 右子节点的索引
func right(node int) int {
return node * 2 + 2
}
代码实现
下面是一个简化版的小顶堆优先级队列核心逻辑的实现,没有特别处理边界情况,供你参考: java:
public class SimpleMinPQ {
// 底层使用数组实现二叉堆
private final int[] heap;
// 堆中元素的数量
private int size;
public SimpleMinPQ(int capacity) {
heap = new int[capacity];
size = 0;
}
public int size() {
return size;
}
// 父节点的索引
private int parent(int node) {
return (node - 1) / 2;
}
// 左子节点的索引
private int left(int node) {
return node * 2 + 1;
}
// 右子节点的索引
private int right(int node) {
return node * 2 + 2;
}
// 交换数组的两个元素
private void swap(int i, int j) {
int temp = heap[i];
heap[i] = heap[j];
heap[j] = temp;
}
// 查,返回堆顶元素,时间复杂度 O(1)
public int peek() {
return heap[0];
}
// 增,向堆中插入一个元素,时间复杂度 O(logN)
public void push(int x) {
// 把新元素追加到最后
heap[size] = x;
// 然后上浮到正确位置
swim(size);
size++;
}
// 删,删除堆顶元素,时间复杂度 O(logN)
public int pop() {
int res = heap[0];
// 把堆底元素放到堆顶
heap[0] = heap[size - 1];
size--;
// 然后下沉到正确位置
sink(0);
return res;
}
// 上浮操作,时间复杂度是树高 O(logN)
private void swim(int node) {
while (node > 0 && heap[parent(node)] > heap[node]) {
swap(parent(node), node);
node = parent(node);
}
}
// 下沉操作,时间复杂度是树高 O(logN)
private void sink(int node) {
while (left(node) < size || right(node) < size) {
// 比较自己和左右子节点,看看谁最小
int min = node;
if (left(node) < size && heap[left(node)] < heap[min]) {
min = left(node);
}
if (right(node) < size && heap[right(node)] < heap[min]) {
min = right(node);
}
if (min == node) {
break;
}
// 如果左右子节点中有比自己小的,就交换
swap(node, min);
node = min;
}
}
public static void main(String[] args) {
SimpleMinPQ pq = new SimpleMinPQ(5);
pq.push(3);
pq.push(2);
pq.push(1);
pq.push(5);
pq.push(4);
System.out.println(pq.pop()); // 1
System.out.println(pq.pop()); // 2
System.out.println(pq.pop()); // 3
System.out.println(pq.pop()); // 4
System.out.println(pq.pop()); // 5
}
}
go:
package main
import "fmt"
type SimpleMinPQ struct {
// 底层使用数组实现二叉堆
heap []int
// 堆中元素的数量
size int
}
// 父节点的索引
func (pq *SimpleMinPQ) parent(node int) int {
return (node - 1) / 2
}
// 左子节点的索引
func (pq *SimpleMinPQ) left(node int) int {
return node*2 + 1
}
// 右子节点的索引
func (pq *SimpleMinPQ) right(node int) int {
return node*2 + 2
}
// 交换数组的两个元素
func (pq *SimpleMinPQ) swap(i, j int) {
pq.heap[i], pq.heap[j] = pq.heap[j], pq.heap[i]
}
// 查,返回堆顶元素,时间复杂度 O(1)
func (pq *SimpleMinPQ) peek() int {
return pq.heap[0]
}
// 增,向堆中插入一个元素,时间复杂度 O(logN)
func (pq *SimpleMinPQ) push(x int) {
// 把新元素追加到最后
pq.heap[pq.size] = x
// 然后上浮到正确位置
pq.swim(pq.size)
pq.size++
}
// 删,删除堆顶元素,时间复杂度 O(logN)
func (pq *SimpleMinPQ) pop() int {
res := pq.heap[0]
// 把堆底元素放到堆顶
pq.heap[0] = pq.heap[pq.size-1]
pq.size--
// 然后下沉到正确位置
pq.sink(0)
return res
}
// 上浮操作,时间复杂度是树高 O(logN)
func (pq *SimpleMinPQ) swim(node int) {
for node > 0 && pq.heap[pq.parent(node)] > pq.heap[node] {
pq.swap(pq.parent(node), node)
node = pq.parent(node)
}
}
// 下沉操作,时间复杂度是树高 O(logN)
func (pq *SimpleMinPQ) sink(node int) {
for pq.left(node) < pq.size || pq.right(node) < pq.size {
// 比较自己和左右子节点,看看谁最小
min := node
if pq.left(node) < pq.size && pq.heap[pq.left(node)] < pq.heap[min] {
min = pq.left(node)
}
if pq.right(node) < pq.size && pq.heap[pq.right(node)] < pq.heap[min] {
min = pq.right(node)
}
if min == node {
break
}
// 如果左右子节点中有比自己小的,就交换
pq.swap(node, min)
node = min
}
}
func main() {
pq := SimpleMinPQ{
heap: make([]int, 5),
size: 0,
}
pq.push(3)
pq.push(2)
pq.push(1)
pq.push(5)
pq.push(4)
fmt.Println(pq.pop()) // 1
fmt.Println(pq.pop()) // 2
fmt.Println(pq.pop()) // 3
fmt.Println(pq.pop()) // 4
fmt.Println(pq.pop()) // 5
}
明白了这个 SimpleMinPQ 类的实现,如果你想实现一个大顶堆 SimpleMaxPQ,只需要把 swim 和 sink 方法中元素大小比较的逻辑反过来即可,这里就不再赘述了。
完善版优先级队列
基于上面的简化版优先级队列,只要加上泛型、自定义比较器、扩容等功能,就可以实现一个比较完善的优先级队列了: java:
import java.util.Comparator;
import java.util.NoSuchElementException;
public class MyPriorityQueue<T> {
private T[] heap;
private int size;
private final Comparator<? super T> comparator;
@SuppressWarnings("unchecked")
public MyPriorityQueue(int capacity, Comparator<? super T> comparator) {
heap = (T[]) new Object[capacity];
size = 0;
this.comparator = comparator;
}
public int size() {
return size;
}
public boolean isEmpty() {
return size == 0;
}
// 父节点的索引
private int parent(int node) {
return (node - 1) / 2;
}
// 左子节点的索引
private int left(int node) {
return node * 2 + 1;
}
// 右子节点的索引
private int right(int node) {
return node * 2 + 2;
}
// 交换数组的两个元素
private void swap(int i, int j) {
T temp = heap[i];
heap[i] = heap[j];
heap[j] = temp;
}
// 查,返回堆顶元素,时间复杂度 O(1)
public T peek() {
if (isEmpty()) {
throw new NoSuchElementException("Priority queue underflow");
}
return heap[0];
}
// 增,向堆中插入一个元素,时间复杂度 O(logN)
public void push(T x) {
// 扩容
if (size == heap.length) {
resize(2 * heap.length);
}
// 把新元素追加到最后
heap[size] = x;
// 然后上浮到正确位置
swim(size);
size++;
}
// 删,删除堆顶元素,时间复杂度 O(logN)
public T pop() {
if (isEmpty()) {
throw new NoSuchElementException("Priority queue underflow");
}
T res = heap[0];
// 把堆底元素放到堆顶
swap(0, size - 1);
// 避免对象游离
heap[size - 1] = null;
size--;
// 然后下沉到正确位置
sink(0);
// 缩容
if ((size > 0) && (size == heap.length / 4)) {
resize(heap.length / 2);
}
return res;
}
// 上浮操作,时间复杂度是树高 O(logN)
private void swim(int node) {
while (node > 0 && comparator.compare(heap[parent(node)], heap[node]) > 0) {
swap(parent(node), node);
node = parent(node);
}
}
// 下沉操作,时间复杂度是树高 O(logN)
private void sink(int node) {
while (left(node) < size || right(node) < size) {
// 比较自己和左右子节点,看看谁最小
int min = node;
if (left(node) < size && comparator.compare(heap[left(node)], heap[min]) < 0) {
min = left(node);
}
if (right(node) < size && comparator.compare(heap[right(node)], heap[min]) < 0) {
min = right(node);
}
if (min == node) {
break;
}
// 如果左右子节点中有比自己小的,就交换
swap(node, min);
node = min;
}
}
// 调整堆的大小
@SuppressWarnings("unchecked")
private void resize(int capacity) {
assert capacity > size;
T[] temp = (T[]) new Object[capacity];
for (int i = 0; i < size; i++) {
temp[i] = heap[i];
}
heap = temp;
}
public static void main(String[] args) {
MyPriorityQueue<Integer> pq = new MyPriorityQueue<>(3, Comparator.naturalOrder());
pq.push(3);
pq.push(1);
pq.push(4);
pq.push(1);
pq.push(5);
pq.push(9);
// 1 1 3 4 5 9
while (!pq.isEmpty()) {
System.out.println(pq.pop());
}
}
}
go:
package main
import (
"fmt"
"errors"
)
type MyPriorityQueue struct {
// 堆数组
heap []interface{}
// 堆中元素的数量
size int
// 元素比较器
comparator func(x, y interface{}) int
}
// 构造函数
func NewMyPriorityQueue(capacity int, comparator func(x, y interface{}) int) *MyPriorityQueue {
return &MyPriorityQueue{
heap: make([]interface{}, capacity),
size: 0,
comparator: comparator,
}
}
// 返回堆的大小
func (pq *MyPriorityQueue) Size() int {
return pq.size
}
// 判断堆是否为空
func (pq *MyPriorityQueue) IsEmpty() bool {
return pq.size == 0
}
// 父节点的索引
func (pq *MyPriorityQueue) Parent(node int) int {
return (node - 1) / 2
}
// 左子节点的索引
func (pq *MyPriorityQueue) Left(node int) int {
return node*2 + 1
}
// 右子节点的索引
func (pq *MyPriorityQueue) Right(node int) int {
return node*2 + 2
}
// 交换数组的两个元素
func (pq *MyPriorityQueue) Swap(i, j int) {
pq.heap[i], pq.heap[j] = pq.heap[j], pq.heap[i]
}
// 查,返回堆顶元素,时间复杂度 O(1)
func (pq *MyPriorityQueue) Peek() (interface{}, error) {
if pq.IsEmpty() {
return nil, errors.New("priority queue underflow")
}
return pq.heap[0], nil
}
// 增,向堆中插入一个元素,时间复杂度 O(logN)
func (pq *MyPriorityQueue) Push(x interface{}) {
// 扩容
if pq.size == len(pq.heap) {
pq.resize(2 * len(pq.heap))
}
// 把新元素追加到最后
pq.heap[pq.size] = x
// 然后上浮到正确位置
pq.swim(pq.size)
pq.size++
}
// 删,删除堆顶元素,时间复杂度 O(logN)
func (pq *MyPriorityQueue) Pop() (interface{}, error) {
if pq.IsEmpty() {
return nil, errors.New("priority queue underflow")
}
res := pq.heap[0]
// 把堆底元素放到堆顶
pq.Swap(0, pq.size-1)
// 避免对象游离
pq.heap[pq.size-1] = nil
pq.size--
// 然后下沉到正确位置
pq.sink(0)
// 缩容
if pq.size > 0 && pq.size == len(pq.heap)/4 {
pq.resize(len(pq.heap) / 2)
}
return res, nil
}
// 上浮操作,时间复杂度是树高 O(logN)
func (pq *MyPriorityQueue) swim(node int) {
for node > 0 && pq.comparator(pq.heap[pq.Parent(node)], pq.heap[node]) > 0 {
pq.Swap(pq.Parent(node), node)
node = pq.Parent(node)
}
}
// 下沉操作,时间复杂度是树高 O(logN)
func (pq *MyPriorityQueue) sink(node int) {
for pq.Left(node) < pq.size {
// 比较自己和左右子节点,看看谁最小
minNode := node
if pq.Left(node) < pq.size && pq.comparator(pq.heap[pq.Left(node)], pq.heap[minNode]) < 0 {
minNode = pq.Left(node)
}
if pq.Right(node) < pq.size && pq.comparator(pq.heap[pq.Right(node)], pq.heap[minNode]) < 0 {
minNode = pq.Right(node)
}
if minNode == node {
break
}
// 如果左右子节点中有比自己小的,就交换
pq.Swap(node, minNode)
node = minNode
}
}
// 调整堆的大小
func (pq *MyPriorityQueue) resize(capacity int) {
newHeap := make([]interface{}, capacity)
for i := 0; i < pq.size; i++ {
newHeap[i] = pq.heap[i]
}
pq.heap = newHeap
}
func main() {
pq := NewMyPriorityQueue(3, func(x, y interface{}) int {
a := x.(int)
b := y.(int)
if a < b {
return -1
} else if a > b {
return 1
}
return 0
})
pq.Push(3)
pq.Push(1)
pq.Push(4)
pq.Push(1)
pq.Push(5)
pq.Push(9)
// 1 1 3 4 5 9
for !pq.IsEmpty() {
item, _ := pq.Pop()
fmt.Println(item)
}
}