Skip to content

Lua协程

概述

协程(Coroutine)是Lua的一个强大特性,它提供了一种协作式多任务处理的方式。与传统的抢占式多线程不同,协程是用户态的轻量级线程,可以在执行过程中主动让出控制权,然后在需要时恢复执行。

1. 协程基础概念

1.1 什么是协程

协程是一种可以暂停和恢复执行的函数。它具有以下特点:

  • 协作式:协程主动让出控制权,而不是被强制中断
  • 轻量级:创建和切换的开销很小
  • 状态保持:暂停时保持局部变量和执行位置
  • 单线程:在同一时刻只有一个协程在执行
lua
-- 创建一个简单的协程
local function simple_coroutine()
    print("协程开始执行")
    coroutine.yield()  -- 暂停执行
    print("协程恢复执行")
    coroutine.yield()  -- 再次暂停
    print("协程结束")
end

-- 创建协程
local co = coroutine.create(simple_coroutine)

-- 查看协程状态
print(coroutine.status(co))  -- suspended

-- 启动协程
coroutine.resume(co)  -- 输出: 协程开始执行
print(coroutine.status(co))  -- suspended

-- 恢复协程
coroutine.resume(co)  -- 输出: 协程恢复执行
print(coroutine.status(co))  -- suspended

-- 再次恢复协程
coroutine.resume(co)  -- 输出: 协程结束
print(coroutine.status(co))  -- dead

1.2 协程状态

协程有四种状态:

  • suspended:暂停状态,可以被恢复
  • running:正在运行
  • dead:执行完毕或出错
  • normal:恢复了另一个协程
lua
local function status_demo()
    print("当前状态:", coroutine.status(coroutine.running()))  -- running
    coroutine.yield("暂停中")
    print("恢复执行")
end

local co = coroutine.create(status_demo)
print("创建后状态:", coroutine.status(co))  -- suspended

local success, message = coroutine.resume(co)
print("恢复后状态:", coroutine.status(co))  -- suspended
print("返回值:", message)  -- 暂停中

coroutine.resume(co)
print("结束后状态:", coroutine.status(co))  -- dead

2. 协程的创建和控制

2.1 coroutine.create 和 coroutine.wrap

lua
-- 方法1: coroutine.create
local function task1()
    for i = 1, 3 do
        print("Task1:", i)
        coroutine.yield()
    end
end

local co1 = coroutine.create(task1)

-- 需要使用coroutine.resume来调用
coroutine.resume(co1)  -- Task1: 1
coroutine.resume(co1)  -- Task1: 2
coroutine.resume(co1)  -- Task1: 3

-- 方法2: coroutine.wrap (更简洁)
local function task2()
    for i = 1, 3 do
        print("Task2:", i)
        coroutine.yield()
    end
end

local co2 = coroutine.wrap(task2)

-- 可以直接调用
co2()  -- Task2: 1
co2()  -- Task2: 2
co2()  -- Task2: 3

2.2 参数传递和返回值

lua
-- 协程可以接收参数和返回值
local function calculator()
    local a, b = coroutine.yield("请输入两个数字")
    local result = a + b
    return result
end

local co = coroutine.create(calculator)

-- 启动协程
local success, message = coroutine.resume(co)
print(message)  -- 请输入两个数字

-- 传递参数并获取结果
local success, result = coroutine.resume(co, 10, 20)
print("计算结果:", result)  -- 计算结果: 30

-- 使用wrap方式
local calc_wrap = coroutine.wrap(function()
    local x = coroutine.yield("输入x:")
    local y = coroutine.yield("输入y:")
    return x * y
end)

print(calc_wrap())        -- 输入x:
print(calc_wrap(5))       -- 输入y:
print(calc_wrap(6))       -- 30

3. 协程的实际应用

3.1 生成器模式

lua
-- 斐波那契数列生成器
local function fibonacci_generator()
    local a, b = 0, 1
    while true do
        coroutine.yield(a)
        a, b = b, a + b
    end
end

local fib = coroutine.wrap(fibonacci_generator)

-- 生成前10个斐波那契数
for i = 1, 10 do
    print(string.format("fib(%d) = %d", i, fib()))
end

-- 素数生成器
local function prime_generator()
    local function is_prime(n)
        if n < 2 then return false end
        for i = 2, math.sqrt(n) do
            if n % i == 0 then return false end
        end
        return true
    end
    
    local n = 2
    while true do
        if is_prime(n) then
            coroutine.yield(n)
        end
        n = n + 1
    end
