Skip to content

Lua函数

概述

函数是Lua编程的核心概念之一。在Lua中,函数是第一类值(first-class values),这意味着函数可以存储在变量中、作为参数传递、作为返回值返回,以及在运行时创建。

1. 函数定义

1.1 基本函数定义

lua
-- 基本函数定义
function greet(name)
    print("Hello, " .. name .. "!")
end

-- 调用函数
greet("Lua")  -- 输出: Hello, Lua!

-- 带返回值的函数
function add(a, b)
    return a + b
end

local result = add(5, 3)
print(result)  -- 输出: 8

1.2 函数表达式

lua
-- 将函数赋值给变量
local multiply = function(a, b)
    return a * b
end

print(multiply(4, 5))  -- 输出: 20

-- 匿名函数
local numbers = {1, 2, 3, 4, 5}
table.sort(numbers, function(a, b) return a > b end)
-- numbers现在是{5, 4, 3, 2, 1}

1.3 局部函数

lua
-- 局部函数定义
local function factorial(n)
    if n <= 1 then
        return 1
    else
        return n * factorial(n - 1)
    end
end

print(factorial(5))  -- 输出: 120

-- 注意:局部函数的递归定义
local function fibonacci(n)
    if n <= 2 then
        return 1
    else
        return fibonacci(n - 1) + fibonacci(n - 2)
    end
end

2. 函数参数

2.1 固定参数

lua
function calculate_area(length, width)
    return length * width
end

print(calculate_area(10, 5))  -- 输出: 50

2.2 可变参数

lua
-- 使用...接收可变参数
function sum(...)
    local args = {...}  -- 将参数打包成表
    local total = 0
    
    for i = 1, #args do
        total = total + args[i]
    end
    
    return total
end

print(sum(1, 2, 3, 4, 5))  -- 输出: 15

-- 使用select函数处理可变参数
function print_args(...)
    local n = select("#", ...)  -- 获取参数个数
    
    for i = 1, n do
        local arg = select(i, ...)
        print("参数" .. i .. ":", arg)
    end
end

print_args("hello", 42, true)

2.3 默认参数

lua
-- Lua没有内置的默认参数,但可以通过or运算符实现
function greet_with_default(name, greeting)
    name = name or "World"
    greeting = greeting or "Hello"
    print(greeting .. ", " .. name .. "!")
end

greet_with_default()              -- Hello, World!
greet_with_default("Alice")       -- Hello, Alice!
greet_with_default("Bob", "Hi")   -- Hi, Bob!

-- 更安全的默认参数处理
function safe_divide(a, b)
    if a == nil then a = 0 end
    if b == nil then b = 1 end
    if b == 0 then
        error("除数不能为零")
    end
    return a / b
end

2.4 命名参数

lua
-- 使用表模拟命名参数
function create_person(options)
    options = options or {}
    
    local person = {
        name = options.name or "Unknown",
        age = options.age or 0,
        city = options.city or "Unknown"
    }
    
    return person
end

-- 使用命名参数
local person1 = create_person({
    name = "Alice",
    age = 30,
    city = "Beijing"
})

local person2 = create_person({name = "Bob", age = 25})

3. 函数返回值

3.1 单个返回值

lua
function square(x)
    return x * x
end

local result = square(5)
print(result)  -- 输出: 25

3.2 多个返回值

lua
-- 返回多个值
function get_name_age()
    return "Alice", 30
end

local name, age = get_name_age()
print(name, age)  -- 输出: Alice	30

-- 只接收部分返回值
local name_only = get_name_age()  -- 只接收第一个返回值
print(name_only)  -- 输出: Alice

-- 数学运算的商和余数
function divmod(a, b)
    return math.floor(a / b), a % b
end

local quotient, remainder = divmod(17, 5)
print(quotient, remainder)  -- 输出: 3	2

3.3 返回表

lua
-- 返回表结构
function get_user_info(id)
    -- 模拟数据库查询
    local users = {
        [1] = {name = "Alice", email = "alice@example.com", age = 30},
        [2] = {name = "Bob", email = "bob@example.com", age = 25}
    }
    
    return users[id]
end

local user = get_user_info(1)
if user then
    print(user.name, user.email, user.age)
end

4. 高阶函数

4.1 函数作为参数

lua
-- 接受函数作为参数的高阶函数
function apply_operation(a, b, operation)
    return operation(a, b)
end

-- 定义操作函数
local function add(x, y) return x + y end
local function multiply(x, y) return x * y end

print(apply_operation(5, 3, add))      -- 输出: 8
print(apply_operation(5, 3, multiply)) -- 输出: 15

-- 数组映射函数
function map(array, func)
    local result = {}
    for i, v in ipairs(array) do
        result[i] = func(v)
    end
    return result
end

