diff --git a/typedlua/tlchecker.lua b/typedlua/tlchecker.lua index b34af306..3a32ebaa 100644 --- a/typedlua/tlchecker.lua +++ b/typedlua/tlchecker.lua @@ -4,13 +4,14 @@ This file implements Typed Lua type checker if not table.unpack then table.unpack = unpack end -local tlchecker = {} - local tlast = require "typedlua.tlast" local tlst = require "typedlua.tlst" local tltype = require "typedlua.tltype" local tlparser = require "typedlua.tlparser" local tldparser = require "typedlua.tldparser" +local tlvisitor = require "typedlua.tlvisitor" + +local tlchecker = setmetatable({}, { __index = tlvisitor }) local Value = tltype.Value() local Any = tltype.Any() @@ -23,8 +24,6 @@ local Number = tltype.Number() local String = tltype.String() local Integer = tltype.Integer(false) -local check_block, check_stm, check_exp, check_var - local function lineno (s, i) if i == 1 then return 1, 1 end local rest, num = s:sub(1,i):gsub("[^\n]*\n", "") @@ -76,7 +75,7 @@ local function check_self (env, torig, t, pos) return r elseif tltype.isTable(t) then local l = {} - for k, v in ipairs(t) do + for _, v in ipairs(t) do table.insert(l, tltype.Field(v.const, v[1], check_self_field(env, torig, v[2], pos))) end local r = tltype.Table(table.unpack(l)) @@ -89,7 +88,7 @@ local function check_self (env, torig, t, pos) end function check_self_field(env, torig, t, pos) - local msg = string.format("self type cannot appear in declaration of type '%s', replacing with 'any'", tltype.tostring(torig)) + --local msg = string.format("self type cannot appear in declaration of type '%s', replacing with 'any'", tltype.tostring(torig)) if tltype.isRecursive(t) then local r = tltype.Recursive(t[1], check_self_field(env, torig, t[2], pos)) r.name = t.name @@ -121,7 +120,7 @@ function check_self_field(env, torig, t, pos) end elseif tltype.isTable(t) then local l = {} - for k, v in ipairs(t) do + for _, v in ipairs(t) do table.insert(l, tltype.Field(v.const, v[1], check_self_field(env, torig, v[2], pos))) end local r = tltype.Table(table.unpack(l)) @@ -163,7 +162,7 @@ local function replace_names (env, t, pos, ignore) tltype.isUnionlist(t) or tltype.isTuple(t) then local r = { tag = t.tag, name = t.name } - for k, v in ipairs(t) do + for k, _ in ipairs(t) do r[k] = replace_names(env, t[k], pos, ignore) end return r @@ -172,7 +171,7 @@ local function replace_names (env, t, pos, ignore) t[2] = replace_names(env, t[2], pos, ignore) return t elseif tltype.isTable(t) then - for k, v in ipairs(t) do + for k, _ in ipairs(t) do t[k][2] = replace_names(env, t[k][2], pos, ignore) end return t @@ -196,7 +195,7 @@ local function close_type (t) if tltype.isUnion(t) or tltype.isUnionlist(t) or tltype.isTuple(t) then - for k, v in ipairs(t) do + for _, v in ipairs(t) do close_type(v) end else @@ -244,7 +243,7 @@ local function check_masking (env, local_name, pos) local masked_local = tlst.masking(env, local_name) if masked_local then local l, c = lineno(env.subject, masked_local.pos) - msg = "masking previous declaration of local %s on line %d" + local msg = "masking previous declaration of local %s on line %d" msg = string.format(msg, local_name, l) typeerror(env, "mask", msg, pos) end @@ -271,7 +270,7 @@ local function check_tl (env, name, path, pos) env.subject = subject env.filename = path tlst.begin_function(env) - check_block(env, ast) + tlchecker:Block(ast, env) local t1 = tltype.first(infer_return_type(env)) tlst.end_function(env) env.subject = s @@ -279,21 +278,6 @@ local function check_tl (env, name, path, pos) return t1 end -local function check_interface (env, stm) - local name, t, is_local = stm[1], stm[2], stm.is_local - if tlst.get_interface(env, name) then - local msg = "attempt to redeclare interface '%s'" - msg = string.format(msg, name) - typeerror(env, "alias", msg, stm.pos) - else - check_self(env, t, t, stm.pos) - local t = replace_names(env, t, stm.pos) - t.name = name - tlst.set_interface(env, name, t, is_local) - end - return false -end - local function check_userdata (env, stm) local name, t, is_local = stm[1], stm[2], stm.is_local if tlst.get_userdata(env, name) then @@ -312,12 +296,12 @@ local function check_tld (env, name, path, pos) return Any end local t = tltype.Table() - for k, v in ipairs(ast) do + for _, v in ipairs(ast) do local tag = v.tag if tag == "Id" then table.insert(t, tltype.Field(v.const, tltype.Literal(v[1]), v[2])) elseif tag == "Interface" then - check_interface(env, v) + tlchecker:Interface(v, env) elseif tag == "Userdata" then check_userdata(env, v) else @@ -341,7 +325,8 @@ local function check_require (env, name, pos, extra_path) end else path = string.gsub(package.path..";", "[.]lua;", ".tld;") - local filepath, msg2 = searchpath(extra_path .. name, path) + local msg2 + filepath, msg2 = searchpath(extra_path .. name, path) if filepath then env["loaded"][name] = check_tld(env, name, filepath, pos) else @@ -362,424 +347,505 @@ local function check_require (env, name, pos, extra_path) return env["loaded"][name] end -local function check_arith (env, exp, op) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - local msg = "attempt to perform arithmetic on a '%s'" - if tltype.subtype(t1, tltype.Integer(true)) and - tltype.subtype(t2, tltype.Integer(true)) then - if op == "div" or op == "pow" then - set_type(exp, Number) +local function check_parameters (env, parlist, pos) + local len = #parlist + if len == 0 then + if env.strict then + return tltype.Void() else - set_type(exp, Integer) - end - elseif tltype.subtype(t1, Number) and tltype.subtype(t2, Number) then - set_type(exp, Number) - if op == "idiv" then - local msg = "integer division on floats" - typeerror(env, "arith", msg, exp1.pos) + return tltype.Tuple({ Value }, true) end - elseif tltype.isAny(t1) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "any", msg, exp1.pos) - elseif tltype.isAny(t2) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t2)) - typeerror(env, "any", msg, exp2.pos) else - set_type(exp, Any) - local wrong_type, wrong_pos = tltype.general(t1), exp1.pos - if tltype.subtype(t1, Number) or tltype.isAny(t1) then - wrong_type, wrong_pos = tltype.general(t2), exp2.pos + local l = {} + if parlist[1][1] == "self" and not parlist[1][2] then + parlist[1][2] = Self + end + for i = 1, len do + if not parlist[i][2] then parlist[i][2] = Any end + l[i] = replace_names(env, parlist[i][2], pos) + end + if parlist[len].tag == "Dots" then + local t = parlist[len][1] or Any + l[len] = t + tlst.set_vararg(env, t) + return tltype.Tuple(l, true) + else + if env.strict then + return tltype.Tuple(l) + else + l[len + 1] = Value + return tltype.Tuple(l, true) + end end - msg = string.format(msg, tltype.tostring(wrong_type)) - typeerror(env, "arith", msg, wrong_pos) end end -local function check_bitwise (env, exp, op) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - local msg = "attempt to perform bitwise on a '%s'" - if tltype.subtype(t1, tltype.Integer(true)) and - tltype.subtype(t2, tltype.Integer(true)) then - set_type(exp, Integer) - elseif tltype.isAny(t1) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "any", msg, exp1.pos) - elseif tltype.isAny(t2) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t2)) - typeerror(env, "any", msg, exp2.pos) +local function check_return_type (env, inf_type, dec_type, pos) + local msg = "return type '%s' does not match '%s'" + if tltype.isUnionlist(dec_type) then + dec_type = tltype.unionlist2tuple(dec_type) + end + dec_type = tltype.unfold(dec_type) + if tltype.subtype(inf_type, dec_type) then + elseif tltype.consistent_subtype(inf_type, dec_type) then + msg = string.format(msg, tltype.tostring(inf_type), tltype.tostring(dec_type)) + typeerror(env, "any", msg, pos) else - set_type(exp, Any) - local wrong_type, wrong_pos = tltype.general(t1), exp1.pos - if tltype.subtype(t1, Number) or tltype.isAny(t1) then - wrong_type, wrong_pos = tltype.general(t2), exp2.pos - end - msg = string.format(msg, tltype.tostring(wrong_type)) - typeerror(env, "arith", msg, wrong_pos) + msg = string.format(msg, tltype.tostring(inf_type), tltype.tostring(dec_type)) + typeerror(env, "ret", msg, pos) end end -local function check_concat (env, exp) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - local msg = "attempt to concatenate a '%s'" - if tltype.subtype(t1, String) and tltype.subtype(t2, String) then - set_type(exp, String) - elseif tltype.isAny(t1) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "any", msg, exp1.pos) - elseif tltype.isAny(t2) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t2)) - typeerror(env, "any", msg, exp2.pos) - else - set_type(exp, Any) - local wrong_type, wrong_pos = tltype.general(t1), exp1.pos - if tltype.subtype(t1, String) or tltype.isAny(t1) then - wrong_type, wrong_pos = tltype.general(t2), exp2.pos +local function explist2typegen (explist) + local len = #explist + return function (i) + if i <= len then + local t = get_type(explist[i]) + return tltype.first(t) + else + local t = Nil + if len > 0 then t = get_type(explist[len]) end + if tltype.isTuple(t) then + if i <= #t then + t = t[i] + else + t = t[#t] + if not tltype.isVararg(t) then t = Nil end + end + else + t = Nil + end + if tltype.isVararg(t) then + return tltype.first(t) + else + return t + end end - msg = string.format(msg, tltype.tostring(wrong_type)) - typeerror(env, "concat", msg, wrong_pos) end end -local function check_equal (env, exp) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - set_type(exp, Boolean) +local function arglist2type (explist, strict) + local len = #explist + if len == 0 then + if strict then + return tltype.Void() + else + return tltype.Tuple({ Nil }, true) + end + else + local l = {} + for i = 1, len do + l[i] = tltype.first(get_type(explist[i])) + end + if strict then + return tltype.Tuple(l) + else + if not tltype.isVararg(explist[len]) then + l[len + 1] = Nil + end + return tltype.Tuple(l, true) + end + end end -local function check_order (env, exp) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - local msg = "attempt to compare '%s' with '%s'" - if tltype.subtype(t1, Number) and tltype.subtype(t2, Number) then - set_type(exp, Boolean) - elseif tltype.subtype(t1, String) and tltype.subtype(t2, String) then - set_type(exp, Boolean) - elseif tltype.isAny(t1) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "any", msg, exp1.pos) - elseif tltype.isAny(t2) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "any", msg, exp2.pos) +local function check_arguments (env, func_name, dec_type, infer_type, pos) + local msg = "attempt to pass '%s' to %s of input type '%s'" + if tltype.subtype(infer_type, dec_type) then + elseif tltype.consistent_subtype(infer_type, dec_type) then + msg = string.format(msg, tltype.tostring(infer_type), func_name, tltype.tostring(dec_type)) + typeerror(env, "any", msg, pos) else - set_type(exp, Any) - t1, t2 = tltype.general(t1), tltype.general(t2) - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "order", msg, exp.pos) + msg = string.format(msg, tltype.tostring(infer_type), func_name, tltype.tostring(dec_type)) + typeerror(env, "args", msg, pos) end end -local function check_and (env, exp) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - if tltype.isNil(t1) or tltype.isFalse(t1) then - set_type(exp, t1) - elseif tltype.isUnion(t1, Nil) then - set_type(exp, tltype.Union(t2, Nil)) - elseif tltype.isUnion(t1, False) then - set_type(exp, tltype.Union(t2, False)) - elseif tltype.isBoolean(t1) then - set_type(exp, tltype.Union(t2, False)) +local function replace_self (env, t, tself) + tself = tself or Nil + if tltype.isSelf(t) then + return tself + elseif tltype.isRecursive(t) then + local r = tltype.Recursive(t[1], replace_self(env, t[2], tself)) + r.name = t.name + return r + elseif tltype.isLiteral(t) or + tltype.isBase(t) or + tltype.isNil(t) or + tltype.isValue(t) or + tltype.isAny(t) or + tltype.isTable(t) or + tltype.isVariable(t) or + tltype.isVoid(t) then + return t + elseif tltype.isUnion(t) or + tltype.isUnionlist(t) or + tltype.isTuple(t) then + local r = { tag = t.tag, name = t.name } + for k, v in ipairs(t) do + r[k] = replace_self(env, v, tself) + end + return r + elseif tltype.isFunction(t) then + return tltype.Function(replace_self(env, t[1], tself), replace_self(env, t[2], tself)) + elseif tltype.isVararg(t) then + return tltype.Vararg(replace_self(env, t[1], tself)) else - set_type(exp, tltype.Union(t1, t2)) + return t end end -local function check_or (env, exp) - local exp1, exp2 = exp[2], exp[3] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - if tltype.isNil(t1) or tltype.isFalse(t1) then - set_type(exp, t2) - elseif tltype.isUnion(t1, Nil) then - set_type(exp, tltype.Union(tltype.filterUnion(t1, Nil), t2)) - elseif tltype.isUnion(t1, False) then - set_type(exp, tltype.Union(tltype.filterUnion(t1, False), t2)) +local function check_local_var (env, id, inferred_type, close_local) + local local_name, local_type, pos = id[1], id[2], id.pos + if tltype.isMethod(inferred_type) then + local msg = "attempt to create a method reference" + typeerror(env, "local", msg, pos) + inferred_type = Nil + end + if not local_type then + if tltype.isNil(inferred_type) then + local_type = Any + else + local_type = tltype.general(inferred_type) + if not local_type.name then local_type.name = local_name end + if inferred_type.unique then + local_type.unique = nil + local_type.open = true + end + if close_local then local_type.open = nil end + end else - set_type(exp, tltype.Union(t1, t2)) + check_self(env, local_type, local_type, pos) + local_type = replace_names(env, local_type, pos) + local msg = "attempt to assign '%s' to '%s'" + local local_type = tltype.unfold(local_type) + msg = string.format(msg, tltype.tostring(inferred_type), tltype.tostring(local_type)) + if tltype.subtype(inferred_type, local_type) then + elseif tltype.consistent_subtype(inferred_type, local_type) then + typeerror(env, "any", msg, pos) + else + typeerror(env, "local", msg, pos) + end end + set_type(id, local_type) + check_masking(env, id[1], id.pos) + tlst.set_local(env, id) end -local function check_binary_op (env, exp) - local op = exp[1] - if op == "add" or op == "sub" or - op == "mul" or op == "idiv" or op == "div" or op == "mod" or - op == "pow" then - check_arith(env, exp, op) - elseif op == "concat" then - check_concat(env, exp) - elseif op == "eq" then - check_equal(env, exp) - elseif op == "lt" or op == "le" then - check_order(env, exp) - elseif op == "and" then - check_and(env, exp) - elseif op == "or" then - check_or(env, exp) - elseif op == "band" or op == "bor" or op == "bxor" or - op == "shl" or op == "shr" then - check_bitwise(env, exp) +local function explist2typelist (explist) + local len = #explist + if len == 0 then + return tltype.Tuple({ Nil }, true) else - error("cannot type check binary operator " .. op) + local l = {} + for i = 1, len - 1 do + table.insert(l, tltype.first(get_type(explist[i]))) + end + local last_type = get_type(explist[len]) + if tltype.isUnionlist(last_type) then + last_type = tltype.unionlist2tuple(last_type) + end + if tltype.isTuple(last_type) then + for _, v in ipairs(last_type) do + table.insert(l, tltype.first(v)) + end + else + table.insert(l, last_type) + end + if not tltype.isVararg(last_type) then + table.insert(l, tltype.Vararg(Nil)) + end + return tltype.Tuple(l) end end -local function check_not (env, exp) - local exp1 = exp[2] - check_exp(env, exp1) - set_type(exp, Boolean) +local function is_global_function_call (exp, fn_name) + return exp.tag == "Call" and exp[1].tag == "Index" and + exp[1][1].tag == "Id" and exp[1][1][1] == "_ENV" and + exp[1][2].tag == "String" and exp[1][2][1] == fn_name end -local function check_bnot (env, exp) - local exp1 = exp[2] - check_exp(env, exp1) - local t1 = tltype.first(get_type(exp1)) - local msg = "attempt to perform bitwise on a '%s'" - if tltype.subtype(t1, tltype.Integer(true)) then - set_type(exp, Integer) - elseif tltype.isAny(t1) then - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "any", msg, exp1.pos) +local function check_var (env, var, exp) + local tag = var.tag + if tag == "Id" then + local name = var[1] + local l = tlst.get_local(env, name) + local t = get_type(l) + if exp and exp.tag == "Id" and tltype.isTable(t) then t.open = nil end + set_type(var, t) + elseif tag == "Index" then + local exp1, exp2 = var[1], var[2] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + local msg = "attempt to index '%s' with '%s'" + t1 = replace_self(env, t1, env.self) + if tltype.isTable(t1) then + local oself = env.self + -- another brittle hack for defining methods + if exp1.tag == "Id" and exp1[1] ~= "_ENV" then env.self = t1 end + local field_type = tltype.getField(t2, t1) + if not tltype.isNil(field_type) then + set_type(var, field_type) + else + if t1.open then + if exp then + local t3 = tltype.general(get_type(exp)) + local t = tltype.general(t1) + table.insert(t, tltype.Field(var.const, t2, t3)) + if tltype.subtype(t, t1) then + table.insert(t1, tltype.Field(var.const, t2, t3)) + else + msg = "could not include field '%s'" + msg = string.format(msg, tltype.tostring(t2)) + typeerror(env, "open", msg, var.pos) + end + if t3.open then t3.open = nil end + set_type(var, t3) + else + set_type(var, Nil) + end + else + if exp1.tag == "Id" and exp1[1] == "_ENV" and exp2.tag == "String" then + msg = "attempt to access undeclared global '%s'" + msg = string.format(msg, exp2[1]) + else + msg = "attempt to use '%s' to index closed table" + msg = string.format(msg, tltype.tostring(t2)) + end + typeerror(env, "open", msg, var.pos) + set_type(var, Nil) + end + end + env.self = oself + elseif tltype.isAny(t1) then + set_type(var, Any) + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "any", msg, var.pos) + else + set_type(var, Nil) + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "index", msg, var.pos) + end else - set_type(exp, Any) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "bitwise", msg, exp1.pos) + error("cannot type check variable " .. tag) end end -local function check_minus (env, exp) - local exp1 = exp[2] - check_exp(env, exp1) - local t1 = tltype.first(get_type(exp1)) + + + + + + +local function check_arith (env, exp, op) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) local msg = "attempt to perform arithmetic on a '%s'" - if tltype.subtype(t1, Integer) then - set_type(exp, Integer) - elseif tltype.subtype(t1, Number) then + if tltype.subtype(t1, tltype.Integer(true)) and + tltype.subtype(t2, tltype.Integer(true)) then + if op == "div" or op == "pow" then + set_type(exp, Number) + else + set_type(exp, Integer) + end + elseif tltype.subtype(t1, Number) and tltype.subtype(t2, Number) then set_type(exp, Number) + if op == "idiv" then + local msg = "integer division on floats" + typeerror(env, "arith", msg, exp1.pos) + end elseif tltype.isAny(t1) then set_type(exp, Any) msg = string.format(msg, tltype.tostring(t1)) typeerror(env, "any", msg, exp1.pos) + elseif tltype.isAny(t2) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t2)) + typeerror(env, "any", msg, exp2.pos) else set_type(exp, Any) - t1 = tltype.general(t1) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "arith", msg, exp1.pos) + local wrong_type, wrong_pos = tltype.general(t1), exp1.pos + if tltype.subtype(t1, Number) or tltype.isAny(t1) then + wrong_type, wrong_pos = tltype.general(t2), exp2.pos + end + msg = string.format(msg, tltype.tostring(wrong_type)) + typeerror(env, "arith", msg, wrong_pos) end end -local function check_len (env, exp) - local exp1 = exp[2] - check_exp(env, exp1) - local t1 = tltype.first(get_type(exp1)) - local msg = "attempt to get length of a '%s'" - if tltype.subtype(t1, String) or - tltype.subtype(t1, tltype.Table()) then - set_type(exp, Integer) +local function check_concat (env, exp) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + local msg = "attempt to concatenate a '%s'" + if tltype.subtype(t1, String) and tltype.subtype(t2, String) then + set_type(exp, String) elseif tltype.isAny(t1) then set_type(exp, Any) msg = string.format(msg, tltype.tostring(t1)) typeerror(env, "any", msg, exp1.pos) + elseif tltype.isAny(t2) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t2)) + typeerror(env, "any", msg, exp2.pos) else set_type(exp, Any) - t1 = tltype.general(t1) - msg = string.format(msg, tltype.tostring(t1)) - typeerror(env, "len", msg, exp1.pos) + local wrong_type, wrong_pos = tltype.general(t1), exp1.pos + if tltype.subtype(t1, String) or tltype.isAny(t1) then + wrong_type, wrong_pos = tltype.general(t2), exp2.pos + end + msg = string.format(msg, tltype.tostring(wrong_type)) + typeerror(env, "concat", msg, wrong_pos) end end -local function check_unary_op (env, exp) - local op = exp[1] - if op == "not" then - check_not(env, exp) - elseif op == "bnot" then - check_bnot(env, exp) - elseif op == "unm" then - check_minus(env, exp) - elseif op == "len" then - check_len(env, exp) - else - error("cannot type check unary operator " .. op) - end +local function check_equal (env, exp) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + set_type(exp, Boolean) end -local function check_op (env, exp) - if exp[3] then - check_binary_op(env, exp) +local function check_order (env, exp) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + local msg = "attempt to compare '%s' with '%s'" + if tltype.subtype(t1, Number) and tltype.subtype(t2, Number) then + set_type(exp, Boolean) + elseif tltype.subtype(t1, String) and tltype.subtype(t2, String) then + set_type(exp, Boolean) + elseif tltype.isAny(t1) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "any", msg, exp1.pos) + elseif tltype.isAny(t2) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "any", msg, exp2.pos) else - check_unary_op(env, exp) - end + set_type(exp, Any) + t1, t2 = tltype.general(t1), tltype.general(t2) + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "order", msg, exp.pos) + end end -local function check_paren (env, exp) - local exp1 = exp[1] - check_exp(env, exp1) - local t1 = get_type(exp1) - set_type(exp, tltype.first(t1)) +local function check_and (env, exp) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + if tltype.isNil(t1) or tltype.isFalse(t1) then + set_type(exp, t1) + elseif tltype.isUnion(t1, Nil) then + set_type(exp, tltype.Union(t2, Nil)) + elseif tltype.isUnion(t1, False) then + set_type(exp, tltype.Union(t2, False)) + elseif tltype.isBoolean(t1) then + set_type(exp, tltype.Union(t2, False)) + else + set_type(exp, tltype.Union(t1, t2)) + end end -local function check_parameters (env, parlist, pos) - local len = #parlist - if len == 0 then - if env.strict then - return tltype.Void() - else - return tltype.Tuple({ Value }, true) - end +local function check_or (env, exp) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + if tltype.isNil(t1) or tltype.isFalse(t1) then + set_type(exp, t2) + elseif tltype.isUnion(t1, Nil) then + set_type(exp, tltype.Union(tltype.filterUnion(t1, Nil), t2)) + elseif tltype.isUnion(t1, False) then + set_type(exp, tltype.Union(tltype.filterUnion(t1, False), t2)) else - local l = {} - if parlist[1][1] == "self" and not parlist[1][2] then - parlist[1][2] = Self - end - for i = 1, len do - if not parlist[i][2] then parlist[i][2] = Any end - l[i] = replace_names(env, parlist[i][2], pos) - end - if parlist[len].tag == "Dots" then - local t = parlist[len][1] or Any - l[len] = t - tlst.set_vararg(env, t) - return tltype.Tuple(l, true) - else - if env.strict then - return tltype.Tuple(l) - else - l[len + 1] = Value - return tltype.Tuple(l, true) - end - end + set_type(exp, tltype.Union(t1, t2)) end end -local function check_explist (env, explist, lselfs) - lselfs = lselfs or {} - for k, v in ipairs(explist) do - check_exp(env, v, lselfs[k]) +local function check_bitwise (env, exp, op) + local exp1, exp2 = exp[2], exp[3] + tlchecker:Expression(exp1, env) + tlchecker:Expression(exp2, env) + local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) + local msg = "attempt to perform bitwise on a '%s'" + if tltype.subtype(t1, tltype.Integer(true)) and + tltype.subtype(t2, tltype.Integer(true)) then + set_type(exp, Integer) + elseif tltype.isAny(t1) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "any", msg, exp1.pos) + elseif tltype.isAny(t2) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t2)) + typeerror(env, "any", msg, exp2.pos) + else + set_type(exp, Any) + local wrong_type, wrong_pos = tltype.general(t1), exp1.pos + if tltype.subtype(t1, Number) or tltype.isAny(t1) then + wrong_type, wrong_pos = tltype.general(t2), exp2.pos + end + msg = string.format(msg, tltype.tostring(wrong_type)) + typeerror(env, "arith", msg, wrong_pos) end end -local function check_return_type (env, inf_type, dec_type, pos) - local msg = "return type '%s' does not match '%s'" - if tltype.isUnionlist(dec_type) then - dec_type = tltype.unionlist2tuple(dec_type) - end - local dec_type = tltype.unfold(dec_type) - if tltype.subtype(inf_type, dec_type) then - elseif tltype.consistent_subtype(inf_type, dec_type) then - msg = string.format(msg, tltype.tostring(inf_type), tltype.tostring(dec_type)) - typeerror(env, "any", msg, pos) +function tlchecker:BinaryOp (exp, env) + local op = exp[1] + if op == "add" or op == "sub" or + op == "mul" or op == "idiv" or op == "div" or op == "mod" or + op == "pow" then + check_arith(env, exp, op) + elseif op == "concat" then + check_concat(env, exp) + elseif op == "eq" then + check_equal(env, exp) + elseif op == "lt" or op == "le" then + check_order(env, exp) + elseif op == "and" then + check_and(env, exp) + elseif op == "or" then + check_or(env, exp) + elseif op == "band" or op == "bor" or op == "bxor" or + op == "shl" or op == "shr" then + check_bitwise(env, exp) else - msg = string.format(msg, tltype.tostring(inf_type), tltype.tostring(dec_type)) - typeerror(env, "ret", msg, pos) + error("cannot type check binary operator " .. op) end end -local function check_function (env, exp, tself) - local oself = env.self - env.self = tself - local idlist, ret_type, block = exp[1], replace_names(env, exp[2], exp.pos), exp[3] - local infer_return = false - if not block then - block = ret_type - ret_type = tltype.Tuple({ Nil }, true) - infer_return = true - end - tlst.begin_function(env) +local function is_exit_point (block) + if #block == 0 then return false end + local last = block[#block] + return last.tag == "Return" or is_global_function_call(last, "error") +end + +function tlchecker:Block (block, env) tlst.begin_scope(env) - local input_type = check_parameters(env, idlist, exp.pos) - local t = tltype.Function(input_type, ret_type) - local len = #idlist - if len > 0 and idlist[len].tag == "Dots" then len = len - 1 end - for k = 1, len do - local v = idlist[k] - v[2] = replace_names(env, v[2], exp.pos) - set_type(v, v[2]) - check_masking(env, v[1], v.pos) - tlst.set_local(env, v) + local r = false + local bkp = env.self + local endswithret = true + local didgoto, _ = false, nil + for _, v in ipairs(block) do + r, _, didgoto = self:Statement(v, env) + env.self = bkp + if didgoto then endswithret = false end end - local r = check_block(env, block) - if not r then tlst.set_return_type(env, tltype.Tuple({ Nil }, true)) end + endswithret = endswithret and is_exit_point(block) check_unused_locals(env) tlst.end_scope(env) - local inferred_type = infer_return_type(env) - if infer_return then - ret_type = inferred_type - t = tltype.Function(input_type, ret_type) - set_type(exp, t) - end - if env.self then - t = check_self_field(env, t, t, exp.pos) - else - t = check_self(env, t, t, exp.pos) - end - check_return_type(env, inferred_type, ret_type, exp.pos) - tlst.end_function(env) - set_type(exp, t) - env.self = oself + return r, endswithret, didgoto end -local function check_table (env, exp) - local l = {} - local i = 1 - local len = #exp - for k, v in ipairs(exp) do - local tag = v.tag - local t1, t2 - if tag == "Pair" then - local exp1, exp2 = v[1], v[2] - check_exp(env, exp1) - check_exp(env, exp2) - t1, t2 = get_type(exp1), tltype.general(get_type(exp2)) - if tltype.subtype(Nil, t1) then - t1 = Any - local msg = "table index can be nil" - typeerror(env, "table", msg, exp1.pos) - elseif not (tltype.subtype(t1, Boolean) or - tltype.subtype(t1, Number) or - tltype.subtype(t1, String)) then - t1 = Any - local msg = "table index is dynamic" - typeerror(env, "any", msg, exp1.pos) - end - else - local exp1 = v - check_exp(env, exp1) - t1, t2 = tltype.Literal(i), tltype.general(get_type(exp1)) - if k == len and tltype.isVararg(t2) then - t1 = Integer - end - i = i + 1 - end - if t2.open then t2.open = nil end - t2 = tltype.first(t2) - l[k] = tltype.Field(v.const, t1, t2) - end - local t = tltype.Table(table.unpack(l)) - t.unique = true - set_type(exp, t) +function tlchecker:Break (stm, env) + return false end local function var2name (var) @@ -797,130 +863,32 @@ local function var2name (var) end end -local function explist2typegen (explist) - local len = #explist - return function (i) - if i <= len then - local t = get_type(explist[i]) - return tltype.first(t) - else - local t = Nil - if len > 0 then t = get_type(explist[len]) end - if tltype.isTuple(t) then - if i <= #t then - t = t[i] - else - t = t[#t] - if not tltype.isVararg(t) then t = Nil end - end - else - t = Nil - end - if tltype.isVararg(t) then - return tltype.first(t) +function tlchecker:Call (exp, env) + local exp1 = exp[1] + local explist = {} + for i = 2, #exp do + explist[i - 1] = exp[i] + end + self:Expression(exp1, env) + self:ExpList(explist, env) + if exp1.tag == "Index" and + exp1[1].tag == "Id" and exp1[1][1] == "_ENV" and + exp1[2].tag == "String" and exp1[2][1] == "setmetatable" then + if explist[1] and explist[2] then + local t1, t2 = get_type(explist[1]), get_type(explist[2]) + local t3 = tltype.getField(tltype.Literal("__index"), t2) + if not tltype.isNil(t3) then + if tltype.isTable(t3) then t3.open = true end + set_type(exp, t3) else - return t + local msg = "second argument of setmetatable must be { __index = e }" + typeerror(env, "call", msg, exp.pos) + set_type(exp, Any) end - end - end -end - -local function arglist2type (explist, strict) - local len = #explist - if len == 0 then - if strict then - return tltype.Void() else - return tltype.Tuple({ Nil }, true) - end - else - local l = {} - for i = 1, len do - l[i] = tltype.first(get_type(explist[i])) - end - if strict then - return tltype.Tuple(l) - else - if not tltype.isVararg(explist[len]) then - l[len + 1] = Nil - end - return tltype.Tuple(l, true) - end - end -end - -local function check_arguments (env, func_name, dec_type, infer_type, pos) - local msg = "attempt to pass '%s' to %s of input type '%s'" - if tltype.subtype(infer_type, dec_type) then - elseif tltype.consistent_subtype(infer_type, dec_type) then - msg = string.format(msg, tltype.tostring(infer_type), func_name, tltype.tostring(dec_type)) - typeerror(env, "any", msg, pos) - else - msg = string.format(msg, tltype.tostring(infer_type), func_name, tltype.tostring(dec_type)) - typeerror(env, "args", msg, pos) - end -end - -local function replace_self (env, t, tself) - tself = tself or Nil - if tltype.isSelf(t) then - return tself - elseif tltype.isRecursive(t) then - local r = tltype.Recursive(t[1], replace_self(env, t[2], tself)) - r.name = t.name - return r - elseif tltype.isLiteral(t) or - tltype.isBase(t) or - tltype.isNil(t) or - tltype.isValue(t) or - tltype.isAny(t) or - tltype.isTable(t) or - tltype.isVariable(t) or - tltype.isVoid(t) then - return t - elseif tltype.isUnion(t) or - tltype.isUnionlist(t) or - tltype.isTuple(t) then - local r = { tag = t.tag, name = t.name } - for k, v in ipairs(t) do - r[k] = replace_self(env, v, tself) - end - return r - elseif tltype.isFunction(t) then - return tltype.Function(replace_self(env, t[1], tself), replace_self(env, t[2], tself)) - elseif tltype.isVararg(t) then - return tltype.Vararg(replace_self(env, t[1], tself)) - else - return t - end -end - -local function check_call (env, exp) - local exp1 = exp[1] - local explist = {} - for i = 2, #exp do - explist[i - 1] = exp[i] - end - check_exp(env, exp1) - check_explist(env, explist) - if exp1.tag == "Index" and - exp1[1].tag == "Id" and exp1[1][1] == "_ENV" and - exp1[2].tag == "String" and exp1[2][1] == "setmetatable" then - if explist[1] and explist[2] then - local t1, t2 = get_type(explist[1]), get_type(explist[2]) - local t3 = tltype.getField(tltype.Literal("__index"), t2) - if not tltype.isNil(t3) then - if tltype.isTable(t3) then t3.open = true end - set_type(exp, t3) - else - local msg = "second argument of setmetatable must be { __index = e }" - typeerror(env, "call", msg, exp.pos) - set_type(exp, Any) - end - else - local msg = "setmetatable must have two arguments" - typeerror(env, "call", msg, exp.pos) - set_type(exp, Any) + local msg = "setmetatable must have two arguments" + typeerror(env, "call", msg, exp.pos) + set_type(exp, Any) end elseif exp1.tag == "Index" and exp1[1].tag == "Id" and exp1[1][1] == "_ENV" and @@ -959,151 +927,120 @@ local function check_call (env, exp) return false end -local function check_invoke (env, exp) - local exp1, exp2 = exp[1], exp[2] - local explist = {} - for i = 3, #exp do - explist[i - 2] = exp[i] +function tlchecker:Dots (exp, env) + set_type(exp, tltype.Vararg(tlst.get_vararg(env))) +end + +function tlchecker:ExpList (explist, env, lselfs) + lselfs = lselfs or {} + for k, v in ipairs(explist) do + self:Expression(v, env, lselfs[k]) end - check_exp(env, exp1) - check_exp(env, exp2) - check_explist(env, explist) - local t1, t2 = get_type(exp1), get_type(exp2) - t1 = replace_self(env, t1, env.self) - table.insert(explist, 1, { type = t1 }) - if tltype.isTable(t1) or - tltype.isString(t1) or - tltype.isStr(t1) then - local inferred_type = replace_self(env, arglist2type(explist, env.strict), env.self) - local t3 - if tltype.isTable(t1) then - t3 = replace_self(env, tltype.getField(t2, t1), t1) - --local s = env.self or Nil - --if not tltype.subtype(s, t1) then env.self = t1 end - else - local string_userdata = env["loaded"]["string"] or tltype.Table() - t3 = replace_self(env, tltype.getField(t2, string_userdata), t1) - inferred_type[1] = String - end - local msg = "attempt to call method '%s' of type '%s'" - if tltype.isFunction(t3) then - check_arguments(env, "field", t3[1], inferred_type, exp.pos) - set_type(exp, t3[2]) - elseif tltype.isAny(t3) then - set_type(exp, Any) - msg = string.format(msg, exp2[1], tltype.tostring(t3)) - typeerror(env, "any", msg, exp.pos) - else - set_type(exp, Nil) - msg = string.format(msg, exp2[1], tltype.tostring(t3)) - typeerror(env, "invoke", msg, exp.pos) +end + +function tlchecker:False (exp, env) + set_type(exp, False) +end + +function tlchecker:Forin (stm, env) + local idlist, explist, block = stm[1], stm[2], stm[3] + tlst.begin_scope(env) + self:ExpList(explist, env) + local t = tltype.first(get_type(explist[1])) + local tuple = explist2typegen({}) + local msg = "attempt to iterate over %s" + if tltype.isFunction(t) then + local l = {} + for k, v in ipairs(t[2]) do + l[k] = {} + set_type(l[k], v) end - elseif tltype.isAny(t1) then - set_type(exp, Any) - local msg = "attempt to index '%s' with '%s'" - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "any", msg, exp.pos) + tuple = explist2typegen(l) + elseif tltype.isAny(t) then + msg = string.format(msg, tltype.tostring(t)) + typeerror(env, "any", msg, idlist.pos) else - set_type(exp, Nil) - local msg = "attempt to index '%s' with '%s'" - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "index", msg, exp.pos) + msg = string.format(msg, tltype.tostring(t)) + typeerror(env, "forin", msg, idlist.pos) end - return false + for k, v in ipairs(idlist) do + local t = tltype.filterUnion(tuple(k), Nil) + check_local_var(env, v, t, false) + end + local r, _, didgoto = self:Block(block, env) + check_unused_locals(env) + tlst.end_scope(env) + return r, _, didgoto end -local function check_local_var (env, id, inferred_type, close_local) - local local_name, local_type, pos = id[1], id[2], id.pos - if tltype.isMethod(inferred_type) then - local msg = "attempt to create a method reference" - typeerror(env, "local", msg, pos) - inferred_type = Nil +local function infer_int(t) + return tltype.isInt(t) or tltype.isInteger(t) +end + +function tlchecker:Fornum (stm, env) + local id, exp1, exp2, exp3, block = stm[1], stm[2], stm[3], stm[4], stm[5] + self:Expression(exp1, env) + local t1 = get_type(exp1) + local msg = "'for' initial value must be a number" + if tltype.subtype(t1, Number) then + elseif tltype.consistent_subtype(t1, Number) then + typeerror(env, "any", msg, exp1.pos) + else + typeerror(env, "fornum", msg, exp1.pos) end - if not local_type then - if tltype.isNil(inferred_type) then - local_type = Any - else - local_type = tltype.general(inferred_type) - if not local_type.name then local_type.name = local_name end - if inferred_type.unique then - local_type.unique = nil - local_type.open = true - end - if close_local then local_type.open = nil end - end + self:Expression(exp2, env) + local t2 = get_type(exp2) + msg = "'for' limit must be a number" + if tltype.subtype(t2, Number) then + elseif tltype.consistent_subtype(t2, Number) then + typeerror(env, "any", msg, exp2.pos) else - check_self(env, local_type, local_type, pos) - local_type = replace_names(env, local_type, pos) - local msg = "attempt to assign '%s' to '%s'" - local local_type = tltype.unfold(local_type) - msg = string.format(msg, tltype.tostring(inferred_type), tltype.tostring(local_type)) - if tltype.subtype(inferred_type, local_type) then - elseif tltype.consistent_subtype(inferred_type, local_type) then - typeerror(env, "any", msg, pos) + typeerror(env, "fornum", msg, exp2.pos) + end + local int_step = true + if block then + self:Expression(exp3, env) + local t3 = get_type(exp3) + msg = "'for' step must be a number" + if not infer_int(t3) then + int_step = false + end + if tltype.subtype(t3, Number) then + elseif tltype.consistent_subtype(t3, Number) then + typeerror(env, "any", msg, exp3.pos) else - typeerror(env, "local", msg, pos) + typeerror(env, "fornum", msg, exp3.pos) end + else + block = exp3 end - set_type(id, local_type) - check_masking(env, id[1], id.pos) + tlst.begin_scope(env) tlst.set_local(env, id) -end - -local function unannotated_idlist (idlist) - for k, v in ipairs(idlist) do - if v[2] then return false end + if infer_int(t1) and infer_int(t2) and int_step then + set_type(id, Integer) + else + set_type(id, Number) end - return true + local r, _, didgoto = self:Block(block, env) + check_unused_locals(env) + tlst.end_scope(env) + return r, _, didgoto end -local function sized_unionlist (t) - for i = 1, #t - 1 do - if #t[i] ~= #t[i + 1] then return false end - end - return true -end - -local function check_local (env, idlist, explist) - check_explist(env, explist) - if unannotated_idlist(idlist) and - #explist == 1 and - tltype.isUnionlist(get_type(explist[1])) and - sized_unionlist(get_type(explist[1])) and - #idlist == #get_type(explist[1])[1] - 1 then - local t = get_type(explist[1]) - for k, v in ipairs(idlist) do - set_type(v, t) - v.i = k - check_masking(env, v[1], v.pos) - tlst.set_local(env, v) - end - else - local tuple = explist2typegen(explist) - for k, v in ipairs(idlist) do - local t = tuple(k) - local close_local = explist[k] and explist[k].tag == "Id" and tltype.isTable(t) - check_local_var(env, v, t, close_local) - end - end - return false -end - -local function check_localrec (env, id, exp) - local idlist, ret_type, block = exp[1], replace_names(env, exp[2], exp.pos), exp[3] - local infer_return = false - if not block then - block = ret_type - ret_type = tltype.Tuple({ Nil }, true) - infer_return = true +function tlchecker:Function (exp, env, tself) + local oself = env.self + env.self = tself + local idlist, ret_type, block = exp[1], replace_names(env, exp[2], exp.pos), exp[3] + local infer_return = false + if not block then + block = ret_type + ret_type = tltype.Tuple({ Nil }, true) + infer_return = true end tlst.begin_function(env) + tlst.begin_scope(env) local input_type = check_parameters(env, idlist, exp.pos) local t = tltype.Function(input_type, ret_type) - id[2] = t - set_type(id, t) - check_masking(env, id[1], id.pos) - tlst.set_local(env, id) - tlst.begin_scope(env) local len = #idlist if len > 0 and idlist[len].tag == "Dots" then len = len - 1 end for k = 1, len do @@ -1113,7 +1050,7 @@ local function check_localrec (env, id, exp) check_masking(env, v[1], v.pos) tlst.set_local(env, v) end - local r = check_block(env, block) + local r = tlchecker:Block(block, env) if not r then tlst.set_return_type(env, tltype.Tuple({ Nil }, true)) end check_unused_locals(env) tlst.end_scope(env) @@ -1121,111 +1058,42 @@ local function check_localrec (env, id, exp) if infer_return then ret_type = inferred_type t = tltype.Function(input_type, ret_type) - id[2] = t - set_type(id, t) - tlst.set_local(env, id) set_type(exp, t) end - check_return_type(env, inferred_type, ret_type, exp.pos) - tlst.end_function(env) - return false -end - -local function explist2typelist (explist) - local len = #explist - if len == 0 then - return tltype.Tuple({ Nil }, true) + if env.self then + t = check_self_field(env, t, t, exp.pos) else - local l = {} - for i = 1, len - 1 do - table.insert(l, tltype.first(get_type(explist[i]))) - end - local last_type = get_type(explist[len]) - if tltype.isUnionlist(last_type) then - last_type = tltype.unionlist2tuple(last_type) - end - if tltype.isTuple(last_type) then - for k, v in ipairs(last_type) do - table.insert(l, tltype.first(v)) - end - else - table.insert(l, last_type) - end - if not tltype.isVararg(last_type) then - table.insert(l, tltype.Vararg(Nil)) - end - return tltype.Tuple(l) + t = check_self(env, t, t, exp.pos) end + check_return_type(env, inferred_type, ret_type, exp.pos) + tlst.end_function(env) + set_type(exp, t) + env.self = oself end -local function check_return (env, stm) - check_explist(env, stm) - local t = explist2typelist(stm) - tlst.set_return_type(env, tltype.general(t)) - return true +function tlchecker:Goto (stm, env) + return false, nil, true end -local function check_assignment (env, varlist, explist) - local lselfs = {} - for k, v in ipairs(varlist) do - if v.tag == "Index" and v[1].tag == "Id" and v[2].tag == "String" then - local l = tlst.get_local(env, v[1][1]) - local t = get_type(l) - -- a brittle hack to type a method definition in the right-hand side? - if tltype.isTable(t) then lselfs[k] = t end - end - end - check_explist(env, explist, lselfs) - local l = {} - for k, v in ipairs(varlist) do - check_var(env, v, explist[k]) - table.insert(l, get_type(v)) - end - table.insert(l, tltype.Vararg(Value)) - local var_type, exp_type = tltype.Tuple(l), explist2typelist(explist) - local msg = "attempt to assign '%s' to '%s'" - if tltype.subtype(exp_type, var_type) then - elseif tltype.consistent_subtype(exp_type, var_type) then - msg = string.format(msg, tltype.tostring(exp_type), tltype.tostring(var_type)) - typeerror(env, "any", msg, varlist[1].pos) +function tlchecker:Id (exp, env) + local name = exp[1] + local l = tlst.get_local(env, name) + local t = get_type(l) + if tltype.isUnionlist(t) and l.i then + set_type(exp, tltype.unionlist2union(t, l.i)) else - msg = string.format(msg, tltype.tostring(exp_type), tltype.tostring(var_type)) - typeerror(env, "set", msg, varlist[1].pos) + set_type(exp, t) end - for k, v in ipairs(varlist) do - local tag = v.tag - if tag == "Id" then - local name = v[1] - local l = tlst.get_local(env, name) - local exp = explist[k] - if exp and exp.tag == "Op" and exp[1] == "or" and - exp[2].tag == "Id" and exp[2][1] == name and not l.assigned then - local t1, t2 = get_type(exp), get_type(l) - if tltype.subtype(t1, t2) then - l.bkp = t2 - set_type(l, t1) - end +end + +local function get_index (u, t, i) + if tltype.isUnionlist(u) then + for k, v in ipairs(u) do + if tltype.subtype(v[i], t) and tltype.subtype(t, v[i]) then + return k end - l.assigned = true - elseif tag == "Index" then - local t1, t2 = get_type(v[1]), get_type(v[2]) end end - return false -end - -local function check_while (env, stm) - local exp1, stm1 = stm[1], stm[2] - check_exp(env, exp1) - local r, _, didgoto = check_block(env, stm1) - return r, _, didgoto -end - -local function check_repeat (env, stm) - local stm1, exp1 = stm[1], stm[2] - local r, _, didgoto = check_block(env, stm1) - check_exp(env, exp1) - return r, _, didgoto end local function tag2type (t) @@ -1249,30 +1117,14 @@ local function tag2type (t) end end -local function get_index (u, t, i) - if tltype.isUnionlist(u) then - for k, v in ipairs(u) do - if tltype.subtype(v[i], t) and tltype.subtype(t, v[i]) then - return k - end - end - end -end - -local function is_global_function_call (exp, fn_name) - return exp.tag == "Call" and exp[1].tag == "Index" and - exp[1][1].tag == "Id" and exp[1][1][1] == "_ENV" and - exp[1][2].tag == "String" and exp[1][2][1] == fn_name -end - -local function check_if (env, stm) +function tlchecker:If (stm, env) local l = {} local rl = {} local isallret = true for i = 1, #stm, 2 do local exp, block = stm[i], stm[i + 1] if block then - check_exp(env, exp) + self:Expression(exp, env) if exp.tag == "Id" then local name = exp[1] local var = tlst.get_local(env, name) @@ -1358,10 +1210,10 @@ local function check_if (env, stm) else block = exp end - local r, isret = check_block(env, block) + local r, isret = self:Block(block, env) table.insert(rl, r) isallret = isallret and isret - for k, v in pairs(l) do + for _, v in pairs(l) do if not tltype.isTuple(v.filter) then set_type(v, v.filter) else @@ -1373,7 +1225,7 @@ local function check_if (env, stm) end end if not isallret then - for k, v in pairs(l) do + for _, v in pairs(l) do if not tltype.isUnionlist(get_type(v)) then set_type(v, v.bkp) else @@ -1383,111 +1235,16 @@ local function check_if (env, stm) end if #stm % 2 == 0 then table.insert(rl, false) end local r = true - for k, v in ipairs(rl) do + for _, v in ipairs(rl) do r = r and v end return r end -local function infer_int(t) - return tltype.isInt(t) or tltype.isInteger(t) -end - -local function check_fornum (env, stm) - local id, exp1, exp2, exp3, block = stm[1], stm[2], stm[3], stm[4], stm[5] - check_exp(env, exp1) - local t1 = get_type(exp1) - local msg = "'for' initial value must be a number" - if tltype.subtype(t1, Number) then - elseif tltype.consistent_subtype(t1, Number) then - typeerror(env, "any", msg, exp1.pos) - else - typeerror(env, "fornum", msg, exp1.pos) - end - check_exp(env, exp2) - local t2 = get_type(exp2) - msg = "'for' limit must be a number" - if tltype.subtype(t2, Number) then - elseif tltype.consistent_subtype(t2, Number) then - typeerror(env, "any", msg, exp2.pos) - else - typeerror(env, "fornum", msg, exp2.pos) - end - local int_step = true - if block then - check_exp(env, exp3) - local t3 = get_type(exp3) - msg = "'for' step must be a number" - if not infer_int(t3) then - int_step = false - end - if tltype.subtype(t3, Number) then - elseif tltype.consistent_subtype(t3, Number) then - typeerror(env, "any", msg, exp3.pos) - else - typeerror(env, "fornum", msg, exp3.pos) - end - else - block = exp3 - end - tlst.begin_scope(env) - tlst.set_local(env, id) - if infer_int(t1) and infer_int(t2) and int_step then - set_type(id, Integer) - else - set_type(id, Number) - end - local r, _, didgoto = check_block(env, block) - check_unused_locals(env) - tlst.end_scope(env) - return r, _, didgoto -end - -local function check_forin (env, idlist, explist, block) - tlst.begin_scope(env) - check_explist(env, explist) - local t = tltype.first(get_type(explist[1])) - local tuple = explist2typegen({}) - local msg = "attempt to iterate over %s" - if tltype.isFunction(t) then - local l = {} - for k, v in ipairs(t[2]) do - l[k] = {} - set_type(l[k], v) - end - tuple = explist2typegen(l) - elseif tltype.isAny(t) then - msg = string.format(msg, tltype.tostring(t)) - typeerror(env, "any", msg, idlist.pos) - else - msg = string.format(msg, tltype.tostring(t)) - typeerror(env, "forin", msg, idlist.pos) - end - for k, v in ipairs(idlist) do - local t = tltype.filterUnion(tuple(k), Nil) - check_local_var(env, v, t, false) - end - local r, _, didgoto = check_block(env, block) - check_unused_locals(env) - tlst.end_scope(env) - return r, _, didgoto -end - -local function check_id (env, exp) - local name = exp[1] - local l = tlst.get_local(env, name) - local t = get_type(l) - if tltype.isUnionlist(t) and l.i then - set_type(exp, tltype.unionlist2union(t, l.i)) - else - set_type(exp, t) - end -end - -local function check_index (env, exp) +function tlchecker:Index (exp, env) local exp1, exp2 = exp[1], exp[2] - check_exp(env, exp1) - check_exp(env, exp2) + self:Expression(exp1, env) + self:Expression(exp2, env) local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) local msg = "attempt to index '%s' with '%s'" t1 = replace_self(env, t1, env.self) @@ -1517,168 +1274,375 @@ local function check_index (env, exp) end end -function check_var (env, var, exp) - local tag = var.tag - if tag == "Id" then - local name = var[1] - local l = tlst.get_local(env, name) - local t = get_type(l) - if exp and exp.tag == "Id" and tltype.isTable(t) then t.open = nil end - set_type(var, t) - elseif tag == "Index" then - local exp1, exp2 = var[1], var[2] - check_exp(env, exp1) - check_exp(env, exp2) - local t1, t2 = tltype.first(get_type(exp1)), tltype.first(get_type(exp2)) - local msg = "attempt to index '%s' with '%s'" - t1 = replace_self(env, t1, env.self) - if tltype.isTable(t1) then - local oself = env.self - -- another brittle hack for defining methods - if exp1.tag == "Id" and exp1[1] ~= "_ENV" then env.self = t1 end - local field_type = tltype.getField(t2, t1) - if not tltype.isNil(field_type) then - set_type(var, field_type) - else - if t1.open then - if exp then - local t3 = tltype.general(get_type(exp)) - local t = tltype.general(t1) - table.insert(t, tltype.Field(var.const, t2, t3)) - if tltype.subtype(t, t1) then - table.insert(t1, tltype.Field(var.const, t2, t3)) - else - msg = "could not include field '%s'" - msg = string.format(msg, tltype.tostring(t2)) - typeerror(env, "open", msg, var.pos) - end - if t3.open then t3.open = nil end - set_type(var, t3) - else - set_type(var, Nil) - end - else - if exp1.tag == "Id" and exp1[1] == "_ENV" and exp2.tag == "String" then - msg = "attempt to access undeclared global '%s'" - msg = string.format(msg, exp2[1]) - else - msg = "attempt to use '%s' to index closed table" - msg = string.format(msg, tltype.tostring(t2)) - end - typeerror(env, "open", msg, var.pos) - set_type(var, Nil) - end - end - env.self = oself - elseif tltype.isAny(t1) then - set_type(var, Any) - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "any", msg, var.pos) +function tlchecker:Interface (stm, env) + local name, t, is_local = stm[1], stm[2], stm.is_local + if tlst.get_interface(env, name) then + local msg = "attempt to redeclare interface '%s'" + msg = string.format(msg, name) + typeerror(env, "alias", msg, stm.pos) + else + check_self(env, t, t, stm.pos) + local t = replace_names(env, t, stm.pos) + t.name = name + tlst.set_interface(env, name, t, is_local) + end + return false +end + +function tlchecker:Invoke (exp, env) + local exp1, exp2 = exp[1], exp[2] + local explist = {} + for i = 3, #exp do + explist[i - 2] = exp[i] + end + self:Expression(exp1, env) + self:Expression(exp2, env) + self:ExpList(explist, env) + local t1, t2 = get_type(exp1), get_type(exp2) + t1 = replace_self(env, t1, env.self) + table.insert(explist, 1, { type = t1 }) + if tltype.isTable(t1) or + tltype.isString(t1) or + tltype.isStr(t1) then + local inferred_type = replace_self(env, arglist2type(explist, env.strict), env.self) + local t3 + if tltype.isTable(t1) then + t3 = replace_self(env, tltype.getField(t2, t1), t1) + --local s = env.self or Nil + --if not tltype.subtype(s, t1) then env.self = t1 end else - set_type(var, Nil) - msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) - typeerror(env, "index", msg, var.pos) + local string_userdata = env["loaded"]["string"] or tltype.Table() + t3 = replace_self(env, tltype.getField(t2, string_userdata), t1) + inferred_type[1] = String end + local msg = "attempt to call method '%s' of type '%s'" + if tltype.isFunction(t3) then + check_arguments(env, "field", t3[1], inferred_type, exp.pos) + set_type(exp, t3[2]) + elseif tltype.isAny(t3) then + set_type(exp, Any) + msg = string.format(msg, exp2[1], tltype.tostring(t3)) + typeerror(env, "any", msg, exp.pos) + else + set_type(exp, Nil) + msg = string.format(msg, exp2[1], tltype.tostring(t3)) + typeerror(env, "invoke", msg, exp.pos) + end + elseif tltype.isAny(t1) then + set_type(exp, Any) + local msg = "attempt to index '%s' with '%s'" + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "any", msg, exp.pos) else - error("cannot type check variable " .. tag) + set_type(exp, Nil) + local msg = "attempt to index '%s' with '%s'" + msg = string.format(msg, tltype.tostring(t1), tltype.tostring(t2)) + typeerror(env, "index", msg, exp.pos) end + return false end -function check_exp (env, exp, tself) - local tag = exp.tag - if tag == "Nil" then - set_type(exp, Nil) - elseif tag == "Dots" then - set_type(exp, tltype.Vararg(tlst.get_vararg(env))) - elseif tag == "True" then - set_type(exp, True) - elseif tag == "False" then - set_type(exp, False) - elseif tag == "Number" then - set_type(exp, tltype.Literal(exp[1])) - elseif tag == "String" then - set_type(exp, tltype.Literal(exp[1])) - elseif tag == "Function" then - check_function(env, exp, tself) - elseif tag == "Table" then - check_table(env, exp) - elseif tag == "Op" then - check_op(env, exp) - elseif tag == "Paren" then - check_paren(env, exp) - elseif tag == "Call" then - check_call(env, exp) - elseif tag == "Invoke" then - check_invoke(env, exp) - elseif tag == "Id" then - check_id(env, exp) - elseif tag == "Index" then - check_index(env, exp) - else - error("cannot type check expression " .. tag) - end -end - -function check_stm (env, stm) - local tag = stm.tag - if tag == "Do" then - return check_block(env, stm) - elseif tag == "Set" then - return check_assignment(env, stm[1], stm[2]) - elseif tag == "While" then - return check_while(env, stm) - elseif tag == "Repeat" then - return check_repeat(env, stm) - elseif tag == "If" then - return check_if(env, stm) - elseif tag == "Fornum" then - return check_fornum(env, stm) - elseif tag == "Forin" then - return check_forin(env, stm[1], stm[2], stm[3]) - elseif tag == "Local" then - return check_local(env, stm[1], stm[2]) - elseif tag == "Localrec" then - return check_localrec(env, stm[1][1], stm[2][1]) - elseif tag == "Goto" then - return false, nil, true - elseif tag == "Label" then - return false - elseif tag == "Return" then - return check_return(env, stm) - elseif tag == "Break" then - return false - elseif tag == "Call" then - return check_call(env, stm) - elseif tag == "Invoke" then - return check_invoke(env, stm) - elseif tag == "Interface" then - return check_interface(env, stm) - else - error("cannot type check statement " .. tag) +function tlchecker:Label (stm, env) + return false +end + +local function unannotated_idlist (idlist) + for _, v in ipairs(idlist) do + if v[2] then return false end end + return true end -local function is_exit_point (block) - if #block == 0 then return false end - local last = block[#block] - return last.tag == "Return" or is_global_function_call(last, "error") +local function sized_unionlist (t) + for i = 1, #t - 1 do + if #t[i] ~= #t[i + 1] then return false end + end + return true end -function check_block (env, block) +function tlchecker:Local (stm, env) + local idlist, explist = stm[1], stm[2] + self:ExpList(explist, env) + if unannotated_idlist(idlist) and + #explist == 1 and + tltype.isUnionlist(get_type(explist[1])) and + sized_unionlist(get_type(explist[1])) and + #idlist == #get_type(explist[1])[1] - 1 then + local t = get_type(explist[1]) + for k, v in ipairs(idlist) do + set_type(v, t) + v.i = k + check_masking(env, v[1], v.pos) + tlst.set_local(env, v) + end + else + local tuple = explist2typegen(explist) + for k, v in ipairs(idlist) do + local t = tuple(k) + local close_local = explist[k] and explist[k].tag == "Id" and tltype.isTable(t) + check_local_var(env, v, t, close_local) + end + end + return false +end + +function tlchecker:Localrec (stm, env) + local id, exp = stm[1][1], stm[2][1] + local idlist, ret_type, block = exp[1], replace_names(env, exp[2], exp.pos), exp[3] + local infer_return = false + if not block then + block = ret_type + ret_type = tltype.Tuple({ Nil }, true) + infer_return = true + end + tlst.begin_function(env) + local input_type = check_parameters(env, idlist, exp.pos) + local t = tltype.Function(input_type, ret_type) + id[2] = t + set_type(id, t) + check_masking(env, id[1], id.pos) + tlst.set_local(env, id) tlst.begin_scope(env) - local r = false - local bkp = env.self - local endswithret = true - local didgoto, _ = false, nil - for k, v in ipairs(block) do - r, _, didgoto = check_stm(env, v) - env.self = bkp - if didgoto then endswithret = false end + local len = #idlist + if len > 0 and idlist[len].tag == "Dots" then len = len - 1 end + for k = 1, len do + local v = idlist[k] + v[2] = replace_names(env, v[2], exp.pos) + set_type(v, v[2]) + check_masking(env, v[1], v.pos) + tlst.set_local(env, v) end - endswithret = endswithret and is_exit_point(block) + local r = self:Block(block, env) + if not r then tlst.set_return_type(env, tltype.Tuple({ Nil }, true)) end check_unused_locals(env) tlst.end_scope(env) - return r, endswithret, didgoto + local inferred_type = infer_return_type(env) + if infer_return then + ret_type = inferred_type + t = tltype.Function(input_type, ret_type) + id[2] = t + set_type(id, t) + tlst.set_local(env, id) + set_type(exp, t) + end + check_return_type(env, inferred_type, ret_type, exp.pos) + tlst.end_function(env) + return false +end + +function tlchecker:Nil (exp, env) + set_type(exp, Nil) +end + +function tlchecker:Number (exp, env) + set_type(exp, tltype.Literal(exp[1])) +end + +function tlchecker:Paren (exp, env) + local exp1 = exp[1] + self:Expression(exp1, env) + local t1 = get_type(exp1) + set_type(exp, tltype.first(t1)) +end + +function tlchecker:Repeat (stm, env) + local stm1, exp1 = stm[1], stm[2] + local r, _, didgoto = self:Block(stm1, env) + self:Expression(exp1, env) + return r, _, didgoto +end + +function tlchecker:Return (stm, env) + self:ExpList(stm, env) + local t = explist2typelist(stm) + tlst.set_return_type(env, tltype.general(t)) + return true +end + +function tlchecker:Set (stm, env) + local varlist, explist = stm[1], stm[2] + local lselfs = {} + for k, v in ipairs(varlist) do + if v.tag == "Index" and v[1].tag == "Id" and v[2].tag == "String" then + local l = tlst.get_local(env, v[1][1]) + local t = get_type(l) + -- a brittle hack to type a method definition in the right-hand side? + if tltype.isTable(t) then lselfs[k] = t end + end + end + self:ExpList(explist, env, lselfs) + local l = {} + for k, v in ipairs(varlist) do + check_var(env, v, explist[k]) + table.insert(l, get_type(v)) + end + table.insert(l, tltype.Vararg(Value)) + local var_type, exp_type = tltype.Tuple(l), explist2typelist(explist) + local msg = "attempt to assign '%s' to '%s'" + if tltype.subtype(exp_type, var_type) then + elseif tltype.consistent_subtype(exp_type, var_type) then + msg = string.format(msg, tltype.tostring(exp_type), tltype.tostring(var_type)) + typeerror(env, "any", msg, varlist[1].pos) + else + msg = string.format(msg, tltype.tostring(exp_type), tltype.tostring(var_type)) + typeerror(env, "set", msg, varlist[1].pos) + end + for k, v in ipairs(varlist) do + local tag = v.tag + if tag == "Id" then + local name = v[1] + local l = tlst.get_local(env, name) + local exp = explist[k] + if exp and exp.tag == "Op" and exp[1] == "or" and + exp[2].tag == "Id" and exp[2][1] == name and not l.assigned then + local t1, t2 = get_type(exp), get_type(l) + if tltype.subtype(t1, t2) then + l.bkp = t2 + set_type(l, t1) + end + end + l.assigned = true + elseif tag == "Index" then + local t1, t2 = get_type(v[1]), get_type(v[2]) + end + end + return false +end + +function tlchecker:String (exp, env) + set_type(exp, tltype.Literal(exp[1])) +end + +function tlchecker:Table (exp, env) + local l = {} + local i = 1 + local len = #exp + for k, v in ipairs(exp) do + local tag = v.tag + local t1, t2 + if tag == "Pair" then + local exp1, exp2 = v[1], v[2] + self:Expression(exp1, env) + self:Expression(exp2, env) + t1, t2 = get_type(exp1), tltype.general(get_type(exp2)) + if tltype.subtype(Nil, t1) then + t1 = Any + local msg = "table index can be nil" + typeerror(env, "table", msg, exp1.pos) + elseif not (tltype.subtype(t1, Boolean) or + tltype.subtype(t1, Number) or + tltype.subtype(t1, String)) then + t1 = Any + local msg = "table index is dynamic" + typeerror(env, "any", msg, exp1.pos) + end + else + local exp1 = v + self:Expression(exp1, env) + t1, t2 = tltype.Literal(i), tltype.general(get_type(exp1)) + if k == len and tltype.isVararg(t2) then + t1 = Integer + end + i = i + 1 + end + if t2.open then t2.open = nil end + t2 = tltype.first(t2) + l[k] = tltype.Field(v.const, t1, t2) + end + local t = tltype.Table(table.unpack(l)) + t.unique = true + set_type(exp, t) +end + +function tlchecker:True (exp, env) + set_type(exp, True) +end + +local function check_not (env, exp) + local exp1 = exp[2] + tlchecker:Expression(exp1, env) + set_type(exp, Boolean) +end + +local function check_bnot (env, exp) + local exp1 = exp[2] + tlchecker:Expression(exp1, env) + local t1 = tltype.first(get_type(exp1)) + local msg = "attempt to perform bitwise on a '%s'" + if tltype.subtype(t1, tltype.Integer(true)) then + set_type(exp, Integer) + elseif tltype.isAny(t1) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "any", msg, exp1.pos) + else + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "bitwise", msg, exp1.pos) + end +end + +local function check_minus (env, exp) + local exp1 = exp[2] + tlchecker:Expression(exp1, env) + local t1 = tltype.first(get_type(exp1)) + local msg = "attempt to perform arithmetic on a '%s'" + if tltype.subtype(t1, Integer) then + set_type(exp, Integer) + elseif tltype.subtype(t1, Number) then + set_type(exp, Number) + elseif tltype.isAny(t1) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "any", msg, exp1.pos) + else + set_type(exp, Any) + t1 = tltype.general(t1) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "arith", msg, exp1.pos) + end +end + +local function check_len (env, exp) + local exp1 = exp[2] + tlchecker:Expression(exp1, env) + local t1 = tltype.first(get_type(exp1)) + local msg = "attempt to get length of a '%s'" + if tltype.subtype(t1, String) or + tltype.subtype(t1, tltype.Table()) then + set_type(exp, Integer) + elseif tltype.isAny(t1) then + set_type(exp, Any) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "any", msg, exp1.pos) + else + set_type(exp, Any) + t1 = tltype.general(t1) + msg = string.format(msg, tltype.tostring(t1)) + typeerror(env, "len", msg, exp1.pos) + end +end + +function tlchecker:UnaryOp (exp, env) + local op = exp[1] + if op == "not" then + check_not(env, exp) + elseif op == "bnot" then + check_bnot(env, exp) + elseif op == "unm" then + check_minus(env, exp) + elseif op == "len" then + check_len(env, exp) + else + error("cannot type check unary operator " .. op) + end +end + +function tlchecker:While (stm, env) + local exp1, stm1 = stm[1], stm[2] + self:Expression(exp1, env) + local r, _, didgoto = self:Block(stm1, env) + return r, _, didgoto end local function load_lua_env (env) @@ -1697,7 +1661,7 @@ local function load_lua_env (env) error("Typed Lua does not support " .. _VERSION) end local t = check_require(env, "base", 0, path) - for k, v in ipairs(l) do + for _, v in ipairs(l) do local t1 = tltype.Literal(v) local t2 = check_require(env, v, 0, path) local f = tltype.Field(false, t1, t2) @@ -1725,8 +1689,8 @@ function tlchecker.typecheck (ast, subject, filename, strict, integer) tlst.begin_scope(env) tlst.set_vararg(env, String) load_lua_env(env) - for k, v in ipairs(ast) do - check_stm(env, v) + for _, v in ipairs(ast) do + tlchecker:Statement(v, env) end check_unused_locals(env) tlst.end_scope(env) @@ -1743,7 +1707,7 @@ function tlchecker.error_msgs (messages, warnings) mask = true, unused = true, } - for k, v in ipairs(messages) do + for _, v in ipairs(messages) do local tag = v.tag if skip_error[tag] then if warnings then diff --git a/typedlua/tlcode.lua b/typedlua/tlcode.lua index 58ded4a2..b854ec3e 100644 --- a/typedlua/tlcode.lua +++ b/typedlua/tlcode.lua @@ -1,10 +1,9 @@ --[[ This file implements the code generator for Typed Lua ]] -local tlcode = {} -local code_block, code_stm, code_exp, code_var -local code_explist, code_varlist, code_fieldlist, code_idlist +local tlvisitor = require "typedlua.tlvisitor" +local tlcode = setmetatable({}, { __index = tlvisitor }) local function spaces (fmt) return string.rep(" ", 2 * fmt.indent) @@ -26,7 +25,7 @@ end local function fix_str (str) local new_str = "" for i=1,string.len(str) do - char = string.byte(str, i) + local char = string.byte(str, i) if char == 34 then new_str = new_str .. string.format("\\\"") elseif char == 92 then new_str = new_str .. string.format("\\\\") elseif char == 7 then new_str = new_str .. string.format("\\a") @@ -47,6 +46,17 @@ local function fix_str (str) return new_str end +local function is_simple_key (key) + return key.tag == "String" and key[1]:match("^[a-zA-Z_][a-zA-Z0-9_]*$") +end + +local function resync_line(node, fmt, out) + while node.l > fmt.line do + table.insert(out, "\n") + fmt.line = fmt.line + 1 + end +end + local op = { add = " + ", sub = " - ", mul = " * ", @@ -70,26 +80,211 @@ local op = { add = " + ", bnot = "~", len = "#" } -local function code_call (call, fmt) + +function tlcode:BinaryOp (exp, fmt) + local str = "" + if _VERSION == "Lua 5.3" then + if exp[2].tag == "Call" and exp[2][1].tag == "Index" and + exp[2][1][1].tag == "Id" and exp[2][1][1][1] == "_ENV" and + exp[2][1][2].tag == "String" and exp[2][1][2][1] == "type" and + exp[3].tag == "String" and exp[3][1] == "integer" then + str = "math." + end + end + return str .. self:Expression(exp[2], fmt) .. op[exp[1]] .. self:Expression(exp[3], fmt) +end + +function tlcode:Block (block, fmt) + local l = {} + local firstline = fmt.line + local saveindent = fmt.indent + if block[1] and block[1].l and block[1].l > firstline then + fmt.indent = fmt.indent + 1 + else + fmt.indent = 0 + end + for _, v in ipairs(block) do + if v.l then + resync_line(v, fmt, l) + else + table.insert(l, "\n") + end + table.insert(l, self:Statement(v, fmt)) + end + if fmt.line ~= firstline then + table.insert(l, "\n") + fmt.line = fmt.line + 1 + else + table.insert(l, " ") + end + fmt.indent = saveindent + return table.concat(l) +end + +function tlcode:Break (stm, fmt) + return indent("break", fmt) +end + +function tlcode:Call (call, fmt) local l = {} for k = 2, #call do - l[k - 1] = code_exp(call[k], fmt) + l[k - 1] = self:Expression(call[k], fmt) end - return code_exp(call[1], fmt) .. "(" .. table.concat(l, ",") .. ")" + return self:Expression(call[1], fmt) .. "(" .. table.concat(l, ",") .. ")" +end + +function tlcode:CallStatement (call, fmt) + return indent(self:Call(call, fmt), fmt) end -local function code_invoke (invoke, fmt) +function tlcode:Do (stm, fmt) + return indent("do ", fmt) .. self:Block(stm, fmt) .. indent("end", fmt) +end + +function tlcode:Dots (exp, fmt) + return "..." +end + +function tlcode:ExpList (explist, fmt) + local l = tlvisitor.ExpList(self, explist, fmt) + return table.concat(l, ", ") +end + +function tlcode:False (exp, fmt) + return "false" +end + +function tlcode:Forin (stm, fmt) + local str = indent("for ", fmt) + str = str .. self:Varlist(stm[1], fmt) .. " in " + str = str .. self:ExpList(stm[2], fmt) .. " do " + str = str .. self:Block(stm[3], fmt) + str = str .. indent("end", fmt) + return str +end + +function tlcode:Fornum (stm, fmt) + local str = indent("for ", fmt) + str = str .. self:Variable(stm[1], fmt) .. " = " .. self:Expression(stm[2], fmt) + str = str .. ", " .. self:Expression(stm[3], fmt) + if stm[5] then + str = str .. ", " .. self:Expression(stm[4], fmt) .. " do " + str = str .. tlcode:Block(stm[5], fmt) + else + str = str .. " do " .. self:Block(stm[4], fmt) + end + str = str .. indent("end", fmt) + return str +end + +function tlcode:Function (exp, fmt) + local str = "function (" + str = str .. self:Parlist(exp[1], fmt) .. ") " + if not exp[3] then + str = str .. self:Block(exp[2], fmt) .. indent("end", fmt) + else + str = str .. self:Block(exp[3], fmt) .. indent("end", fmt) + end + return str +end + +function tlcode:Goto (stm, fmt) + return indent("goto ", fmt) .. stm[1] +end + +function tlcode:Id (var, fmt) + return var[1] +end + +function tlcode:If (stm, fmt) + local str = indent("if ", fmt) .. self:Expression(stm[1], 0) .. " then " + str = str .. self:Block(stm[2], fmt) + local len = #stm + if len % 2 == 0 then + for k=3, len, 2 do + str = str .. indent("elseif ", fmt) .. self:Expression(stm[k], 0) .. " then " + str = str .. self:Block(stm[k+1], fmt) + end + else + for k=3, len-1, 2 do + str = str .. indent("elseif ", fmt) .. self:Expression(stm[k], 0) .. " then " + str = str .. self:Block(stm[k+1], fmt) + end + str = str .. indent("else ", fmt) + str = str .. self:Block(stm[len], fmt) + end + str = str .. indent("end", fmt) + return str +end + +function tlcode:Index (var, fmt) + if var[1].tag == "Id" and var[1][1] == "_ENV" and var[2].tag == "String" then + local v = { tag = "Id", [1] = var[2][1] } + return self:Expression(v, fmt) + else + if is_simple_key(var[2]) then + return self:Expression(var[1], fmt) .. "." .. var[2][1] + else + return self:Expression(var[1], fmt) .. "[" .. self:Expression(var[2], fmt) .. "]" + end + end +end + +function tlcode:Interface (stm, fmt) + return "" +end + +function tlcode:Invoke (invoke, fmt) local l = {} for k = 3, #invoke do - l[k - 2] = code_exp(invoke[k], fmt) + l[k - 2] = self:Expression(invoke[k], fmt) end - local str = code_exp(invoke[1], fmt) + local str = self:Expression(invoke[1], fmt) str = str .. ":" .. invoke[2][1] str = str .. "(" .. table.concat(l, ",") .. ")" return str end -local function code_parlist (parlist, fmt) +function tlcode:InvokeStatement (stm, fmt) + return indent(self:Invoke(stm, fmt), fmt) +end + +function tlcode:Label (stm, fmt) + return indent("::", fmt) .. stm[1] .. "::" +end + +function tlcode:Local (stm, fmt) + local str = indent("local ", fmt) .. self:Varlist(stm[1], fmt) + if #stm[2] > 0 then + str = str .. " = " .. self:ExpList(stm[2], fmt) + end + return str +end + +function tlcode:Localrec (stm, fmt) + local str = indent("local function ", fmt) .. self:Variable(stm[1][1], fmt) + str = str .. " (" .. self:Parlist(stm[2][1][1], fmt) .. ") " + if not stm[2][1][3] then + str = str .. self:Block(stm[2][1][2], fmt) .. indent("end", fmt) + else + str = str .. self:Block(stm[2][1][3], fmt) .. indent("end", fmt) + end + return str +end + +function tlcode:Nil (exp, fmt) + return "nil" +end + +function tlcode:Number (exp, fmt) + return tostring(exp[1]) +end + +function tlcode:Paren (exp, fmt) + return "(" .. self:Expression(exp[1], fmt) .. ")" +end + +function tlcode:Parlist (parlist, fmt) local l = {} local len = #parlist local is_vararg = false @@ -97,9 +292,8 @@ local function code_parlist (parlist, fmt) is_vararg = true len = len - 1 end - local k = 1 for k=1, len do - l[k] = code_var(parlist[k], fmt) + l[k] = self:Variable(parlist[k], fmt) end if is_vararg then table.insert(l, "...") @@ -107,251 +301,66 @@ local function code_parlist (parlist, fmt) return table.concat(l, ", ") end -local function is_simple_key (key) - return key.tag == "String" and key[1]:match("^[a-zA-Z_][a-zA-Z0-9_]*$") +function tlcode:Repeat (stm, fmt) + local str = indent("repeat ", fmt) + str = str .. self:Block(stm[1], fmt) + str = str .. indent("until ", fmt) + str = str .. self:Expression(stm[2], fmt) + return str +end + +function tlcode:Return (stm, fmt) + return indent("return ", fmt) .. self:ExpList(stm, fmt) +end + +function tlcode:Set (stm, fmt) + return spaces(fmt) .. self:Varlist(stm[1], fmt) .. " = " .. self:ExpList(stm[2], fmt) end -local function code_fieldlist (fieldlist, fmt) +function tlcode:String (exp, fmt) + return '"' .. fix_str(exp[1]) .. '"' +end + +function tlcode:Table (fieldlist, fmt) local l = {} for k, v in ipairs(fieldlist) do if v.tag == "Pair" then if is_simple_key(v[1]) then - l[k] = v[1][1] .. " = " .. code_exp(v[2], fmt) + l[k] = v[1][1] .. " = " .. self:Expression(v[2], fmt) else - l[k] = "[" .. code_exp(v[1], fmt) .. "] = " .. code_exp(v[2], fmt) + l[k] = "[" .. self:Expression(v[1], fmt) .. "] = " .. self:Expression(v[2], fmt) end else - l[k] = code_exp(v, fmt) + l[k] = self:Expression(v, fmt) end end - return table.concat(l, ", ") -end - -function code_var (var, fmt) - local tag = var.tag - if tag == "Id" then - return var[1] - elseif tag == "Index" then - if var[1].tag == "Id" and var[1][1] == "_ENV" and var[2].tag == "String" then - local v = { tag = "Id", [1] = var[2][1] } - return code_exp(v, fmt) - else - if is_simple_key(var[2]) then - return code_exp(var[1], fmt) .. "." .. var[2][1] - else - return code_exp(var[1], fmt) .. "[" .. code_exp(var[2], fmt) .. "]" - end - end - else - error("trying to generate code for a variable, but got a " .. tag) - end + return "{" .. table.concat(l, ", ") .. "}" end -function code_varlist (varlist, fmt) - local l = {} - for k, v in ipairs(varlist) do - l[k] = code_var(v, fmt) - end - return table.concat(l, ", ") +function tlcode:True (exp, fmt) + return "true" end -function code_exp (exp, fmt) - local tag = exp.tag - if tag == "Nil" then - return "nil" - elseif tag == "Dots" then - return "..." - elseif tag == "True" then - return "true" - elseif tag == "False" then - return "false" - elseif tag == "Number" then - return tostring(exp[1]) - elseif tag == "String" then - return '"' .. fix_str(exp[1]) .. '"' - elseif tag == "Function" then - local str = "function (" - str = str .. code_parlist(exp[1], fmt) .. ") " - if not exp[3] then - str = str .. code_block(exp[2], fmt) .. indent("end", fmt) - else - str = str .. code_block(exp[3], fmt) .. indent("end", fmt) - end - return str - elseif tag == "Table" then - local str = "{" .. code_fieldlist(exp, fmt) .. "}" - return str - elseif tag == "Op" then - local str = "" - if exp[3] then - if _VERSION == "Lua 5.3" then - if exp[2].tag == "Call" and exp[2][1].tag == "Index" and - exp[2][1][1].tag == "Id" and exp[2][1][1][1] == "_ENV" and - exp[2][1][2].tag == "String" and exp[2][1][2][1] == "type" and - exp[3].tag == "String" and exp[3][1] == "integer" then - str = "math." - end - end - str = str .. code_exp(exp[2], fmt) .. op[exp[1]] .. code_exp(exp[3], fmt) - else - str = str .. op[exp[1]] .. "(" .. code_exp(exp[2], fmt) .. ")" - end - return str - elseif tag == "Paren" then - local str = "(" .. code_exp(exp[1], fmt) .. ")" - return str - elseif tag == "Call" then - return code_call(exp, fmt) - elseif tag == "Invoke" then - return code_invoke(exp, fmt) - elseif tag == "Id" or - tag == "Index" then - return code_var(exp, fmt) - else - error("trying to generate code for a expression, but got a " .. tag) - end +function tlcode:UnaryOp (exp, fmt) + return op[exp[1]] .. "(" .. self:Expression(exp[2], fmt) .. ")" end -function code_explist (explist, fmt) - local l = {} - for k, v in ipairs(explist) do - l[k] = code_exp(v, fmt) - end +function tlcode:Varlist (varlist, fmt) + local l = tlvisitor.Varlist(self, varlist, fmt) return table.concat(l, ", ") end -function code_stm (stm, fmt) - local tag = stm.tag - if tag == "Do" then - local str = indent("do ", fmt) .. code_block(stm, fmt) .. indent("end", fmt) - return str - elseif tag == "Set" then - local str = spaces(fmt) - str = str .. code_varlist(stm[1], fmt) .. " = " .. code_explist(stm[2], fmt) - return str - elseif tag == "While" then - local str = indent("while ", fmt) .. code_exp(stm[1], 0) .. " do " - str = str .. code_block(stm[2], fmt) .. indent("end", fmt) - return str - elseif tag == "Repeat" then - local str = indent("repeat ", fmt) - str = str .. code_block(stm[1], fmt) - str = str .. indent("until ", fmt) - str = str .. code_exp(stm[2], fmt) - return str - elseif tag == "If" then - local str = indent("if ", fmt) .. code_exp(stm[1], 0) .. " then " - str = str .. code_block(stm[2], fmt) - local len = #stm - if len % 2 == 0 then - for k=3, len, 2 do - str = str .. indent("elseif ", fmt) .. code_exp(stm[k], 0) .. " then " - str = str .. code_block(stm[k+1], fmt) - end - else - for k=3, len-1, 2 do - str = str .. indent("elseif ", fmt) .. code_exp(stm[k], 0) .. " then " - str = str .. code_block(stm[k+1], fmt) - end - str = str .. indent("else ", fmt) - str = str .. code_block(stm[len], fmt) - end - str = str .. indent("end", fmt) - return str - elseif tag == "Fornum" then - local str = indent("for ", fmt) - str = str .. code_var(stm[1], fmt) .. " = " .. code_exp(stm[2], fmt) - str = str .. ", " .. code_exp(stm[3], fmt) - if stm[5] then - str = str .. ", " .. code_exp(stm[4], fmt) .. " do " - str = str .. code_block(stm[5], fmt) - else - str = str .. " do " .. code_block(stm[4], fmt) - end - str = str .. indent("end", fmt) - return str - elseif tag == "Forin" then - local str = indent("for ", fmt) - str = str .. code_varlist(stm[1], fmt) .. " in " - str = str .. code_explist(stm[2], fmt) .. " do " - str = str .. code_block(stm[3], fmt) - str = str .. indent("end", fmt) - return str - elseif tag == "Local" then - local str = indent("local ", fmt) .. code_varlist(stm[1], fmt) - if #stm[2] > 0 then - str = str .. " = " .. code_explist(stm[2], fmt) - end - return str - elseif tag == "Localrec" then - local str = indent("local function ", fmt) .. code_var(stm[1][1], fmt) - str = str .. " (" .. code_parlist(stm[2][1][1], fmt) .. ") " - if not stm[2][1][3] then - str = str .. code_block(stm[2][1][2], fmt) .. indent("end", fmt) - else - str = str .. code_block(stm[2][1][3], fmt) .. indent("end", fmt) - end - return str - elseif tag == "Goto" then - local str = indent("goto ", fmt) .. stm[1] - return str - elseif tag == "Label" then - local str = indent("::", fmt) .. stm[1] .. "::" - return str - elseif tag == "Return" then - local str = indent("return ", fmt) .. code_explist(stm, fmt) - return str - elseif tag == "Break" then - return indent("break", fmt) - elseif tag == "Call" then - return indent(code_call(stm, fmt), fmt) - elseif tag == "Invoke" then - return indent(code_invoke(stm, fmt), fmt) - elseif tag == "Interface" then - return "" - else - error("tyring to generate code for a statement, but got " .. tag) - end -end - -local function resync_line(node, fmt, out) - while node.l > fmt.line do - table.insert(out, "\n") - fmt.line = fmt.line + 1 - end +function tlcode:While (stm, fmt) + local str = indent("while ", fmt) .. self:Expression(stm[1], 0) .. " do " + str = str .. self:Block(stm[2], fmt) .. indent("end", fmt) + return str end -function code_block (block, fmt) - local l = {} - local firstline = fmt.line - local saveindent = fmt.indent - if block[1] and block[1].l and block[1].l > firstline then - fmt.indent = fmt.indent + 1 - else - fmt.indent = 0 - end - for _, v in ipairs(block) do - if v.l then - resync_line(v, fmt, l) - else - table.insert(l, "\n") - end - table.insert(l, code_stm(v, fmt)) - end - if fmt.line ~= firstline then - table.insert(l, "\n") - fmt.line = fmt.line + 1 - else - table.insert(l, " ") - end - fmt.indent = saveindent - return table.concat(l) -end function tlcode.generate (ast) assert(type(ast) == "table") local fmt = { line = 1, indent = -1 } - return code_block(ast, fmt) .. "\n" + return tlcode:visit(ast, fmt) .. "\n" end return tlcode diff --git a/typedlua/tlvisitor.lua b/typedlua/tlvisitor.lua new file mode 100644 index 00000000..694f3ad9 --- /dev/null +++ b/typedlua/tlvisitor.lua @@ -0,0 +1,424 @@ +--[[ +This file implements a visitor for the Typed Lua AST +]] + +local default_visitor = {} + + +--------------------------------------------------------------------------------------------------- +-- Common +--------------------------------------------------------------------------------------------------- + +function default_visitor:visit (node, ...) + local tag = node.tag + if tag then + local method = self[tag] + if method then + return method(self, node, ...) + else + error("the visitor doesn't know how to visit " .. tag) + end + --[[ + else + -- Fallback: visit all subnodes + for _, child in pairs(node) do + if type(child) == "table" then + self:visit(child, ...) + end + end + --]] + end +end + + +--------------------------------------------------------------------------------------------------- +-- Nodes +--------------------------------------------------------------------------------------------------- + +function default_visitor:BinaryOp (exp, ...) + self:Expression(exp[2], ...) + self:Expression(exp[3], ...) +end + +-- block: { stat* } +function default_visitor:Block (block, ...) + for _, stm in ipairs(block) do + self:Statement(stm, ...) + end +end + +-- Break -- break +function default_visitor:Break () -- stm, ... +end + +-- Call{ expr expr* } +function default_visitor:Call (call, ...) + self:Expression(call[1], ...) + for k = 2, #call do + self:Expression(call[k], ...) + end +end + +function default_visitor:CallExpression (exp, ...) + return self:Call(exp, ...) +end + +function default_visitor:CallStatement (stm, ...) + return self:Call(stm, ...) +end + +-- Do{ stat* } +function default_visitor:Do (stm, ...) + return self:Block(stm, ...) +end + +-- Dots +function default_visitor:Dots () -- exp, ... +end + +function default_visitor:ExpList (explist, ...) + local l = {} + for k, exp in ipairs(explist) do + l[k] = self:Expression(exp, ...) + end + return l +end + +function default_visitor:Expression (exp, ...) + local tag = exp.tag + if tag == "Id" or tag == "Index" then + return self:Variable(exp, ...) + elseif tag == "Call" then + return self:CallExpression(exp, ...) + elseif tag == "Invoke" then + return self:InvokeExpression(exp, ...) + elseif tag == "Dots" + or tag == "False" + or tag == "Function" + or tag == "Nil" + or tag == "Number" + or tag == "Op" + or tag == "Paren" + or tag == "String" + or tag == "Table" + or tag == "True" then + + return self:visit(exp, ...) + else + error("tyring to visit an expression, but got " .. tag) + end +end + +-- False +function default_visitor:False () -- exp, ... +end + +-- Forin{ {ident+} {expr+} block } -- for i1, i2... in e1, e2... do b end +function default_visitor:Forin (stm, ...) + self:Varlist(stm[1], ...) + self:ExpList(stm[2], ...) + self:Block(stm[3], ...) +end + +-- Fornum{ ident expr expr expr? block } -- for ident = e, e[, e] do b end +function default_visitor:Fornum (stm, ...) + self:Variable(stm[1], ...) + self:Expression(stm[2], ...) + self:Expression(stm[3], ...) + if stm[5] then + self:Expression(stm[4], ...) + self:Block(stm[5], ...) + else + self:Block(stm[4], ...) + end +end + +-- Function{ { ident* { `Dots type? }? } typelist? block } +function default_visitor:Function (exp, ...) + self:Parlist(exp[1], ...) + if not exp[3] then + self:Block(exp[2], ...) + else + self:Block(exp[3], ...) + end +end + +-- Goto{ } -- goto str +function default_visitor:Goto () -- stm, ... +end + +-- Id{ type? } +function default_visitor:Id () -- var, ... +end + +-- If{ (expr block)+ block? } -- if e1 then b1 [elseif e2 then b2] ... [else bn] end +function default_visitor:If (stm, ...) + self:Expression(stm[1], ...) + self:Block(stm[2], ...) + + local len = #stm + if len % 2 == 0 then + for k=3, len, 2 do + self:Expression(stm[k], ...) + self:Block(stm[k+1], ...) + end + else + for k=3, len-1, 2 do + self:Expression(stm[k], ...) + self:Block(stm[k+1], ...) + end + self:Block(stm[len], ...) + end +end + +-- Index{ expr expr } +function default_visitor:Index (var, ...) + if var[1].tag == "Id" and var[1][1] == "_ENV" and var[2].tag == "String" then + local v = { tag = "Id", [1] = var[2][1] } + self:Expression(v, ...) + else + self:Expression(var[1], ...) -- obj + self:Expression(var[2], ...) -- field + end +end + +-- Interface{ type } +function default_visitor:Interface () -- stm, ... +end + +-- Invoke{ expr `String{ } expr* } +function default_visitor:Invoke (invoke, ...) + self:Expression(invoke[1], ...) + --invoke[2][1] -- method name + for k = 3, #invoke do + self:Expression(invoke[k], ...) + end +end + +function default_visitor:InvokeExpression (exp, ...) + return self:Invoke(exp, ...) +end + +function default_visitor:InvokeStatement (stm, ...) + return self:Invoke(stm, ...) +end + +-- Label{ } -- ::str:: +function default_visitor:Label () -- stm, ... +end + +-- Local{ {ident+} {expr+}? } -- local i1, i2... = e1, e2... +function default_visitor:Local (stm, ...) + self:Varlist(stm[1], ...) + if #stm[2] > 0 then + self:ExpList(stm[2], ...) + end +end + +-- Localrec{ ident expr } -- only used for 'local function' +function default_visitor:Localrec (stm, ...) + self:Variable(stm[1][1], ...) + self:Parlist(stm[2][1][1], ...) + if not stm[2][1][3] then + self:Block(stm[2][1][2], ...) + else + self:Block(stm[2][1][3], ...) + end +end + +function default_visitor:NameList () -- list, ... + -- TODO +end + +-- Nil +function default_visitor:Nil () -- exp, ... +end + +-- Number{ } +function default_visitor:Number () -- exp, ... +end + +-- Op{ opid expr expr? } +function default_visitor:Op (exp, ...) + if exp[3] then + return self:BinaryOp(exp, ...) + else + return self:UnaryOp(exp, ...) + end +end + +-- Paren{ expr } -- significant to cut multiple values returns +function default_visitor:Paren (exp, ...) + return self:Expression(exp[1], ...) +end + +function default_visitor:Parlist (parlist, ...) + local len = #parlist + for k=1, len do + self:Variable(parlist[k], ...) + end +end + +-- Repeat{ block expr } -- repeat b until e +function default_visitor:Repeat (stm, ...) + self:Block(stm[1], ...) + self:Expression(stm[2], ...) +end + +-- Return{ } -- return e1, e2... +function default_visitor:Return (stm, ...) + return self:ExpList(stm, ...) +end + +-- Set{ {lhs+} {expr+} } -- lhs1, lhs2... = e1, e2... +function default_visitor:Set (stm, ...) + self:Varlist(stm[1], ...) + self:ExpList(stm[2], ...) +end + +function default_visitor:Statement (stm, ...) + local tag = stm.tag + if tag == "Call" then + return self:CallStatement(stm, ...) + elseif tag == "Invoke" then + return self:InvokeStatement(stm, ...) + elseif tag == "Break" + or tag == "Do" + or tag == "Forin" + or tag == "Fornum" + or tag == "Goto" + or tag == "If" + or tag == "Interface" + or tag == "Label" + or tag == "Local" + or tag == "Localrec" + or tag == "Repeat" + or tag == "Return" + or tag == "Set" + or tag == "While" then + + return self:visit(stm, ...) + else + error("tyring to visit a statement, but got " .. tag) + end +end + +-- String{ } +function default_visitor:String () -- exp, ... +end + +-- Table{ ( `Pair{ expr expr } | expr )* } +function default_visitor:Table (fieldlist, ...) + for _, v in ipairs(fieldlist) do + if v.tag == "Pair" then + self:Expression(v[1], ...) -- field + self:Expression(v[2], ...) -- value + else + self:Expression(v, ...) + end + end +end + +-- True +function default_visitor:True () -- exp, ... +end + +function default_visitor:UnaryOp (exp, ...) + return self:Expression(exp[2], ...) +end + +function default_visitor:Variable (var, ...) + local tag = var.tag + if tag == "Dots" + or tag == "Id" + or tag == "Index" then + return self:visit(var, ...) + else + error("tyring to visit a variable, but got " .. tag) + end +end + +function default_visitor:Varlist (varlist, ...) + local l = {} + for k, var in ipairs(varlist) do + l[k] = self:Variable(var, ...) + end + return l +end + +-- While{ expr block } -- while e do b end +function default_visitor:While (stm, ...) + self:Expression(stm[1], ...) + self:Block(stm[2], ...) +end + + +--------------------------------------------------------------------------------------------------- +-- Types +--------------------------------------------------------------------------------------------------- + +-- TAny +function default_visitor:TAny () -- t, ... +end + +-- TBase{ 'boolean' | 'number' | 'string' } +function default_visitor:TBase () -- t, ... +end + +-- TODO +function default_visitor:TField () -- t, ... +end + +-- TFunction{ type type } +function default_visitor:TFunction () -- t, ... +end + +-- TLiteral{ false | true | | } +function default_visitor:TLiteral () -- t, ... +end + +-- TNil +function default_visitor:TNil () -- t, ... +end + +-- TODO +-- TRecursive{ type } + +-- TSelf +function default_visitor:TSelf () -- t, ... +end + +-- TTable{ type type* } +function default_visitor:TTable () -- t, ... +end + +-- TTuple{ type type* } +function default_visitor:TTuple () -- t, ... +end + +-- TUnion{ type type type* } +function default_visitor:TUnion () -- t, ... +end + +-- TUnionlist{ type type type* } +function default_visitor:TUnionlist () -- t, ... +end + +-- TValue +function default_visitor:TValue () -- t, ... +end + +-- TVararg{ type } +function default_visitor:TVararg () -- t, ... +end + +-- TVariable{ } +function default_visitor:TVariable () -- t, ... +end + +-- TVoid +function default_visitor:TVoid () -- t, ... +end + + +return default_visitor