Skip to content

Lua数据结构

概述

虽然Lua只有表(table)这一种内置的数据结构,但我们可以利用表的灵活性来实现各种常用的数据结构。本文档将介绍如何在Lua中实现栈、队列、链表、集合、二叉搜索树、哈希表和图等常用数据结构。

1. 栈(Stack)

栈是一种后进先出(LIFO)的数据结构。

1.1 基本栈实现

lua
--- Basic stack (LIFO).
-- Elements are kept in the array part of the stack object itself
-- (indices 1..n, top at index n); methods are resolved via the metatable.
local Stack = {}
Stack.__index = Stack

--- Create a new, empty stack.
-- @return a Stack instance
function Stack.new()
    local instance = {}
    setmetatable(instance, Stack)
    return instance
end

--- Push an element on top of the stack.
-- @param item value to push
-- @return the stack itself, so calls can be chained
function Stack:push(item)
    self[#self + 1] = item
    return self
end

--- Remove and return the top element.
-- @raise error "栈为空" when the stack has no elements
function Stack:pop()
    local top = #self
    if top == 0 then
        error("栈为空")
    end
    local item = self[top]
    self[top] = nil
    return item
end

--- Return the top element without removing it (nil when empty).
function Stack:peek()
    local top = #self
    if top > 0 then
        return self[top]
    end
    return nil
end

--- True when the stack holds no elements.
function Stack:is_empty()
    return self:size() == 0
end

--- Number of elements currently on the stack.
function Stack:size()
    return #self
end

--- Remove every element, starting from the top.
-- @return the stack itself
function Stack:clear()
    local top = #self
    while top > 0 do
        self[top] = nil
        top = top - 1
    end
    return self
end

-- Stack demo: push 1..3, show size and top, then pop until empty.
local function stack_example()
    print("=== 栈示例 ===")

    local s = Stack.new()

    -- Push the values 1, 2, 3 in order.
    for n = 1, 3 do
        s:push(n)
    end

    print("栈大小:", s:size())
    print("栈顶元素:", s:peek())

    -- Pop everything; elements come back in reverse push order.
    while s:size() > 0 do
        print("弹出:", s:pop())
    end

    print("栈是否为空:", s:is_empty())
end

stack_example()

1.2 增强栈实现

lua
--- Enhanced stack: optional capacity limit plus iteration helpers.
-- Elements live in the array part of the instance (top at index #self);
-- the optional limit is stored in the `max_capacity` field, nil meaning
-- unbounded.
local EnhancedStack = {}
EnhancedStack.__index = EnhancedStack

--- Create a new stack.
-- @param max_capacity optional non-negative element limit (nil = unbounded)
-- @raise error when max_capacity is neither nil nor a non-negative number
function EnhancedStack.new(max_capacity)
    if max_capacity ~= nil
        and (type(max_capacity) ~= "number" or max_capacity < 0) then
        -- Fail at construction time rather than on the first push(), where
        -- comparing #self with a non-number would raise a confusing error.
        error("容量必须是非负数或nil", 2)
    end
    return setmetatable({
        max_capacity = max_capacity
    }, EnhancedStack)
end

--- Push an element on top of the stack.
-- @raise error "栈已满,容量: <n>" when the capacity limit has been reached
-- @return the stack itself, for chaining
function EnhancedStack:push(item)
    if self:is_full() then
        error("栈已满,容量: " .. self.max_capacity)
    end
    table.insert(self, item)
    return self
end

--- Remove and return the top element.
-- Returns nil (instead of raising) when the stack is empty.
function EnhancedStack:pop()
    if #self == 0 then
        return nil
    end
    return table.remove(self)
end

--- Return the top element without removing it (nil when empty).
function EnhancedStack:peek()
    return self[#self]
end

--- True when the stack holds no elements.
function EnhancedStack:is_empty()
    return #self == 0
end

--- True when a capacity limit is set and has been reached.
-- Always returns a boolean: unbounded stacks report false, not nil.
function EnhancedStack:is_full()
    return self.max_capacity ~= nil and #self >= self.max_capacity
end

--- Number of elements currently on the stack.
function EnhancedStack:size()
    return #self
end

--- The capacity limit, or math.huge for an unbounded stack.
function EnhancedStack:capacity()
    return self.max_capacity or math.huge
end

--- Iterator over the stack from top to bottom.
-- Each step yields (index, value), where index is the array position.
function EnhancedStack:iterate()
    local index = #self + 1
    return function()
        index = index - 1
        if index > 0 then
            return index, self[index]
        end
    end
end

--- Copy the elements into a new plain array (bottom element first).
function EnhancedStack:to_array()
    local result = {}
    for i = 1, #self do
        result[i] = self[i]
    end
    return result
end

-- Enhanced stack demo: fill to capacity, show overflow handling, then
-- iterate over and export the contents.
local function enhanced_stack_example()
    print("\n=== 增强栈示例 ===")

    local limit = 5
    local s = EnhancedStack.new(limit)  -- capacity of 5 elements

    for n = 1, limit do
        s:push("item" .. n)
    end

    print("栈容量:", s:capacity())
    print("栈大小:", s:size())
    print("栈是否满:", s:is_full())

    -- Pushing past the limit raises; capture the error instead of crashing.
    local ok, message = pcall(function()
        s:push("overflow")
    end)
    local outcome
    if ok then
        outcome = "成功"
    else
        outcome = message
    end
    print("添加超出容量的元素:", outcome)

    print("栈内容(从顶到底):")
    for position, value in s:iterate() do
        print("  " .. position .. ": " .. value)
    end

    print("转换为数组:", table.concat(s:to_array(), ", "))
end

enhanced_stack_example()

2. 队列(Queue)

队列是一种先进先出(FIFO)的数据结构。

2.1 基本队列实现

lua
--- Basic FIFO queue.
-- Elements are stored on the instance at integer keys first..last;
-- `first > last` means the queue is empty.
local Queue = {}
Queue.__index = Queue

--- Create an empty queue.
function Queue.new()
    local queue = { first = 1, last = 0 }
    return setmetatable(queue, Queue)
end

--- Append an element at the back.
-- @return the queue itself, for chaining
function Queue:enqueue(item)
    local slot = self.last + 1
    self.last = slot
    self[slot] = item
    return self
end

--- Remove and return the element at the front.
-- @raise error "队列为空" when the queue is empty
function Queue:dequeue()
    local head = self.first
    if head > self.last then
        error("队列为空")
    end

    local item = self[head]
    self[head] = nil
    head = head + 1

    if head > self.last then
        -- The queue just drained: rewind the counters to 1/0 so the
        -- indices do not keep growing across fill/drain cycles.
        self.first, self.last = 1, 0
    else
        self.first = head
    end

    return item
end

--- Return the front element without removing it (nil when empty).
function Queue:front()
    if self:is_empty() then
        return nil
    end
    return self[self.first]
end

--- True when the queue holds no elements.
function Queue:is_empty()
    return not (self.first <= self.last)
end

--- Number of queued elements.
function Queue:size()
    return (self.last + 1) - self.first
end

--- Drop every element and reset the counters.
-- @return the queue itself
function Queue:clear()
    local slot = self.first
    while slot <= self.last do
        self[slot] = nil
        slot = slot + 1
    end
    self.first, self.last = 1, 0
    return self
end

-- Queue demo: enqueue five tasks, then process them in FIFO order.
local function queue_example()
    print("\n=== 队列示例 ===")

    local q = Queue.new()

    for n = 1, 5 do
        q:enqueue("task" .. n)
    end

    print("队列大小:", q:size())
    print("队首元素:", q:front())

    -- Dequeue until empty; tasks come out in insertion order.
    while q:size() > 0 do
        print("处理任务:", q:dequeue())
    end

    print("队列是否为空:", q:is_empty())
end

queue_example()

2.2 循环队列实现

lua
--- Fixed-capacity circular queue (ring buffer).
-- `front` is the slot holding the oldest element and `rear` the slot the
-- next element is written to; both wrap around within 1..capacity, and
-- `count` tracks how many slots are occupied.
local CircularQueue = {}
CircularQueue.__index = CircularQueue

--- Create an empty circular queue.
-- @param capacity maximum number of elements; must be a positive integer
-- @raise error when capacity is not a positive integer
function CircularQueue.new(capacity)
    -- Validate up front: a nil, zero, negative or fractional capacity would
    -- otherwise fail later with confusing errors or produce fractional slots.
    if type(capacity) ~= "number" or capacity < 1
        or capacity ~= math.floor(capacity) then
        error("循环队列容量必须是正整数", 2)
    end
    return setmetatable({
        capacity = capacity,
        data = {},
        front = 1,
        rear = 1,
        count = 0
    }, CircularQueue)
end

--- Append an element at the rear.
-- @raise error "队列已满" when every slot is occupied
-- @return the queue itself, for chaining
function CircularQueue:enqueue(item)
    if self.count >= self.capacity then
        error("队列已满")
    end

    self.data[self.rear] = item
    -- Advance and wrap: capacity maps back to slot 1.
    self.rear = (self.rear % self.capacity) + 1
    self.count = self.count + 1
    return self
end

--- Remove and return the oldest element.
-- @raise error "队列为空" when the queue is empty
function CircularQueue:dequeue()
    if self.count == 0 then
        error("队列为空")
    end

    local item = self.data[self.front]
    self.data[self.front] = nil
    self.front = (self.front % self.capacity) + 1
    self.count = self.count - 1

    return item
end

--- Return the oldest element without removing it (nil when empty).
function CircularQueue:peek()
    if self.count == 0 then
        return nil
    end
    return self.data[self.front]
end

--- True when the queue holds no elements.
function CircularQueue:is_empty()
    return self.count == 0
end

--- True when every slot is occupied.
function CircularQueue:is_full()
    return self.count >= self.capacity
end

--- Number of queued elements.
function CircularQueue:size()
    return self.count
end

--- Maximum number of elements the queue can hold.
function CircularQueue:get_capacity()
    return self.capacity
end

-- Circular queue demo: fill it, free one slot, reuse it (the write
-- position wraps around), then drain the queue.
local function circular_queue_example()
    print("\n=== 循环队列示例 ===")

    local q = CircularQueue.new(3)

    -- Occupy all three slots.
    for _, letter in ipairs({"A", "B", "C"}) do
        q:enqueue(letter)
    end

    print("队列已满:", q:is_full())

    -- Free the oldest slot ...
    print("出队:", q:dequeue())

    -- ... and reuse it: the write position has wrapped back to slot 1.
    q:enqueue("D")

    print("当前队列大小:", q:size())

    -- Drain whatever is left.
    while q:size() > 0 do
        print("出队:", q:dequeue())
    end
end

circular_queue_example()

3. 链表(Linked List)

链表是一种动态数据结构,每个节点保存数据以及对下一个节点的引用(在Lua中就是对另一个表的引用),节点之间通过这些引用串联起来。

3.1 单向链表

lua
--- Node of a singly linked list: a payload plus a reference to the next node.
local ListNode = {}
ListNode.__index = ListNode

--- Create a detached node holding `data`.
function ListNode.new(data)
    return setmetatable({
        data = data,
        next = nil
    }, ListNode)
end

--- Singly linked list with head/tail references and an element counter.
-- Positions are 1-based, like Lua sequences.
local LinkedList = {}
LinkedList.__index = LinkedList

-- True when `index` is an integer position in [1, upper].
-- Fractional positions are rejected: they would pass a plain range check
-- but make the walking loops stop at the wrong node (and, in remove(),
-- leave `tail` pointing at a node that is no longer in the list).
local function is_valid_position(index, upper)
    return type(index) == "number"
        and index >= 1
        and index <= upper
        and index == math.floor(index)
end

-- Return the node at 1-based position `index` (the caller validates range).
local function node_at(list, index)
    local current = list.head
    for _ = 2, index do
        current = current.next
    end
    return current
end

--- Create an empty list.
function LinkedList.new()
    return setmetatable({
        head = nil,
        tail = nil,
        size = 0
    }, LinkedList)
end

--- Add `data` at the end of the list.
-- @return the list itself, for chaining
function LinkedList:append(data)
    local new_node = ListNode.new(data)

    if not self.head then
        self.head = new_node
        self.tail = new_node
    else
        self.tail.next = new_node
        self.tail = new_node
    end

    self.size = self.size + 1
    return self
end

--- Add `data` at the front of the list.
-- @return the list itself, for chaining
function LinkedList:prepend(data)
    local new_node = ListNode.new(data)

    if not self.head then
        self.head = new_node
        self.tail = new_node
    else
        new_node.next = self.head
        self.head = new_node
    end

    self.size = self.size + 1
    return self
end

--- Insert `data` so that it ends up at position `index`.
-- @param index integer position in [1, size + 1]
-- @raise error "索引超出范围" when index is not a valid position
-- @return the list itself
function LinkedList:insert(index, data)
    if not is_valid_position(index, self.size + 1) then
        error("索引超出范围")
    end

    if index == 1 then
        return self:prepend(data)
    elseif index == self.size + 1 then
        return self:append(data)
    end

    -- Link the new node right after the node currently at index - 1.
    local previous = node_at(self, index - 1)
    local new_node = ListNode.new(data)
    new_node.next = previous.next
    previous.next = new_node
    self.size = self.size + 1

    return self
end

--- Remove the element at position `index` and return its data.
-- @param index integer position in [1, size]
-- @raise error "索引超出范围" when index is not a valid position
function LinkedList:remove(index)
    if not is_valid_position(index, self.size) then
        error("索引超出范围")
    end

    local removed
    if index == 1 then
        removed = self.head
        self.head = removed.next
        if not self.head then
            self.tail = nil
        end
    else
        local previous = node_at(self, index - 1)
        removed = previous.next
        previous.next = removed.next
        -- Removing the last node makes its predecessor the new tail.
        if removed == self.tail then
            self.tail = previous
        end
    end

    self.size = self.size - 1
    return removed.data
end

--- Return the data stored at position `index`.
-- @raise error "索引超出范围" when index is not a valid position
function LinkedList:get(index)
    if not is_valid_position(index, self.size) then
        error("索引超出范围")
    end
    return node_at(self, index).data
end

--- Return the position of the first element equal (==) to `data`, or nil.
function LinkedList:find(data)
    local current = self.head
    local index = 1

    while current do
        if current.data == data then
            return index
        end
        current = current.next
        index = index + 1
    end

    return nil
end

--- Number of elements in the list.
function LinkedList:get_size()
    return self.size
end

--- True when the list has no elements.
function LinkedList:is_empty()
    return self.size == 0
end

--- Copy the elements, head first, into a new plain array.
function LinkedList:to_array()
    local result = {}
    local current = self.head

    while current do
        table.insert(result, current.data)
        current = current.next
    end

    return result
end

--- Iterator yielding each element's data from head to tail.
function LinkedList:iterate()
    local current = self.head
    return function()
        if current then
            local data = current.data
            current = current.next
            return data
        end
    end
end

-- Singly linked list demo: build a list, query it, remove an element and
-- iterate over the rest.
local function linked_list_example()
    print("\n=== 链表示例 ===")

    local list = LinkedList.new()

    -- Build 0 -> A -> 1.5 -> B -> C.
    list:append("A"):append("B"):append("C")
    list:prepend("0")
    list:insert(3, "1.5")

    -- Print `label` followed by the list rendered as "x -> y -> z".
    local function show(label)
        print(label, table.concat(list:to_array(), " -> "))
    end

    print("链表大小:", list:get_size())
    show("链表内容:")

    print("查找 'B' 的位置:", list:find("B"))
    print("获取第3个元素:", list:get(3))

    print("删除第2个元素:", list:remove(2))
    show("删除后链表:")

    print("迭代链表:")
    for value in list:iterate() do
        print("  " .. value)
    end
end

linked_list_example()

3.2 双向链表

lua
--- Node of a doubly linked list: payload plus `next`/`prev` links, which
-- start out as nil.
local DoublyListNode = {}
DoublyListNode.__index = DoublyListNode

--- Create a detached node holding `data`.
function DoublyListNode.new(data)
    return setmetatable({ data = data }, DoublyListNode)
end

--- Doubly linked list with head/tail references and an element counter.
local DoublyLinkedList = {}
DoublyLinkedList.__index = DoublyLinkedList

--- Create an empty list (head and tail start out nil).
function DoublyLinkedList.new()
    return setmetatable({ size = 0 }, DoublyLinkedList)
end

--- Add `data` at the tail.
-- @return the list itself, for chaining
function DoublyLinkedList:append(data)
    local node = DoublyListNode.new(data)
    if self.head == nil then
        self.head = node
    else
        local last = self.tail
        node.prev = last
        last.next = node
    end
    self.tail = node
    self.size = self.size + 1
    return self
end

--- Add `data` at the head.
-- @return the list itself, for chaining
function DoublyLinkedList:prepend(data)
    local node = DoublyListNode.new(data)
    local first = self.head
    if first == nil then
        self.tail = node
    else
        node.next = first
        first.prev = node
    end
    self.head = node
    self.size = self.size + 1
    return self
end

--- Remove the head element and return its data (nil when empty).
function DoublyLinkedList:remove_first()
    local first = self.head
    if first == nil then
        return nil
    end

    if first == self.tail then
        -- Single element: the list becomes empty.
        self.head, self.tail = nil, nil
    else
        local second = first.next
        second.prev = nil
        self.head = second
    end

    self.size = self.size - 1
    return first.data
end

--- Remove the tail element and return its data (nil when empty).
function DoublyLinkedList:remove_last()
    local last = self.tail
    if last == nil then
        return nil
    end

    if self.head == last then
        -- Single element: the list becomes empty.
        self.head, self.tail = nil, nil
    else
        local before = last.prev
        before.next = nil
        self.tail = before
    end

    self.size = self.size - 1
    return last.data
end

--- Number of elements in the list.
function DoublyLinkedList:get_size()
    return self.size
end

--- True when the list has no elements.
function DoublyLinkedList:is_empty()
    return self.size == 0
end

-- Collect node payloads starting at `node` and following the `link`
-- field ("next" or "prev") until it runs out.
local function collect_from(node, link)
    local out = {}
    while node do
        out[#out + 1] = node.data
        node = node[link]
    end
    return out
end

--- Elements from head to tail as a new plain array.
function DoublyLinkedList:to_array()
    return collect_from(self.head, "next")
end

--- Elements from tail to head as a new plain array.
function DoublyLinkedList:to_array_reverse()
    return collect_from(self.tail, "prev")
end

-- Doubly linked list demo: build from both ends, traverse in both
-- directions, then remove from both ends.
local function doubly_linked_list_example()
    print("\n=== 双向链表示例 ===")

    local list = DoublyLinkedList.new()

    for _, value in ipairs({"A", "B", "C"}) do
        list:append(value)
    end
    list:prepend("0")

    print("链表大小:", list:get_size())
    print("正向遍历:", table.concat(list:to_array(), " -> "))
    print("反向遍历:", table.concat(list:to_array_reverse(), " -> "))

    print("删除第一个元素:", list:remove_first())
    print("删除最后一个元素:", list:remove_last())
    print("删除后链表:", table.concat(list:to_array(), " -> "))
end

doubly_linked_list_example()

4. 集合(Set)

集合是不包含重复元素的数据结构。

4.1 基本集合实现

lua
--- Set of unique values.
-- Members are stored as keys (with value `true`) of the private `items`
-- table rather than on the set object itself. Storing them on the object
-- would let members such as "add" or "size" shadow the methods that are
-- looked up through the metatable.
local Set = {}
Set.__index = Set

-- Return the raw membership table of `other`: its `items` when it is a
-- Set, or `other` itself when a plain { [member] = true } table is given.
local function members_of(other)
    if getmetatable(other) == Set then
        return other.items
    end
    return other
end

--- Create a set, optionally seeded from the sequence `items`.
function Set.new(items)
    local set = setmetatable({ items = {} }, Set)
    if items then
        for _, item in ipairs(items) do
            set:add(item)
        end
    end
    return set
end

--- Add `item` (no effect if already present).
-- @return the set itself, for chaining
function Set:add(item)
    self.items[item] = true
    return self
end

--- Remove `item` (no effect if absent).
-- @return the set itself, for chaining
function Set:remove(item)
    self.items[item] = nil
    return self
end

--- True when `item` is a member.
function Set:contains(item)
    return self.items[item] == true
end

--- Number of members (counted on each call).
function Set:size()
    local count = 0
    for _ in pairs(self.items) do
        count = count + 1
    end
    return count
end

--- True when the set has no members.
function Set:is_empty()
    return next(self.items) == nil
end

--- Remove every member.
-- @return the set itself
function Set:clear()
    self.items = {}
    return self
end

--- Members as a new plain array (order follows pairs() and is unspecified).
function Set:to_array()
    local result = {}
    for item in pairs(self.items) do
        table.insert(result, item)
    end
    return result
end

--- Generic-for iterator yielding (member, true).
function Set:iterate()
    return pairs(self.items)
end

--- New set containing members of either set.
-- `other` may also be a plain { [member] = true } table.
function Set:union(other)
    local result = Set.new()

    for item in pairs(self.items) do
        result:add(item)
    end

    for item in pairs(members_of(other)) do
        result:add(item)
    end

    return result
end

--- New set containing members present in both sets.
function Set:intersection(other)
    local result = Set.new()

    for item in pairs(self.items) do
        if other:contains(item) then
            result:add(item)
        end
    end

    return result
end

--- New set containing members of self that are not in `other`.
function Set:difference(other)
    local result = Set.new()

    for item in pairs(self.items) do
        if not other:contains(item) then
            result:add(item)
        end
    end

    return result
end

--- True when every member of self is also in `other`.
function Set:is_subset(other)
    for item in pairs(self.items) do
        if not other:contains(item) then
            return false
        end
    end
    return true
end

--- True when every member of `other` is also in self.
function Set:is_superset(other)
    return other:is_subset(self)
end

-- Set demo: build two sets, run set algebra, then test membership.
local function set_example()
    print("\n=== 集合示例 ===")

    local set1 = Set.new({1, 2, 3, 4, 5})
    local set2 = Set.new({4, 5, 6, 7, 8})

    -- Render a set as a comma-separated list (order follows pairs()).
    local function render(set)
        return table.concat(set:to_array(), ", ")
    end

    print("集合1:", render(set1))
    print("集合2:", render(set2))

    print("并集:", render(set1:union(set2)))
    print("交集:", render(set1:intersection(set2)))
    print("差集:", render(set1:difference(set2)))

    print("子集检查:", Set.new({1, 2}):is_subset(set1))

    print("包含3:", set1:contains(3))
    set1:remove(3)
    print("删除3后包含3:", set1:contains(3))
end

set_example()

5. 二叉树(Binary Tree)

二叉树是每个节点最多有两个子节点的树结构。

5.1 二叉搜索树

lua
--- Node of a binary tree: payload plus left/right children.
local TreeNode = {}
TreeNode.__index = TreeNode

--- Create a leaf node holding `data`.
function TreeNode.new(data)
    return setmetatable({
        data = data,
        left = nil,
        right = nil
    }, TreeNode)
end

--- Unbalanced binary search tree of mutually comparable (<, >) values.
-- Smaller values go left, larger go right; duplicates are not stored.
local BinarySearchTree = {}
BinarySearchTree.__index = BinarySearchTree

--- Create an empty tree.
function BinarySearchTree.new()
    return setmetatable({
        root = nil,
        size = 0
    }, BinarySearchTree)
end

--- Insert `data`; inserting a value that is already present is a no-op.
-- `size` grows only when a new node is actually created.
-- @return the tree itself
function BinarySearchTree:insert(data)
    local inserted = false

    local function insert_node(node, value)
        if not node then
            inserted = true
            return TreeNode.new(value)
        end

        if value < node.data then
            node.left = insert_node(node.left, value)
        elseif value > node.data then
            node.right = insert_node(node.right, value)
        end
        -- Equal values fall through: duplicates are ignored.

        return node
    end

    self.root = insert_node(self.root, data)
    if inserted then
        self.size = self.size + 1
    end
    return self
end

--- True when `data` is stored in the tree.
function BinarySearchTree:search(data)
    local function search_node(node, value)
        if not node then
            return false
        end

        if value == node.data then
            return true
        elseif value < node.data then
            return search_node(node.left, value)
        else
            return search_node(node.right, value)
        end
    end

    return search_node(self.root, data)
end

--- Remove `data` if present; removing a missing value is a no-op.
-- `size` shrinks only when a node was actually unlinked.
-- @return the tree itself
function BinarySearchTree:remove(data)
    local removed = false

    -- Leftmost (smallest) node of the subtree rooted at `node`.
    local function find_min(node)
        while node.left do
            node = node.left
        end
        return node
    end

    local function remove_node(node, value)
        if not node then
            return nil
        end

        if value < node.data then
            node.left = remove_node(node.left, value)
        elseif value > node.data then
            node.right = remove_node(node.right, value)
        else
            removed = true
            if not node.left then
                return node.right
            elseif not node.right then
                return node.left
            end
            -- Two children: copy the in-order successor (smallest value of
            -- the right subtree) into this node, then delete the successor.
            local successor = find_min(node.right)
            node.data = successor.data
            node.right = remove_node(node.right, successor.data)
        end

        return node
    end

    self.root = remove_node(self.root, data)
    if removed then
        self.size = self.size - 1
    end
    return self
end

--- Values in ascending order (left, node, right).
function BinarySearchTree:inorder()
    local result = {}

    local function inorder_traverse(node)
        if node then
            inorder_traverse(node.left)
            table.insert(result, node.data)
            inorder_traverse(node.right)
        end
    end

    inorder_traverse(self.root)
    return result
end

--- Values in pre-order (node, left, right).
function BinarySearchTree:preorder()
    local result = {}

    local function preorder_traverse(node)
        if node then
            table.insert(result, node.data)
            preorder_traverse(node.left)
            preorder_traverse(node.right)
        end
    end

    preorder_traverse(self.root)
    return result
end

--- Values in post-order (left, right, node).
function BinarySearchTree:postorder()
    local result = {}

    local function postorder_traverse(node)
        if node then
            postorder_traverse(node.left)
            postorder_traverse(node.right)
            table.insert(result, node.data)
        end
    end

    postorder_traverse(self.root)
    return result
end

--- Number of distinct values stored.
function BinarySearchTree:get_size()
    return self.size
end

--- True when the tree has no nodes.
function BinarySearchTree:is_empty()
    return self.root == nil
end

-- BST demo: insert sample values, print the three traversals, search,
-- then delete a node with two children.
local function binary_search_tree_example()
    print("\n=== 二叉搜索树示例 ===")

    local bst = BinarySearchTree.new()

    for _, value in ipairs({50, 30, 70, 20, 40, 60, 80}) do
        bst:insert(value)
    end

    print("树的大小:", bst:get_size())

    -- Join a traversal result as "a, b, c".
    local function joined(values)
        return table.concat(values, ", ")
    end

    print("中序遍历:", joined(bst:inorder()))
    print("前序遍历:", joined(bst:preorder()))
    print("后序遍历:", joined(bst:postorder()))

    print("搜索40:", bst:search(40))
    print("搜索25:", bst:search(25))

    bst:remove(30)
    print("删除30后中序遍历:", joined(bst:inorder()))
end

binary_search_tree_example()

6. 哈希表(Hash Table)

虽然Lua的表在内部已经由哈希部分(外加一个数组部分)实现,但为了展示哈希表的工作原理——哈希函数、桶、冲突处理(链地址法)以及扩容——我们可以手动实现一个更明确的哈希表结构。

6.1 自定义哈希表

lua
--- Hash table with separate chaining.
-- `buckets[i]` is an array of { key = k, value = v } pairs whose keys hash
-- to bucket i (1-based). The bucket count doubles whenever the load factor
-- (size / capacity) exceeds `load_factor_threshold`.
local HashTable = {}
HashTable.__index = HashTable

--- Create an empty hash table.
-- @param initial_capacity number of buckets to start with (default 16)
function HashTable.new(initial_capacity)
    local capacity = initial_capacity or 16
    return setmetatable({
        buckets = {},
        capacity = capacity,
        size = 0,
        load_factor_threshold = 0.75
    }, HashTable)
end

-- Polynomial (base 31) hash of a string, reduced modulo `capacity` at each
-- step so intermediate values stay small. Returns an index in 1..capacity.
local function string_bucket(str, capacity)
    local hash = 0
    for i = 1, #str do
        hash = (hash * 31 + string.byte(str, i)) % capacity
    end
    return hash + 1  -- Lua arrays are 1-based
end

-- Locate `key` inside one bucket (which may be nil).
-- Returns the stored pair and its position, or nil when absent.
local function scan_bucket(bucket, key)
    if bucket then
        for position, entry in ipairs(bucket) do
            if entry.key == key then
                return entry, position
            end
        end
    end
    return nil
end

--- Map `key` to a bucket index in 1..capacity.
-- Strings use a polynomial hash. Finite integral numbers use
-- `key % capacity`. Everything else (fractional or infinite numbers,
-- booleans, tables, ...) is hashed through its tostring() form.
-- Fractional and infinite numbers must not go through `%`: they would give
-- a fractional bucket index or NaN, and NaN cannot be used as a table key.
function HashTable:hash(key)
    local kind = type(key)
    if kind == "string" then
        return string_bucket(key, self.capacity)
    end
    -- `key - key == 0` is false for +/-inf and NaN.
    if kind == "number" and key == math.floor(key) and key - key == 0 then
        return (key % self.capacity) + 1
    end
    return string_bucket(tostring(key), self.capacity)
end

--- Insert or update `key` -> `value`.
-- May trigger resize() when the load factor exceeds the threshold.
-- @return the hash table itself, for chaining
function HashTable:put(key, value)
    local index = self:hash(key)
    local bucket = self.buckets[index]

    -- Existing key: update in place.
    local entry = scan_bucket(bucket, key)
    if entry then
        entry.value = value
        return self
    end

    -- New key: append to (a possibly new) bucket.
    if not bucket then
        bucket = {}
        self.buckets[index] = bucket
    end
    bucket[#bucket + 1] = {key = key, value = value}
    self.size = self.size + 1

    if self.size / self.capacity > self.load_factor_threshold then
        self:resize()
    end

    return self
end

--- Value stored for `key`, or nil when the key is absent.
function HashTable:get(key)
    local entry = scan_bucket(self.buckets[self:hash(key)], key)
    if entry then
        return entry.value
    end
    return nil
end

--- Remove `key` and return its value (nil when the key was absent).
function HashTable:remove(key)
    local bucket = self.buckets[self:hash(key)]
    local entry, position = scan_bucket(bucket, key)
    if not entry then
        return nil
    end
    table.remove(bucket, position)
    self.size = self.size - 1
    return entry.value
end

--- True when `key` has an entry.
-- This checks key presence rather than `get(key) ~= nil`, so it agrees
-- with size() and get_keys() even for entries stored with a nil value.
function HashTable:contains(key)
    return scan_bucket(self.buckets[self:hash(key)], key) ~= nil
end

--- Number of stored entries.
function HashTable:get_size()
    return self.size
end

--- True when no entries are stored.
function HashTable:is_empty()
    return self.size == 0
end

--- Double the bucket count and rehash every entry, then print a summary.
function HashTable:resize()
    local old_buckets = self.buckets
    local old_capacity = self.capacity
    local old_size = self.size

    self.capacity = old_capacity * 2
    self.buckets = {}
    self.size = 0

    -- Re-insert everything; hashes depend on the new capacity.
    for _, bucket in pairs(old_buckets) do
        for _, pair in ipairs(bucket) do
            self:put(pair.key, pair.value)
        end
    end

    print(string.format("哈希表扩容: %d -> %d, 元素数量: %d",
        old_capacity, self.capacity, old_size))
end

--- Current load factor (entries per bucket).
function HashTable:get_load_factor()
    return self.size / self.capacity
end

--- All keys as a new array (bucket traversal order, not insertion order).
function HashTable:get_keys()
    local keys = {}
    for _, bucket in pairs(self.buckets) do
        for _, pair in ipairs(bucket) do
            table.insert(keys, pair.key)
        end
    end
    return keys
end

--- All values as a new array (same order as get_keys()).
function HashTable:get_values()
    local values = {}
    for _, bucket in pairs(self.buckets) do
        for _, pair in ipairs(bucket) do
            table.insert(values, pair.value)
        end
    end
    return values
end

-- Hash table demo.
-- The table starts with only 4 buckets so that growth is visible. With the
-- default load-factor threshold of 0.75, the 4th insertion (4/4 = 1.0)
-- triggers a resize to 8 buckets, while the 5th (5/8 = 0.625) does not.
local function hash_table_example()
    print("\n=== 哈希表示例 ===")

    local ht = HashTable.new(4)  -- small capacity to demonstrate resizing

    -- Insert key/value pairs
    ht:put("name", "Alice")
    ht:put("age", 30)
    ht:put("city", "Beijing")
    ht:put("job", "Engineer")   -- load factor 4/4 > 0.75: triggers the resize
    ht:put("hobby", "Reading")  -- load factor 5/8 = 0.625: no resize

    print("哈希表大小:", ht:get_size())
    print("负载因子:", string.format("%.2f", ht:get_load_factor()))

    -- Look up values (a missing key yields nil)
    print("姓名:", ht:get("name"))
    print("年龄:", ht:get("age"))
    print("不存在的键:", ht:get("nonexistent"))

    -- Membership test
    print("包含'city':", ht:contains("city"))

    -- Removal returns the removed value
    print("删除'job':", ht:remove("job"))
    print("删除后大小:", ht:get_size())

    -- Dump all keys and values (bucket traversal order, not insertion order)
    print("所有键:", table.concat(ht:get_keys(), ", "))
    print("所有值:", table.concat(ht:get_values(), ", "))
end

hash_table_example()

7. 图(Graph)

图是由顶点和边组成的数据结构。

7.1 邻接表表示的图

lua
--- Graph stored as an adjacency list.
-- `vertices[v]` is an array of { vertex = neighbour, weight = w } edges
-- leaving v. In an undirected graph every edge is recorded in both
-- endpoints' lists.
local Graph = {}
Graph.__index = Graph

-- Remove, in place, every edge in `edges` that points at `target`.
-- Walks backwards so table.remove does not skip elements.
local function drop_edges_to(edges, target)
    for i = #edges, 1, -1 do
        if edges[i].vertex == target then
            table.remove(edges, i)
        end
    end
end

--- Create an empty graph.
-- @param directed true for a directed graph; defaults to undirected
function Graph.new(directed)
    return setmetatable({
        vertices = {},
        directed = directed or false
    }, Graph)
end

--- Add `vertex` if it is not present yet.
-- @return the graph itself, for chaining
function Graph:add_vertex(vertex)
    if not self.vertices[vertex] then
        self.vertices[vertex] = {}
    end
    return self
end

--- Add an edge (creating missing endpoints). `weight` defaults to 1.
-- Undirected graphs also record the reverse edge.
-- @return the graph itself, for chaining
function Graph:add_edge(from, to, weight)
    local w = weight or 1
    self:add_vertex(from)
    self:add_vertex(to)

    table.insert(self.vertices[from], {vertex = to, weight = w})

    if not self.directed then
        table.insert(self.vertices[to], {vertex = from, weight = w})
    end

    return self
end

--- Remove `vertex` together with every edge pointing at it.
-- @return the graph itself
function Graph:remove_vertex(vertex)
    if not self.vertices[vertex] then
        return self
    end

    for _, edges in pairs(self.vertices) do
        drop_edges_to(edges, vertex)
    end

    self.vertices[vertex] = nil
    return self
end

--- Remove all edges from `from` to `to` (both directions if undirected).
-- @return the graph itself
function Graph:remove_edge(from, to)
    local outgoing = self.vertices[from]
    if outgoing then
        drop_edges_to(outgoing, to)
    end

    if not self.directed then
        local reverse = self.vertices[to]
        if reverse then
            drop_edges_to(reverse, from)
        end
    end

    return self
end

--- Edge list of `vertex` (the internal array), or {} if it does not exist.
function Graph:get_neighbors(vertex)
    return self.vertices[vertex] or {}
end

--- True when `vertex` exists.
function Graph:has_vertex(vertex)
    return self.vertices[vertex] ~= nil
end

--- True when an edge from `from` to `to` exists.
function Graph:has_edge(from, to)
    if not self.vertices[from] then
        return false
    end

    for _, edge in ipairs(self.vertices[from]) do
        if edge.vertex == to then
            return true
        end
    end

    return false
end

--- All vertices as a new array (order follows pairs() and is unspecified).
function Graph:get_vertices()
    local vertices = {}
    for vertex in pairs(self.vertices) do
        table.insert(vertices, vertex)
    end
    return vertices
end

--- Number of vertices.
function Graph:vertex_count()
    local count = 0
    for _ in pairs(self.vertices) do
        count = count + 1
    end
    return count
end

--- Number of edges.
-- Undirected edges are stored once per endpoint, so the raw count is
-- halved. math.floor keeps the result an integer: plain `/` yields a float
-- (e.g. 5.0) on Lua 5.3+.
function Graph:edge_count()
    local count = 0
    for _, edges in pairs(self.vertices) do
        count = count + #edges
    end

    if not self.directed then
        count = math.floor(count / 2)
    end

    return count
end

--- Recursive depth-first search from `start_vertex`.
-- Calls `visit_func(vertex)` (if given) on first visit.
-- @return table mapping each visited vertex to true
function Graph:dfs(start_vertex, visit_func)
    local visited = {}

    local function dfs_recursive(vertex)
        visited[vertex] = true
        if visit_func then
            visit_func(vertex)
        end

        for _, edge in ipairs(self.vertices[vertex] or {}) do
            if not visited[edge.vertex] then
                dfs_recursive(edge.vertex)
            end
        end
    end

    dfs_recursive(start_vertex)
    return visited
end

--- Breadth-first search from `start_vertex`.
-- Calls `visit_func(vertex)` (if given) when a vertex is dequeued.
-- @return table mapping each visited vertex to true
function Graph:bfs(start_vertex, visit_func)
    local visited = {}
    local queue = {start_vertex}
    local front = 1  -- read position; avoids O(n) table.remove(queue, 1)

    visited[start_vertex] = true

    while front <= #queue do
        local vertex = queue[front]
        front = front + 1

        if visit_func then
            visit_func(vertex)
        end

        for _, edge in ipairs(self.vertices[vertex] or {}) do
            if not visited[edge.vertex] then
                visited[edge.vertex] = true
                table.insert(queue, edge.vertex)
            end
        end
    end

    return visited
end

-- Graph demo: build a small undirected graph, query it, then traverse it
-- depth-first and breadth-first from "A".
local function graph_example()
    print("\n=== 图示例 ===")

    local graph = Graph.new(false)  -- undirected

    -- add_edge creates missing vertices on the fly.
    local edge_list = {
        {"A", "B"}, {"A", "C"}, {"B", "D"}, {"C", "D"}, {"D", "E"},
    }
    for _, ends in ipairs(edge_list) do
        graph:add_edge(ends[1], ends[2])
    end

    print("顶点数量:", graph:vertex_count())
    print("边数量:", graph:edge_count())
    print("所有顶点:", table.concat(graph:get_vertices(), ", "))

    print("A和B是否相连:", graph:has_edge("A", "B"))
    print("A和E是否相连:", graph:has_edge("A", "E"))

    local names = {}
    for i, edge in ipairs(graph:get_neighbors("A")) do
        names[i] = edge.vertex
    end
    print("A的邻居:", table.concat(names, ", "))

    -- Both traversals write the visit order on a single line.
    local function show(vertex)
        io.write(vertex .. " ")
    end

    print("\n从A开始的DFS:")
    graph:dfs("A", show)
    print()

    print("从A开始的BFS:")
    graph:bfs("A", show)
    print()
end

graph_example()

总结

本文档介绍了如何在Lua中实现各种常用的数据结构:

  1. 栈(Stack) - 后进先出的数据结构,适用于函数调用、表达式求值等
  2. 队列(Queue) - 先进先出的数据结构,适用于任务调度、广度优先搜索等
  3. 链表(Linked List) - 动态数据结构,在表头/表尾(或已定位到的节点处)插入和删除只需O(1)时间,但按索引访问需要O(n)
  4. 集合(Set) - 不包含重复元素的数据结构,支持集合运算
  5. 二叉搜索树(BST) - 有序的树结构,树保持平衡时搜索、插入和删除为O(log n);本文的实现不做自平衡,按有序顺序插入时会退化为O(n)
  6. 哈希表(Hash Table) - 基于哈希函数的键值存储结构
  7. 图(Graph) - 由顶点和边组成的复杂数据结构

这些数据结构的实现展示了Lua表的强大灵活性,以及如何利用元表和面向对象编程来创建复杂的数据结构。在实际应用中,可以根据具体需求选择合适的数据结构来提高程序的效率和可维护性。

基于 MIT 许可发布