container/heap包使用指南

本文基于官方文档介绍 golang 标准库中提供的堆/优先队列的使用方法.

1. 概述

“container/heap” 包提供了实现堆操作的接口,用户只需要定义满足 “heap.Interface” 接口的类型,就可以通过包提供的函数,像操作大根堆或小根堆一样,对实例数组变量进行 PushPop 操作。
堆通常是一个可以被看做一棵树的数组对象,堆总是满足下列性质:

  • 堆中某个结点的值总是不大于或不小于其父结点的值;
  • 堆总是一棵完全二叉树。

堆的定义如下:n 个元素的序列 {k1,k2,ki,…,kn} 当且仅当满足下关系时,称之为堆:

xxrlF0.png

下面将基于具体实例介绍"container/heap" 包的使用。

2. 整数堆

下面的代码实现了一个整数类型的最小堆:

package main

import (
    "container/heap"
    "log"
)

type IntHeap []int

// 下面几个函数必须实现,heap包会进行回调
func (h IntHeap) Len() int { return len(h) }

// Less函数的实现决定最终实现是最大堆还是最小堆
func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] }
func (h IntHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *IntHeap) Push(x any)        { *h = append((*h), x.(int)) }
func (h *IntHeap) Pop() any {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}

func main() {
    h := &IntHeap{2, 1, 5}
    heap.Init(h)
    heap.Push(h, 3)
    log.Println(*h)
    for h.Len() > 0 {
        log.Println(heap.Pop(h))
    }
}

3. 标准库堆的实现

通过整数堆的使用方式,发现是通过定义新的整数数组类型,并为其实现 pointer receivers 的方法 Push、Pop,以及 value receivers 的方法 Len、Less、Swap 方法之后,借助 “container/heap” 包提供的方法对该类型定义的几个方法进行回调,从而实现堆的功能。下面是具体的实现代码,逻辑见注释:

package heap

import "sort"

// Interface 接口指明了想要使用这个包中的方法去实现堆,应该提供的接口方法
// 在 heap.Init 方法被调用、数据为空或者原始数据有序时,满足下列条件的情况下,小根堆会被建立
// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len()
// 注意:包中的方法会在添加和删除元素的时候调用Interface 接口实现的 Push 和 Pop方法,详见下面代码
type Interface interface {
    sort.Interface // 包括 Len()、Less(i, j int)、Swap(i, j int)
    Push(x any)    // add x as element Len()
    Pop() any      // remove and return element Len() - 1.
}

// 将传入的变量初始化为堆,时间复杂度为 O(n), 其中 n = h.Len().
func Init(h Interface) {
    // heapify
    n := h.Len()
    for i := n/2 - 1; i >= 0; i-- {
        down(h, i, n)
    }
}

// 向堆中加入新的元素,时间复杂度为 O(log n), 其中 n = h.Len().
func Push(h Interface, x any) {
    h.Push(x)
    up(h, h.Len()-1)
}

// Pop 移除并返回堆中的最小或最大元素,具体根据 h.Less 函数确定,时间复杂度为 O(log n).
// Pop 等价于 Remove(h, 0).
func Pop(h Interface) any {
    n := h.Len() - 1
    h.Swap(0, n)
    down(h, 0, n)
    // 由此看出实现 h.Pop 方法时只需要将数组最后元素取出并返回即可
    return h.Pop()
}

// Remove 移除并返回堆中索引为 i 的元素,时间复杂度为 O(log n).
func Remove(h Interface, i int) any {
    n := h.Len() - 1
    if n != i {
        h.Swap(i, n)
        if !down(h, i, n) {
            up(h, i)
        }
    }
    return h.Pop()
}

// 当堆数组中索引 i 处的元素的值或优先级发生变更的时候调用 Fix 调整元素 i 在堆中的位置
func Fix(h Interface, i int) {
    if !down(h, i, h.Len()) {
        up(h, i)
    }
}

// 向上进行堆调整,将新增元素上升到满足条件的位置
func up(h Interface, j int) {
    for {
        i := (j - 1) / 2 // parent
        if i == j || !h.Less(j, i) {
            break
        }
        h.Swap(i, j)
        j = i
    }
}

// 向下进行堆调整,确保i0节点是左右子树中的最小节点
func down(h Interface, i0, n int) bool {
    i := i0
    for {
        j1 := 2*i + 1
        if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
            break
        }
        j := j1 // left child
        if j2 := j1 + 1; j2 < n && h.Less(j2, j1) {
            j = j2 // = 2*i + 2  // right child
        }
        if !h.Less(j, i) {
            break
        }
        h.Swap(i, j)
        i = j
    }
    return i > i0
}

4. 优先队列

举例实现针对结构实现的最小堆,也即一般意义上的优先队列:

package main

import (
    "container/heap"
    "log"
)

type Node struct {
    Val  int
    Next float32
}

type NodeHeap []Node

func (pq NodeHeap) Less(i, j int) bool { return pq[i].Val < pq[j].Val }
func (pq NodeHeap) Swap(i, j int)      { pq[i], pq[j] = pq[j], pq[i] }
func (pq NodeHeap) Len() int           { return len(pq) }
func (pq *NodeHeap) Push(x any)        { *pq = append(*pq, x.(Node)) }
func (pq *NodeHeap) Pop() any {
    old := *pq
    n := len(old)
    x := old[n-1]
    *pq = old[0 : n-1]
    return x
}

func main() {
    pq := &NodeHeap{}
    heap.Init(pq)
    heap.Push(pq, Node{Val: 10, Next: 1.0})
    heap.Push(pq, Node{Val: 11, Next: 2.0})
    heap.Push(pq, Node{Val: 1, Next: 3.0})
    for pq.Len() > 0 {
        log.Println(heap.Pop(pq).(Node))
    }
}

5. 优先队列的应用

力扣第 23 题「 合并 K 个升序链表」,要求合并 k 个有序链表为 1 个有序列表,如何快速得到 k 个节点中的最小节点,接到结果链表上?
此时就可以使用上述实现的优先队列了,不过需要稍微改动一下结构体,实现如下:

/**
 * Definition for singly-linked list.
 * type ListNode struct {
 *     Val int
 *     Next *ListNode
 * }
 */
type ItemHeap []*ListNode
func (pq ItemHeap) Less(i, j int) bool { return pq[i].Val < pq[j].Val}
func (pq ItemHeap) Swap(i, j int) {pq[i], pq[j] = pq[j], pq[i]}
func (pq ItemHeap) Len() int {return len(pq)}
func (pq *ItemHeap) Push(x interface{}) { *pq = append(*pq, x.(*ListNode)) }
func (pq *ItemHeap) Pop() interface{} {
    old := *pq
    n := len(old)
    x := old[n-1]
    *pq = old[0:n-1]
    return x
}


func mergeKLists(lists []*ListNode) *ListNode {
    pq := &ItemHeap{}
    heap.Init(pq)
    for _, v := range lists {
        if v != nil {
            heap.Push(pq, v)
        }
    }
    // dummy
    dummy := &ListNode{}
    p := dummy
    for pq.Len() > 0 {
        x := heap.Pop(pq).(*ListNode)
        p.Next = x
        p = p.Next
        if x.Next != nil {
            heap.Push(pq, x.Next)
        }
    }
    return dummy.Next
}

6. 参考资料