end

local primes = coroutine.wrap(prime_generator)

-- 生成前10个素数
print("前10个素数:")
for i = 1, 10 do
    print(primes())
end

3.2 迭代器实现

lua
-- 使用协程实现树的遍历
local function create_tree_iterator(tree)
    return coroutine.wrap(function()
        local function traverse(node)
            if node then
                coroutine.yield(node.value)
                traverse(node.left)
                traverse(node.right)
            end
        end
        traverse(tree)
    end)
end

-- 创建一个简单的二叉树
local tree = {
    value = 1,
    left = {
        value = 2,
        left = {value = 4},
        right = {value = 5}
    },
    right = {
        value = 3,
        left = {value = 6},
        right = {value = 7}
    }
}

-- 遍历树
print("树的前序遍历:")
for value in create_tree_iterator(tree) do
    print(value)
end

-- 文件行迭代器
local function lines_iterator(filename)
    return coroutine.wrap(function()
        local file = io.open(filename, "r")
        if not file then
            error("无法打开文件: " .. filename)
        end
        
        for line in file:lines() do
            coroutine.yield(line)
        end
        
        file:close()
    end)
end

-- 使用示例(需要实际文件)
--[[
for line in lines_iterator("test.txt") do
    print("行内容:", line)
end
--]]

3.3 状态机实现

lua
-- 使用协程实现状态机
local function create_state_machine()
    local state = "idle"
    
    return coroutine.wrap(function()
        while true do
            local event = coroutine.yield(state)
            
            if state == "idle" then
                if event == "start" then
                    state = "running"
                    print("状态转换: idle -> running")
                end
            elseif state == "running" then
                if event == "pause" then
                    state = "paused"
                    print("状态转换: running -> paused")
                elseif event == "stop" then
                    state = "idle"
                    print("状态转换: running -> idle")
                end
            elseif state == "paused" then
                if event == "resume" then
                    state = "running"
                    print("状态转换: paused -> running")
                elseif event == "stop" then
                    state = "idle"
                    print("状态转换: paused -> idle")
                end
            end
        end
    end)
end

-- 使用状态机
local sm = create_state_machine()

print("初始状态:", sm())        -- idle
print("当前状态:", sm("start"))  -- running
print("当前状态:", sm("pause"))  -- paused
print("当前状态:", sm("resume")) -- running
print("当前状态:", sm("stop"))   -- idle

4. 协程的高级用法

4.1 协程池

lua
-- 协程池实现
local CoroutinePool = {}
CoroutinePool.__index = CoroutinePool

function CoroutinePool.new(size)
    local pool = {
        size = size,
        available = {},
        busy = {},
        tasks = {}
    }
    
    -- 预创建协程
    for i = 1, size do
        local co = coroutine.create(function()
            while true do
                local task = coroutine.yield()
                if task then
                    task()
                end
            end
        end)
        table.insert(pool.available, co)
    end
    
    return setmetatable(pool, CoroutinePool)
end

function CoroutinePool:execute(task)
    if #self.available > 0 then
        local co = table.remove(self.available)
        table.insert(self.busy, co)
        
        local success, error_msg = coroutine.resume(co, task)
        if not success then
            print("协程执行错误:", error_msg)
        end
        
        -- 任务完成后回收协程
        for i, busy_co in ipairs(self.busy) do
            if busy_co == co then
                table.remove(self.busy, i)
                table.insert(self.available, co)
                break
            end
        end
    else
        -- 没有可用协程,加入任务队列
        table.insert(self.tasks, task)
    end
end

function CoroutinePool:get_stats()
    return {
        available = #self.available,
        busy = #self.busy,
        queued = #self.tasks
    }
end

-- 使用协程池
local pool = CoroutinePool.new(3)

-- 提交任务
for i = 1, 5 do
    pool:execute(function()
        print("执行任务", i)
        -- 模拟耗时操作
        local start = os.clock()
        while os.clock() - start < 0.1 do end
        print("任务", i, "完成")
    end)
end

local stats = pool:get_stats()
print(string.format("可用: %d, 忙碌: %d, 排队: %d", 
    stats.available, stats.busy, stats.queued))

4.2 协程通信