local numbers = {1, 2, 3, 4, 5}
local squares = map(numbers, function(x) return x * x end)
-- squares = {1, 4, 9, 16, 25}

4.2 函数作为返回值

lua
-- 返回函数的函数
function create_multiplier(factor)
    return function(x)
        return x * factor
    end
end

local double = create_multiplier(2)
local triple = create_multiplier(3)

print(double(5))  -- 输出: 10
print(triple(5))  -- 输出: 15

-- 创建计数器
function create_counter(initial)
    initial = initial or 0
    return function()
        initial = initial + 1
        return initial
    end
end

local counter1 = create_counter()
local counter2 = create_counter(10)

print(counter1())  -- 1
print(counter1())  -- 2
print(counter2())  -- 11
print(counter2())  -- 12

5. 闭包

5.1 闭包基础

lua
-- 闭包示例:函数访问外部变量
function create_bank_account(initial_balance)
    local balance = initial_balance or 0
    
    return {
        deposit = function(amount)
            balance = balance + amount
            return balance
        end,
        
        withdraw = function(amount)
            if amount <= balance then
                balance = balance - amount
                return balance
            else
                error("余额不足")
            end
        end,
        
        get_balance = function()
            return balance
        end
    }
end

local account = create_bank_account(100)
print(account.get_balance())  -- 100
account.deposit(50)
print(account.get_balance())  -- 150
account.withdraw(30)
print(account.get_balance())  -- 120

5.2 闭包的实际应用

lua
-- 事件处理器
function create_event_handler()
    local handlers = {}
    
    return {
        on = function(event, handler)
            if not handlers[event] then
                handlers[event] = {}
            end
            table.insert(handlers[event], handler)
        end,
        
        emit = function(event, ...)
            if handlers[event] then
                for _, handler in ipairs(handlers[event]) do
                    handler(...)
                end
            end
        end
    }
end

local emitter = create_event_handler()

emitter.on("user_login", function(username)
    print("用户登录: " .. username)
end)

emitter.on("user_login", function(username)
    print("记录日志: " .. username .. " 已登录")
end)

emitter.emit("user_login", "Alice")
-- 输出:
-- 用户登录: Alice
-- 记录日志: Alice 已登录

6. 递归函数

6.1 基本递归

lua
-- 阶乘函数
function factorial(n)
    if n <= 1 then
        return 1
    else
        return n * factorial(n - 1)
    end
end

print(factorial(5))  -- 输出: 120

-- 斐波那契数列
function fibonacci(n)
    if n <= 2 then
        return 1
    else
        return fibonacci(n - 1) + fibonacci(n - 2)
    end
end

print(fibonacci(10))  -- 输出: 55

6.2 尾递归优化

lua
-- 尾递归版本的阶乘
function factorial_tail(n, acc)
    acc = acc or 1
    if n <= 1 then
        return acc
    else
        return factorial_tail(n - 1, n * acc)
    end
end

print(factorial_tail(5))  -- 输出: 120

-- 尾递归版本的斐波那契
function fibonacci_tail(n, a, b)
    a = a or 1
    b = b or 1
    if n <= 2 then
        return b
    else
        return fibonacci_tail(n - 1, b, a + b)
    end
end

print(fibonacci_tail(10))  -- 输出: 55

7. 函数的高级特性

7.1 函数重载模拟

lua
-- 通过参数类型和数量模拟函数重载
function print_value(value, format)
    local t = type(value)
    
    if t == "number" then
        if format == "hex" then
            print(string.format("0x%x", value))
        elseif format == "binary" then
            -- 简单的二进制转换
            local binary = ""
            local n = value
            while n > 0 do
                binary = (n % 2) .. binary
                n = math.floor(n / 2)
            end
            print("0b" .. binary)
        else
            print(value)
        end
    elseif t == "string" then
        if format == "upper" then
            print(string.upper(value))
        elseif format == "lower" then
            print(string.lower(value))
        else
            print(value)
        end
    else
        print(tostring(value))
    end
end

print_value(255, "hex")     -- 0xff
print_value("hello", "upper") -- HELLO
print_value(42)             -- 42

7.2 函数缓存(记忆化)

lua
-- 带缓存的斐波那契函数
function create_memoized_fibonacci()
    local cache = {}
    
    local function fib(n)
        if cache[n] then
            return cache[n]
        end
        
        local result
        if n <= 2 then
            result = 1
        else
            result = fib(n - 1) + fib(n - 2)
        end
        
        cache[n] = result
        return result
    end
    
    return fib
end

local fast_fib = create_memoized_fibonacci()
print(fast_fib(40))  -- 很快就能计算出结果

7.3 函数装饰器

lua
-- 计时装饰器
function time_it(func)
    return function(...)
        local start_time = os.clock()
        local results = {func(...)}
        local end_time = os.clock()
        
        print(string.format("函数执行时间: %.4f 秒", end_time - start_time))
        return table.unpack(results)
    end
