2023年政策修订增补工作正在进行中,欢迎参与!
Module:Sandbox/あめろ
跳转到导航
跳转到搜索
-- Cache frequently used globals as locals.
local type, ipairs = type, ipairs
local fmt = string.format
local get_mt, set_mt = getmetatable, setmetatable

-- Schema: the module table that collects every public schema.
local schema = {}
--- Returns true when `val` is a plain table, i.e. a table with no metatable.
-- Non-table values always yield false.
local function is_raw_table(val)
	if type(val) ~= 'table' then
		return false
	end
	return not get_mt(val)
end
--- Filters an array.
-- Keeps the elements for which `filter(v)` returns a truthy value and packs them,
-- in their original order, into a new sequence. Iterates with ipairs, so it stops
-- at the first nil element.
-- @param t array to filter (not modified)
-- @param filter predicate called once per element
-- @return the new array, and the number of elements that were dropped
local function ifilter(t, filter)
	local kept = {}
	local kept_count = 0
	local dropped = 0
	for _, v in ipairs(t) do
		if filter(v) then
			kept_count = kept_count + 1
			kept[kept_count] = v
		else
			dropped = dropped + 1
		end
	end
	return kept, dropped
end
--- Builds a `test` method that checks the primitive Lua type of a value.
---@param ty string type name as returned by `type()`, e.g. 'number'
---@param a string? optional article placed before the type name in the error message (e.g. 'a')
---@return function checker `function(self, testee)` returning `true`, or `false` plus an error message
local function type_checker(ty, a)
	local fmt_str = "%s (type: %s) isn't "
	if a then
		fmt_str = fmt_str..a..' '
	end
	fmt_str = fmt_str..ty
	return function(self, testee)
		if type(testee) == ty then
			return true
		end
		-- tostring() is required: on Lua 5.1 (e.g. Scribunto) the `%s` format only
		-- accepts strings and numbers, so formatting a nil/boolean/table testee raised
		-- an error instead of producing this message.
		return false, fmt(fmt_str, tostring(testee), type(testee))
	end
end
--- Normalizes `val` into a list: when `val` is `x` or `{x}`, returns `{x}`.
-- A plain table (no metatable) is returned unchanged, as the very same object;
-- any other value (including tables with a metatable, and nil) is wrapped as `{val}`.
local function ensure_wrapped(val)
	if is_raw_table(val) then
		return val
	end
	return {val}
end
--- Registry of the metatables that mark schema objects (metatable -> true).
-- Keys are weak, so the registry alone never keeps a metatable alive.
local mts = {} ---@type {[metatable]: true}
set_mt(mts, {__mode = 'k'})
--- Creates a metatable for a schema kind and records it in `mts`.
---@param name string kind name, stored as the metatable's `__name`
---@param super_mt metatable? metatable whose `__index` methods are inherited
---@param without_override boolean? when true, reuse the parent's `__index` table itself instead of layering a fresh table on top of it
---@return metatable
local function reg_mt(name, super_mt, without_override)
	local parent_index = super_mt and super_mt.__index or {}
	local index
	if without_override then
		index = parent_index
	else
		-- A fresh method table whose lookups fall back to the parent's methods.
		index = set_mt({}, {__index = parent_index})
	end
	local mt = {__name = name, __index = index}
	mts[mt] = true
	return mt
end
--- Returns the kind name of a schema object (the `__name` of its metatable,
-- e.g. 'Number' or 'Union'), or nil when `v`'s metatable is not registered in `mts`.
---@param v any
---@return string | nil
local function get_scm_type(v)
	local mt = get_mt(v)
	local registered = mt ~= nil and mts[mt]
	if registered then
		return mt.__name
	end
	return nil