lua
-- 协程间通信示例
local function create_producer_consumer()
    local buffer = {}
    local max_size = 5
    
    local producer = coroutine.create(function()
        for i = 1, 10 do
            -- 等待缓冲区有空间
            while #buffer >= max_size do
                coroutine.yield("buffer_full")
            end
            
            table.insert(buffer, "item_" .. i)
            print("生产:", "item_" .. i, "缓冲区大小:", #buffer)
            coroutine.yield("produced")
        end
        coroutine.yield("producer_done")
    end)
    
    local consumer = coroutine.create(function()
        while true do
            -- 等待缓冲区有数据
            while #buffer == 0 do
                coroutine.yield("buffer_empty")
            end
            
            local item = table.remove(buffer, 1)
            print("消费:", item, "缓冲区大小:", #buffer)
            coroutine.yield("consumed")
        end
    end)
    
    return producer, consumer, buffer
end

-- 协调生产者和消费者
local function run_producer_consumer()
    local producer, consumer, buffer = create_producer_consumer()
    local producer_done = false
    
    while not producer_done or #buffer > 0 do
        -- 运行生产者
        if not producer_done then
            local success, status = coroutine.resume(producer)
            if status == "producer_done" then
                producer_done = true
                print("生产者完成")
            end
        end
        
        -- 运行消费者
        if #buffer > 0 then
            coroutine.resume(consumer)
        end
        
        -- 简单的调度延迟
        local start = os.clock()
        while os.clock() - start < 0.01 do end
    end
    
    print("生产消费完成")
end

run_producer_consumer()

4.3 异步操作模拟

lua
-- 模拟异步文件操作
local function async_file_operations()
    local function async_read_file(filename)
        return coroutine.wrap(function()
            print("开始读取文件:", filename)
            -- 模拟异步IO操作
            for i = 1, 3 do
                coroutine.yield("reading...")
            end
            return "文件内容: " .. filename
        end)
    end
    
    local function async_write_file(filename, content)
        return coroutine.wrap(function()
            print("开始写入文件:", filename)
            -- 模拟异步IO操作
            for i = 1, 2 do
                coroutine.yield("writing...")
            end
            return "写入完成: " .. filename
        end)
    end
    
    -- 创建异步操作
    local read_op = async_read_file("input.txt")
    local write_op = async_write_file("output.txt", "some content")
    
    -- 并发执行异步操作
    local read_done, write_done = false, false
    local read_result, write_result
    
    while not read_done or not write_done do
        if not read_done then
            local result = read_op()
            if result and not string.match(result, "%.%.%.") then
                read_result = result
                read_done = true
                print("读取完成:", read_result)
            else
                print("读取状态:", result)
            end
        end
        
        if not write_done then
            local result = write_op()
            if result and not string.match(result, "%.%.%.") then
                write_result = result
                write_done = true
                print("写入完成:", write_result)
            else
                print("写入状态:", result)
            end
        end
        
        -- 模拟事件循环延迟
        local start = os.clock()
        while os.clock() - start < 0.1 do end
    end
end

async_file_operations()

5. 协程测试和调试

5.1 协程测试框架

lua
-- 简单的协程测试框架
local CoroutineTest = {}

function CoroutineTest.run_test(name, test_func)
    print("运行测试:", name)
    
    local success, error_msg = pcall(test_func)
    
    if success then
        print("✓ 测试通过:", name)
    else
        print("✗ 测试失败:", name, error_msg)
    end
end

function CoroutineTest.assert_coroutine_status(co, expected_status, message)
    local actual_status = coroutine.status(co)
    if actual_status ~= expected_status then
        error(message or string.format("期望状态 %s,实际状态 %s", expected_status, actual_status))
    end
end

-- 测试用例
CoroutineTest.run_test("协程创建测试", function()
    local co = coroutine.create(function()
        coroutine.yield()
    end)
    
    CoroutineTest.assert_coroutine_status(co, "suspended", "新创建的协程应该是suspended状态")
end)

CoroutineTest.run_test("协程恢复测试", function()
    local co = coroutine.create(function()
        coroutine.yield("test_value")
        return "finished"
    end)
    
    local success, value = coroutine.resume(co)
    assert(success, "协程恢复应该成功")
    assert(value == "test_value", "应该返回正确的yield值")
    
    CoroutineTest.assert_coroutine_status(co, "suspended", "yield后应该是suspended状态")
    
    local success, result = coroutine.resume(co)
    assert(success, "协程第二次恢复应该成功")
    assert(result == "finished", "应该返回正确的返回值")
    
    CoroutineTest.assert_coroutine_status(co, "dead", "完成后应该是dead状态")
end)

CoroutineTest.run_test("协程错误处理测试", function()
    local co = coroutine.create(function()
        error("测试错误")
    end)
    
    local success, error_msg = coroutine.resume(co)
    assert(not success, "有错误的协程恢复应该失败")
    assert(string.find(error_msg, "测试错误"), "应该包含错误信息")
    
    CoroutineTest.assert_coroutine_status(co, "dead", "出错后应该是dead状态")
end)

5.2 协程调试工具

lua
-- 协程调试工具
local CoroutineDebugger = {}

function CoroutineDebugger.trace_coroutine(co, name)
    name = name or "unnamed"
    
    local original_resume = coroutine.resume
    local original_yield = coroutine.yield
    
    -- 包装resume函数
    local function traced_resume(...)
        print(string.format("[DEBUG] 恢复协程 %s", name))
        local results = {original_resume(...)}
        print(string.format("[DEBUG] 协程 %s 状态: %s", name, coroutine.status(co)))
        if results[1] then
            print(string.format("[DEBUG] 协程 %s 返回: %s", name, table.concat({select(2, table.unpack(results))}, ", ")))
        else
            print(string.format("[DEBUG] 协程 %s 错误: %s", name, results[2]))
        end
        return table.unpack(results)
    end
    
    return traced_resume, co
end

function CoroutineDebugger.monitor_yields()
    local original_yield = coroutine.yield
    
    coroutine.yield = function(...)
        local co = coroutine.running()
        print(string.format("[DEBUG] 协程 yield,参数: %s", table.concat({...}, ", ")))
        return original_yield(...)
    end
end

-- 使用调试工具
CoroutineDebugger.monitor_yields()

local function debug_test()
    print("协程开始")
    coroutine.yield("第一次yield")
    print("协程中间")
    coroutine.yield("第二次yield")
    print("协程结束")
    return "完成"
end

local co = coroutine.create(debug_test)
local traced_resume, traced_co = CoroutineDebugger.trace_coroutine(co, "测试协程")

traced_resume(traced_co)
traced_resume(traced_co)
traced_resume(traced_co)

6. 性能考虑和最佳实践

6.1 协程性能测试

lua
-- 协程vs函数调用性能对比
local function performance_test()
    local iterations = 1000000
    
    -- 普通函数调用
    local function normal_function()
        return 42
    end
    
    local start_time = os.clock()
    for i = 1, iterations do
        normal_function()
    end
    local normal_time = os.clock() - start_time
    
    -- 协程调用
    local function coroutine_function()
        coroutine.yield(42)
    end
    
    start_time = os.clock()
    for i = 1, iterations do
        local co = coroutine.create(coroutine_function)
        coroutine.resume(co)
    end
    local coroutine_time = os.clock() - start_time
    
    print(string.format("普通函数调用: %.4f 秒", normal_time))
    print(string.format("协程调用: %.4f 秒", coroutine_time))
    print(string.format("协程开销: %.2f 倍", coroutine_time / normal_time))
end

performance_test()

6.2 最佳实践

lua
-- 最佳实践示例

-- 1. 合理使用协程池避免频繁创建
local function create_reusable_coroutine_pool()
    local pool = {}
    
    local function get_coroutine(func)
        local co = table.remove(pool)
        if not co then
            co = coroutine.create(function()
                while true do
                    local task = coroutine.yield()
                    if task then
                        task()
                    end
                end
            end)
        end
        return co
    end
    
    local function return_coroutine(co)
        if coroutine.status(co) ~= "dead" then
            table.insert(pool, co)
        end
    end
    
    return get_coroutine, return_coroutine
end

-- 2. 错误处理
local function safe_coroutine_execution(func, ...)
    local co = coroutine.create(func)
    local success, result = coroutine.resume(co, ...)
    
    if not success then
        print("协程执行错误:", result)
        return nil, result
    end
    
    return result
end

-- 3. 避免深度嵌套的协程调用
local function avoid_deep_nesting()
    -- 不好的做法:深度嵌套
    local function bad_nested_coroutines(depth)
        if depth > 0 then
            local co = coroutine.create(function()
                bad_nested_coroutines(depth - 1)
            end)
            coroutine.resume(co)
        end
        coroutine.yield()
    end
    
    -- 好的做法:使用队列管理
    local function good_queue_based(tasks)
        local queue = {}
        
        for _, task in ipairs(tasks) do
            table.insert(queue, coroutine.create(task))
        end
        
        while #queue > 0 do
            local co = table.remove(queue, 1)
            local success = coroutine.resume(co)
            
            if success and coroutine.status(co) ~= "dead" then
                table.insert(queue, co)  -- 重新加入队列
            end
        end
    end
end

7. 实际项目应用

7.1 Web服务器协程调度

lua
-- 简化的Web服务器协程调度示例
local WebServer = {}

function WebServer.new()
    return {
        connections = {},
        request_handlers = {}
    }
end

function WebServer:handle_request(connection_id, request)
    local handler_co = coroutine.create(function()
        print("处理请求:", connection_id, request.path)
        
        -- 模拟异步数据库查询
        coroutine.yield("db_query")
        local db_result = "数据库结果"
        
        -- 模拟异步文件读取
        coroutine.yield("file_read")
        local file_content = "文件内容"
        
        -- 生成响应
        local response = {
            status = 200,
            body = db_result .. " + " .. file_content
        }
        
        return response
    end)
    
    self.request_handlers[connection_id] = handler_co
    return self:process_handler(connection_id)
end

function WebServer:process_handler(connection_id)
    local handler = self.request_handlers[connection_id]
    if not handler then return nil end
    
    local success, result = coroutine.resume(handler)
    
    if not success then
        print("处理器错误:", result)
        self.request_handlers[connection_id] = nil
        return {status = 500, body = "Internal Server Error"}
    end
    
    if coroutine.status(handler) == "dead" then
        self.request_handlers[connection_id] = nil
        return result  -- 最终响应
    end
    
    -- 需要继续处理
    return nil
end

function WebServer:tick()
    -- 处理所有待处理的请求
    for connection_id, handler in pairs(self.request_handlers) do
        local response = self:process_handler(connection_id)
        if response then
            print("发送响应:", connection_id, response.status)
        end
    end
end

-- 使用示例
local server = WebServer.new()

-- 模拟接收请求
server:handle_request("conn1", {path = "/api/users"})
server:handle_request("conn2", {path = "/api/posts"})

-- 模拟服务器事件循环
for i = 1, 5 do
    print("Tick", i)
    server:tick()
end

7.2 游戏AI行为树

lua
-- 使用协程实现游戏AI行为树
local BehaviorTree = {}

function BehaviorTree.sequence(...)
    local children = {...}
    return coroutine.wrap(function()
        for _, child in ipairs(children) do
            local result = child()
            if result ~= "success" then
                return result
            end
        end
        return "success"
    end)
end

function BehaviorTree.selector(...)
    local children = {...}
    return coroutine.wrap(function()
        for _, child in ipairs(children) do
            local result = child()
            if result == "success" then
                return "success"
            end
        end
        return "failure"
    end)
end

function BehaviorTree.condition(check_func)
    return coroutine.wrap(function()
        if check_func() then
            return "success"
        else
            return "failure"
        end
    end)
end

function BehaviorTree.action(action_func)
    return coroutine.wrap(function()
        return action_func()
    end)
end

-- 游戏AI示例
local function create_guard_ai(guard)
    local function can_see_player()
        return guard.player_distance < 10
    end
    
    local function is_player_in_range()
        return guard.player_distance < 3
    end
    
    local function patrol()
        print("守卫巡逻")
        return "success"
    end
    
    local function chase_player()
        print("守卫追击玩家")
        return "success"
    end
    
    local function attack_player()
        print("守卫攻击玩家")
        return "success"
    end
    
    -- 构建行为树
    return BehaviorTree.selector(
        BehaviorTree.sequence(
            BehaviorTree.condition(can_see_player),
            BehaviorTree.selector(
                BehaviorTree.sequence(
                    BehaviorTree.condition(is_player_in_range),
                    BehaviorTree.action(attack_player)
                ),
                BehaviorTree.action(chase_player)
            )
        ),
        BehaviorTree.action(patrol)
    )
end

-- 使用AI
local guard = {player_distance = 15}
local ai = create_guard_ai(guard)

-- 模拟游戏循环
for frame = 1, 5 do
    print("帧", frame)
    guard.player_distance = guard.player_distance - 3
    local result = ai()
    print("AI结果:", result)
    print()
end

总结

Lua协程是一个强大的特性,具有以下优势:

  1. 轻量级 - 创建和切换开销很小
  2. 协作式 - 程序员完全控制执行流程
  3. 状态保持 - 自动保存局部变量和执行位置
  4. 简单易用 - API简洁,概念清晰
  5. 应用广泛 - 适用于生成器、迭代器、状态机、异步编程等场景

掌握协程的使用可以让你写出更加优雅和高效的Lua程序,特别是在需要处理复杂控制流程的场景中。

基于 MIT 许可发布