end

-- 使用装饰器
local function slow_function(n)
    local sum = 0
    for i = 1, n do
        sum = sum + i
    end
    return sum
end

local timed_slow_function = time_it(slow_function)
local result = timed_slow_function(1000000)
print("结果:", result)

8. 错误处理

8.1 使用error函数

lua
function safe_divide(a, b)
    if type(a) ~= "number" or type(b) ~= "number" then
        error("参数必须是数字")
    end
    
    if b == 0 then
        error("除数不能为零")
    end
    
    return a / b
end

-- 使用pcall捕获错误
local success, result = pcall(safe_divide, 10, 2)
if success then
    print("结果:", result)
else
    print("错误:", result)
end

8.2 返回错误信息

lua
-- 返回结果和错误信息
function safe_sqrt(x)
    if type(x) ~= "number" then
        return nil, "参数必须是数字"
    end
    
    if x < 0 then
        return nil, "不能计算负数的平方根"
    end
    
    return math.sqrt(x), nil
end

local result, err = safe_sqrt(16)
if result then
    print("平方根:", result)
else
    print("错误:", err)
end

9. 实际应用示例

9.1 配置管理

lua
-- 配置管理函数
function create_config_manager(default_config)
    local config = {}
    
    -- 复制默认配置
    for k, v in pairs(default_config or {}) do
        config[k] = v
    end
    
    return {
        get = function(key, default)
            return config[key] or default
        end,
        
        set = function(key, value)
            config[key] = value
        end,
        
        load_from_table = function(new_config)
            for k, v in pairs(new_config) do
                config[k] = v
            end
        end,
        
        get_all = function()
            local copy = {}
            for k, v in pairs(config) do
                copy[k] = v
            end
            return copy
        end
    }
end

-- 使用配置管理器
local config = create_config_manager({
    host = "localhost",
    port = 8080,
    debug = false
})

print(config.get("host"))        -- localhost
config.set("debug", true)
print(config.get("debug"))       -- true

9.2 数据验证

lua
-- 数据验证函数
function create_validator()
    local rules = {}
    
    return {
        add_rule = function(field, validator, message)
            if not rules[field] then
                rules[field] = {}
            end
            table.insert(rules[field], {validator = validator, message = message})
        end,
        
        validate = function(data)
            local errors = {}
            
            for field, field_rules in pairs(rules) do
                local value = data[field]
                
                for _, rule in ipairs(field_rules) do
                    if not rule.validator(value) then
                        if not errors[field] then
                            errors[field] = {}
                        end
                        table.insert(errors[field], rule.message)
                    end
                end
            end
            
            return next(errors) == nil, errors
        end
    }
end

-- 使用验证器
local validator = create_validator()

validator.add_rule("name", function(v) return type(v) == "string" and #v > 0 end, "姓名不能为空")
validator.add_rule("age", function(v) return type(v) == "number" and v >= 0 and v <= 150 end, "年龄必须在0-150之间")
validator.add_rule("email", function(v) return type(v) == "string" and string.match(v, "@") end, "邮箱格式不正确")

local user_data = {
    name = "Alice",
    age = 30,
    email = "alice@example.com"
}

local is_valid, errors = validator.validate(user_data)
if is_valid then
    print("数据验证通过")
else
    for field, field_errors in pairs(errors) do
        print(field .. ":", table.concat(field_errors, ", "))
    end
end

10. 性能优化

10.1 避免全局查找

lua
-- 缓存全局函数到局部变量
local math_sin = math.sin
local math_cos = math.cos

function calculate_circle_points(radius, num_points)
    local points = {}
    local angle_step = 2 * math.pi / num_points
    
    for i = 1, num_points do
        local angle = (i - 1) * angle_step
        points[i] = {
            x = radius * math_cos(angle),
            y = radius * math_sin(angle)
        }
    end
    
    return points
end

10.2 尾调用优化

lua
-- 利用尾调用优化避免栈溢出
function sum_tail(n, acc)
    acc = acc or 0
    if n <= 0 then
        return acc
    else
        return sum_tail(n - 1, acc + n)  -- 尾调用
    end
end

print(sum_tail(100000))  -- 不会栈溢出

总结

Lua函数具有以下特点:

  1. 第一类值 - 可以存储、传递和返回
  2. 支持闭包 - 可以访问外部作用域的变量
  3. 多返回值 - 一个函数可以返回多个值
  4. 可变参数 - 支持不定数量的参数
  5. 尾调用优化 - 避免深度递归的栈溢出
  6. 高阶函数 - 支持函数式编程范式

掌握函数的使用是编写高质量Lua代码的关键,合理运用这些特性可以写出简洁、高效、可维护的程序。

基于 MIT 许可发布