end
--- Collects the custom validators declared in a constraints table.
-- Reads `constraints.validators`, falling back to `constraints.validator`; either may be
-- a single function or a plain list of functions. Raises an error if any collected
-- element is not a function.
---@param constraints table
---@return table | nil validators list of validator functions, or nil when none were given
local function get_validators_from_constraints(constraints)
	local candidates = ensure_wrapped(constraints.validators or constraints.validator)
	local validators = {}
	for _, candidate in ipairs(candidates) do
		assert(type(candidate) == 'function', 'validator需要是函数或元素为函数的表')
		validators[#validators + 1] = candidate
	end
	if #validators == 0 then
		return nil
	end
	return validators
end
--- Runs every custom validator against `val`, stopping at the first failure.
-- Each validator is called as `validator(val)`; a truthy first result accepts the value,
-- and a falsy one rejects it, optionally with an error message as the second result.
---@param validators function[]? list of validators, or nil for "no extra checks"
---@param val any
---@return boolean, string? true, or false plus an error message
local function validate_all(validators, val)
	if not validators then
		return true
	end
	for i, validator in ipairs(validators) do
		local valid, msg = validator(val)
		if not valid then
			-- Validators such as `function(v) return v end` fail without giving a
			-- message. Synthesize one so callers never receive a nil message
			-- (error(nil) would raise a nil error object).
			if not msg then
				msg = fmt('%s was rejected by validator #%d', tostring(val), i)
			end
			return false, msg
		end
	end
	return true
end
local Any_mt = reg_mt('Any', nil)
--- Test function of the root `Any` schema: every value passes.
local function accept_anything()
	return true
end
--- Root schema that every value satisfies.
schema.Any = set_mt({test = accept_anything}, Any_mt)
--- Refines a schema: `base{validator = fn}` or `base{validators = {fn, ...}}` returns a
-- new schema that has the same metatable as `base`, has `base` as its `super`, and
-- additionally carries the given validators.
function Any_mt:__call(constraints)
	local own_mt = get_mt(self)
	local validators = get_validators_from_constraints(constraints)
	return set_mt({super = self, validators = validators}, own_mt)
end
--- Default `test` for refined schemas: check the parent schema (`super`) first,
-- then run this schema's own validators.
-- Returns true, or false plus an error message.
function Any_mt.__index:test(testee)
	local parent = self.super
	local ok, err = true, nil
	if parent then
		ok, err = parent:test(testee)
	end
	if not ok then
		return false, err
	end
	return validate_all(self.validators, testee)
end
--- Like `test`, but raises instead of returning false.
-- Returns true when `testee` passes. Otherwise raises the message produced by `test`,
-- reported at the caller's position (error level 2).
function Any_mt.__index:assert(testee)
	local valid, msg = self:test(testee)
	if valid then
		return true
	end
	-- `test` can fail without a message (e.g. a custom validator that only returns
	-- false/nil). error(nil) would raise a nil error object, so fall back to a default.
	if msg == nil then
		msg = fmt('%s failed schema validation', tostring(testee))
	end
	error(msg, 2)
end
-- `Nil` reuses Any's method table directly (no own methods); its own `test` is a plain type check.
local Nil_mt = reg_mt('Nil', Any_mt, true)
local nil_schema = {test = type_checker('nil')}
schema.Nil = set_mt(nil_schema, Nil_mt)
-- `Boolean` reuses Any's method table directly (no own methods); its own `test` is a plain type check.
local Boolean_mt = reg_mt('Boolean', Any_mt, true)
local boolean_schema = {test = type_checker('boolean', 'a')}
schema.Boolean = set_mt(boolean_schema, Boolean_mt)
-- The base Number schema only checks `type(testee) == 'number'`.
local Number_mt = reg_mt('Number', Any_mt)
schema.Number = set_mt({}, Number_mt)
schema.Number.test = type_checker('number', 'a')
--- Refines a Number schema.
-- Copied constraints: `int`, `lt`, `gt`, `ne`; `le` (falling back to `max`);
-- `ge` (falling back to `min`); plus `validator`/`validators`.
-- Returns a new Number schema whose `super` is the called schema.
function Number_mt:__call(constraints)
	local refined = {super = self}
	for _, key in ipairs({'int', 'lt', 'gt', 'ne'}) do
		refined[key] = constraints[key]
	end
	refined.le = constraints.le or constraints.max
	refined.ge = constraints.ge or constraints.min
	refined.validators = get_validators_from_constraints(constraints)
	return set_mt(refined, Number_mt)
end
-- Bound constraints of Number schemas, checked in this order. Each entry is
-- {constraint field, predicate returning true when `x` VIOLATES the bound, message format}.
local number_bound_checks = {
	{'lt', function(x, bound) return x >= bound end, "%s isn't < %s"},
	{'gt', function(x, bound) return x <= bound end, "%s isn't > %s"},
	{'le', function(x, bound) return x > bound end, "%s isn't <= %s"},
	{'ge', function(x, bound) return x < bound end, "%s isn't >= %s"},
}
--- Tests `testee` against a refined Number schema.
-- Checks, in order: the parent schema (`super`), `int`, the bounds `lt`, `gt`, `le`,
-- `ge`, then `ne`, then the custom validators; the first failing check wins.
-- Returns true, or false plus an error message.
function Number_mt.__index:test(testee)
	local parent = self.super
	local ok, err = true, nil
	if parent then
		ok, err = parent:test(testee)
	end
	if not ok then
		return false, err
	end

	if self.int and math.fmod(testee, 1) ~= 0 then
		return false, fmt("%s isn't an integer", testee)
	end
	for _, check in ipairs(number_bound_checks) do
		local field, violates, message = check[1], check[2], check[3]
		local bound = self[field]
		if bound and violates(testee, bound) then
			return false, fmt(message, testee, bound)
		end
	end
	local forbidden = self.ne
	if forbidden and testee == forbidden then
		return false, fmt('testee equals %s', forbidden)
	end
	return validate_all(self.validators, testee)
end
-- The base String schema only checks `type(testee) == 'string'`.
local String_mt = reg_mt('String', Any_mt)
schema.String = set_mt({}, String_mt)
schema.String.test = type_checker('string', 'a')
--- Refines a String schema.
-- Copied constraints: `max_len`, `min_len`, `pattern`, plus `validator`/`validators`.
-- Returns a new String schema whose `super` is the called schema.
function String_mt:__call(constraints)
	local refined = {super = self}
	for _, key in ipairs({'max_len', 'min_len', 'pattern'}) do
		refined[key] = constraints[key]
	end
	refined.validators = get_validators_from_constraints(constraints)
	return set_mt(refined, String_mt)
end
--- Tests `testee` against a refined String schema.
-- Checks, in order: the parent schema (`super`), `max_len` and `min_len` (byte length
-- via `#`), `pattern` (Lua pattern via string.match), then the custom validators.
-- Returns true, or false plus an error message.
function String_mt.__index:test(testee)
	local parent = self.super
	local ok, err = true, nil
	if parent then
		ok, err = parent:test(testee)
	end
	if not ok then
		return false, err
	end

	local longest = self.max_len
	if longest and #testee > longest then
		return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, longest)
	end
	local shortest = self.min_len
	if shortest and #testee < shortest then
		return false, fmt("the length of %q (%d) is under %s", testee, #testee, shortest)
	end
	local pattern = self.pattern
	if pattern and not testee:match(pattern) then
		return false, fmt("%q doesn't match the pattern %q", testee, pattern)
	end
	return validate_all(self.validators, testee)
end
-- The base Function schema only checks `type(testee) == 'function'`.
local Function_mt = reg_mt('Function', Any_mt)
local function_schema = {test = type_checker('function', 'a')}
schema.Function = set_mt(function_schema, Function_mt)
-- Refining a Function schema works exactly like refining Any (validators only).
Function_mt.__call = Any_mt.__call
-- A refined Function schema is tested exactly like a refined Any schema
-- (parent schema first, then its own validators), so reuse that implementation.
Function_mt.__index.test = Any_mt.__index.test
-- The base Table schema only checks `type(testee) == 'table'`.
local Table_mt = reg_mt('Table', Any_mt)
local table_schema = {}
table_schema.test = type_checker('table', 'a')
schema.Table = set_mt(table_schema, Table_mt)
--- Refines a Table schema from a constraints table. Keys are sorted as follows:
-- * a schema whose kind is 'Literal': its `.val` becomes a specific field name;
-- * any other schema: a generic `[key_schema] = value_schema` rule;
-- * `validator` / `validators`: custom validators (not treated as fields);
-- * anything else: a specific field name mapped to the schema its value must satisfy.
function Table_mt:__call(constraints)
	local specific, generic = {}, {}
	for key, value in pairs(constraints) do
		local kind = get_scm_type(key)
		if kind == 'Literal' then
			specific[key.val] = value
		elseif kind then
			generic[key] = value
		elseif key ~= 'validators' and key ~= 'validator' then
			specific[key] = value
		end
	end
	return set_mt({
		super = self,
		specific = specific,
		generic = generic,
		validators = get_validators_from_constraints(constraints),
	}, Table_mt)
end
--- Tests `testee` against a refined Table schema.
-- Checks, in order: the parent schema (`super`); every generic `[key_schema] = value_schema`
-- rule against each entry of `testee` whose key matches `key_schema`; every specific
-- field; then the custom validators.
-- Returns true, or false plus an error message.
function Table_mt.__index:test(testee)
	if self.super then
		local valid, msg = self.super:test(testee)
		if not valid then
			return false, msg
		end
	end
	for key_scm, val_scm in pairs(self.generic) do
		for testee_key, testee_val in pairs(testee) do
			if key_scm:test(testee_key) then
				local valid, msg = val_scm:test(testee_val)
				if not valid then
					return false, msg
				end
			end
		end
	end
	for key, val_scm in pairs(self.specific) do
		local field_val = testee[key]
		local valid, msg = val_scm:test(field_val)
		if not valid then
			if field_val == nil then
				-- tostring() is required: on Lua 5.1 (Scribunto) `%s` only accepts strings
				-- and numbers, so formatting the table `testee` (or a non-string key)
				-- raised "string expected, got table" instead of returning this message.
				return false, fmt('`%s` misses field `%s`', tostring(testee), tostring(key))
			end
			return false, msg
		end
	end
	return validate_all(self.validators, testee)
end
local Union_mt = reg_mt('Union', Any_mt)
--- Builds a union schema from its arguments.
-- Each argument may be a schema, a plain value, or nil (which adds `schema.Nil`).
-- Union arguments are flattened: their members are copied in rather than nested.
-- Members are stored as the keys of the returned table (key -> true).
function schema.Union(...)
	local members = {}
	local argc = select('#', ...)
	local args = {...}
	for i = 1, argc do
		local member = args[i]
		if member == nil then
			members[schema.Nil] = true
		elseif get_scm_type(member) ~= 'Union' then
			members[member] = true
		else
			for nested in next, member do
				members[nested] = true
			end
		end
	end
	return set_mt(members, Union_mt)
end
--- Short human-readable description of a union member for error messages:
-- the kind name for schema objects, a quoted string for string values,
-- and tostring() for anything else.
local function describe_union_member(member)
	local kind = get_scm_type(member)
	if kind then
		return kind
	elseif type(member) == 'string' then
		return fmt('%q', member)
	end
	return tostring(member)
end
--- Tests `testee` against a union. It passes as soon as `testee` satisfies one member
-- schema, or is equal (==) to one non-schema member.
-- Returns true, or false plus an error message listing the members.
function Union_mt.__index:test(testee)
	for allowed_val in next, self do
		if get_scm_type(allowed_val) then
			if allowed_val:test(testee) then
				return true
			end
		elseif testee == allowed_val then
			return true
		end
	end
	-- Describe the members explicitly. Passing `self` (a table) or a nil/boolean/table
	-- testee straight to `%s` raises "string expected" on Lua 5.1 (Scribunto); on newer
	-- Lua versions it would only print the table's name and address.
	local members = {}
	for allowed_val in next, self do
		members[#members + 1] = describe_union_member(allowed_val)
	end
	return false, fmt('testee `%s` fails to match each value in the union: %s',
		tostring(testee), table.concat(members, ', '))
end
-- Give every metatable registered so far the `|` (__bor, Lua 5.3+) and `/` (__div)
-- operators; both build a union of their two operands via schema.Union.
for registered_mt in pairs(mts) do
	registered_mt.__bor = schema.Union
	registered_mt.__div = schema.Union
end
-- Validators for the two convenience schemas below.
local function is_truthy(v) return v end
local function is_falsy(v) return not v end
--- Accepts any value except nil and false.
schema.Truthy = schema.Any{validator = is_truthy}
--- Accepts only nil and false.
schema.Falsy = schema.Any{validator = is_falsy}
return schema