From 45222662ef9f6c36987c45b0b9b0c74e4aa9d555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Thu, 24 Aug 2023 12:22:56 +0200 Subject: [PATCH 1/7] function monomorphization [skip ci] --- src/codegen/codegenEnv.ml | 2 +- .../monomorphization/monomorphization.ml | 280 +++++++++++++++++- .../monomorphization/monomorphizationMonad.ml | 45 +++ .../monomorphization/monomorphizationUtils.ml | 112 +++++++ src/passes/process/process.ml | 6 +- src/passes/process/processUtils.ml | 2 +- 6 files changed, 428 insertions(+), 19 deletions(-) create mode 100644 src/passes/monomorphization/monomorphizationMonad.ml create mode 100644 src/passes/monomorphization/monomorphizationUtils.ml diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index b01e25f..93ae0fe 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -43,7 +43,7 @@ let getLLVMBasicType f t llc llm : lltype = | String -> i8_type llc |> pointer_type | ArrayType (t,s) -> array_type (aux t) s | Box t | RefType (t,_) -> aux t |> pointer_type - | GenericType _ -> failwith "there should be no generic type, was degenerifyType used ? " + | GenericType _ -> failwith "no generic type in codegen" | CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc (* for extern functions *) | CompoundType {origin=None;_} | CompoundType {decl_ty=None;_} -> failwith "compound type with no origin or decl_ty" | CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} -> diff --git a/src/passes/monomorphization/monomorphization.ml b/src/passes/monomorphization/monomorphization.ml index 219e178..039e6d8 100644 --- a/src/passes/monomorphization/monomorphization.ml +++ b/src/passes/monomorphization/monomorphization.ml @@ -1,26 +1,276 @@ open Common +open Monad open TypesCommon -module E = Common.Error.Logger -open Monad.UseMonad(E) -open IrMir -open IrHir -open SailParser - -type mono_body = { - monomorphics : AstMir.mir_function method_defn list; - polymorphics : AstMir.mir_function method_defn list; - processes : (HirUtils.statement,HirUtils.expression) AstParser.process_body process_defn list -} +module E = Common.Error +open Monad.MonadSyntax (E.Logger) +open IrMir.AstMir +open MonomorphizationMonad +module M = MonoMonad +open MonomorphizationUtils +open MonadSyntax(M) +open MonadOperator(M) +open MonadFunctions(M) + module Pass = Pass.Make (struct let name = "Monomorphization" - type in_body = (AstMir.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes - type out_body = mono_body + type in_body = (MonomorphizationUtils.in_body,(IrHir.HirUtils.statement,IrHir.HirUtils.expression) SailParser.AstParser.process_body) SailModule.methods_processes + type out_body = MonomorphizationUtils.out_body module Env = SailModule.DeclEnv + let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t = + + let mono_exp (e : expression) : sailtype M.t = + let rec aux (e : expression) : sailtype M.t = + match e.exp with + | Variable s -> M.get_var s >>| fun v -> (v |> Option.get |> snd).ty + + | Literal l -> return (sailtype_of_literal l) + + | ArrayRead (e, idx) -> + begin + let* t = aux e in + match t with + | ArrayType (t, _) -> + let+ idx_t = aux idx in + let _ = resolveType idx_t (Int 32) [] [] in + t + | _ -> failwith "cannot happen" + end + | UnOp (_, e) -> aux e + + | BinOp (_, e1, e2) -> + let* t1 = aux e1 in + let+ t2 = aux e2 in + let _ = resolveType t1 t2 [] [] in + t1 + + | Ref (m, e) -> + let+ t = aux e in + RefType (t, m) + + | Deref e -> ( + let+ t = aux e in + match t with + | RefType _ -> t + | _ -> failwith "cannot happen" + ) + + | ArrayStatic (e :: h) -> + let* t = aux e in + let+ t = + ListM.fold_left (fun last_t e -> + let+ next_t = aux e in + let _ = resolveType next_t last_t [] [] in + next_t + ) t h + in + ArrayType (t, List.length (e :: h)) + + | ArrayStatic [] -> failwith "error : empty array" + | StructAlloc (_, _, _) -> failwith "todo: struct alloc" + | EnumAlloc (_, _) -> failwith "todo: enum alloc" + | StructRead (_, _, _) -> failwith "todo: struct read" + | MethodCall _ -> failwith "no method call at this stage" + in + aux e + in + + let construct_call (calle : string) (el : expression list) : (string * sailtype option) M.t = + (* we construct the types of the args (and collect extra new calls) *) + Logs.debug (fun m -> m "contructing call to %s from %s" calle f.m_proto.name); + let* monos = M.get_monos and* funs = M.get_funs in + Logs.debug (fun m -> m "current monos : %s" (String.concat ";" (List.map ( fun (g,(t:sailor_args)) -> g ^ " -> " ^ (List.map (fun (id,t) -> "(" ^ id ^ "," ^ string_of_sailtype (Some t) ^ ")") t |> String.concat "," )) monos))); + Logs.debug (fun m -> m "current funs : %s" (FieldMap.fold (fun name _ acc -> Fmt.str "%s;%s" name acc) funs "")); + + + let* call_args = + ListM.fold_left + (fun l e -> + Logs.debug (fun m -> m "analyze param expression"); + let* t = mono_exp e in + Logs.debug (fun m -> m "param is %s " @@ string_of_sailtype @@ Some t); + return (t :: l) + ) + [] el + in + + (*don't do anything if the function is already added *) + let mname = mangle_method_name calle call_args in + let* funs = M.get_funs in + match FieldMap.find_opt mname funs with + | Some f -> + Logs.debug (fun m -> m "function %s already discovered, skipping" calle); + return (mname,f.methd.m_proto.rtype) + | None -> + begin + let* f = find_callable calle sm |> M.lift in + match f with + | None -> (*import *) return (mname,Some (Int 32) (*fixme*)) + | Some f -> + begin + Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); + match f.m_body with + | Right _ -> + (* process and method + + we make sure they correspond to what the callable wants + if the callable is generic we check all the generic types are present at least once + + we build a (string*sailtype) list of generic to type correspondance + if the generic is not found in the list, we add it with the corresponding type + if the generic already exists with the same type as the new one, we are good else we fail + *) + let* resolved_generics = check_args call_args f |> M.lift in + List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; + + let* () = M.push_monos calle resolved_generics in + + let* rtype = + match f.m_proto.rtype with + | Some t -> + (* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) + let+ t = (degenerifyType t resolved_generics|> M.lift) in + (* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) + Some t + | None -> return None + in + + let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in + let name = mname in + let methd = { f with m_proto = { f.m_proto with rtype ; params } } in + let+ () = + let* f = M.get_decl name (Self Method) in + if Option.is_none f then + M.add_decl name ((dummy_pos,name),(defn_to_proto (Method methd))) Method + else return () + in + mname,rtype + | Left _ -> (* external method *) return (calle,f.m_proto.rtype) + end + end + in + + let rec mono_body (lbl: label) (treated: LabelSet.t) (blocks : (VE.t,unit) basicBlock BlockMap.t): (LabelSet.t * (_,_) basicBlock BlockMap.t) MonoMonad.t = + (* collect calls and name correctly *) + if LabelSet.mem lbl treated then return (treated,blocks) + else + begin + let treated = LabelSet.add lbl treated in + + let bb = BlockMap.find lbl blocks in + let* () = M.set_ve bb.forward_info in + let* () = ListM.iter (fun assign -> mono_exp assign.target >>= fun _ty -> mono_exp assign.expression >>| fun _ty -> ()) bb.assignments + in + + match bb.terminator |> Option.get with + | Return e -> + let+ _ = + begin + match e with + | Some e -> let+ t = mono_exp e in Some t + | None -> return None + end + in treated,blocks + + | Invoke new_f -> + let* (id,_) = construct_call new_f.id new_f.params in + mono_body new_f.next treated BlockMap.(update lbl (fun _ -> Some {bb with terminator=Some (Invoke {new_f with id})}) blocks) + + | Goto lbl -> mono_body lbl treated blocks + + | SwitchInt si -> + let* _ = mono_exp si.choice in + let* treated,blocks = mono_body si.default treated blocks in + ListM.fold_left ( fun (treated,blocks) (_,lbl) -> + mono_body lbl treated blocks + ) (treated,blocks) si.paths + + | Break -> failwith "no break should be there" + end + in + + match f.m_body with + | Right (decls,cfg) -> mono_body cfg.input LabelSet.empty cfg.blocks >>= fun (_,blocks) -> + let params = List.map (fun (p:param) -> p.ty) f.m_proto.params in + let name = mangle_method_name f.m_proto.name params in + let methd = {m_proto = f.m_proto; m_body=Right (decls,{cfg with blocks})} in + M.add_fun name {methd; generics=[]} + + | Left _ -> (* external *) return () + + + let analyse_functions (sm : in_body SailModule.t) : unit M.t = + + (* find the function, apply generic substitutions to its signature and monomorphize *) + let find_fun_and_mono (name, (g : sailor_args)) : unit M.t = + let* f = find_callable name sm |> M.lift in + match f with + | None -> (* fixme imports *) return () + | Some f -> + (* monomorphize signature with resolved generics (if any) *) + let* params = ListM.map (fun (p : param) -> let+ ty = degenerifyType p.ty g |> M.lift in {p with ty}) f.m_proto.params in + let* rtype = + match f.m_proto.rtype with + | Some t -> let+ t = degenerifyType t g |> M.lift in Some t + | None -> return None + in + (* update function signature *) + let f = { f with m_proto = { f.m_proto with params; rtype } } in + (* monomorphize, updating env with any new function calls found *) + mono_fun f sm + in + + let rec aux () : unit M.t = + let* empty = M.get_monos >>| (=) [] in + if not empty then (* runs until no more new monomorphic function is found *) + begin + let* name,args = M.pop_monos in + Logs.debug (fun m -> m "looking at function %s with args %s " name (List.map (fun (_,t) -> string_of_sailtype @@ Some t) args |> String.concat " ")); + + let mname = mangle_method_name name (List.split args |> snd) in + + (* we only look at untreated functions *) + let* funs = M.get_funs in + match FieldMap.find_opt mname funs with + | Some _ -> + Logs.debug (fun m -> m "%s already checked" mname); + return () + | None -> + Logs.debug (fun m -> m "analyzing monomorphic function %s" mname); + find_fun_and_mono (name, args) >>= aux + end + else return () + in + let* empty = M.get_monos >>| (=) [] in M.throw_if Error.(make dummy_pos "no monomorphic callable (no main?)") empty >>= aux + + let transform (smdl : in_body SailModule.t) : out_body SailModule.t E.t = - let polymorphics,monomorphics = List.partition (fun m -> m.m_proto.generics <> []) smdl.body.methods in - return {smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}} + let add_if_mono name args gens else_ret = + let args = List.map (fun (p:param) -> p.id,p.ty) args in + if gens <> [] then M.pure [else_ret] else M.push_monos name args >>| fun () -> [] + in + + + let mono_poly = (* add monomorphics to the env and collect generic methods *) + M.pure [] + (* our entry points are non generic methods and processes *) + >>= fun l -> ListM.fold_left (fun acc m -> add_if_mono m.m_proto.name m.m_proto.params m.m_proto.generics m >>| Fun.flip List.append acc) l smdl.body.methods + (* + analyze them, find and resolve calls to generic functions + IMPORTANT : we must keep the generic functions : if one of them + is called from an other module and we don't have a monomorphic versio, we must generate one using the generic version + *) + >>= fun l -> analyse_functions smdl >>| fun () -> l + in + + let open MonadSyntax(E) in + let+ polymorphics,mono_env = M.run smdl.declEnv mono_poly in + Logs.info (fun m -> m "generated %i monomorphic functions : " (List.length (FieldMap.bindings mono_env.functions))); + FieldMap.iter print_method_proto mono_env.functions; + let monomorphics = List.filter (fun m -> Either.is_left m.m_body) smdl.body.methods |> FieldMap.fold (fun name f acc -> {f.methd with m_proto={f.methd.m_proto with name}}::acc) mono_env.functions in + + {smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}} end) \ No newline at end of file diff --git a/src/passes/monomorphization/monomorphizationMonad.ml b/src/passes/monomorphization/monomorphizationMonad.ml new file mode 100644 index 0000000..19b49a6 --- /dev/null +++ b/src/passes/monomorphization/monomorphizationMonad.ml @@ -0,0 +1,45 @@ +open Common +open Monad +open TypesCommon +open MonomorphizationUtils + +type env = {monos: monomorphics; functions : sailor_functions; env: varTypesMap} + +module MonoMonad = struct + module S = MonadState.T(Error.Logger)(struct type t = env end) + open MonadSyntax(S) + open MonadOperator(S) + include S + (* error *) + let throw e = E.throw e |> lift + let throw_if e c = E.throw_if e c |> lift + + let get_decl id ty = get >>| fun e -> Env.get_decl id ty e.env + let add_decl id decl ty = update (fun e -> E.bind (Env.add_decl id decl ty e.env) (fun env -> E.pure {e with env})) + let get_var id = get >>| fun e -> Env.get_var id e.env + let set_ve ve = update (fun e -> E.pure {e with env=(ve,snd e.env)}) + + + let add_fun mname (f: 'a sailor_method) = S.update (fun e -> E.pure {e with functions=FieldMap.add mname f e.functions}) + let get_funs = let+ e = S.get in e.functions + + let push_monos name generics = S.update (fun e -> E.pure {e with monos=(name,generics)::e.monos}) + let get_monos = let+ e = S.get in e.monos + let pop_monos = let* e = S.get in + match e.monos with + | [] -> throw Error.(make dummy_pos "empty_monos") + | h::monos -> S.set {e with monos} >>| fun () -> h + + let run (decls:Env.D.t) (x: 'a t) : ('a * env) E.t = x {monos=[];functions=FieldMap.empty;env=Env.empty decls} + +end + + +let mangle_method_name (name : string) (args : sailtype list) : string = + let back = + List.fold_left (fun s t -> s ^ string_of_sailtype (Some t) ^ "_") "" args + in + let front = "_" ^ name ^ "_" in + let res = front ^ back in + Logs.debug (fun m -> m "renamed %s to %s" name res); + res \ No newline at end of file diff --git a/src/passes/monomorphization/monomorphizationUtils.ml b/src/passes/monomorphization/monomorphizationUtils.ml new file mode 100644 index 0000000..868b8f6 --- /dev/null +++ b/src/passes/monomorphization/monomorphizationUtils.ml @@ -0,0 +1,112 @@ +open Common +open TypesCommon +open Monad +open IrHir +module E = Error.Logger +module Env = SailModule.SailEnv(IrMir.AstMir.V) +open UseMonad(E) + +type in_body = IrMir.AstMir.mir_function +type out_body = { + monomorphics : in_body method_defn list; + polymorphics : in_body method_defn list; + processes : (HirUtils.statement,HirUtils.expression) SailParser.AstParser.process_body process_defn list +} + + +type sailor_args = sailtype dict +type varTypesMap = Env.t +type monomorphics = sailor_args dict +type sailor_function = in_body method_defn +type 'a sailor_method = { methd : 'a method_defn; generics : sailor_args } +type sailor_functions = in_body sailor_method FieldMap.t + +let print_method_proto (name : string) (methd : in_body sailor_method) = + let args_type = + List.map (fun (p : param) -> p.ty) methd.methd.m_proto.params + in + let args = + String.concat "," + (List.map (fun t -> string_of_sailtype (Some t)) args_type) + in + let methd_string = Printf.sprintf "method %s (%s)" name args in + Logs.debug (fun m -> m "%s" methd_string) + + + +let resolveType (arg : sailtype) (m_param : sailtype) (generics : string list) (resolved_generics : sailor_args) : (sailtype * sailor_args) E.t = + let rec aux (a : sailtype) (m : sailtype) (g : sailor_args) = + match (a, m) with + | Bool, Bool -> return (Bool, g) + | Int x, Int y when x = y -> return (Int x, g) + | Float, Float -> return (Float, g) + | Char, Char -> return (Char, g) + | String, String -> return (String, g) + | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in ArrayType (t, s), g + | GenericType _g1, GenericType _g2 -> return (Int 32,g) + (* E.throw Error.(make dummy_pos @@ Fmt.str "resolveType between generic %s and %s" g1 g2) *) + | at, GenericType gt -> + let* () = E.throw_if Error.(make dummy_pos @@ Fmt.str "generic type %s not declared" gt) (not @@ List.mem gt generics) in + begin + match List.assoc_opt gt g with + | None -> return (at, (gt, at) :: g) + | Some t -> + E.throw_if + Error.(make dummy_pos @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt (string_of_sailtype (Some t)) (string_of_sailtype (Some at))) + (t <> at) + >>| fun () -> at, g + end + | RefType (at, _), RefType (mt, _) -> aux at mt g + + | CompoundType _, CompoundType _ -> failwith "todocompoundtype" + | Box _at, Box _mt -> failwith "todobox" + | _ -> E.throw Error.(make dummy_pos @@ Fmt.str "cannot happen : %s vs %s" (string_of_sailtype (Some a)) (string_of_sailtype (Some m))) + in + aux arg m_param resolved_generics + +let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = + let rec aux = function + | Bool -> return Bool + | Int n -> return (Int n) + | Float -> return Float + | Char -> return Char + | String -> return String + | ArrayType (t, s) -> let+ t = aux t in ArrayType (t, s) + | Box t -> let+ t = aux t in Box t + | RefType (t, m) -> let+ t = aux t in RefType (t, m) + | GenericType _t when generics = [] -> + (* E.throw Error.(make dummy_pos @@ Fmt.str "generic type %s present but empty generics list" t) *) + return (Int 32) + + | GenericType _n -> + (* E.throw_if_none Error.(make dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) (List.assoc_opt n generics) *) + return (Int 32) + | CompoundType _ -> failwith "todo compoundtype" + in + aux t + +let check_args (caller_args : sailtype list) (f:sailor_function) : sailor_args E.t = + let margs = List.map (fun (p:param) -> p.ty) f.m_proto.params in + Logs.debug (fun m -> m "caller args : %s" + (List.fold_left (fun acc t ->Printf.sprintf "%s %s," acc (string_of_sailtype (Some t))) "" caller_args)); + Logs.debug (fun m -> + m "method args : %s" + (List.fold_left (fun acc t -> Printf.sprintf "%s %s," acc (string_of_sailtype (Some t))) "" margs)); + + let args = if f.m_proto.variadic then List.filteri (fun i _ -> i < (List.length margs)) caller_args else caller_args in + +let+ resolved_generics = ListM.fold_right2 (fun ca a g -> resolveType ca a f.m_proto.generics g >>| snd) args margs [] in + List.rev resolved_generics + +let find_callable (name : string) (sm : _ SailModule.methods_processes SailModule.t) : sailor_function option E.t = + (* fixme imports *) + Logs.debug (fun m -> m "looking for function %s" name); + Logs.debug (fun m -> m "name is %s" name); + Logs.debug (fun m -> m "%s" @@ SailModule.DeclEnv.string_of_env sm.declEnv); + match SailModule.DeclEnv.find_decl name (All (Method)) sm.declEnv with + + | [_,_] -> + return @@ List.find_opt (fun m -> print_string m.m_proto.name; print_newline (); m.m_proto.name = name) sm.body.methods + + | [] -> E.throw Error.(make dummy_pos @@ Fmt.str "mono : %s not found" name) + | l -> E.throw Error.(make dummy_pos @@ Fmt.str "multiple symbols for %s : %s" name (List.map (fun (i,_) -> i.mname) l |> String.concat " ")) \ No newline at end of file diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index 0c12548..d5b7187 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -92,7 +92,9 @@ module Pass = Pass.Make(struct in let open Monad.UseMonad(E) in - let+ main = lower_processes sm.body.processes in + let* main = lower_processes sm.body.processes in let body : in_body = { methods = main::sm.body.methods; processes = sm.body.processes} in - {sm with body} + let decl = SailModule.method_decl_of_defn main in + let+ declEnv = SailModule.DeclEnv.add_decl "main" decl Method sm.declEnv in + {sm with body ; declEnv} end) \ No newline at end of file diff --git a/src/passes/process/processUtils.ml b/src/passes/process/processUtils.ml index af26c49..ea5e83d 100644 --- a/src/passes/process/processUtils.ml +++ b/src/passes/process/processUtils.ml @@ -44,7 +44,7 @@ let find_process_source (name: l_str) (import : l_str option) procs : 'a process else let find_import = List.find_opt (fun i -> i.mname = origin) (HirUtils.D.get_imports env) in let+ i = M.throw_if_none Error.(make dummy_pos "can't happen") find_import in - let sm = In_channel.with_open_bin (i.dir ^ i.mname ^ Constants.mir_file_ext) @@ fun c -> (Marshal.from_channel c : Mono.Monomorphization.mono_body SailModule.t) + let sm = In_channel.with_open_bin (i.dir ^ i.mname ^ Constants.mir_file_ext) @@ fun c -> (Marshal.from_channel c : Mono.MonomorphizationUtils.out_body SailModule.t) in sm.body.processes in List.find_opt (fun (p:_ process_defn) -> p.p_name = snd name) procs From eacb0f1748f339240e2eb018a73c79f8876c687e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Sat, 26 Aug 2023 23:06:46 +0200 Subject: [PATCH 2/7] global typing context --- bin/sailor.ml | 39 +- examples/imperative/arrays/minArray.sl | 2 +- examples/imperative/arrays/sumArray.sl | 2 +- examples/imperative/complex/list.sl | 2 +- examples/imperative/complex/testLVal.sl | 2 +- examples/imperative/genericity/generics2.sl | 2 +- examples/imperative/genericity/min2.sl | 2 +- .../imperative/genericity/minArrayGeneric.sl | 2 +- .../genericity/testInnerGenericity.sl | 2 +- .../genericity/testInnerGenericity2.sl | 2 +- examples/imperative/loops/sum.sl | 2 +- examples/imperative/loops/while1.sl | 2 +- .../imperative/pointers/bettercallsaul.sl | 2 +- examples/imperative/pointers/drop1.sl | 2 +- examples/imperative/pointers/drop2.sl | 2 +- examples/imperative/pointers/drop3.sl | 2 +- examples/imperative/pointers/drop4.sl | 2 +- examples/imperative/pointers/dropassign1.sl | 2 +- examples/imperative/pointers/dropblock1.sl | 2 +- examples/imperative/pointers/free1.sl | 2 +- examples/imperative/pointers/free2.sl | 2 +- examples/imperative/pointers/testBox.sl | 2 +- examples/imperative/pointers/testBox1.sl | 2 +- examples/imperative/pointers/testMove1.sl | 2 +- examples/imperative/pointers/testMove2.sl | 2 +- examples/imperative/pointers/testMove3.sl | 2 +- examples/imperative/simple/arithmetic.mir | Bin 993 -> 1986 bytes examples/imperative/simple/arithmetic.sl | 2 +- examples/imperative/simple/decl1.sl | 2 +- examples/imperative/simple/decl2.sl | 2 +- examples/imperative/simple/factorial.sl | 2 +- examples/imperative/simple/helloworld.sl | 2 +- examples/imperative/simple/min.sl | 2 +- examples/imperative/simple/mutual_rec.sl | 2 +- examples/imperative/simple/testReturnVal.sl | 2 +- .../structuresAndEnums/lr_values.sl | 2 +- .../imperative/structuresAndEnums/point.sl | 2 +- .../structuresAndEnums/testFieldAssign.sl | 2 +- src/codegen/codegenEnv.ml | 60 ++-- src/codegen/codegenUtils.ml | 24 +- src/codegen/codegen_.ml | 213 ++++++----- src/common/builtins.ml | 2 +- src/common/env.ml | 52 ++- src/common/error.ml | 176 --------- src/common/logging.ml | 195 ++++++++++ src/common/monadic/monad.ml | 8 +- src/common/monadic/monadEither.ml | 1 + src/common/monadic/monadError.ml | 1 - src/common/monadic/monadState.ml | 2 +- src/common/pass.ml | 37 +- src/common/ppCommon.ml | 2 +- src/common/sailModule.ml | 5 +- src/common/typesCommon.ml | 46 +-- src/parsing/astParser.ml | 22 +- src/parsing/parser.mly | 32 +- src/parsing/parsing.ml | 12 +- src/passes/ir/sailHir/astHir.ml | 4 +- src/passes/ir/sailHir/hir.ml | 202 ++--------- src/passes/ir/sailHir/hirMonad.ml | 39 +- src/passes/ir/sailHir/hirUtils.ml | 339 +++++++++++++----- src/passes/ir/sailHir/pp_hir.ml | 12 +- src/passes/ir/sailMir/mir.ml | 35 +- src/passes/ir/sailMir/mirMonad.ml | 18 +- src/passes/ir/sailMir/mirUtils.ml | 23 +- src/passes/ir/sailMir/pp_mir.ml | 5 +- src/passes/ir/sailThir/thir.ml | 259 +++++++------ src/passes/ir/sailThir/thirMonad.ml | 78 ++-- src/passes/ir/sailThir/thirUtils.ml | 253 +++++++++---- src/passes/misc/cfg_analysis.ml | 24 +- src/passes/misc/imports.ml | 2 +- src/passes/misc/methodCall.ml | 179 --------- .../monomorphization/monomorphization.ml | 127 +++---- .../monomorphization/monomorphizationMonad.ml | 9 +- .../monomorphization/monomorphizationUtils.ml | 77 ++-- src/passes/process/process.ml | 16 +- src/passes/process/processMonad.ml | 2 +- src/passes/process/processUtils.ml | 4 +- test/blackbox-tests/sailor.t/test_utils.sl | 4 + 78 files changed, 1424 insertions(+), 1288 deletions(-) delete mode 100644 src/common/error.ml create mode 100644 src/common/logging.ml delete mode 100644 src/passes/misc/methodCall.ml diff --git a/bin/sailor.ml b/bin/sailor.ml index 6895f31..8c42a49 100644 --- a/bin/sailor.ml +++ b/bin/sailor.ml @@ -1,7 +1,7 @@ open Common open TypesCommon open SailParser -module E = Error.Logger +module E = Logging.Logger module Const = Constants module C = Codegen @@ -17,15 +17,25 @@ module Thir = IrThir.Thir.Pass module Mir = IrMir.Mir.Pass module MirChecks = Misc.Cfg_analysis.Pass module Imports = Misc.Imports.Pass -module MCall = Misc.MethodCall.Pass module Mono = Mono.Monomorphization.Pass (* error handling *) open Monad.UseMonad(E) let apply_passes (sail_module : Hir.in_body SailModule.t) (comp_mode : Cli.comp_mode) (dump_ir : bool): Mono.out_body SailModule.t E.t = - let hir_debug = fun m -> let+ m in Out_channel.with_open_text (sail_module.md.name ^ ".hir.debug") (fun f -> Format.(fprintf (formatter_of_out_channel f)) "%a" IrHir.Pp_hir.ppPrintModule m); m in - let mir_debug = fun m -> let+ m in Out_channel.with_open_text (sail_module.md.name ^ ".mir.debug") (fun f -> Format.(fprintf (formatter_of_out_channel f)) "%a" IrMir.Pp_mir.ppPrintModule m); m in + let hir_debug = fun m -> let+ m in Out_channel.with_open_text + (sail_module.md.name ^ ".hir.debug") + (fun f -> IrHir.Pp_hir.ppPrintModule (Format.formatter_of_out_channel f) m); m + in + let mir_debug = fun m -> let+ m in Out_channel.with_open_text + (sail_module.md.name ^ ".mir.debug") + (fun f -> IrMir.Pp_mir.ppPrintModule (Format.formatter_of_out_channel f) m); m + in + + let mir_mono_debug = fun (m: Mono.out_body SailModule.t E.t) -> let+ m in + Out_channel.with_open_text + (sail_module.md.name ^ ".mir_mono.debug") + Format.(fun f -> (pp_print_list IrMir.Pp_mir.ppPrintMethod) (formatter_of_out_channel f) m.body.monomorphics); m in let open Pass.Progression in let active_if cond p = if cond then p else Fun.id in @@ -35,11 +45,11 @@ let apply_passes (sail_module : Hir.in_body SailModule.t) (comp_mode : Cli.comp_ @> active_if dump_ir hir_debug @> Thir.transform @> Imports.transform - @> MCall.transform @> Mir.transform @> MirChecks.transform @> active_if dump_ir mir_debug @> Mono.transform + @> active_if dump_ir mir_mono_debug @> finish in run passes (return sail_module) @@ -218,7 +228,7 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum "a module cannot import itself" else "dependency cycle : " ^ (String.concat " -> " ((List.split compiling |> fst |> List.rev) @ [slmd.md.name;i.mname])) - in Error.make i.loc msg + in Logging.make_msg i.loc msg ) (List.mem_assoc i.mname compiling) in let mir_name = i.mname ^ Const.mir_file_ext in let source = i.mname ^ Const.sail_file_ext in @@ -234,11 +244,11 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum | None, Some m -> (* mir but no source -> use mir *) let mir = unmarshal_sm m in E.throw_if - (Error.make dummy_pos @@ Printf.sprintf "module %s was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) + Logging.(make_msg dummy_pos @@ Printf.sprintf "module %s was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) (mir.md.version <> Const.sailor_version) >>| fun () -> treated,import m | None,None -> (* nothing to work with *) - E.throw @@ Error.make i.loc "import not found" + E.throw Logging.(make_msg i.loc "import not found") | Some s, _ -> (* source but no mir or mir not up-to-date -> compile *) begin let+ treated = process_file s treated ((slmd.md.name,i.loc)::compiling) Cli.Library @@ -257,7 +267,7 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum (* if mir file exists, check hash, if same hash, no need to compile *) if Sys.file_exists mir_file && (List.length force_comp = 0) then let mir = unmarshal_sm mir_file in - let* () = E.throw_if (Error.make dummy_pos @@ Printf.sprintf "module %s was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) + let* () = E.throw_if Logging.(make_msg dummy_pos @@ Printf.sprintf "module %s was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) (mir.md.version <> Const.sailor_version) in if not @@ Digest.equal mir.md.hash slmd.md.hash then @@ -272,10 +282,13 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum in try - match ListM.fold_left (fun t f -> let+ t = process_file f t [] comp_mode in f::t) [] files with - | Ok treated,_ -> Logs.debug (fun m -> m "files processed : %s " @@ String.concat " " treated) ; `Ok () - | Error e,errs -> - Error.print_errors (e::errs); + let process_files = ListM.fold_left (fun t f -> let+ t = process_file f t [] comp_mode in f::t) [] in + match process_files files with + | Ok treated,l -> + Logging.print_log l; + Logs.debug (fun m -> m "files processed : %s " @@ String.concat " " treated) ; `Ok () + | Error e,l -> + Logging.print_log {l with errors=e::l.errors}; `Error(false, "compilation aborted") with | e -> diff --git a/examples/imperative/arrays/minArray.sl b/examples/imperative/arrays/minArray.sl index c8d79ab..11c817a 100644 --- a/examples/imperative/arrays/minArray.sl +++ b/examples/imperative/arrays/minArray.sl @@ -10,5 +10,5 @@ Run: cpt = cpt + 1 } print_int (a[res]); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/arrays/sumArray.sl b/examples/imperative/arrays/sumArray.sl index 60e1d75..a9f3063 100644 --- a/examples/imperative/arrays/sumArray.sl +++ b/examples/imperative/arrays/sumArray.sl @@ -14,5 +14,5 @@ Run: } print_int (res); print_newline(); print_string("Hello\n"); - exit(0); + quit(); } diff --git a/examples/imperative/complex/list.sl b/examples/imperative/complex/list.sl index 0325e88..86878f3 100644 --- a/examples/imperative/complex/list.sl +++ b/examples/imperative/complex/list.sl @@ -49,5 +49,5 @@ Loop: var u : int = length(l); print_int(u);print_newline(); print_int(length(l));print_newline(); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/complex/testLVal.sl b/examples/imperative/complex/testLVal.sl index 214b5ce..0a7c2be 100644 --- a/examples/imperative/complex/testLVal.sl +++ b/examples/imperative/complex/testLVal.sl @@ -41,5 +41,5 @@ Loop: }; var y : int = *a[0] + *a[1]; if (y == 5) print_string ("OK\n") else print_string ("KO\n"); - exit(0); + quit(); } diff --git a/examples/imperative/genericity/generics2.sl b/examples/imperative/genericity/generics2.sl index 4a92983..3fd7c58 100644 --- a/examples/imperative/genericity/generics2.sl +++ b/examples/imperative/genericity/generics2.sl @@ -19,5 +19,5 @@ process Main { // printf("%b\n", b); // printf("%c\n", c); - exit(0); + quit(); } diff --git a/examples/imperative/genericity/min2.sl b/examples/imperative/genericity/min2.sl index b4a7c1c..f2c98e1 100644 --- a/examples/imperative/genericity/min2.sl +++ b/examples/imperative/genericity/min2.sl @@ -7,5 +7,5 @@ process Main { print_int(min(3,4)); print_newline (); // printf("%f\n", min(3.5,4.5)); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/genericity/minArrayGeneric.sl b/examples/imperative/genericity/minArrayGeneric.sl index f4ff402..9d31549 100644 --- a/examples/imperative/genericity/minArrayGeneric.sl +++ b/examples/imperative/genericity/minArrayGeneric.sl @@ -27,5 +27,5 @@ Loop: printf("%i\n", getMin(a)); printf("%f\n", getMin(c)); - exit(0); + quit(); } diff --git a/examples/imperative/genericity/testInnerGenericity.sl b/examples/imperative/genericity/testInnerGenericity.sl index c2850e1..1a05ea0 100644 --- a/examples/imperative/genericity/testInnerGenericity.sl +++ b/examples/imperative/genericity/testInnerGenericity.sl @@ -12,5 +12,5 @@ Init: Loop: f(1); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/genericity/testInnerGenericity2.sl b/examples/imperative/genericity/testInnerGenericity2.sl index 94fd642..b78ecc9 100644 --- a/examples/imperative/genericity/testInnerGenericity2.sl +++ b/examples/imperative/genericity/testInnerGenericity2.sl @@ -12,5 +12,5 @@ Init: Loop: print_int(f(g(1), 2, 2.1)); ; - exit(0); + quit(); } diff --git a/examples/imperative/loops/sum.sl b/examples/imperative/loops/sum.sl index e3cb2ed..1eb1fc7 100644 --- a/examples/imperative/loops/sum.sl +++ b/examples/imperative/loops/sum.sl @@ -14,5 +14,5 @@ process Main { Run: print_int(sumTo(10)); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/loops/while1.sl b/examples/imperative/loops/while1.sl index e027b95..f49caf8 100644 --- a/examples/imperative/loops/while1.sl +++ b/examples/imperative/loops/while1.sl @@ -9,5 +9,5 @@ Run: } print_int(x); print_string(" Worlds\n"); - exit(0); + quit(); } diff --git a/examples/imperative/pointers/bettercallsaul.sl b/examples/imperative/pointers/bettercallsaul.sl index 686c0e3..101abbb 100644 --- a/examples/imperative/pointers/bettercallsaul.sl +++ b/examples/imperative/pointers/bettercallsaul.sl @@ -24,5 +24,5 @@ Loop: print_string(" "); print_int(z); print_newline(); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/drop1.sl b/examples/imperative/pointers/drop1.sl index 8651e9e..e62c367 100644 --- a/examples/imperative/pointers/drop1.sl +++ b/examples/imperative/pointers/drop1.sl @@ -14,5 +14,5 @@ Loop: //print_int(*x); print_newline(); // print_int(y); print_newline() x = box(3); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/drop2.sl b/examples/imperative/pointers/drop2.sl index 3857c8a..f4d5ce0 100644 --- a/examples/imperative/pointers/drop2.sl +++ b/examples/imperative/pointers/drop2.sl @@ -11,5 +11,5 @@ Loop: // Ok, the content of y was tagged as moved, no free here } // OK, the pointer is freed once here; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/drop3.sl b/examples/imperative/pointers/drop3.sl index e95e216..704262c 100644 --- a/examples/imperative/pointers/drop3.sl +++ b/examples/imperative/pointers/drop3.sl @@ -13,5 +13,5 @@ Loop: // Error x is not initialized as we don't enter the loop // print_int(*x); print_newline() ; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/drop4.sl b/examples/imperative/pointers/drop4.sl index b01c145..bbf5aa1 100644 --- a/examples/imperative/pointers/drop4.sl +++ b/examples/imperative/pointers/drop4.sl @@ -15,5 +15,5 @@ Loop: // Error x is points to the box which has been freed // print_int(*x); print_newline(); // x = box(5) // needed otherwise we will try to drop the box a second time; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/dropassign1.sl b/examples/imperative/pointers/dropassign1.sl index b87dcd5..d7cfbce 100644 --- a/examples/imperative/pointers/dropassign1.sl +++ b/examples/imperative/pointers/dropassign1.sl @@ -8,5 +8,5 @@ Loop: x = box(3); x = box(1); print_string("done\n"); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/dropblock1.sl b/examples/imperative/pointers/dropblock1.sl index 8193347..94f2dc6 100644 --- a/examples/imperative/pointers/dropblock1.sl +++ b/examples/imperative/pointers/dropblock1.sl @@ -9,5 +9,5 @@ Loop: print_string("done\n") } print_string("done\n"); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/free1.sl b/examples/imperative/pointers/free1.sl index 2ead4d5..e026022 100644 --- a/examples/imperative/pointers/free1.sl +++ b/examples/imperative/pointers/free1.sl @@ -13,5 +13,5 @@ Loop: } * x = * x + 1; print_int(*x); print_newline(); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/free2.sl b/examples/imperative/pointers/free2.sl index ef10ba6..9d12628 100644 --- a/examples/imperative/pointers/free2.sl +++ b/examples/imperative/pointers/free2.sl @@ -8,5 +8,5 @@ process Main (){ drop(y) } ; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/testBox.sl b/examples/imperative/pointers/testBox.sl index b5cdf05..cc78324 100644 --- a/examples/imperative/pointers/testBox.sl +++ b/examples/imperative/pointers/testBox.sl @@ -5,5 +5,5 @@ Loop: var x : box = box(3); print_int(*x); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/testBox1.sl b/examples/imperative/pointers/testBox1.sl index f52d301..314acdc 100644 --- a/examples/imperative/pointers/testBox1.sl +++ b/examples/imperative/pointers/testBox1.sl @@ -30,5 +30,5 @@ Loop: var a : t = t { x: box(1)}; *(a.x) = 18; print_int(*(a.x)); - exit(0); + quit(); } diff --git a/examples/imperative/pointers/testMove1.sl b/examples/imperative/pointers/testMove1.sl index 21ed65d..bb5684b 100644 --- a/examples/imperative/pointers/testMove1.sl +++ b/examples/imperative/pointers/testMove1.sl @@ -11,5 +11,5 @@ Loop: //x = box(1); //y = x; //z = x; - exit(0); + quit(); } diff --git a/examples/imperative/pointers/testMove2.sl b/examples/imperative/pointers/testMove2.sl index 42232e0..4e5a120 100644 --- a/examples/imperative/pointers/testMove2.sl +++ b/examples/imperative/pointers/testMove2.sl @@ -10,5 +10,5 @@ Loop: var x : box; x = box(1); f(x,x); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/pointers/testMove3.sl b/examples/imperative/pointers/testMove3.sl index 9004a1d..df3167f 100644 --- a/examples/imperative/pointers/testMove3.sl +++ b/examples/imperative/pointers/testMove3.sl @@ -12,5 +12,5 @@ Loop: x = box(1); f(x); y = x; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/simple/arithmetic.mir b/examples/imperative/simple/arithmetic.mir index 73a35b59426c8b62b25ff41fab75e1863509eecf..c725855b70cca606a923981a5e843b7e930c2f77 100644 GIT binary patch literal 1986 zcmah}TWnNS6g>ym4Ac}fzz2_rPG2a2P79*ZAJ3iYU@b!#%EJ%MbSP6gGA(UqN*~BW zLNGyy7i?)8AUVB(ib$|X;Ku-mm_QIi>f^&N(-=N{{P;-tGrG>brGP))J9GCwd#}CD z-skLd_799+6=M2%A>{W$gfH^Hb7EYTXWI%ZmTPUEQA+%SX>Ox(#>U50S!+7mrfDyF zHq(=CZO>)$(QIpbrZZj0?#M**)N`3=x-(m7Y0VU}O_ljvO+$mhw6#*)mf~wa-YEMx zWJ_2!6ipgvB(pY{=_q9|%TdWCg-pJX>MCS&`By5Vb`ag1OE+!XyuD|c9p$6>bT+rW zv$7CX<9t+ErYBqQ6I}W~_y^23$$WuH*5pM3(~Z)mB40=sy7JU8gH~o!L=k6tD5g@& z94S_FS@<~5wRdv8QvBPQF=v%nC)OFvt(D?EDfS_3DI+^9+XCxrGFWKk7K>hFwXFtE zTlr4F7+V%Fwq?3=*|v<&T3#>3b^Z)0l2ZIE#V?kZ5361p!1D#>vDHUc)=9CSKZ9sO zimOswdpL`gfp^H@g+Lp&nYWU#2N!9mgpd#Yop?^L>t0x zBP;aRYjrH~mO&zxc-wl{#uAMN$ylP9T#j;LWgT6?24Q1^6u zZCz@~gFe!~V{a)pT}Rd{ZA#j!930Sie*j_SY8?DX;}Aj&8Ph}_@Vm^XIKmv5#Ru_R zyTPG&ZjV}$sPc1E{M>l1$KXgjx6@WY!?9Ju5w95xTDf)#-j}Mj@Z;8+Kxh)I zYsH(Q(O|^NZByXhls!5)ZLKX9tj3hs#C=i|t{zv2Xp|e2Dpa6^!}%Idu_F`~VF6c2 z`T?RKP*>(gWWN_PB zcT9%fuinp)A>UeOFNy>AWo2W+20vJHA9wXW-d`=^f3%KWm^X=cCh>Qzc{jrSBQ=dQ z8T?8En;&HXuzw^Lr1(eBSceS$3^Z<#pXI$^c=$hg`pZtfzXP|6Vl1CZibW`+V#;y9 ziKeEGU@E&}c}j=mleniyu1!c5@dyk_F-rCEK>ZkVr)W5b5($f&6*x~&eIiiLjMFuN zt-OyOooMo8+KEa+at3e7q-B`5MvAjioToZJP_K|7Ie`w#AQCp$ij+M9Hgai5&SvBW z8(G5ReNBolsg?%nEk*hSi558WEagP4@F!{R%!%cIP30MOS|j3J8Pp>772b3 zG=B9|aTSVQ#7%Iim%3Q5!4y4|5ctsF;A@?G*RWylo-@CXbM8Ht8DDt)fw96bjIk|> zbF}VnW?gi#@=_?8;9QiPFzrx088g$qXgq1ALRNIh^rb0}nZ8geYTb^TRy16nj`_Q~ z1QfJuY)xZ1f%4I5Q4KDjm{I{HI(D8)rlJX}H)BO(>6&_9I+Td?4cI<(T^CR0IpR6l z&h4=Od8kt>3$Ot7YGZ>e1-k`u;TfVoD3(k&^^x(KW2imU<;@B6ZY-KGWnn12qOn!_ z3E11Iv7Z{-fFdEqF4pGNX^VDwONkG)GmBXQ$~rW*qOq)${%d%IIE_sTs8GpNK-H*L zHTspiDrkfRR0oX*+rG5iXl>9K5KtF1lH|})_iq~$nRuU>%0hwT7$iHZrq-n6RT9eXXsE|97SY*pa8wScB8aYfu5j!Pr%Kf%~}W^Y1)Rv7ujWYT|h*} zK4ONOsY(_j@)FQ>xQkH6;cynp(PPug0S@I^IEbDR+ABxoS*YS1j)T{c33VC^gfjvj z2Es#*F~5igDb6Tjs)a>ZOu)2^aj0@+H^n#-NeY;eaTw0$kYJh;Jd?m6Tp+34sHD#& zV!`(^(d3Vz$lMm6?Wz^Y2dp_|H%G%4<#f+hsK zSIMN}hMQ@qUtxvA9auxs(dj<5Y3!-SUI_S{&t5{mDuV{3Q8!%ZL4kCYWL)N8!#ZLF zZ1hK52oYltGwaal#$a4d2OePnxeGsF1LZhjLJl8Nw2}y=&Md;TRBXZ%FMS={(~Vpl xOE%~*y8ye;nS(<5o)CVcE?}wQ+pF)R>$$DA<-&@VkNbY@7w8G2?ro?){TDwZVm1H( diff --git a/examples/imperative/simple/arithmetic.sl b/examples/imperative/simple/arithmetic.sl index 5774957..0ce309b 100644 --- a/examples/imperative/simple/arithmetic.sl +++ b/examples/imperative/simple/arithmetic.sl @@ -12,5 +12,5 @@ process Main { } print_int(z); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/decl1.sl b/examples/imperative/simple/decl1.sl index 6d41fa5..b4c1f7e 100644 --- a/examples/imperative/simple/decl1.sl +++ b/examples/imperative/simple/decl1.sl @@ -6,5 +6,5 @@ process Main { x = 3; print_int(x); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/decl2.sl b/examples/imperative/simple/decl2.sl index eabd193..c1d145e 100644 --- a/examples/imperative/simple/decl2.sl +++ b/examples/imperative/simple/decl2.sl @@ -11,5 +11,5 @@ Run: var x : int = 1 } // print_int(x); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/factorial.sl b/examples/imperative/simple/factorial.sl index 4503971..e4ac19e 100644 --- a/examples/imperative/simple/factorial.sl +++ b/examples/imperative/simple/factorial.sl @@ -11,5 +11,5 @@ Run: x = factorial (5); print_int(x); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/helloworld.sl b/examples/imperative/simple/helloworld.sl index 6f498aa..d293d05 100644 --- a/examples/imperative/simple/helloworld.sl +++ b/examples/imperative/simple/helloworld.sl @@ -4,5 +4,5 @@ process Main { Run: print_string("Hello World\n"); - exit(0); + quit(); } diff --git a/examples/imperative/simple/min.sl b/examples/imperative/simple/min.sl index 8d70618..83c232e 100644 --- a/examples/imperative/simple/min.sl +++ b/examples/imperative/simple/min.sl @@ -15,5 +15,5 @@ Run: x = min(5,3,6); print_int(x); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/mutual_rec.sl b/examples/imperative/simple/mutual_rec.sl index 4329af9..bff06d8 100644 --- a/examples/imperative/simple/mutual_rec.sl +++ b/examples/imperative/simple/mutual_rec.sl @@ -16,5 +16,5 @@ method b(x : int) : int { process Main { Run: print_int(a(12)); print_newline(); - exit(0); + quit(); } diff --git a/examples/imperative/simple/testReturnVal.sl b/examples/imperative/simple/testReturnVal.sl index 354f593..ded1e37 100644 --- a/examples/imperative/simple/testReturnVal.sl +++ b/examples/imperative/simple/testReturnVal.sl @@ -6,5 +6,5 @@ method f(x : int) : int{ process Main { Run: print_int(f(1)); - exit(0); + quit(); } diff --git a/examples/imperative/structuresAndEnums/lr_values.sl b/examples/imperative/structuresAndEnums/lr_values.sl index 4340ee0..217e975 100644 --- a/examples/imperative/structuresAndEnums/lr_values.sl +++ b/examples/imperative/structuresAndEnums/lr_values.sl @@ -13,5 +13,5 @@ process Main () { // var b : &Point = &a; // var c : Point = *b; // var d : int = * (a.z); - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/structuresAndEnums/point.sl b/examples/imperative/structuresAndEnums/point.sl index f74bee1..3df3e04 100644 --- a/examples/imperative/structuresAndEnums/point.sl +++ b/examples/imperative/structuresAndEnums/point.sl @@ -13,5 +13,5 @@ Loop: var p : point = point {x:5, y:7, c:Red}; var y : int = p.x + p.y; var z : color = p.c; - exit(0); + quit(); } \ No newline at end of file diff --git a/examples/imperative/structuresAndEnums/testFieldAssign.sl b/examples/imperative/structuresAndEnums/testFieldAssign.sl index b7a9e66..7fb6f64 100644 --- a/examples/imperative/structuresAndEnums/testFieldAssign.sl +++ b/examples/imperative/structuresAndEnums/testFieldAssign.sl @@ -18,5 +18,5 @@ process Main { var c : pointBis = pointBis{z:b}; (*c.z).x = 2; print_int(a.x); - exit(0); + quit(); } \ No newline at end of file diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index 93ae0fe..6bf9e9e 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -4,7 +4,7 @@ open TypesCommon open Env open Mono open IrMir -module E = Error.Logger +module E = Logging.Logger open Monad.UseMonad(E) open MakeOrderedFunctions(ImportCmp) @@ -34,56 +34,58 @@ open Declarations type in_body = Monomorphization.Pass.out_body -let getLLVMBasicType f t llc llm : lltype = - let rec aux = function - | Bool -> i1_type llc - | Int n -> integer_type llc n - | Float -> double_type llc - | Char -> i8_type llc - | String -> i8_type llc |> pointer_type - | ArrayType (t,s) -> array_type (aux t) s - | Box t | RefType (t,_) -> aux t |> pointer_type - | GenericType _ -> failwith "no generic type in codegen" - | CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc (* for extern functions *) - | CompoundType {origin=None;_} | CompoundType {decl_ty=None;_} -> failwith "compound type with no origin or decl_ty" - | CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} -> - f (mname,name,d) llc llm aux +let getLLVMBasicType f t llc llm : lltype E.t = + let rec aux t = + match snd t with + | Bool -> i1_type llc |> return + | Int n -> integer_type llc n |> return + | Float -> double_type llc |> return + | Char -> i8_type llc |> return + | String -> i8_type llc |> pointer_type |> return + | ArrayType (t,s) -> let+ t = aux t in array_type t s + | Box t | RefType (t,_) -> aux t <&> pointer_type + | GenericType _ -> E.throw Logging.(make_msg (fst t) "no generic type in codegen") + | CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc |> return (* for extern functions *) + | CompoundType {origin=None;_} + | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg (fst t) "compound type with no origin or decl_ty") + | CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} -> + f (mname,name,d) llc llm aux in aux t - let handle_compound_type_codegen env (mname,name,d) llc _llm aux : lltype = + let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> lltype E.t) : lltype E.t = match DeclEnv.find_decl name (Specific (mname,Filter [d])) env with | Some (T tdef) -> begin match tdef with | {ty=Some t;_} -> aux t - | {ty=None;_} -> i64_type llc + | {ty=None;_} -> i64_type llc |> return end | Some (E _enum) -> failwith "todo enum" - | Some (S {ty;_}) -> ty + | Some (S {ty;_}) -> return ty | Some _ -> failwith "something is broken" | None -> failwith @@ Fmt.str "getLLVMType : %s '%s' not found in module '%s'" (string_of_decl d) name mname let getLLVMType = fun e -> getLLVMBasicType (handle_compound_type_codegen e) - let handle_compound_type env (mname,name,d) llc llm aux : lltype = + let handle_compound_type env (mname,name,d) llc llm (aux : sailtype -> lltype E.t) : lltype E.t = match SailModule.DeclEnv.find_decl name (Specific (mname,Filter [d])) env with | Some (T tdef) -> begin match tdef with | {ty=Some t;_} -> aux t - | {ty=None;_} -> i64_type llc + | {ty=None;_} -> i64_type llc |> return end | Some (E _enum) -> failwith "todo enum" | Some (S (_,defn)) -> let _,f_types = List.split defn.fields in - let elts = List.map (fun (_,t,_) -> aux t) f_types |> Array.of_list in + let* elts = ListM.map (fun (_,t,_) -> aux t) f_types <&> Array.of_list in begin match type_by_name llm ("struct." ^ name) with - | Some ty -> ty + | Some ty -> return ty | None -> (let ty = named_struct_type llc ("struct." ^ name) in - struct_set_body ty elts false; ty) + struct_set_body ty elts false; return ty) end | Some _ -> failwith "something is broken" | None -> failwith @@ Fmt.str "getLLVMType : %s '%s' not found in module '%s'" (string_of_decl d) name mname @@ -92,11 +94,11 @@ let getLLVMBasicType f t llc llm : lltype = let _getLLVMType = fun e -> getLLVMBasicType (handle_compound_type e) let llvm_proto_of_method_sig (m:method_sig) env llc llm = - let llvm_rt = match m.rtype with + let* llvm_rt = match m.rtype with | Some t -> getLLVMType env t llc llm - | None -> void_type llc + | None -> void_type llc |> return in - let args_type = List.map (fun ({ty;_}: param) -> getLLVMType env ty llc llm) m.params |> Array.of_list in + let+ args_type = ListM.map (fun ({ty;_}: param) -> getLLVMType env ty llc llm) m.params <&> Array.of_list in let method_t = if m.variadic then var_arg_function_type else function_type in let name = if not (m.extern || m.name = "main") then Fmt.str "_%s_%s" (DeclEnv.get_name env) m.name else m.name in declare_function name (method_t llvm_rt args_type ) llm @@ -127,7 +129,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = ); let valueify_method_sig (m:method_sig) : method_sig = - let value = fun pos -> CompoundType{origin=None;name=(pos,"_value");generic_instances=[];decl_ty=None} in + let value = fun pos -> dummy_pos,CompoundType{origin=None;name=(pos,"_value");generic_instances=[];decl_ty=None} in let rtype = m.rtype in (* keep the current type *) let params = List.map (fun (p:param) -> {p with ty=(value p.loc)}) m.params in {m with params; rtype} @@ -143,7 +145,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = else false,m.m_proto in - let llproto = llvm_proto_of_method_sig proto env llc llm + let* llproto = llvm_proto_of_method_sig proto env llc llm in let m_body = if is_import then @@ -165,7 +167,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = let load_structs structs write_env = SEnv.fold (fun acc (name,(_,defn)) -> let _,f_types = List.split defn.fields in - let elts = List.map (fun (_,t,_) -> _getLLVMType sm.declEnv t llc llm) f_types |> Array.of_list in + let* elts = ListM.map (fun (_,t,_) -> _getLLVMType sm.declEnv t llc llm) f_types <&> Array.of_list in let ty = match type_by_name llm ("struct." ^ name) with | Some ty -> ty | None -> let ty = named_struct_type llc ("struct." ^ name) in diff --git a/src/codegen/codegenUtils.ml b/src/codegen/codegenUtils.ml index a14b8b8..d0d9d44 100644 --- a/src/codegen/codegenUtils.ml +++ b/src/codegen/codegenUtils.ml @@ -2,6 +2,7 @@ open Llvm open Common open TypesCommon open CodegenEnv +open Monad.UseMonad(Logging.Logger) type llvm_args = { c:llcontext; b:llbuilder;m:llmodule; layout : Llvm_target.DataLayout.t} let mangle_method_name (name:string) (mname:string) (args: sailtype list ) : string = @@ -20,7 +21,7 @@ let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue = | LString s -> build_global_stringptr s ".str" llvm.b let ty_of_alias(t:sailtype) env : sailtype = - match t with + match snd t with | CompoundType {origin=Some (_,mname); name=(_,name);decl_ty=Some T ();_} -> begin match DeclEnv.find_decl name (Specific (mname,Type)) env with @@ -32,7 +33,7 @@ let ty_of_alias(t:sailtype) env : sailtype = let unary (op:unOp) (t,v) : llbuilder -> llvalue = let f = - match t,op with + match snd t,op with | Float,Neg -> build_fneg | Int _,Neg -> build_neg | _,Not -> build_not @@ -75,21 +76,20 @@ let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llva | And -> "and" | Or -> "or" | Le -> "le" | Lt -> "lt" | Ge -> "ge" | Gt -> "gt" | Mul -> "mul" | NEq -> "neq" | Div -> "div" in - let t = if t = Bool then Int 1 else t in (* thir will have checked for correctness *) - let l = operators t in + let t = if snd t = Bool then fst t,Int 1 else t in (* thir will have checked for correctness *) + let l = operators (snd t) in let open Common.Monad.MonadOperator(Common.MonadOption.M) in - match l >>| List.assoc_opt op with - | Some Some oper -> oper l1 l2 "" - | Some None | None -> Printf.sprintf "codegen: bad usage of binop '%s' with type %s" (string_of_binop op) (string_of_sailtype @@ Some t) |> failwith + match l >>| List.assoc_opt op |> Option.join with + | Some oper -> oper l1 l2 "" + | None -> Printf.sprintf "codegen: bad usage of binop '%s' with type %s" (string_of_binop op) (string_of_sailtype @@ Some t) |> failwith -let toLLVMArgs (args: param list ) (env:DeclEnv.t) (llvm:llvm_args) : (bool * sailtype * llvalue) array = - let llvalue_list = List.map ( +let toLLVMArgs (args: param list ) (env:DeclEnv.t) (llvm:llvm_args) : (bool * sailtype * llvalue) array E.t = + ListM.map ( fun {id;mut;ty=t;_} -> - let ty = getLLVMType env t llvm.c llvm.m in + let+ ty = getLLVMType env t llvm.c llvm.m in mut,t,build_alloca ty id llvm.b - ) args in - Array.of_list llvalue_list + ) args <&> Array.of_list let get_memcpy_intrinsic llvm = diff --git a/src/codegen/codegen_.ml b/src/codegen/codegen_.ml index e6c087d..871ec93 100644 --- a/src/codegen/codegen_.ml +++ b/src/codegen/codegen_.ml @@ -5,59 +5,70 @@ open TypesCommon open IrMir open Monad.UseMonad(E) module L = Llvm -module E = Error.Logger - +module E = Logging.Logger let get_type (e:AstMir.expression) = snd e.info -let rec eval_l (env:SailEnv.t) (llvm:llvm_args) (x: AstMir.expression) : L.llvalue = +let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: AstMir.expression) : L.llvalue E.t = match x.exp with - | Variable x -> let _,v = match (SailEnv.get_var x env) with Some (_,n) -> n | None -> failwith @@ Fmt.str "var '%s' not found" x |> snd in v + | Variable x -> + let+ _,v = match (SailEnv.get_var x venv) with + | Some (_,n) -> return n + | None -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "var '%s' not found" x) + in v + | Deref x -> eval_r env llvm x + | ArrayRead (array_exp, index_exp) -> - let array_val = eval_l env llvm array_exp in - let index = eval_r env llvm index_exp in + let* array_val = eval_l env llvm array_exp in + let+ index = eval_r env llvm index_exp in let llvm_array = L.build_in_bounds_gep array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in llvm_array + | StructRead ((_,mname),struct_exp,(_,field)) -> - let st = eval_l env llvm struct_exp in - let st_type_name = match snd struct_exp.info with CompoundType c -> snd c.name | _ -> failwith "problem with structure type" in - let fields = (SailEnv.get_decl st_type_name (Specific (mname,Struct)) env |> Option.get).defn.fields in + let* st = eval_l env llvm struct_exp in + let+ st_type_name = Env.TypeEnv.get_from_id struct_exp.info tenv >>= function _,CompoundType c -> return (snd c.name) | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") in + let fields = (SailEnv.get_decl st_type_name (Specific (mname,Struct)) venv |> Option.get).defn.fields in let _,_,idx = List.assoc field fields in L.build_struct_gep st idx "" llvm.b | StructAlloc (_,(_,name),fields) -> let _,fieldlist = fields |> List.split in - let strct_ty = match L.type_by_name llvm.m ("struct." ^ name) with - | Some s -> s - | None -> "unknown structure : " ^ ("struct." ^ name) |> failwith + let* strct_ty = match L.type_by_name llvm.m ("struct." ^ name) with + | Some s -> return s + | None -> + E.throw Logging.(make_msg (fst x.info) @@ "unknown structure : " ^ ("struct." ^ name)) in let struct_v = L.build_alloca strct_ty "" llvm.b in - List.iteri ( fun i f -> - let v = eval_r env llvm f in + let+ () = ListM.iteri ( fun i (_,f) -> + let+ v = eval_r env llvm f in let v_f = L.build_struct_gep struct_v i "" llvm.b in L.build_store v v_f llvm.b |> ignore - ) fieldlist; + ) fieldlist in struct_v - | _ -> failwith "unexpected rvalue for codegen" + | _ -> E.throw Logging.(make_msg (fst x.info) "unexpected rvalue for codegen") -and eval_r (env:SailEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue = - let ty = get_type x in +and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue E.t = + let* ty = Env.TypeEnv.get_from_id x.info tenv in match x.exp with - | Variable _ | StructRead _ | ArrayRead _ | StructAlloc _ -> let v = eval_l env llvm x in L.build_load v "" llvm.b + | Variable _ | StructRead _ | ArrayRead _ | StructAlloc _ -> let+ v = eval_l env llvm x in L.build_load v "" llvm.b + + | Literal l -> return @@ getLLVMLiteral l llvm + + | UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (ty_of_alias ty (snd venv),l) llvm.b - | Literal l -> getLLVMLiteral l llvm - | UnOp (op,e) -> let l = eval_r env llvm e in unary op (ty_of_alias ty (snd env),l) llvm.b | BinOp (op,e1, e2) -> - let l1 = eval_r env llvm e1 - and l2 = eval_r env llvm e2 - in binary op (ty_of_alias ty (snd env)) l1 l2 llvm.b + let+ l1 = eval_r env llvm e1 + and* l2 = eval_r env llvm e2 + in binary op (ty_of_alias ty (snd venv)) l1 l2 llvm.b | Ref (_,e) -> eval_l env llvm e - | Deref e -> let v = eval_l env llvm e in L.build_load v "" llvm.b + + | Deref e -> let+ v = eval_l env llvm e in L.build_load v "" llvm.b + | ArrayStatic elements -> begin - let array_values = List.map (eval_r env llvm) elements in + let+ array_values = ListM.map (eval_r env llvm) elements in let ty = List.hd array_values |> L.type_of in let array_values = Array.of_list array_values in let array_type = L.array_type ty (List.length elements) in @@ -69,70 +80,72 @@ and eval_r (env:SailEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue = L.build_load array "" llvm.b end - | EnumAlloc _ -> failwith "enum allocation unimplemented" + | EnumAlloc _ -> E.throw Logging.(make_msg (fst x.info) "enum allocation unimplemented") - | _ -> failwith "problem with thir" + | _ -> E.throw Logging.(make_msg (fst x.info) "problem with thir") -and construct_call (name:string) ((_,mname):l_str) (args:AstMir.expression list) (env:SailEnv.t) (llvm:llvm_args) : L.llvalue = - let args_type,llargs = List.map (fun arg -> get_type arg,eval_r env llvm arg) args |> List.split +and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t = + let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.info,r) args >>| List.split in (* let mname = mangle_method_name name origin.mname args_type in *) let mangled_name = "_" ^ mname ^ "_" ^ name in Logs.debug (fun m -> m "constructing call to %s" name); - let llval,ext = match SailEnv.get_decl mangled_name (Specific (mname,Method)) env with + let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname,Method)) venv with | None -> begin - match SailEnv.get_decl name (Specific (mname,Method)) env with - | Some {llval;extern;_} -> llval,extern - | None -> Printf.sprintf "implementation of %s not found" mangled_name |> failwith + match SailEnv.get_decl name (Specific (mname,Method)) venv with + | Some {llval;extern;_} -> return (llval,extern) + | None -> E.throw Logging.(make_msg loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) end - | Some {llval;extern;_} -> llval,extern + | Some {llval;extern;_} -> return (llval,extern) in - let args = + let+ args = if ext then - List.map2 (fun t v -> + ListM.map2 (fun t v -> + let+ t = Env.TypeEnv.get_from_id t tenv in let builder = - match ty_of_alias t (snd env) with + match snd (ty_of_alias t (snd venv)) with | Bool | Int _ | Char -> L.build_zext | Float -> L.build_bitcast | CompoundType _ -> fun v _ _ _ -> v | _ -> L.build_ptrtoint - in + in builder v (L.i64_type llvm.c) "" llvm.b ) args_type llargs else - llargs + return llargs in L.build_call llval (Array.of_list args) "" llvm.b open AstMir -let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (env :SailEnv.t) : unit = - let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:AstMir.expression option) (env:SailEnv.t) : SailEnv.t E.t= - let _ = mut in (* todo manage mutable types *) - let entry_b = L.(entry_block proto |> instr_begin |> builder_at llvm.c) in - let v = - match exp with - | Some e -> - let t = get_type e - and v = eval_r env llvm e in - let x = L.build_alloca (getLLVMType (snd env) t llvm.c llvm.m) name entry_b in - L.build_store v x llvm.b |> ignore; x - | None -> - let t' = getLLVMType (snd env) ty llvm.c llvm.m in - L.build_alloca t' name entry_b - in - SailEnv.declare_var name (dummy_pos,(mut,v)) env - - and assign_var (target:expression) (exp:expression) (env:SailEnv.t) = - let lvalue = eval_l env llvm target in - let rvalue = eval_r env llvm exp in +let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,tenv : SailEnv.t*Env.TypeEnv.t) : unit E.t = + let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:AstMir.expression option) (venv : SailEnv.t) : SailEnv.t E.t = + let _ = mut in (* todo manage mutable types *) + let entry_b = L.(entry_block proto |> instr_begin |> builder_at llvm.c) in + let* v = + match exp with + | Some e -> + let* t = Env.TypeEnv.get_from_id e.info tenv + and* v = eval_r (venv,tenv) llvm e in + let+ ty = getLLVMType (snd venv) t llvm.c llvm.m in + let x = L.build_alloca ty name entry_b in + L.build_store v x llvm.b |> ignore; x + | None -> + let+ t' = getLLVMType (snd venv) ty llvm.c llvm.m in + L.build_alloca t' name entry_b + in + SailEnv.declare_var name (dummy_pos,(mut,v)) venv + in + let assign_var (target:expression) (exp:expression) (env : SailEnv.t*Env.TypeEnv.t) = + let* lvalue = eval_l env llvm target in + let+ rvalue = eval_r env llvm exp in L.build_store rvalue lvalue llvm.b |> ignore - in + in - let rec aux (lbl:label) (llvm_bbs : L.llbasicblock BlockMap.t) (env:SailEnv.t) : L.llbasicblock BlockMap.t = - if BlockMap.mem lbl llvm_bbs then llvm_bbs (* already treated, nothing to do *) + let rec aux (lbl:label) (llvm_bbs : L.llbasicblock BlockMap.t) (venv : SailEnv.t) : L.llbasicblock BlockMap.t E.t = + if BlockMap.mem lbl llvm_bbs then return llvm_bbs (* already treated, nothing to do *) else begin let bb = BlockMap.find lbl cfg.blocks @@ -140,73 +153,71 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (env :S let llvm_bb = L.append_block llvm.c bb_name proto in let llvm_bbs = BlockMap.add lbl llvm_bb llvm_bbs in L.position_at_end llvm_bb llvm.b; - List.iter (fun x -> assign_var x.target x.expression env) bb.assignments; + let* () = ListM.iter (fun x -> assign_var x.target x.expression (venv,tenv)) bb.assignments in match bb.terminator with | Some (Return e) -> - let ret = match e with - | Some r -> let v = eval_r env llvm r in L.build_ret v - | None -> L.build_ret_void + let+ ret = match e with + | Some r -> let+ v = eval_r (venv,tenv) llvm r in L.build_ret v + | None -> return L.build_ret_void in ret llvm.b |> ignore; llvm_bbs | Some (Goto lbl) -> - let llvm_bbs = aux lbl llvm_bbs env in + let+ llvm_bbs = aux lbl llvm_bbs venv in L.position_at_end llvm_bb llvm.b; - L.build_br (BlockMap.find lbl llvm_bbs) llvm.b |> ignore; + let _ = L.build_br (BlockMap.find lbl llvm_bbs) llvm.b in llvm_bbs - - | Some (Invoke f) -> - let c = construct_call f.id f.origin f.params env llvm in + | Some (Invoke f) -> + let* c = construct_call f.id f.origin f.params (venv,tenv) llvm in begin match f.target with - | Some id -> L.build_store c (let _,v = SailEnv.get_var id env |> Option.get |> snd in v) llvm.b |> ignore + | Some id -> L.build_store c (let _,v = SailEnv.get_var id venv |> Option.get |> snd in v) llvm.b |> ignore | None -> () end; - let llvm_bbs = aux f.next llvm_bbs env in + let+ llvm_bbs = aux f.next llvm_bbs venv in L.position_at_end llvm_bb llvm.b; L.build_br (BlockMap.find f.next llvm_bbs) llvm.b |> ignore; llvm_bbs + | Some (SwitchInt si) -> - let sw_val = eval_r env llvm si.choice in - let sw_val = L.build_intcast sw_val (L.i32_type llvm.c) "" llvm.b (* for condition, expression val will be bool *) - and llvm_bbs = aux si.default llvm_bbs env in + let* sw_val = eval_r (venv,tenv) llvm si.choice in + let sw_val = L.build_intcast sw_val (L.i32_type llvm.c) "" llvm.b in (* for condition, expression val will be bool *) + let* llvm_bbs = aux si.default llvm_bbs venv in L.position_at_end llvm_bb llvm.b; let sw = L.build_switch sw_val (BlockMap.find si.default llvm_bbs) (List.length si.paths) llvm.b in - List.fold_left ( + ListM.fold_left ( fun bm (n,lbl) -> - let n = L.const_int (L.i32_type llvm.c) n - and bm = aux lbl bm env + let n = L.const_int (L.i32_type llvm.c) n in + let+ bm = aux lbl bm venv in L.add_case sw n (BlockMap.find lbl bm); bm ) llvm_bbs si.paths - | None -> failwith "no terminator : mir is broken" (* can't happen *) - | Some Break -> failwith "no break should be there" + | None -> E.throw Logging.(make_msg bb.location "no terminator : mir is broken") + | Some Break -> E.throw Logging.(make_msg bb.location "no break should be there") end in ( - let+ env = ListM.fold_left (fun e (d:declaration) -> declare_var d.mut d.id d.varType None e) env decls + let* env = ListM.fold_left (fun e (d:declaration) -> declare_var d.mut d.id d.varType None e) venv decls in - let init_bb = L.insertion_block llvm.b - and llvm_bbs = aux cfg.input BlockMap.empty env in + let init_bb = L.insertion_block llvm.b in + let+ llvm_bbs = aux cfg.input BlockMap.empty env in L.position_at_end init_bb llvm.b; - L.build_br (BlockMap.find cfg.input llvm_bbs) llvm.b - ) |> ignore + L.build_br (BlockMap.find cfg.input llvm_bbs) llvm.b |> ignore + ) -let methodToIR (llc:L.llcontext) (llm:L.llmodule) (decl:Declarations.method_decl) (env:SailEnv.t) (name : string) : L.llvalue = +let methodToIR (llc:L.llcontext) (llm:L.llmodule) (decl:Declarations.method_decl) (venv,tenv:SailEnv.t * Env.TypeEnv.t) (name : string) : L.llvalue E.t = match Either.find_right decl.defn.m_body with - | None -> decl.llval (* extern method *) + | None -> return decl.llval (* extern method *) | Some b -> Logs.info (fun m -> m "codegen of %s" name); let builder = L.builder llc in let llvm = {b=builder; c=llc ; m = llm; layout=Llvm_target.DataLayout.of_string (L.data_layout llm)} in - - if L.block_begin decl.llval <> At_end decl.llval then failwith ("redefinition of function " ^ name); - + let* () = E.throw_if Logging.(make_msg dummy_pos @@ "redefinition of function " ^ name) (L.block_begin decl.llval <> At_end decl.llval) in let bb = L.append_block llvm.c "" decl.llval in L.position_at_end bb llvm.b; - let args = toLLVMArgs decl.defn.m_proto.params (snd env) llvm in + let* args = toLLVMArgs decl.defn.m_proto.params (snd venv) llvm in let new_env,args = Array.fold_left_map ( fun env (m,_,v) -> @@ -215,22 +226,26 @@ let methodToIR (llc:L.llcontext) (llm:L.llmodule) (decl:Declarations.method_decl let+ new_env = SailEnv.declare_var (L.value_name v) (dummy_pos,(m,v)) env in new_env ),v - ) (E.pure env) args + ) (E.pure venv) args in Array.iteri (fun i arg -> L.build_store (L.param decl.llval i) arg llvm.b |> ignore ) args; - (let+ new_env in cfgToIR decl.llval b llvm new_env) |> ignore; + let* new_env in + let+ () = cfgToIR decl.llval b llvm (new_env,tenv) in decl.llval let moduleToIR (sm: in_body SailModule.t) (verify_ir:bool) : L.llmodule E.t = let llc = L.create_context () in let llm = L.create_module llc sm.md.name in let* decls = get_declarations sm llc llm in - let env = SailEnv.empty decls in - - DeclEnv.iter_decls (fun name m -> let func = methodToIR llc llm m env name in if verify_ir then Llvm_analysis.assert_valid_function func) (Self Method) decls >>= fun () -> + let env = SailEnv.empty decls,sm.typeEnv in + let method_cg name m : unit E.t = + let+ m = methodToIR llc llm m env name in + if verify_ir then Llvm_analysis.assert_valid_function m + in + let* _ = DeclEnv.fold_decls (fun id m accu -> let r = method_cg id m in r::accu) [] Method decls |> ListM.sequence in if verify_ir then match Llvm_analysis.verify_module llm with | None -> return llm - | Some reason -> E.throw @@ Error.make dummy_pos (Fmt.str "LLVM : %s" reason) + | Some reason -> E.throw Logging.(make_msg dummy_pos (Fmt.str "LLVM : %s" reason)) else return llm \ No newline at end of file diff --git a/src/common/builtins.ml b/src/common/builtins.ml index 05b4468..8f00dab 100644 --- a/src/common/builtins.ml +++ b/src/common/builtins.ml @@ -5,4 +5,4 @@ let register_builtin name generics p rtype variadic l: method_sig list = let get_builtins () : method_sig list = [] - |> register_builtin "box" ["T"] [GenericType "T"] (Some (Box (GenericType "T"))) false \ No newline at end of file + |> register_builtin "box" ["T"] [dummy_pos,GenericType "T"] (Some (dummy_pos,Box (dummy_pos,GenericType "T"))) false \ No newline at end of file diff --git a/src/common/env.ml b/src/common/env.ml index 9b07f47..6f4368a 100644 --- a/src/common/env.ml +++ b/src/common/env.ml @@ -1,5 +1,5 @@ open TypesCommon -module E = Error.Logger +module E = Logging.Logger open Monad open MonadSyntax(E) open MonadOperator(E) @@ -178,13 +178,13 @@ module DeclarationsEnv : DeclEnvType = functor (D:Declarations) -> struct let overwrite_decls (type d) (field: d container) = update_decls (fun _ -> field) let add_decl (type d) id (decl:d) (ty: d decl_ty) (env:t) : t E.t = - E.throw_if (Error.make dummy_pos @@ Fmt.str "duplicate declarations for '%s'" id) (FieldMap.mem id (get_decls ty env.self)) + E.throw_if Logging.(make_msg dummy_pos @@ Fmt.str "duplicate declarations for '%s'" id) (FieldMap.mem id (get_decls ty env.self)) >>= fun () -> return {env with self=update_decls (FieldMap.add id decl) ty env.self} let remove_decl id (ty:'a decl_ty) t = let new_env = update_decls (FieldMap.remove id) ty t.self in let+ () = E.throw_if - (Error.make dummy_pos @@ Fmt.str "attempting to remove unknown declaration '%s'" id) + Logging.(make_msg dummy_pos @@ Fmt.str "attempting to remove unknown declaration '%s'" id) (new_env <> t.self) in {t with self=new_env} @@ -245,7 +245,7 @@ module DeclarationsEnv : DeclEnvType = functor (D:Declarations) -> struct if m = env.name then return (FieldMap.iter f (get_decls d env.self)) else - let+ env = E.throw_if_none (Error.make dummy_pos "can't happen") + let+ env = E.throw_if_none Logging.(make_msg dummy_pos "can't happen") (List.find_opt (fun ({mname;_},_) -> mname = m) env.imports) in FieldMap.iter f (get_decls d (snd env) ) @@ -259,7 +259,7 @@ module DeclarationsEnv : DeclEnvType = functor (D:Declarations) -> struct let find_closest name env : string list = if String.length name > 3 then - let check = (fun n _ l -> if Error.levenshtein_distance n name < 3 then n::l else l) in + let check = (fun n _ l -> if Logging.levenshtein_distance n name < 3 then n::l else l) in FieldMap.fold check env.self.methods [] |> fun l -> FieldMap.fold check env.self.processes l |> fun l -> FieldMap.fold check env.self.structs l @@ -355,12 +355,12 @@ module VariableEnv : VariableEnvType = functor (V : Variable) -> struct let upd_frame = FieldMap.add name v current in let stack = push_frame stack upd_frame in {stack} |> E.pure | Some _ -> - E.throw (Error.make l @@ Printf.sprintf "variable %s already declared in current frame!" name) + E.throw Logging.(make_msg l @@ Printf.sprintf "variable %s already declared in current frame!" name) >>| fun () -> e let init_env (args:param list) : t E.t = - let open Monad.MonadFunctions(Error.Logger) in + let open Monad.MonadFunctions(Logging.Logger) in let env = empty |> new_frame in ListM.fold_right (fun (p:param) -> let v = V.param_to_var p in @@ -373,7 +373,7 @@ module VariableEnv : VariableEnvType = functor (V : Variable) -> struct let current,stack = current_frame stack in match FieldMap.find_opt name current with | Some v -> let+ v' = f v in FieldMap.add name v' current :: stack - | None when stack = [] -> E.throw (Error.make l @@ Printf.sprintf "variable %s not found" name) + | None when stack = [] -> E.throw Logging.(make_msg l @@ Printf.sprintf "variable %s not found" name) >>| fun () -> e.stack | _ -> let+ e = aux stack in current :: e in @@ -383,6 +383,28 @@ module VariableEnv : VariableEnvType = functor (V : Variable) -> struct end +module TypeEnv = struct + type t = sailtype FieldMap.t + let empty = FieldMap.empty + let get_id ty (te :t) : string * t = + let add_if_no_exists s = FieldMap.update s (Option.fold ~none:(Some ty) ~some:Option.some) in + let s = match snd ty with + | Bool -> "bool" + | Float -> "float" + | Char -> "char" + | String -> "string" + | Int n -> "int" ^ string_of_int n + | ArrayType _ -> "array" + | GenericType t -> t + | CompoundType t -> snd t.name + | Box _ -> "box" + | RefType _ -> "ref" + in s,add_if_no_exists s te + + let get_from_id (lid,id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg lid @@ Fmt.str "id '%s' not found" id) (FieldMap.find_opt id te) + +end + module VariableDeclEnv = functor (D:Declarations) (V:Variable) -> struct module D = DeclarationsEnv(D) module VE = VariableEnv(V) @@ -391,22 +413,22 @@ module VariableDeclEnv = functor (D:Declarations) (V:Variable) -> struct let get_decl name ty (_,g) = D.find_decl name ty g - let add_decl name d ty (v,g) = let+ g = D.add_decl name d ty g in v,g + let add_decl name d ty (v,g) : t E.t = let+ g = D.add_decl name d ty g in v,g let get_imports (_,g) = D.get_imports g let get_closest name (_,g) = D.find_closest name g let get_var id (env,_) = VE.get_var id env - let declare_var id v (env,g) = let+ env = VE.declare_var id v env in env,g + let declare_var id v (env,g) : t E.t = let+ env = VE.declare_var id v env in env,g - let update_var l v id (env,g) = let+ env = VE.update_var l v id env in env,g - let new_frame (env,g) = VE.new_frame env,g - let pop_frame (env,g) = VE.pop_frame env,g - let get_env (env,_) = env + let update_var l v id (env,g) : t E.t = let+ env = VE.update_var l v id env in env,g + let new_frame (env,g) : t = VE.new_frame env,g + let pop_frame (env,g) : t = VE.pop_frame env,g + let get_env (env,_,_) = env let empty g : t = VE.empty,g let get_start_env (decls:D.t) (args:param list) : t E.t = let+ venv = VE.init_env args in venv,decls -end +end \ No newline at end of file diff --git a/src/common/error.ml b/src/common/error.ml deleted file mode 100644 index 581e10f..0000000 --- a/src/common/error.ml +++ /dev/null @@ -1,176 +0,0 @@ -open Lexing -open Monad - -let show_context text (p1,p2) = - let open MenhirLib.ErrorReports in - let p1 = { p1 with pos_cnum=p1.pos_bol} - and p2 = match String.index_from_opt text p2.pos_cnum '\n' with - | Some pos_cnum -> { p2 with pos_cnum } - | None -> p2 - in - extract text (p1,p2) |> sanitize - - - type error = - { - where : TypesCommon.loc; - what : string; - why : (TypesCommon.loc option * string) option; - hint : (TypesCommon.loc option * string) option; - label : string option; - } - type errors = error list - type 'a result = ('a, errors) Result.t - - let make ?(why=None) ?(hint=None) ?(label=None) where what : error = - {where;what;why;label;hint} - - let print_errors (errs:errors) : unit = - let print_indication fmt (where: TypesCommon.loc) what : unit = - let location = MenhirLib.LexerUtil.range where in - let f = ((fst where).pos_fname |> open_in |> In_channel.input_all) in - let indication = show_context f where in - let start = String.make ((fst where).pos_cnum - (fst where).pos_bol )' ' in - let ending = String.make ((snd where).pos_cnum - (fst where).pos_cnum )'^' in - Fmt.pf fmt "@[%s@ %s@ %s%s %s@,@]@ @ " location indication start ending what - in - if errs <> [] then - let s fmt = List.iter ( - fun {where;what;hint;_} -> - if where = (dummy_pos,dummy_pos) then - Fmt.pf fmt "@[%s@ @] @ @ " what - else - print_indication fmt where what; - match hint with - | Some (where,h) -> - begin - match where with - | Some loc -> print_indication fmt loc h - | None -> Fmt.pf fmt "@[Hint : %s @ @]@ " h - end - | None -> () - ) - errs in - Logs.err (fun m -> m "@[found %i error(s) :@." (List.length errs) ); - s Fmt.stderr - - -(* taken from http://rosettacode.org/wiki/Levenshtein_distance#A_recursive_functional_version *) -let levenshtein_distance s t = - let m = String.length s - and n = String.length t in - let d = Array.make_matrix (m + 1) (n + 1) 0 in - for i = 0 to m do d.(i).(0) <- i done; - for j = 0 to n do d.(0).(j) <- j done; - for j = 1 to n do - for i = 1 to m do - if s.[i - 1] = t.[j - 1] then d.(i).(j) <- d.(i - 1).(j - 1) - else - d.(i).(j) <- min (d.(i - 1).(j) + 1) (min (d.(i).(j - 1) + 1) (d.(i - 1).(j - 1) + 1)) - done - done; - d.(m).(n) - - - -module type Logger = sig - include MonadTransformer - val catch : (error -> 'a t) -> 'a t -> 'a t - - val get_error : (error -> unit) -> 'a t -> 'a t - val throw : error -> 'a t - val log : error -> unit t - val recover : 'a -> 'a t -> 'a t - val fail : 'a t -> 'a t - val log_if : error -> bool -> unit t - val throw_if : error -> bool -> unit t - val throw_if_none : error -> 'a option -> 'a t - val throw_if_some : ('a -> error) -> 'a option -> unit t - - val get_warnings : (error list -> unit) -> 'a t -> 'a t - - -end - - module MakeTransformer (M : Monad) : Logger with type 'a t = (('a,error) Result.t * error list) M.t and type 'a old_t = 'a M.t = struct - open MonadSyntax(M) - - type 'a old_t = 'a M.t - type 'a t = (('a, error) Result.t * error list) old_t - - let pure (x:'a) : 'a t = (Ok x,[]) |> M.pure - - let fmap (f: 'a -> 'b) (x:'a t) : 'b t = - let+ v,l = x in - match v with - | Ok x -> Ok (f x),l - | Error e -> Error e,l - - let apply (f:('a -> 'b) t) (x: 'a t) : 'b t = - let+ f,l1 = f and* x,l2 = x in - match f,x with - | Error err1,Error err2 -> Error err1, err2 :: l1 @ l2 - | Error err1,_ -> Error err1, l1@l2 - | Ok f, Ok x -> Ok (f x),l1@l2 - | Ok _, Error err -> Error err,l1@l2 - - let bind (x:'a t) (f : 'a -> 'b t) : 'b t = - let* v,l1 = x in - match v with - | Error err -> (Error err,l1) |> M.pure - | Ok x -> let+ v,l2 = f x in v,l1@l2 - - let lift (x:'a M.t) : 'a t = let+ x in Ok x,[] - - let throw (e:error) : 'a t = (Error e,[]) |> M.pure - - let catch (f:error -> 'a t) (x : 'a t) : 'a t = - let* v,l = x in - match v with - | Error err -> let+ x,l2 = f err in x,l@l2 - | Ok x -> (Ok x,l) |> M.pure - - - let get_error (f:error -> unit) (x : 'a t) : 'a t = - let* v,l = x in - match v with - | Error err -> f err; ((Error err),l) |> M.pure - | Ok x -> (Ok x,l) |> M.pure - - - let log (msg:error) : unit t = (Ok (),[msg]) |> M.pure - - let recover (default : 'a) (x:'a t) : 'a t = - let+ v,l = x in - match v with - | Ok x -> Ok x,l - | Error err -> Ok default,err::l - - - let fail (x:'a t) : 'a t = - let+ v,l = x in - match v,l with - | Error err,_ -> Error err,l - | Ok x,[] -> Ok x,[] - | Ok _,h::t -> Error h,t - - let log_if e b = if b then log e else pure () - let throw_if e b = if b then throw e else pure () - - let throw_if_none (e: error) (x:'a option) : 'a t = - match x with - | None -> throw e - | Some r -> pure r - - let throw_if_some (f: 'a -> error) (x:'a option) : unit t = - match x with - | Some r -> throw (f r) - | None -> pure () - - - let get_warnings (f : error list -> unit) (x : 'a t) : 'a t = - let+ v,l = x in f l; (v,l) - -end - -module Logger = MakeTransformer(MonadIdentity) \ No newline at end of file diff --git a/src/common/logging.ml b/src/common/logging.ml new file mode 100644 index 0000000..c5a114a --- /dev/null +++ b/src/common/logging.ml @@ -0,0 +1,195 @@ +open Lexing +open Monad + +type msg = +{ + where : TypesCommon.loc; + what : string; + why : (TypesCommon.loc option * string) option; + hint : (TypesCommon.loc option * string) option; + label : string option; +} + +let make_msg ?(why=None) ?(hint=None) ?(label=None) where what : msg = {where;what;why;label;hint;} + + +type log = {errors: msg list; warnings : msg list} + +module Log : Monoid with type t = log = struct + type t = log + let mempty = {errors=[]; warnings=[]} + let mconcat l1 l2 = {errors=l1.errors@l2.errors;warnings=l1.warnings@l2.warnings} +end + + +let show_context text (p1,p2) = + let open MenhirLib.ErrorReports in + let p1 = { p1 with pos_cnum=p1.pos_bol} + and p2 = match String.index_from_opt text p2.pos_cnum '\n' with + | Some pos_cnum -> { p2 with pos_cnum } + | None -> p2 + in + extract text (p1,p2) |> sanitize + + + + +let print_log (log:Log.t) : unit = + let print_indication fmt (where: TypesCommon.loc) what : unit = + let location = MenhirLib.LexerUtil.range where in + let f = ((fst where).pos_fname |> open_in |> In_channel.input_all) in + let indication = show_context f where in + let start = String.make ((fst where).pos_cnum - (fst where).pos_bol )' ' in + let ending = String.make ((snd where).pos_cnum - (fst where).pos_cnum )'^' in + Fmt.pf fmt "@[%s@ %s@ %s%s %s@,@]@ @ " location indication start ending what + in + + let format_msg fmt msg = + if msg.where = (dummy_pos,dummy_pos) then + Fmt.pf fmt "@[%s@ @] @ @ " msg.what + else + print_indication fmt msg.where msg.what; + match msg.hint with + | Some (where,h) -> + begin + match where with + | Some loc -> print_indication fmt loc h + | None -> Fmt.pf fmt "@[Hint : %s @ @]@ " h + end + | None -> () + + in + (* Logs.err (fun m -> m "@[found %i error(s) :@." (List.length errs) ); *) + + List.iter (format_msg Fmt.stderr) log.errors; + List.iter (format_msg Fmt.stderr) log.warnings + + + + +(* taken from http://rosettacode.org/wiki/Levenshtein_distance#A_recursive_functional_version *) +let levenshtein_distance s t = + let m = String.length s + and n = String.length t in + let d = Array.make_matrix (m + 1) (n + 1) 0 in + for i = 0 to m do d.(i).(0) <- i done; + for j = 0 to n do d.(0).(j) <- j done; + for j = 1 to n do + for i = 1 to m do + if s.[i - 1] = t.[j - 1] then d.(i).(j) <- d.(i - 1).(j - 1) + else + d.(i).(j) <- min (d.(i - 1).(j) + 1) (min (d.(i).(j - 1) + 1) (d.(i - 1).(j - 1) + 1)) + done + done; + d.(m).(n) + + + +module type Logger = sig + include MonadTransformer + val catch : (msg -> 'a t) -> 'a t -> 'a t + + val get_error : (msg -> unit) -> 'a t -> 'a t + val throw : msg -> 'a t + val log : msg -> unit t + val recover : 'a -> 'a t -> 'a t + val fail : 'a t -> 'a t + val log_if : msg -> bool -> unit t + val throw_if : msg -> bool -> unit t + val throw_if_none : msg -> 'a option -> 'a t + val throw_if_some : ('a -> msg) -> 'a option -> unit t + + val get_warnings : (msg list -> unit) -> 'a t -> 'a t + + val run : ('a -> 'b t) -> ('b -> unit) -> (msg list -> unit) -> 'a -> unit old_t + + +end + + module MakeTransformer (M : Monad) (*: Logger with type 'a t = (('a,msg) Result.t * msg list) M.t and type 'a old_t = 'a M.t *) = struct + open MonadSyntax(M) + + type 'a old_t = 'a M.t + type 'a t = (('a, msg) Result.t * log) old_t + + let pure (x:'a) : 'a t = (Ok x,Log.mempty) |> M.pure + + let fmap (f: 'a -> 'b) (x:'a t) : 'b t = + let+ v,l = x in + match v with + | Ok x -> Ok (f x),l + | Error e -> Error e,l + + let apply (f:('a -> 'b) t) (x: 'a t) : 'b t = + let+ f,l1 = f and* x,l2 = x in + match f,x with + | Error err1 , Error err2 -> Error err1, Log.mconcat l1 {l2 with errors=err2::l2.errors} + | Error err1,_ -> Error err1, Log.mconcat l1 l2 + | Ok f, Ok x -> Ok (f x),Log.mconcat l1 l2 + | Ok _, Error err -> Error err,Log.mconcat l1 l2 + + let bind (x:'a t) (f : 'a -> 'b t) : 'b t = + let* v,l1 = x in + match v with + | Error err -> (Error err,l1) |> M.pure + | Ok x -> let+ v,l2 = f x in v,Log.mconcat l1 l2 + + let lift (x:'a M.t) : 'a t = let+ x in Ok x,Log.mempty + + let throw (e:msg) : 'a t = (Error e,Log.mempty) |> M.pure + + let catch (f:msg -> 'a t) (x : 'a t) : 'a t = + let* v,l = x in + match v with + | Error err -> let+ x,l2 = f err in x,Log.mconcat l l2 + | Ok x -> (Ok x,l) |> M.pure + + + let get_error (f:msg -> unit) (x : 'a t) : 'a t = + let* v,l = x in + match v with + | Error err -> f err; ((Error err),l) |> M.pure + | Ok x -> (Ok x,l) |> M.pure + + + let log (msg:msg) : unit t = (Ok (),{errors=[]; warnings=[msg]}) |> M.pure + + let recover (default : 'a) (x:'a t) : 'a t = + let+ v,l = x in + match v with + | Ok x -> Ok x,l + | Error err -> Ok default,{l with errors=err::l.errors} + + + let fail (x:'a t) : 'a t = + let+ v,l = x in + match v,l with + | Error err,_ -> Error err,l + | Ok x,{errors=[];_} -> Ok x,l + | Ok _,{errors=h::t;_} -> Error h,{l with errors=t} + + let log_if e b = if b then log e else pure () + let throw_if e b = if b then throw e else pure () + + let throw_if_none (e: msg) (x:'a option) : 'a t = + match x with + | None -> throw e + | Some r -> pure r + + let throw_if_some (f: 'a -> msg) (x:'a option) : unit t = + match x with + | Some r -> throw (f r) + | None -> pure () + + + let get_warnings (f : msg list -> unit) (x : 'a t) : 'a t = + let+ v,l = x in f l.warnings; (v,l) + + let run (f : 'a -> 'b t) (on_success : 'b -> unit) (on_error : log -> unit) : 'a -> unit old_t = + fun x -> let* res = f x in + match res with + | Error e,l -> return (on_error {l with errors=e::l.errors}) + | Ok x,_ -> return @@ on_success x +end + +module Logger = MakeTransformer(MonadIdentity) \ No newline at end of file diff --git a/src/common/monadic/monad.ml b/src/common/monadic/monad.ml index 4c5483c..ac37192 100644 --- a/src/common/monadic/monad.ml +++ b/src/common/monadic/monad.ml @@ -193,7 +193,7 @@ module MonadFunctions (M : Monad) = struct | h::t -> let+ u = (f h) and* t = map f t in u::t - let rec map2 (f : 'a -> 'a -> 'b M.t) (l1 : 'a list) (l2 : 'a list): ('b list) M.t = + let rec map2 (f : 'a -> 'b -> 'c M.t) (l1 : 'a list) (l2 : 'b list): ('c list) M.t = match l1,l2 with | [],[] -> return [] | h1::t1,h2::t2 -> let+ u = (f h1 h2) and* t = map2 f t1 t2 in u::t @@ -203,6 +203,12 @@ module MonadFunctions (M : Monad) = struct match l with | [] -> return () | h::t -> f h >>= (fun () -> iter f t) + let iteri (f : int -> 'a -> unit M.t) (l : 'a list) : unit M.t = + let rec aux i l = + match l with + | [] -> return () + | h::t -> f i h >>= (fun () -> aux (i+1) t) + in aux 0 l let rec iter2 f l1 l2 = match (l1, l2) with diff --git a/src/common/monadic/monadEither.ml b/src/common/monadic/monadEither.ml index e6ccea5..4436b5f 100644 --- a/src/common/monadic/monadEither.ml +++ b/src/common/monadic/monadEither.ml @@ -28,6 +28,7 @@ module MakeTransformer (M : Monad) (T : Type) : MonadTransformer | Left l -> Left l |> M.pure let lift (x:'a M.t) : 'a t = let+ x in right x + end module Make = MakeTransformer(MonadIdentity) diff --git a/src/common/monadic/monadError.ml b/src/common/monadic/monadError.ml index 061f225..831d0ca 100644 --- a/src/common/monadic/monadError.ml +++ b/src/common/monadic/monadError.ml @@ -46,7 +46,6 @@ end module ErrorMonadEither = struct module Make (T : Type) : ErrorMonad with type 'a t = (T.t, 'a) Either.t and type error = T.t = struct - include MonadEither.Make(T) type error = T.t diff --git a/src/common/monadic/monadState.ml b/src/common/monadic/monadState.ml index 3f6fb79..68a2dad 100644 --- a/src/common/monadic/monadState.ml +++ b/src/common/monadic/monadState.ml @@ -10,7 +10,7 @@ module type S = functor (M:Monad) (State: Type) -> sig val update : (State.t -> State.t M.t) -> unit t -end with type 'a t = State.t -> ('a * State.t) M.t and type 'a old_t := 'a M.t +end with type 'a t = State.t -> ('a * State.t) M.t and type 'a old_t := 'a M.t module T : S = functor (M:Monad) (State: Type) -> struct open MonadSyntax(M) diff --git a/src/common/pass.ml b/src/common/pass.ml index b2a5f7c..78bc4fb 100644 --- a/src/common/pass.ml +++ b/src/common/pass.ml @@ -1,4 +1,4 @@ -open Error +open Logging open Monad open TypesCommon @@ -79,10 +79,10 @@ module MakeFunctionPass type p_out - val lower_method : m_in * method_sig -> SailModule.SailEnv(V).t -> (m_in,p_in) SailModule.methods_processes SailModule.t -> (m_out * SailModule.SailEnv(V).D.t) Logger.t + val lower_method : m_in * method_sig -> (SailModule.SailEnv(V).t * Env.TypeEnv.t) -> (m_in,p_in) SailModule.methods_processes SailModule.t -> (m_out * SailModule.SailEnv(V).D.t * Env.TypeEnv.t) Logger.t - val lower_process : p_in process_defn -> SailModule.SailEnv(V).t -> (m_in,p_in) SailModule.methods_processes SailModule.t -> (p_out * SailModule.SailEnv(V).D.t) Logger.t + val lower_process : p_in process_defn -> (SailModule.SailEnv(V).t * Env.TypeEnv.t)-> (m_in,p_in) SailModule.methods_processes SailModule.t -> (p_out * SailModule.SailEnv(V).D.t * Env.TypeEnv.t) Logger.t val preprocess : (m_in,p_in) SailModule.methods_processes SailModule.t -> (m_in,p_in) SailModule.methods_processes SailModule.t Logger.t end) @@ -95,20 +95,19 @@ struct module VEnv = SailModule.SailEnv(V) - let lower_method (m:T.m_in method_defn) (sm : (T.m_in,T.p_in) SailModule.methods_processes SailModule.t) : (VEnv.D.t * T.m_out method_defn) Logger.t = + let lower_method (m:T.m_in method_defn) (sm : (T.m_in,T.p_in) SailModule.methods_processes SailModule.t) : ((VEnv.D.t * Env.TypeEnv.t) * T.m_out method_defn) Logger.t = match m.m_body with | Right f -> let* ve = VEnv.get_start_env sm.declEnv m.m_proto.params in - let+ b,d = T.lower_method (f,m.m_proto) ve sm in - d,{ m with m_body=Either.right b } - | Left x -> Logger.pure (sm.declEnv,{ m with m_body = Left x}) (* nothing to do for externals *) + let+ b,d,t = T.lower_method (f,m.m_proto) (ve,sm.typeEnv) sm in + (d,t),{ m with m_body=Either.right b } + | Left x -> Logger.pure ((sm.declEnv,sm.typeEnv),{ m with m_body = Left x}) (* nothing to do for externals *) - let lower_process (p: T.p_in process_defn) (sm : (T.m_in,T.p_in) SailModule.methods_processes SailModule.t) : (VEnv.D.t * T.p_out process_defn ) Logger.t = - let start_env = VEnv.get_start_env sm.declEnv p.p_interface.p_params in - let* ve = start_env in - let+ p_body,d = T.lower_process p ve sm in - d,{ p with p_body} + let lower_process (p: T.p_in process_defn) (sm : (T.m_in,T.p_in) SailModule.methods_processes SailModule.t) : ((VEnv.D.t * Env.TypeEnv.t) * T.p_out process_defn ) Logger.t = + let* ve = VEnv.get_start_env sm.declEnv p.p_interface.p_params in + let+ p_body,d,t = T.lower_process p (ve,sm.typeEnv) sm in + (d,t),{ p with p_body} @@ -116,15 +115,15 @@ struct let* sm = sm >>= T.preprocess in Logs.info (fun m -> m "Lowering module '%s' to '%s'" sm.md.name name); ( - let* declEnv,methods = ListM.fold_left_map - (fun declEnv methd -> lower_method methd {sm with declEnv}) - sm.declEnv sm.body.methods |> Logger.recover (sm.declEnv,[]) + let* (declEnv,typeEnv),methods = ListM.fold_left_map + (fun (declEnv,typeEnv) methd -> lower_method methd {sm with declEnv ; typeEnv}) + (sm.declEnv,sm.typeEnv) sm.body.methods |> Logger.recover ((sm.declEnv,sm.typeEnv),[]) in - let+ declEnv,processes = ListM.fold_left_map - (fun declEnv proccess -> lower_process proccess {sm with declEnv}) - declEnv sm.body.processes |> Logger.recover (sm.declEnv,[]) in + let+ (declEnv,typeEnv),processes = ListM.fold_left_map + (fun (declEnv,typeEnv) proccess -> lower_process proccess {sm with declEnv ; typeEnv}) + (declEnv,typeEnv) sm.body.processes |> Logger.recover ((declEnv,typeEnv),[]) in - { sm with body=SailModule.{processes; methods} ; declEnv } + { sm with body=SailModule.{processes; methods} ; declEnv; typeEnv } ) |> Logger.fail end \ No newline at end of file diff --git a/src/common/ppCommon.ml b/src/common/ppCommon.ml index 864ec8b..f6064b8 100644 --- a/src/common/ppCommon.ml +++ b/src/common/ppCommon.ml @@ -40,7 +40,7 @@ let pp_binop pf b = - let rec pp_type (pf : formatter) (t : sailtype) : unit = + let rec pp_type (pf : formatter) (_,t : sailtype) : unit = match t with Bool -> pp_print_string pf "bool" | Int n -> Format.fprintf pf "i%i" n diff --git a/src/common/sailModule.ml b/src/common/sailModule.ml index a4073ef..fd0dc1d 100644 --- a/src/common/sailModule.ml +++ b/src/common/sailModule.ml @@ -1,5 +1,5 @@ open TypesCommon -module E = Error.Logger +module E = Logging.Logger module Declarations = struct type process_decl = loc * process_proto @@ -13,11 +13,13 @@ module DeclEnv = Env.DeclarationsEnv(Declarations) module SailEnv = Env.VariableDeclEnv(Declarations) +module TypeMap = Map.Make(struct type t = loc let compare = compare end) type ('m,'p) methods_processes = {methods : 'm method_defn list ; processes : 'p process_defn list; } type 'a t = { + typeEnv: Env.TypeEnv.t; (* create functions for each type to create if not exist entry to map (otherwise gives same id )*) declEnv: DeclEnv.t; builtins : method_sig list ; body : 'a; @@ -27,6 +29,7 @@ type 'a t = let emptyModule empty_content = { + typeEnv = Env.TypeEnv.empty; declEnv = DeclEnv.empty; builtins = []; body = empty_content; diff --git a/src/common/typesCommon.ml b/src/common/typesCommon.ml index 019e18c..bfdae09 100644 --- a/src/common/typesCommon.ml +++ b/src/common/typesCommon.ml @@ -50,7 +50,7 @@ let string_of_decl : (_,_,_,_,_) decl_sum -> string = function | T _ -> "type" -type sailtype = +type sailtype = loc * sailtype_ and sailtype_ = | Bool | Int of int | Float @@ -75,29 +75,33 @@ type literal = | LString of string let sailtype_of_literal = function -| LBool _ -> Bool -| LFloat _ -> Float -| LInt l -> Int l.size -| LChar _ -> Char -| LString _ -> String +| LBool _ -> dummy_pos,Bool +| LFloat _ -> dummy_pos,Float +| LInt l -> dummy_pos,Int l.size +| LChar _ -> dummy_pos,Char +| LString _ -> dummy_pos,String let rec string_of_sailtype (t : sailtype option) : string = let open Printf in match t with - | Some Bool -> "bool" - | Some Int size -> "i" ^ string_of_int size - | Some Float -> "float" - | Some Char -> "char" - | Some String -> "string" - | Some ArrayType (t,s) -> sprintf "array<%s;%d>" (string_of_sailtype (Some t)) s - | Some CompoundType {name=(_,x); generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" x - | Some CompoundType {name=(_,x); generic_instances;_} -> sprintf "%s<%s>" x (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) - | Some Box(t) -> sprintf "ref<%s>" (string_of_sailtype (Some t)) - | Some RefType (t,b) -> - if b then sprintf "&mut %s" (string_of_sailtype (Some t)) - else sprintf "&%s" (string_of_sailtype (Some t)) - | Some GenericType(s) -> s + | Some (_,t) -> + begin + match t with + | Bool -> "bool" + | Int size -> "i" ^ string_of_int size + | Float -> "float" + | Char -> "char" + | String -> "string" + | ArrayType (t,s) -> sprintf "array<%s;%d>" (string_of_sailtype (Some t)) s + | CompoundType {name=(_,x); generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" x + | CompoundType {name=(_,x); generic_instances;_} -> sprintf "%s<%s>" x (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) + | Box(t) -> sprintf "ref<%s>" (string_of_sailtype (Some t)) + | RefType (t,b) -> + if b then sprintf "&mut %s" (string_of_sailtype (Some t)) + else sprintf "&%s" (string_of_sailtype (Some t)) + | GenericType(s) -> s + end | None -> "void" type unOp = Neg | Not @@ -118,7 +122,7 @@ type struct_defn = s_pos : loc; s_name : string; s_generics : string list; - s_fields : (string * (loc * sailtype)) list; + s_fields : (l_str * sailtype) list; } type enum_defn = @@ -222,7 +226,7 @@ let defn_to_proto (type proto) (decl: proto decl) : proto = match decl with and generics = d.p_generics and params = d.p_interface.p_params in {read;write;generics;params} -| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i (n,(l,t)) -> n,(l,t,i)) d.s_fields} +| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i ((l,n),t) -> n,(l,t,i)) d.s_fields} | Enum d -> {generics=d.e_generics;injections=d.e_injections} type import = diff --git a/src/parsing/astParser.ml b/src/parsing/astParser.ml index 721b25c..93e94c9 100644 --- a/src/parsing/astParser.ml +++ b/src/parsing/astParser.ml @@ -34,17 +34,17 @@ type expression = loc * expression_ and expression_ = | BinOp of binOp * expression * expression | Ref of bool * expression | ArrayStatic of expression list - | StructAlloc of l_str option * l_str * expression dict + | StructAlloc of l_str option * l_str * (loc * expression) dict | EnumAlloc of l_str * expression list | MethodCall of l_str option * l_str * expression list - type pattern = - | PVar of string - | PCons of string * pattern list +type pattern = +| PVar of string +| PCons of string * pattern list type statement = loc * statement_ and statement_ = - | DeclVar of bool * string * sailtype option * expression option + | DeclVar of bool * string * sailtype option * expression option | Skip | Assign of expression * expression | Seq of statement * statement @@ -67,7 +67,7 @@ type ('s,'e) p_statement = loc * ('s,'e) p_statement_ and ('s,'e) p_statement_ = | PGroup of {p_ty : pgroup_ty ; cond : 'e option ; children : ('s,'e) p_statement list} type ('s,'e) process_body = { - locals : (loc * (string * (loc*sailtype))) list; + locals : (l_str * sailtype) list; init : 's; proc_init : (loc * 'e proc_init) list; loop : ('s,'e) p_statement; @@ -80,7 +80,7 @@ type 'a defn = | Method of statement method_defn list | Process of 'a process_defn -module E = Error.Logger +module E = Logging.Logger let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement,expression) process_body) SailModule.methods_processes SailModule.t E.t = let open SailModule in @@ -100,8 +100,8 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement in (env,m,p) | Struct d -> - let s_fields = List.sort_uniq (fun (s1,_) (s2,_) -> String.compare s1 s2) d.s_fields in - E.throw_if (Error.make d.s_pos "duplicate fields" ) + let s_fields = List.sort_uniq (fun ((_,s1),_) ((_,s2),_) -> String.compare s1 s2) d.s_fields in + E.throw_if Logging.(make_msg d.s_pos "duplicate fields" ) (List.(length s_fields <> length d.s_fields)) >>= fun () -> let+ env = DeclEnv.add_decl d.s_name (d.s_pos, defn_to_proto (Struct {d with s_fields})) Struct e |> rethrow d.s_pos in (env,m,p) @@ -114,7 +114,7 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement let+ env,funs = ListM.fold_left (fun (e,f) d -> - let* () = E.throw_if Error.(make d.m_proto.pos "calling a method 'main' is not allowed") (d.m_proto.name = "main") in + let* () = E.throw_if Logging.(make_msg d.m_proto.pos "calling a method 'main' is not allowed") (d.m_proto.name = "main") in let true_name = (match d.m_body with Left (sname,_) -> sname | Right _ -> d.m_proto.name) in let+ env = DeclEnv.add_decl d.m_proto.name ((d.m_proto.pos,true_name),defn_to_proto (Method d)) Method e in (env,d::f) @@ -126,5 +126,5 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement in let+ declEnv,methods,processes = aux l in let builtins = Builtins.get_builtins () in - {md; imports; declEnv ; body={methods;processes};builtins} + {typeEnv=Env.TypeEnv.empty; md; imports; declEnv ; body={methods;processes};builtins} diff --git a/src/parsing/parser.mly b/src/parsing/parser.mly index 04aa39c..71b8fb7 100644 --- a/src/parsing/parser.mly +++ b/src/parsing/parser.mly @@ -126,7 +126,7 @@ let var_list(var) := let process_body := - locals = list(located(id_colon(sailtype))) ; + locals = list(id_colon(sailtype)) ; init = midrule(P_INIT ; ":" ; statement?)? ; proc_init = midrule(P_PROC_INIT ; ":" ; ~ = list(located(proc_init)) ; <>)? ; loop = midrule(P_LOOP ; ":" ; loop?)? ; @@ -139,7 +139,7 @@ let process_body := } let proc_init := - | id = UID ; ":" ; "=" ; mloc = ioption(module_loc); proc = UID + | id = UID ; ":" ; "=" ; mloc = module_loc?; proc = UID ; params = midrule(p = process_params(separated_list(",", expression)); {Option.value p ~default:[]}) ; (read,write) = shared_vars(located(ID)) ; { {mloc;id;proc;params;read;write} } | mloc = ioption(module_loc) ; id = UID @@ -200,12 +200,13 @@ let expression := | "*" ; ~ = expression ; %prec UNARY | e1 = expression ; op =binOp ; e2 =expression ; { BinOp(op,e1,e2) } | ~ = delimited ("[", separated_list(",", expression), "]") ; - | ~ = ioption(module_loc) ; ~ =located(ID) ; ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); {List.fold_left (fun l (y,(_,z)) -> (y,z)::l) [] l}) ; + | ~ = ioption(module_loc) ; ~ =located(ID) ; ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); + {List.fold_left (fun l ((ly,y),z) -> (y,(ly,z))::l) [] l}) ; | ~ = located(UID) ; ~ = loption(parenthesized (separated_list(",", expression))) ; | ~ = ioption(module_loc) ; ~ = located(ID) ; ~ = parenthesized(separated_list (",", expression)) ; ) -let id_colon(X) := ~ =ID ; ":" ; ~ = located(X) ; <> +let id_colon(X) := ~ = located(ID) ; ":" ; ~ = X ; <> let literal := | TRUE ; {LBool(true) } @@ -264,7 +265,7 @@ let vardecl := VAR ; ~ = mut ; ~ = ID ; ~ = preceded(":", sailtype)? ; ~ = prece let brace_del_sep_list(sep,x) := delimited("{", separated_nonempty_list(sep, x), "}") -let located(x) := ~ = x ; { ($loc,x) } +let located(x) == ~ = x ; { ($loc,x) } let case := separated_pair(pattern, ":", statement) @@ -292,16 +293,17 @@ let pattern := | ~ = UID ; ~ = delimited("(", separated_list(",", pattern), ")") ; -let sailtype := -| TYPE_BOOL ; {Bool} -| l = TYPE_INT ; {Int l} -| TYPE_FLOAT ; {Float} -| TYPE_CHAR ; {Char} -| TYPE_STRING ; {String} -| ARRAY ; "<" ; ~ = sailtype ; ";" ; ~ = midrule(size = INT ; {Z.to_int size}) ; ">" ; -| mloc = ioption(module_loc) ; name = located(ID) ; generic_instances = instance ; {CompoundType {origin=mloc; name; generic_instances; decl_ty=None} } -| ~ = UID ; -| REF ; b = mut ; t = sailtype ; {RefType(t,b)} +let sailtype := located( + | TYPE_BOOL ; {Bool} + | l = TYPE_INT ; {Int l} + | TYPE_FLOAT ; {Float} + | TYPE_CHAR ; {Char} + | TYPE_STRING ; {String} + | ARRAY ; "<" ; ~ = sailtype ; ";" ; ~ = midrule(size = INT ; {Z.to_int size}) ; ">" ; + | mloc = ioption(module_loc) ; name = located(ID) ; generic_instances = instance ; {CompoundType {origin=mloc; name; generic_instances; decl_ty=None} } + | ~ = UID ; + | REF ; b = mut ; t = sailtype ; {RefType(t,b)} +) let instance := loption(delimited("<", separated_list(",", sailtype), ">")) \ No newline at end of file diff --git a/src/parsing/parsing.ml b/src/parsing/parsing.ml index fc4dee3..6941d04 100644 --- a/src/parsing/parsing.ml +++ b/src/parsing/parsing.ml @@ -3,7 +3,7 @@ open Common open Lexer open Lexing -open Error +open Logging open TypesCommon open AstParser module L = MenhirLib.LexerUtil @@ -20,7 +20,7 @@ let print_error_position lexbuf = -let fastParse filename : (string * (statement,(statement,expression) process_body) SailModule.methods_processes SailModule.t Error.Logger.t, string) Result.t = +let fastParse filename : (string * (statement,(statement,expression) process_body) SailModule.methods_processes SailModule.t Logging.Logger.t, string) Result.t = let text, lexbuf = L.read filename in let hash = Digest.string text in @@ -32,7 +32,7 @@ let fastParse filename : (string * (statement,(statement,expression) process_bod let lexer_prefix = "Lexer - " in (* removes lexer prefix in case of a lexing error *) let msg = String.(if starts_with ~prefix:lexer_prefix msg then sub msg (length lexer_prefix) (length msg - length lexer_prefix) else msg) in - Error.print_errors [Error.make loc msg]; + Logging.print_log {errors=[make_msg loc msg];warnings=[]}; exit 1 | exception Parser.Error -> @@ -59,7 +59,7 @@ let state checkpoint : int = let get text checkpoint i = match I.get i (env checkpoint) with | Some (I.Element (_, _, pos1, pos2)) -> - Error.show_context text (pos1, pos2) + Logging.show_context text (pos1, pos2) | None -> "???" @@ -72,9 +72,9 @@ let fail text buffer (checkpoint : _ I.checkpoint) = let message = ParserMessages.message state_num in let message = E.expand (get text checkpoint) message in Logs.debug (fun m -> m "reached error state %i "state_num); - Logger.throw @@ Error.make location message + Logger.throw @@ make_msg location message with Not_found -> - Logger.throw @@ Error.make location "Syntax error" + Logger.throw @@ make_msg location "Syntax error" let slowParse filename text = diff --git a/src/passes/ir/sailHir/astHir.ml b/src/passes/ir/sailHir/astHir.ml index d90a7bd..8dfd46c 100644 --- a/src/passes/ir/sailHir/astHir.ml +++ b/src/passes/ir/sailHir/astHir.ml @@ -32,7 +32,7 @@ type ('info,'import) expression = {info: 'info ; exp: ('info,'import) _expressio | BinOp of binOp * ('info,'import) expression * ('info,'import) expression | Ref of bool * ('info,'import) expression | ArrayStatic of ('info,'import) expression list - | StructAlloc of 'import * l_str * ('info,'import) expression dict + | StructAlloc of 'import * l_str * (loc * ('info,'import) expression) dict | EnumAlloc of l_str * ('info,'import) expression list | MethodCall of l_str * 'import * ('info,'import) expression list @@ -47,7 +47,7 @@ type ('info,'import,'exp) statement = {info: 'info; stmt: ('info,'import,'exp) _ | Loop of ('info,'import,'exp) statement | Break | Case of 'exp * (string * string list * ('info,'import,'exp) statement) list - | Invoke of string option * 'import * l_str * 'exp list + | Invoke of {ret_var:string option; import: 'import; id: l_str; args:'exp list} | Return of 'exp option (* | DeclSignal of string diff --git a/src/passes/ir/sailHir/hir.ml b/src/passes/ir/sailHir/hir.ml index 7eccce6..53e384a 100644 --- a/src/passes/ir/sailHir/hir.ml +++ b/src/passes/ir/sailHir/hir.ml @@ -5,7 +5,6 @@ open Monad open AstHir open HirUtils open M -module SM = SailModule type expression = HirUtils.expression type statement = HirUtils.statement @@ -20,58 +19,19 @@ struct type p_in = (m_in,AstParser.expression) AstParser.process_body type p_out = (m_out,expression) AstParser.process_body - open UseMonad(ECSW) + open UseMonad(M) open MakeOrderedFunctions(String) (* let get_hint id env = Option.bind (List.nth_opt (HIREnv.get_closest id env) 0) (fun id -> Some (None,Printf.sprintf "Did you mean %s ?" id)) *) + let preprocess = resolve_names - let lower_expression (e : AstParser.expression) : expression ECSW.t = - let open MonadSyntax(ECSW) in - let rec aux (info,e : AstParser.expression) : expression ECSW.t = - match e with - | Variable id -> - let* v = (ECS.find_var id |> ECSW.lift) in - ECSW.throw_if_none Error.(make info @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> - {info; exp=Variable id} - | Deref e -> - let+ e = aux e in {info;exp=Deref e} - | StructRead (e, id) -> - let+ e = aux e in {info; exp=StructRead (None, e, id)} - | ArrayRead (e1, e2) -> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=ArrayRead(e1,e2)} - | Literal l -> return {info; exp=Literal l} - | UnOp (op, e) -> - let+ e = aux e in {info;exp=UnOp (op, e)} - | BinOp(op,e1,e2)-> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=BinOp (op, e1, e2)} - | Ref (b, e) -> - let+ e = aux e in {info;exp=Ref(b, e)} - | ArrayStatic el -> - let+ el = ListM.map aux el in {info;exp=ArrayStatic el} - | StructAlloc (origin,id, m) -> - let m' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) m in - let* () = ECSW.throw_if (Error.make info "duplicate fields") List.(length m <> length m') in - let+ m' = ListM.map (pairMap2 aux) m' in - {info; exp=StructAlloc (origin, id, m')} - - | EnumAlloc (id, el) -> - let+ el = ListM.map aux el in {info;exp=EnumAlloc (id, el)} - | MethodCall (mod_loc, id, el) -> - let+ el = ListM.map aux el in {info ; exp=MethodCall(id, mod_loc, el)} - in aux e - - - let lower_method (body,_) (env:HIREnv.t) _ : (m_out * HIREnv.D.t) E.t = - let open MonadSyntax(ECS) in - let open MonadOperator(ECS) in + let lower_method (body,_) (env,tenv:HIREnv.t * _) _ : (m_out * HIREnv.D.t * _) M.E.t = + let open MonadSyntax(M.ECS) in + let open MonadOperator(M.ECS) in - let rec aux (info,s : m_in) : m_out ECS.t = + let rec aux (info,s : m_in) : m_out M.ECS.t = let buildSeq s1 s2 = {info; stmt = Seq (s1, s2)} in let buildStmt stmt = {info;stmt} in @@ -79,13 +39,13 @@ struct match s with | DeclVar (mut, id, t, e ) -> - ECS.set_var info id >>= fun () -> + M.ECS.set_var info id >>= fun () -> let* t = match t with | None -> return None | Some t -> - let* (ve,d) = ECS.get in - let* t',d' = (follow_type t d) |> EC.lift |> ECS.lift in - let+ () = ECS.update (fun _ -> E.pure (ve,d') |> EC.lift ) in + let* (ve,d) = M.ECS.get in + let* t',d' = (follow_type t d) |> M.EC.lift |> M.ECS.lift in + let+ () = M.ECS.update (fun _ -> M.E.pure (ve,d') |> M.EC.lift ) in Some t' in begin match e with @@ -97,7 +57,7 @@ struct | Assign(e1, e2) -> let* e1,s1 = lower_expression e1 in let+ e2,s2 = lower_expression e2 in - buildSeq s1 @@ buildSeqStmt s2 (Assign (e1, e2)) + buildSeq s1 @@ buildSeqStmt s2 @@ Assign (e1, e2) | Seq (c1, c2) -> let+ c1 = aux c1 and* c2 = aux c2 in {info;stmt=Seq (c1, c2)} | If (e, c1, Some c2) -> @@ -125,9 +85,9 @@ struct let i_id = "_for_i_" ^ var in let arr_length = List.length el in - let tab_decl = info,DeclVar (false, arr_id, Some (ArrayType (Int 32,arr_length)), Some iterable) in - let var_decl = info,DeclVar (true, var, Some (Int 32), None) in - let i_decl = info,DeclVar (true, i_id, Some (Int 32), Some (info,(Literal (LInt {l=Z.zero;size=32})))) in + let tab_decl = info,DeclVar (false, arr_id, Some (dummy_pos,ArrayType ((dummy_pos,Int 32),arr_length)), Some iterable) in + let var_decl = info,DeclVar (true, var, Some (dummy_pos,Int 32), None) in + let i_decl = info,DeclVar (true, i_id, Some (dummy_pos,Int 32), Some (info,(Literal (LInt {l=Z.zero;size=32})))) in let tab = info,Variable arr_id in let var = info,Variable var in @@ -142,7 +102,7 @@ struct let _while = info,While (cond,body) in let _for = info,Seq(init,_while) in aux _for - | loc,_ -> ECS.throw (Error.make loc "for loop only allows a static array expression at the moment") + | loc,_ -> M.ECS.throw Logging.(make_msg loc "for loop only allows a static array expression at the moment") end | Break () -> return {info; stmt=Break} @@ -150,9 +110,9 @@ struct | Case (e, _cases) -> let+ e,s = lower_expression e in buildSeqStmt s (Case (e, [])) - | Invoke (mod_loc, lid, el) -> - let+ el,s = ListM.map lower_expression el in - buildSeqStmt s (Invoke(None, mod_loc, lid,el)) + | Invoke (mod_loc, id, args) -> + let+ args,s = ListM.map lower_expression args in + buildSeqStmt s (Invoke {ret_var=None;import=mod_loc;id;args}) | Return e -> begin match e with @@ -161,25 +121,24 @@ struct buildSeqStmt s (Return (Some e)) end | Block c -> - let* () = ECS.update (fun e -> EC.pure (HIREnv.new_frame e)) in + let* () = M.ECS.update (fun e -> M.EC.pure (HIREnv.new_frame e)) in let* c = aux c in - let+ () = ECS.update (fun e -> EC.pure (HIREnv.pop_frame e)) in + let+ () = M.ECS.update (fun e -> M.EC.pure (HIREnv.pop_frame e)) in buildStmt (Block c) in - ECS.run (aux body env) |> E.recover ({info=dummy_pos;stmt=Skip},snd env) + M.E.(bind (M.run aux env body) (fun (r,venv) -> pure (r,venv,tenv))) |> M.E.recover ({info=dummy_pos;stmt=Skip},snd env,tenv) - - let lower_process (p:p_in process_defn) (env,decls:HIREnv.t) sm : (p_out * HIREnv.D.t) E.t = + let lower_process (p:p_in process_defn) ((env,decls),tenv: HIREnv.t * _ ) sm : (p_out * HIREnv.D.t * _) M.E.t = let open AstParser in - let module F = MonadFunctions(ECSW) in + let module F = MonadFunctions(M) in let open UseMonad(E) in let params = List.to_seq p.p_interface.p_params |> Seq.map (fun (p:param) -> p.id,p.loc) |> FieldMap.of_seq in - let locals = List.to_seq p.p_body.locals |> Seq.map (fun (l,(id,_)) -> id,l ) |> FieldMap.of_seq in + let locals = List.to_seq p.p_body.locals |> Seq.map (fun ((l,id),_) -> id,l ) |> FieldMap.of_seq in let read,write = p.p_interface.p_shared_vars in let read = List.to_seq read |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in let write = List.to_seq write |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in @@ -191,26 +150,26 @@ struct <> cardinal params + cardinal locals + cardinal read + cardinal write ) in - E.throw_if Error.(make dummy_pos @@ Fmt.str "process '%s' : name conflict between params,local decls or shared variables" p.p_name) has_name_conflict + E.throw_if Logging.(make_msg dummy_pos @@ Fmt.str "process '%s' : name conflict between params,local decls or shared variables" p.p_name) has_name_conflict >>= fun () -> - let add_locals v e = ListM.fold_left (fun e (l,(id,_)) -> HIREnv.declare_var id (l,()) e) (e,decls) v >>| fst in + let add_locals v e = ListM.fold_left (fun e ((l,id),_) -> HIREnv.declare_var id (l,()) e) (e,decls) v >>| fst in let add_rw r e = ListM.fold_left (fun e (l,(id,_)) -> HIREnv.declare_var id (l,()) e) (e,decls) r >>| fst in let* env = add_locals p.p_body.locals env >>= add_rw (fst p.p_interface.p_shared_vars) >>= add_rw (snd p.p_interface.p_shared_vars) in - let* init,decls = lower_method (p.p_body.init,()) (env,decls) sm |> E.recover ({info=dummy_pos;stmt=Skip},decls) in + let* init,decls,tenv = lower_method (p.p_body.init,()) ((env,decls),tenv) sm |> E.recover ({info=dummy_pos;stmt=Skip},decls,tenv) in let* (proc_init,_),decls = F.ListM.map (fun ((l,p): loc * _ proc_init) -> - let open UseMonad(ECSW) in + let open UseMonad(M) in let+ params = F.ListM.map lower_expression p.params in l,{p with params} - ) p.p_body.proc_init (env,decls) |> ECS.run + ) p.p_body.proc_init (env,decls) |> M.ECS.run in let+ loop = - let process_cond = function None -> return None | Some c -> let+ (cond,_),_ = lower_expression c (env,decls) |> ECS.run in Some cond in + let process_cond = function None -> return None | Some c -> let+ (cond,_),_ = lower_expression c (env,decls) |> M.ECS.run in Some cond in let rec aux (l,s) = match s with - | Statement (s,cond) -> let+ s,_ = lower_method (s,()) (env,decls) sm and* cond = process_cond cond in l,Statement (s,cond) + | Statement (s,cond) -> let+ s,_,_ = lower_method (s,()) ((env,decls),tenv) sm and* cond = process_cond cond in l,Statement (s,cond) | Run (proc,cond) -> let+ cond = process_cond cond in l,Run (proc,cond) | PGroup g -> let* cond = process_cond g.cond in @@ -218,104 +177,7 @@ struct l,PGroup {g with cond ; children} in aux p.p_body.loop in - {p.p_body with init ; proc_init; loop},decls - - - - let preprocess (sm: ('a,'b) SM.methods_processes SM.t) : ('a,'b) SM.methods_processes SM.t E.t = - let module ES = struct - module S = MonadState.M(struct type t = D.t end) - include Error.MakeTransformer(S) - let update_env f = S.update f |> lift - let set_env e = S.set e |> lift - let get_env = S.get |> lift - end - in - let open UseMonad(ES) in - let module TEnv = MakeFromSequencable(SM.DeclEnv.TypeSeq) in - let module SEnv = MakeFromSequencable(SM.DeclEnv.StructSeq) in - (* let module MEnv = F.MakeFromSequencable(SM.DeclEnv.MethodSeq) in - let module PEnv = F.MakeFromSequencable(SM.DeclEnv.ProcessSeq) in *) - let open SM.DeclEnv in - - (* resolving aliases *) - let sm = ( - - let* declEnv = ES.get_env in - - let* () = - TEnv.iter ( - fun (id,({ty; _} as def)) -> - let* ty = match ty with - | None -> return None - | Some t -> - let* env = ES.get_env in - let* t,env = (follow_type t env) |> ES.S.lift in - let+ () = ES.set_env env in Some t - in - ES.update_env (update_decl id {def with ty} (Self Type)) - ) (get_own_decls declEnv |> get_decls Type) in - - let* env = ES.get_env in - - let* () = SEnv.iter ( - fun (id,(l,{fields; generics})) -> - let* fields = ListM.map ( - fun (name,(l,t,n)) -> - let* env = ES.get_env in - let* t,env = (follow_type t env) |> ES.S.lift in - let+ () = ES.set_env env in - name,(l,t,n) - ) fields - in - let proto = l,{fields;generics} in - ES.update_env (update_decl id proto (Self Struct)) - ) (get_own_decls env |> get_decls Struct) - in - - let* methods = ListM.map ( - fun ({m_proto;m_body} as m) -> - let* rtype = match m_proto.rtype with - | None -> return None - | Some t -> - let* env = ES.get_env in - let* t,env = (follow_type t env) |> ES.S.lift in - let+ () = ES.set_env env in Some t - in - let* params = ListM.map ( - fun (({ty;_}:param) as p) -> - let* env = ES.get_env in - let* ty,env = (follow_type ty env) |> ES.S.lift in - let+ () = ES.set_env env in {p with ty} - ) m_proto.params in - let m = {m with m_proto={m_proto with params; rtype}} in - let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in - let+ () = ES.update_env (update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method)) - in m - ) sm.body.methods in - - let* processes = ListM.map ( - fun proc -> - let* p_params = ListM.map ( - fun (({ty;_}:param) as p) -> - let* env = ES.get_env in - let* ty,env = (follow_type ty env) |> ES.S.lift in - let+ () = ES.set_env env in {p with ty} - ) proc.p_interface.p_params in - let p = {proc with p_interface={proc.p_interface with p_params}} in - let+ () = ES.update_env (update_decl p.p_name (p.p_pos, defn_to_proto (Process p)) (Self Process)) - in p - ) sm.body.processes in - - (* at this point, all types must have an origin *) - - - let* declEnv = ES.get_env in - let+ () = SEnv.iter (fun (id,proto) -> check_non_cyclic_struct id proto declEnv |> ES.S.lift) (get_own_decls declEnv |> get_decls Struct) in + {p.p_body with init ; proc_init; loop},decls,tenv - (* Logs.debug (fun m -> m "%s" @@ string_of_declarations declEnv); *) - {sm with body=SM.{methods; processes}; declEnv} - ) sm.declEnv |> fst in - sm end ) \ No newline at end of file diff --git a/src/passes/ir/sailHir/hirMonad.ml b/src/passes/ir/sailHir/hirMonad.ml index 4a99417..53ac723 100644 --- a/src/passes/ir/sailHir/hirMonad.ml +++ b/src/passes/ir/sailHir/hirMonad.ml @@ -8,25 +8,26 @@ module Make(MonoidSeq : Monad.Monoid) = struct end module HIREnv = SailModule.SailEnv(V) - module E = Error.Logger - module EC = MonadState.CounterTransformer(E)(struct type t = int let succ = Int.succ let init = 0 end) - module ECS = struct - include MonadState.T(EC)(HIREnv) - let fresh = EC.fresh |> lift - let run e = let e = EC.run e in E.bind e (fun (e,(_,s)) -> E.pure (e,s) ) - let find_var id = bind get (fun e -> HIREnv.get_var id e |> pure) - let set_var loc id = update (fun e -> HIREnv.declare_var id (loc,()) e |> EC.lift) - let throw e = E.throw e |> EC.lift |> lift - let throw_if_none b e = E.throw_if_none b e |> EC.lift |> lift - let log e = E.log e |> EC.lift |> lift - let log_if b e = E.log_if b e |> EC.lift |> lift - let throw_if b e = E.throw_if b e |> EC.lift |> lift - - let get_decl id ty = bind get (fun e -> HIREnv.get_decl id ty e |> pure) - end - - module ECSW = struct + module M = struct + module E = Logging.Logger + module EC = MonadState.CounterTransformer(E)(struct type t = int let succ = Int.succ let init = 0 end) + module ECS = struct + include MonadState.T(EC)(HIREnv) + let fresh = EC.fresh |> lift + let run e = let e = EC.run e in E.bind e (fun (e,(_,s)) -> E.pure (e,s) ) + let find_var id = bind get (fun e -> HIREnv.get_var id e |> pure) + let set_var loc id = update (fun e -> HIREnv.declare_var id (loc,()) e |> EC.lift) + + let throw e = E.throw e |> EC.lift |> lift + let throw_if_none b e = E.throw_if_none b e |> EC.lift |> lift + let log e = E.log e |> EC.lift |> lift + let log_if b e = E.log_if b e |> EC.lift |> lift + let throw_if b e = E.throw_if b e |> EC.lift |> lift + + let get_decl id ty = bind get (fun e -> HIREnv.get_decl id ty e |> pure) + end + include MonadWriter.MakeTransformer(ECS)(MonoidSeq) let get_decl id ty = ECS.bind ECS.get (fun e -> HIREnv.get_decl id ty e |> ECS.pure) |> lift let fresh = ECS.fresh |> lift @@ -37,5 +38,7 @@ module Make(MonoidSeq : Monad.Monoid) = struct let log e = ECS.log e |> lift let get_env = ECS.get |> lift + + let run (f : 'a -> 'b old_t) env = fun x -> ECS.run (f x env) end end \ No newline at end of file diff --git a/src/passes/ir/sailHir/hirUtils.ml b/src/passes/ir/sailHir/hirUtils.ml index b7b2a85..4ce2e08 100644 --- a/src/passes/ir/sailHir/hirUtils.ml +++ b/src/passes/ir/sailHir/hirUtils.ml @@ -1,7 +1,7 @@ open Common open TypesCommon open Monad - +open SailParser type expression = (loc,l_str option) AstHir.expression type statement = (loc,l_str option,expression) AstHir.statement @@ -14,92 +14,144 @@ module M = HirMonad.Make( struct ) open M -open MonadSyntax(M.E) -open MonadOperator(M.E) -open MonadFunctions(M.E) module D = SailModule.DeclEnv -let find_symbol_source ?(filt = [E (); S (); T ()] ) (loc,id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) E.t = -match import with -| Some (iloc,name) -> - if name = Constants.sail_module_self || name = D.get_name env then - let+ decl = - D.find_decl id (Self (Filter filt)) env - |> E.throw_if_none (Error.make loc @@ "no declaration named '" ^ id ^ "' in current module ") +let lower_expression (e : AstParser.expression) : expression M.t = + let open UseMonad(M) in + let rec aux (info,e : AstParser.expression) : expression M.t = + let open AstHir in + match e with + | Variable id -> + let* v = (M.ECS.find_var id |> M.lift) in + M.throw_if_none Logging.(make_msg info @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> + {info; exp=Variable id} + + | Deref e -> + let+ e = aux e in {info;exp=Deref e} + + | StructRead (e, id) -> + let+ e = aux e in {info; exp=StructRead (None, e, id)} + + | ArrayRead (e1, e2) -> + let* e1 = aux e1 in + let+ e2 = aux e2 in + {info; exp=ArrayRead(e1,e2)} + | Literal l -> return {info; exp=Literal l} + + | UnOp (op, e) -> + let+ e = aux e in {info;exp=UnOp (op, e)} + + | BinOp(op,e1,e2)-> + let* e1 = aux e1 in + let+ e2 = aux e2 in + {info; exp=BinOp (op, e1, e2)} + + | Ref (b, e) -> + let+ e = aux e in {info;exp=Ref(b, e)} + + | ArrayStatic el -> + let+ el = ListM.map aux el in {info;exp=ArrayStatic el} + + | StructAlloc (origin,id, m) -> + let m' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) m in + let* () = M.throw_if Logging.(make_msg info "duplicate fields") List.(length m <> length m') in + let+ m' = ListM.map (aux |> pairMap2 |> pairMap2) m' in + {info; exp=StructAlloc (origin, id, m')} + + | EnumAlloc (id, el) -> + let+ el = ListM.map aux el in {info;exp=EnumAlloc (id, el)} + + + | MethodCall (import, id, el) -> let+ el = ListM.map aux el in {info ; exp=MethodCall(id, import, el)} + in aux e + + +open UseMonad(M.E) + +let find_symbol_source ?(filt = [E (); S (); T ()] ) (loc,id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) M.E.t = + match import with + | Some (iloc,name) -> + if name = Constants.sail_module_self || name = D.get_name env then + let+ decl = + D.find_decl id (Self (Filter filt)) env + |> M.E.throw_if_none Logging.(make_msg loc @@ "no declaration named '" ^ id ^ "' in current module ") + in + (iloc,D.get_name env),decl + else + let+ t = + M.E.throw_if_none Logging.(make_msg iloc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" name)) @@ "unknown module " ^ name) + (List.find_opt (fun {mname;_} -> mname = name) (D.get_imports env)) >>= fun _ -> + M.E.throw_if_none Logging.(make_msg loc @@ "declaration " ^ id ^ " not found in module " ^ name) + (D.find_decl id (Specific (name, Filter filt)) env) in - (iloc,D.get_name env),decl - else - let+ t = - E.throw_if_none (Error.make iloc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" name)) @@ "unknown module " ^ name) - (List.find_opt (fun {mname;_} -> mname = name) (D.get_imports env)) >>= fun _ -> - E.throw_if_none (Error.make loc @@ "declaration " ^ id ^ " not found in module " ^ name) - (D.find_decl id (Specific (name, Filter filt)) env) - in - (iloc,name),t -| None -> (* find it ourselves *) - begin - let decl = D.find_decl id (All (Filter filt)) env in - match decl with - | [i,m] -> - (* Logs.debug (fun m -> m "'%s' is from %s" id i.mname); *) - return ((dummy_pos,i.mname),m) - - | [] -> E.throw @@ Error.make loc @@ "unknown declaration " ^ id - - | _ as choice -> E.throw - @@ Error.make loc ~hint:(Some (None,"specify one with '::' annotation")) - @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id - (List.map (fun (i,def) -> match def with T def -> Fmt.str "from %s : %s" i.mname (string_of_sailtype (def.ty)) | _ -> "") choice |> String.concat "\n\t") + (iloc,name),t + | None -> (* find it ourselves *) + begin + let decl = D.find_decl id (All (Filter filt)) env in + match decl with + | [i,m] -> + (* Logs.debug (fun m -> m "'%s' is from %s" id i.mname); *) + return ((dummy_pos,i.mname),m) + + | [] -> M.E.throw Logging.(make_msg loc @@ "unknown declaration " ^ id) + + | _ as choice -> M.E.throw + @@ Logging.make_msg loc ~hint:(Some (None,"specify one with '::' annotation")) + @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id + (List.map (fun (i,def) -> match def with T def -> Fmt.str "from %s : %s" i.mname (string_of_sailtype (def.ty)) | _ -> "") choice |> String.concat "\n\t") end -let follow_type ty env : (sailtype * D.t) E.t = - +let follow_type ty env : (sailtype * D.t) M.E.t = let current = SailModule.DeclEnv.get_name env in (* Logs.debug (fun m -> m "following type '%s'" (string_of_sailtype (Some ty))); *) - let rec aux ty' path : (sailtype * ty_defn list) E.t = + let rec aux (l_ty,ty') path : (sailtype * ty_defn list) M.E.t = (* Logs.debug (fun m -> m "path: %s" (List.map (fun ({name;_}:ty_defn) -> name)path |> String.concat " ")); *) - match ty' with - | CompoundType {origin;name=id;generic_instances;_} -> (* compound type, find location and definition *) + let+ (t,path : sailtype_ * ty_defn list) = + match ty' with + + | ArrayType (t,n) -> let+ t,path = aux t path in ArrayType (t,n),path + | Box t -> let+ t,path = aux t path in Box t,path + | RefType (t,mut) -> let+ t,path = aux t path in RefType (t,mut),path + | Bool | Char | Int _ | Float | String | GenericType _ as t -> (* basic type, stop *) + (* Logs.debug (fun m -> m "'%s' resolves to '%s'" (string_of_sailtype (Some ty)) (string_of_sailtype (Some ty'))); *) + return (t,path) + | CompoundType {origin;name=id;generic_instances;_} -> (* compound type, find location and definition *) let* (l,origin),def = find_symbol_source id origin env in let default = fun ty -> CompoundType {origin=Some (l,origin);name=id; generic_instances;decl_ty=Some ty} in begin match def with | T def when origin=current -> begin - match def.ty with - | Some ty -> ( - E.throw_if_some (fun _ -> (Error.make def.loc - @@ Fmt.str "circular type declaration : %s" - @@ String.concat " -> " (List.rev (def::path) |> List.map (fun ({name;_}:ty_defn) -> name)))) - (List.find_opt (fun (d:ty_defn) -> d.name = def.name) path) - - >>= (fun () -> aux ty (def::path)) - ) - | None -> (* abstract type *) - (* Logs.debug (fun m -> m "'%s' resolves to abstract type '%s' " (string_of_sailtype (Some ty)) def.name); *) - return (default (T ()),path) + match def.ty with + | Some ty -> ( + M.E.throw_if_some + (fun _ -> Logging.(make_msg def.loc + @@ Fmt.str "circular type declaration : %s" + @@ String.concat " -> " (List.rev (def::path) |> List.map (fun ({name;_}:ty_defn) -> name))) + ) + (List.find_opt (fun (d:ty_defn) -> d.name = def.name) path) + + >>= fun () -> let+ ((_,t),p) = aux ty (def::path) in t,p + ) + | None -> (* abstract type *) + (* Logs.debug (fun m -> m "'%s' resolves to abstract type '%s' " (string_of_sailtype (Some ty)) def.name); *) + return (default (T ()),path) end - | _ -> - return (default (unit_decl_of_decl def),path) (* must point to an enum or struct, nothing to resolve *) - end - | ArrayType (t,n) -> let+ t,path = aux t path in ArrayType (t,n),path - | Box t -> let+ t,path = aux t path in Box t,path - | RefType (t,mut) -> let+ t,path = aux t path in RefType (t,mut),path - | Bool | Char | Int _ | Float | String | GenericType _ as t -> (* basic type, stop *) - (* Logs.debug (fun m -> m "'%s' resolves to '%s'" (string_of_sailtype (Some ty)) (string_of_sailtype (Some ty'))); *) - return (t,path) + | _ -> return (default @@ unit_decl_of_decl def,path) (* must point to an enum or struct, nothing to resolve *) + end + in (l_ty,t),path in let+ res,p = aux ty [] in (* p only contains type_def from the current module *) let env = List.fold_left (fun env (td:ty_defn) -> D.update_decl td.name {td with ty=Some res} (Self Type) env) env p in res,env -let check_non_cyclic_struct (name:string) (l,proto) env : unit E.t = +let check_non_cyclic_struct (name:string) (l,proto) env : unit M.E.t = let rec aux id curr_loc (s:struct_proto) checked = - let* () = E.throw_if - (Error.make l + let* () = M.E.throw_if + Logging.(make_msg l ~hint:(Some (Some curr_loc,"Hint : try boxing this type")) @@ Fmt.str "circular structure declaration : %s" @@ String.concat " -> " (List.rev (id::checked)) @@ -109,7 +161,7 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit E.t = (List.mem id checked) in let checked = id::checked in ListM.iter ( - fun (_,(l,t,_)) -> match t with + fun (_,(l,t,_)) -> match snd t with | CompoundType {name=_,name;origin=Some (_,origin); decl_ty = Some S ();_} -> begin @@ -117,37 +169,37 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit E.t = | Some (S (_,d)) -> aux name l d checked | _ -> failwith "invariant : all compound types must have a correct origin and type at this step" end - | CompoundType {origin=None;decl_ty=None;_} -> E.throw (Error.make l "follow type not called") + | CompoundType {origin=None;decl_ty=None;_} -> M.E.throw Logging.(make_msg l "follow type not called") | _ -> return () ) s.fields in aux name l proto [] - let rename_var_exp (f: string -> string) (e: _ AstHir.expression) = - let open AstHir in - let rec aux (e : _ expression) = - let buildExp = buildExp e.info in - match e.exp with - | Variable id -> buildExp @@ Variable (f id) - | Deref e -> let e = aux e in buildExp @@ Deref e - | StructRead (mod_loc,e, id) -> let e = aux e in buildExp @@ StructRead(mod_loc,e,id) - | ArrayRead (e1, e2) -> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ ArrayRead (e1,e2) - | Literal _ as l -> buildExp l - | UnOp (op, e) -> let e = aux e in buildExp @@ UnOp (op,e) - | BinOp(op,e1,e2)-> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ BinOp(op,e1,e2) - | Ref (b, e) -> - let e = aux e in buildExp @@ Ref(b,e) - | ArrayStatic el -> let el = List.map aux el in buildExp @@ ArrayStatic el - | StructAlloc (origin,id, m) -> let m = List.map (fun (n,e) -> n,aux e) m in buildExp @@ StructAlloc (origin,id,m) - | EnumAlloc (id, el) -> let el = List.map aux el in buildExp @@ EnumAlloc (id,el) - | MethodCall (mod_loc, id, el) -> let el = List.map aux el in buildExp @@ MethodCall (mod_loc,id,el) - in aux e +let rename_var_exp (f: string -> string) (e: _ AstHir.expression) = + let open AstHir in + let rec aux (e : _ expression) = + let buildExp = buildExp e.info in + match e.exp with + | Variable id -> buildExp @@ Variable (f id) + | Deref e -> let e = aux e in buildExp @@ Deref e + | StructRead (mod_loc,e, id) -> let e = aux e in buildExp @@ StructRead(mod_loc,e,id) + | ArrayRead (e1, e2) -> + let e1 = aux e1 in + let e2 = aux e2 in + buildExp @@ ArrayRead (e1,e2) + | Literal _ as l -> buildExp l + | UnOp (op, e) -> let e = aux e in buildExp @@ UnOp (op,e) + | BinOp(op,e1,e2)-> + let e1 = aux e1 in + let e2 = aux e2 in + buildExp @@ BinOp(op,e1,e2) + | Ref (b, e) -> + let e = aux e in buildExp @@ Ref(b,e) + | ArrayStatic el -> let el = List.map aux el in buildExp @@ ArrayStatic el + | StructAlloc (origin,id, m) -> let m = List.map (fun (n,(l,e)) -> n,(l,aux e)) m in buildExp @@ StructAlloc (origin,id,m) + | EnumAlloc (id, el) -> let el = List.map aux el in buildExp @@ EnumAlloc (id,el) + | MethodCall (mod_loc, id, el) -> let el = List.map aux el in buildExp @@ MethodCall (mod_loc,id,el) + in aux e let rename_var_stmt (f:string -> string) s = let open AstHir in @@ -173,12 +225,109 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit E.t = | Loop c -> let c = aux c in buildStmt (Loop c) | Break -> buildStmt Break | Case(e, _cases) -> let e = rename_var_exp f e in buildStmt (Case (e, [])) - | Invoke (var, mod_loc, id, el) -> - let el = List.map (rename_var_exp f) el in - let var = MonadOption.M.fmap f var in - buildStmt @@ Invoke(var,mod_loc, id,el) + | Invoke i -> + let args = List.map (rename_var_exp f) i.args in + let ret_var = MonadOption.M.fmap f i.ret_var in + buildStmt @@ Invoke {i with ret_var;args} | Return e -> let e = MonadOption.M.fmap (rename_var_exp f) e in buildStmt @@ Return e | Block c -> let c = aux c in buildStmt (Block c) | Skip -> buildStmt Skip in - aux s \ No newline at end of file + aux s + + +let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = + let module ES = struct + module S = MonadState.M(struct type t = D.t end) + include Logging.MakeTransformer(S) + let update_env f = S.update f |> lift + let set_env e = S.set e |> lift + let get_env = S.get |> lift + end + in + let open UseMonad(ES) in + let module TEnv = MakeFromSequencable(SailModule.DeclEnv.TypeSeq) in + let module SEnv = MakeFromSequencable(SailModule.DeclEnv.StructSeq) in + (* let module MEnv = F.MakeFromSequencable(SM.DeclEnv.MethodSeq) in + let module PEnv = F.MakeFromSequencable(SM.DeclEnv.ProcessSeq) in *) + let open SailModule.DeclEnv in + + (* resolving aliases *) + let sm = ( + + let* declEnv = ES.get_env in + + let* () = + TEnv.iter ( + fun (id,({ty; _} as def)) -> + let* ty = match ty with + | None -> return None + | Some t -> + let* env = ES.get_env in + let* t,env = (follow_type t env) |> ES.S.lift in + let+ () = ES.set_env env in Some t + in + ES.update_env (update_decl id {def with ty} (Self Type)) + ) (get_own_decls declEnv |> get_decls Type) in + + let* env = ES.get_env in + + let* () = SEnv.iter ( + fun (id,(l,{fields; generics})) -> + let* fields = ListM.map ( + fun (name,(l,t,n)) -> + let* env = ES.get_env in + let* t,env = (follow_type t env) |> ES.S.lift in + let+ () = ES.set_env env in + name,(l,t,n) + ) fields + in + let proto = l,{fields;generics} in + ES.update_env (update_decl id proto (Self Struct)) + ) (get_own_decls env |> get_decls Struct) + in + + let* methods = ListM.map ( + fun ({m_proto;m_body} as m) -> + let* rtype = match m_proto.rtype with + | None -> return None + | Some t -> + let* env = ES.get_env in + let* t,env = (follow_type t env) |> ES.S.lift in + let+ () = ES.set_env env in Some t + in + let* params = ListM.map ( + fun (({ty;_}:param) as p) -> + let* env = ES.get_env in + let* ty,env = (follow_type ty env) |> ES.S.lift in + let+ () = ES.set_env env in {p with ty} + ) m_proto.params in + let m = {m with m_proto={m_proto with params; rtype}} in + let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in + let+ () = ES.update_env (update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method)) + in m + ) sm.body.methods in + + let* processes = ListM.map ( + fun proc -> + let* p_params = ListM.map ( + fun (({ty;_}:param) as p) -> + let* env = ES.get_env in + let* ty,env = (follow_type ty env) |> ES.S.lift in + let+ () = ES.set_env env in {p with ty} + ) proc.p_interface.p_params in + let p = {proc with p_interface={proc.p_interface with p_params}} in + let+ () = ES.update_env (update_decl p.p_name (p.p_pos, defn_to_proto (Process p)) (Self Process)) + in p + ) sm.body.processes in + + (* at this point, all types must have an origin *) + + + let* declEnv = ES.get_env in + let+ () = SEnv.iter (fun (id,proto) -> check_non_cyclic_struct id proto declEnv |> ES.S.lift) (get_own_decls declEnv |> get_decls Struct) in + + (* Logs.debug (fun m -> m "%s" @@ string_of_declarations declEnv); *) + {sm with body=SailModule.{methods; processes}; declEnv} + ) sm.declEnv |> fst in + sm \ No newline at end of file diff --git a/src/passes/ir/sailHir/pp_hir.ml b/src/passes/ir/sailHir/pp_hir.ml index 6cdf0ff..05b04f3 100644 --- a/src/passes/ir/sailHir/pp_hir.ml +++ b/src/passes/ir/sailHir/pp_hir.ml @@ -21,7 +21,7 @@ let rec ppPrintExpression (pf : Format.formatter) (e : expression) : unit = fprintf pf "[%a]" (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el |StructAlloc (_,id, m) -> - let pp_field pf (x, y) = fprintf pf "%s:%a" x ppPrintExpression y in + let pp_field pf (x, (_,y)) = fprintf pf "%s:%a" x ppPrintExpression y in fprintf pf "%s{%a}" (snd id) (pp_print_list ~pp_sep:pp_comma pp_field) m @@ -47,11 +47,11 @@ let rec ppPrintStatement (pf : Format.formatter) (s : statement) : unit = match | Loop c -> fprintf pf "\nloop {%a\n}" ppPrintStatement c | Break -> fprintf pf "break;" | Case(_e, _cases) -> () -| Invoke (var, mod_loc, (_,id), el) -> fprintf pf "\n%a%a%s(%a);" - (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) var - (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) mod_loc - id - (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el +| Invoke i -> fprintf pf "\n%a%a%s(%a);" + (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.ret_var + (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) i.import + (snd i.id) + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.args | Return e -> fprintf pf "\nreturn %a;" (pp_print_option ppPrintExpression) e | Block c -> fprintf pf "\n{\n@[ %a @]\n}" ppPrintStatement c | Skip -> () diff --git a/src/passes/ir/sailMir/mir.ml b/src/passes/ir/sailMir/mir.ml index d018a25..23f0387 100644 --- a/src/passes/ir/sailMir/mir.ml +++ b/src/passes/ir/sailMir/mir.ml @@ -31,8 +31,8 @@ struct | Deref e -> rexpr e | ArrayRead (e1, e2) -> let+ e1' = lexpr e1 and* e2' = rexpr e2 in buildExp lt (ArrayRead(e1',e2')) | StructRead (origin,e,field) -> let+ e = lexpr e in buildExp lt (StructRead (origin,e,field)) - | Ref _ -> M.error @@ Error.make (fst lt) "todo" - | _ -> M.error @@ Error.make (fst lt) @@ "thir didn't lower correctly this expression" + | Ref _ -> M.error Logging.(make_msg (fst lt) "todo") + | _ -> M.error Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") and rexpr (e : Thir.expression) : expression M.t = let lt = e.info in let open AstHir in @@ -51,17 +51,17 @@ struct buildExp lt (StructRead (origin,exp,field)) | StructAlloc (origin,id, fields) -> - let+ fields = ListM.map (pairMap2 (rexpr)) fields in + let+ fields = ListM.map (rexpr |> pairMap2 |> pairMap2) fields in buildExp lt (StructAlloc(origin,id,fields)) | MethodCall _ - | _ -> M.error @@ Error.make (fst lt) @@ "thir didn't lower correctly this expression" + | _ -> M.error @@ Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") open UseMonad(M.E) - let lower_method (body,_ : m_in * method_sig) env (_sm: (m_in,p_in) SailModule.methods_processes SailModule.t) : (m_out * SailModule.DeclEnv.t) M.E.t = + let lower_method (body,_ : m_in * method_sig) (env,tenv) (_sm: (m_in,p_in) SailModule.methods_processes SailModule.t) : (m_out * SailModule.DeclEnv.t * _) M.E.t = let rec aux (s : Thir.statement) : m_out M.t = let open UseMonad(M) in let loc = s.info in @@ -73,10 +73,11 @@ struct [{location=loc; mut; id; varType=ty}],bb | DeclVar(mut, id, Some ty, Some e) -> + let* id_ty = M.get_type_id ty in let* expression = rexpr e in let* id = M.fresh_scoped_var >>| get_scoped_var id in let* () = M.declare_var loc id {ty;mut;id;loc} in - let target = AstHir.buildExp (loc,ty) (Variable id) in + let target = AstHir.buildExp (loc,id_ty) (Variable id) in let+ bn = assignBasicBlock loc {location=loc; target; expression } in [{location=loc; mut; id=id; varType=ty}],bn (* ++ other statements *) @@ -114,17 +115,17 @@ struct (d, l) | Break -> - let* env = M.get_env in + let* env,_ = M.get_env in let bb = {location=loc; forward_info=env; backward_info = (); assignments=[]; predecessors=LabelSet.empty;terminator=Some Break} in let+ cfg = singleBlock bb in ([],cfg) - | Invoke (target, ((_,mname) as origin), (l,id), el) -> - let* ((_,realname),_) = M.throw_if_none (Error.make loc @@ Fmt.str "Compiler Error : function '%s' must exist" id) - (SailModule.DeclEnv.find_decl id (Specific (mname,Method)) (snd env)) + | Invoke i -> + let* ((_,realname),_) = M.throw_if_none Logging.(make_msg loc @@ Fmt.str "Compiler Error : function '%s' must exist" (snd i.id)) + (SailModule.DeclEnv.find_decl (snd i.id) (Specific (snd i.import,Method)) (snd env)) in - let* el' = ListM.map rexpr el in - let+ invoke = buildInvoke loc origin (l,realname) target el' in + let* args = ListM.map rexpr i.args in + let+ invoke = buildInvoke loc i.import (fst i.id,realname) i.ret_var args in ([], invoke) | Return e -> @@ -134,7 +135,7 @@ struct let+ ret = buildReturn loc e in ([], ret) - | Case _ -> M.error @@ Error.make loc "unimplemented" + | Case _ -> M.error Logging.(make_msg loc "unimplemented") | Block s -> let* env = M.get_env in @@ -143,11 +144,11 @@ struct res in - let+ res = M.run aux body (fst env) in - res,(snd env) + let+ res = M.run aux body (fst env,tenv) in + res,(snd env),tenv - let preprocess = Error.Logger.pure + let preprocess = Logging.Logger.pure - let lower_process (c:p_in process_defn) env _ = M.E.pure (c.p_body,snd env) + let lower_process (c:p_in process_defn) ((_,env),tenv) _ = M.E.pure (c.p_body,env,tenv) end ) diff --git a/src/passes/ir/sailMir/mirMonad.ml b/src/passes/ir/sailMir/mirMonad.ml index ec11d87..b3af329 100644 --- a/src/passes/ir/sailMir/mirMonad.ml +++ b/src/passes/ir/sailMir/mirMonad.ml @@ -2,8 +2,13 @@ open AstMir open Common module M = struct - module E = Error.Logger - module ES = MonadState.T(E)(VE) + module E = Logging.Logger + module ES = struct + include MonadState.T(E)(struct type t = VE.t * Env.TypeEnv.t end ) + let get_type_from_id id = bind get (fun (_,te) -> Env.TypeEnv.get_from_id id te |> pure) + + let get_type_id ty : string t = fun (e,te) -> let id,te = Env.TypeEnv.get_id ty te in E.pure (id,(e,te)) + end module C = struct type t = {scoped_vars : int; bblocks: int; vars : int} end module ESC = MonadState.T(ES)(C) @@ -16,9 +21,9 @@ module M = struct let get_env = ES.get |> ESC.lift let update_env f = ES.update f |> ESC.lift let set_env e = ES.set e |> ESC.lift - let update_var l f id = ES.update (VE.update_var l id f) |> ESC.lift - let declare_var l id v = ES.update (VE.declare_var id (l,v)) |> ESC.lift - let find_var id = bind get_env (fun e -> VE.get_var id e |> pure) + let update_var l f id = ES.update (fun (e,t) -> E.bind (VE.update_var l id f e) (fun e -> E.pure (e,t))) |> ESC.lift + let declare_var l id v = ES.update (fun (e,t) -> E.bind (VE.declare_var id (l,v) e) (fun e -> E.pure (e,t))) |> ESC.lift + let find_var id = bind get_env (fun (e,_) -> VE.get_var id e |> pure) let throw_if_none opt e = E.throw_if_none opt e |> ES.lift |> ESC.lift let run (f:'a -> 'b t) x env = E.bind (f x {scoped_vars=0; bblocks=0;vars=0} env) (fun ((res,_),_) -> E.pure res) @@ -31,5 +36,6 @@ module M = struct let fresh_var = bind ESC.get (fun n -> bind (fun n -> ESC.set {n with vars=n.vars + 1} n) (fun () -> pure n.vars)) let current_var = bind ESC.get (fun n -> pure n.vars) - + let get_type_from_id id = ES.get_type_from_id id |> ESC.lift + let get_type_id t = ES.get_type_id t |> ESC.lift end diff --git a/src/passes/ir/sailMir/mirUtils.ml b/src/passes/ir/sailMir/mirUtils.ml index 74f2ecd..f6f755d 100644 --- a/src/passes/ir/sailMir/mirUtils.ml +++ b/src/passes/ir/sailMir/mirUtils.ml @@ -16,7 +16,7 @@ let rename (src : label) (tgt : label) (t : terminator) : terminator = | _ -> t let emptyBasicBlock (location:loc) : cfg M.t = - let+ lbl = M.fresh_block and* env = M.get_env in + let+ lbl = M.fresh_block and* env,_ = M.get_env in { input = lbl; output = lbl; @@ -32,7 +32,7 @@ let singleBlock (bb : _ basicBlock) : cfg M.t = } let assignBasicBlock (location : loc) (a : assignment) : cfg M.t = - let* env = M.get_env in + let* env,_ = M.get_env in let bb = {assignments = [a]; predecessors = LabelSet.empty; forward_info=env; backward_info = (); location; terminator=None} in let+ lbl = M.fresh_block in { @@ -51,7 +51,7 @@ let buildSeq (cfg1 : cfg) (cfg2 : cfg) : cfg M.t = and right = BlockMap.find cfg2.input cfg2.blocks in match left.terminator with - | Some (Invoke _) -> let+ () = M.error @@ Error.make left.location "invalid output node" in cfg1 + | Some (Invoke _) -> let+ () = M.error Logging.(make_msg left.location "invalid output node") in cfg1 | Some _ -> { input = cfg1.input; @@ -59,7 +59,7 @@ let buildSeq (cfg1 : cfg) (cfg2 : cfg) : cfg M.t = blocks = BlockMap.union assert_disjoint cfg1.blocks cfg2.blocks } |> M.ok | None -> - let+ env = M.get_env in + let+ env,_ = M.get_env in let bb = {assignments = left.assignments@right.assignments; predecessors = left.predecessors; forward_info=env; backward_info = (); location= right.location; terminator = right.terminator} in { input = cfg1.input; @@ -79,7 +79,7 @@ let buildSeq (cfg1 : cfg) (cfg2 : cfg) : cfg M.t = } let addGoto (lbl : label) (cfg : cfg) : cfg M.t = - let* env = M.get_env in + let* env,_ = M.get_env in let bb = {assignments=[]; predecessors=LabelSet.empty; forward_info=env; backward_info = (); location = dummy_pos; terminator=Some (Goto lbl)} in let* cfg' = singleBlock bb in buildSeq cfg cfg' @@ -97,7 +97,7 @@ let addPredecessors (lbls : label list) (cfg : cfg) : cfg = let buildIfThen (location : loc) (e : expression) (cfg : cfg) : cfg M.t = let* outputLbl = M.fresh_block and* inputLbl = M.fresh_block in let* goto = addGoto outputLbl cfg >>| addPredecessors [inputLbl] in - let+ env = M.get_env in + let+ env,_ = M.get_env in let inputBlock = {assignments = []; predecessors = LabelSet.empty ; forward_info=env; backward_info = (); location; terminator = Some (SwitchInt {choice=e; paths=[(0,outputLbl)]; default=cfg.input})} in let outputBlock = {assignments = []; predecessors = LabelSet.of_list [inputLbl;cfg.input] ; forward_info=env; backward_info = (); location; terminator = None} in { @@ -114,7 +114,7 @@ let buildIfThenElse (location : loc) (e : expression) (cfgTrue : cfg) (cfgFalse let* gotoF = addGoto outputLbl cfgFalse >>| addPredecessors [inputLbl] and* gotoT = addGoto outputLbl cfgTrue >>| addPredecessors [inputLbl] in - let+ env = M.get_env in + let+ env,_ = M.get_env in let inputBlock = {assignments = []; predecessors = LabelSet.empty ; forward_info=env; backward_info = (); location; terminator = Some (SwitchInt {choice=e; paths=[(0,cfgFalse.input)]; default=cfgTrue.input})} and outputBlock = {assignments = []; forward_info=env; backward_info = (); predecessors = LabelSet.of_list [cfgTrue.output;cfgFalse.output] ; location; terminator = None} in @@ -129,7 +129,7 @@ let buildIfThenElse (location : loc) (e : expression) (cfgTrue : cfg) (cfgFalse let buildSwitch (choice : expression) (blocks : (int * cfg) list) (cfg : cfg): cfg M.t = - let* env = M.get_env in + let* env,_ = M.get_env in let paths = List.map (fun (value, cfg) -> (value, cfg.input)) blocks in let bb1 = {assignments = []; predecessors = LabelSet.empty ; forward_info=env; backward_info = (); location = dummy_pos; terminator = Some (SwitchInt {choice ; paths; default=cfg.input})} and bb2 = {assignments = []; predecessors = LabelSet.empty ; forward_info=env; backward_info = (); location = dummy_pos; terminator = None} in @@ -146,7 +146,7 @@ let buildSwitch (choice : expression) (blocks : (int * cfg) list) (cfg : cfg): c } let buildLoop (location : loc) (cfg : cfg) : cfg M.t = - let* env = M.get_env in + let* env,_ = M.get_env in let* inputLbl = M.fresh_block and* outputLbl = M.fresh_block in (* all break terminators within the body go to outputLbl *) @@ -167,7 +167,7 @@ let buildLoop (location : loc) (cfg : cfg) : cfg M.t = } let buildInvoke (l : loc) (origin:l_str) (id : l_str) (target : string option) (el : expression list) : cfg M.t = - let* env = match target with + let* env,_ = match target with | None -> M.get_env | Some tid -> M.update_var l tid assign_var >>= fun () -> M.get_env in @@ -190,7 +190,7 @@ let buildInvoke (l : loc) (origin:l_str) (id : l_str) (target : string option) ( let buildReturn (location : loc) (e : expression option) : cfg M.t = - let* env = M.get_env in + let* env,_ = M.get_env in let returnBlock = {assignments=[]; predecessors = LabelSet.empty ; forward_info=env; backward_info = (); location; terminator= Some (Return e)} in let+ returnLbl = M.fresh_block in { @@ -260,4 +260,3 @@ module Traversal(M : Monad) = struct end - diff --git a/src/passes/ir/sailMir/pp_mir.ml b/src/passes/ir/sailMir/pp_mir.ml index af420b4..e1f7509 100644 --- a/src/passes/ir/sailMir/pp_mir.ml +++ b/src/passes/ir/sailMir/pp_mir.ml @@ -18,10 +18,9 @@ let rec ppPrintExpression (pf : Format.formatter) (e : AstMir.expression) : unit Format.fprintf pf "[%a]" (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el |StructAlloc (_,id, m) -> - let pp_field pf (x, y) = Format.fprintf pf "%s:%a" x ppPrintExpression y in + let pp_field pf (x, (_ , y)) = Format.fprintf pf "%s:%a" x ppPrintExpression y in Format.fprintf pf "%s{%a}" (snd id) - (Format.pp_print_list ~pp_sep:pp_comma pp_field) - m + (Format.pp_print_list ~pp_sep:pp_comma pp_field) m | EnumAlloc (id,el) -> Format.fprintf pf "[%s(%a)]" (snd id) (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el diff --git a/src/passes/ir/sailThir/thir.ml b/src/passes/ir/sailThir/thir.ml index ad4d213..fa93a4c 100644 --- a/src/passes/ir/sailThir/thir.ml +++ b/src/passes/ir/sailThir/thir.ml @@ -1,16 +1,14 @@ open Common open TypesCommon -open Error +open Logging open Monad open IrHir open AstHir open SailParser -open ThirMonad open ThirUtils - -open MonadSyntax(ES) -open MonadFunctions(ES) -open MonadOperator(ES) +open M +open UseMonad(M) +module SM = SailModule type expression = ThirUtils.expression @@ -26,11 +24,14 @@ struct type p_out = p_in - let rec lower_lexp (e : Hir.expression) : expression ES.t = - let rec aux (e:Hir.expression) : expression ES.t = - let loc = e.info in match e.exp with + let rec lower_lexp (e : Hir.expression) : expression M.t = + let rec aux (e:Hir.expression) : expression M.t = + let loc = e.info in match e.exp with | Variable id -> - let+ (_,(_,t)) = ES.get_var id >>= ES.throw_if_none (Error.make loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* venv,tenv = M.get_env in + let t,tenv = t tenv in + let+ () = M.set_env (venv,tenv) in buildExp (loc,t) @@ Variable id | Deref e -> let* e = lower_rexp e in @@ -40,156 +41,197 @@ struct | Ref (_,r) -> return @@ buildExp r.info @@ Deref e | _ -> return e end - | ArrayRead (array_exp,idx) -> let* array_exp = aux array_exp and* idx = lower_rexp idx in + | ArrayRead (array_exp,idx) -> + let* array_exp = aux array_exp and* idx_exp = lower_rexp idx in + let* array_ty = M.get_type_from_id (array_exp.info) + and* idx_ty = M.get_type_from_id (idx_exp.info) in begin - match snd array_exp.info with - | ArrayType (t,sz) -> - let* _ = matchArgParam (idx.info) (Int 32) in + match array_ty with + | l,ArrayType (t,sz) -> + let* t = M.get_type_id t in + let* _ = matchArgParam l idx_ty (dummy_pos,Int 32) |> M.ESC.lift |> M.lift in begin (* can do a simple oob check if the type is an int literal *) match idx.exp with | Literal (LInt n) -> - ES.throw_if (Error.make (fst idx.info) @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" + M.throw_if (make_msg (fst idx_exp.info) @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" (sz - 1) Z.(to_string n.l) ) Z.( n.l < ~$0 || n.l > ~$sz) | _ -> return () - end >>| fun () -> buildExp (loc,t) @@ ArrayRead (array_exp,idx) - | _ -> ES.throw (Error.make loc "not an array !") + end >>| fun () -> buildExp (loc,t) @@ ArrayRead (array_exp,idx_exp) + | _ -> M.throw (make_msg loc "not an array !") end | StructRead (origin,e,(fl,field)) -> let* e = lower_lexp e in + let* ty_e = M.get_type_from_id e.info in let+ origin,t = - begin - match e.info with - | _, CompoundType {name=l,name;decl_ty=Some S ();_} -> - let* origin,(_,strct) = find_struct_source (l,name) origin in - let+ _,t,_ = List.assoc_opt field strct.fields - |> ES.throw_if_none (Error.make fl @@ Fmt.str "field '%s' is not part of structure '%s'" field name) - in origin,t - | l,t -> - let* str = string_of_sailtype_thir (Some t) in - ES.throw (Error.make l @@ Fmt.str "expected a structure but got type '%s'" str) - end + begin + match ty_e with + | _,CompoundType {name=l,name;decl_ty=Some S ();_} -> + let* origin,(_,strct) = find_struct_source (l,name) origin |> M.ESC.lift |> M.lift in + let* _,t,_ = List.assoc_opt field strct.fields + |> M.throw_if_none (make_msg fl @@ Fmt.str "field '%s' is not part of structure '%s'" field name) + in + let+ t_id = M.get_type_id t in + origin,t_id + | l,t -> + let* str = string_of_sailtype_thir (Some (l,t)) |> M.ESC.lift |> M.lift in + M.throw (make_msg l @@ Fmt.str "expected a structure but got type '%s'" str) + end in let x : expression = buildExp (loc,t) (StructRead (origin,e,(fl,field))) in x - | _ -> ES.throw (Error.make loc "not a lvalue !") + | _ -> M.throw (make_msg loc "not a lvalue !") in aux e - and lower_rexp (e : Hir.expression) : expression ES.t = - let rec aux (e:Hir.expression) : expression ES.t = + and lower_rexp (e : Hir.expression) : expression M.t = + let rec aux (e:Hir.expression) : expression M.t = let loc = e.info in match e.exp with | Variable id -> - let+ (_,(_,t)) = ES.get_var id >>= ES.throw_if_none (Error.make loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* venv,tenv = M.get_env in + let t,tenv = t tenv in + let+ () = M.set_env (venv,tenv) in buildExp (loc,t) @@ Variable id + | Literal li -> - let+ () = + let* () = match li with | LInt t -> - let* () = ES.throw_if Error.(make loc "signed integers use a minimum of 2 bits") (t.size < 2) in + let* () = M.throw_if Logging.(make_msg loc "signed integers use a minimum of 2 bits") (t.size < 2) in let max_int = Z.( ~$2 ** t.size - ~$1) in let min_int = Z.( ~-max_int + ~$1) in - ES.throw_if + M.throw_if ( - Error.make loc @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" + make_msg loc @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" t.size (Z.to_string min_int) (Z.to_string max_int) (Z.to_string t.l) ) Z.(lt t.l min_int || gt t.l max_int) | _ -> return () in - let t = sailtype_of_literal li in + let+ t = M.get_type_id (sailtype_of_literal li) in buildExp (loc,t) @@ Literal li + | UnOp (op,e) -> let+ e = aux e in buildExp e.info @@ UnOp (op,e) + | BinOp (op,le,re) -> let* le = aux le in let* re = aux re in - let lt = le.info and rt = re.info in - let+ t = check_binop op lt rt |> ES.recover (snd lt) in + let+ t = check_binop op le.info re.info |> M.recover (snd le.info) in buildExp (loc,t) @@ BinOp (op,le,re) - | Ref (mut,e) -> let+ e = lower_lexp e in - let t = RefType (snd e.info,mut) in + | Ref (mut,e) -> + let* e = lower_lexp e in + let* e_t = M.get_type_from_id e.info in + let+ t = M.get_type_id (dummy_pos,RefType (e_t,mut)) in buildExp (loc,t) @@ Ref(mut, e) + | ArrayStatic el -> let* first_t = aux (List.hd el) in - let first_t = snd first_t.info in - let* el = ListM.map ( - fun e -> let+ e = aux e in - matchArgParam e.info first_t >>| fun _ -> e + let* first_t = M.get_type_from_id first_t.info in + let* el = ListM.map (fun e -> + let* e = aux e in + let+ e_t = M.get_type_from_id e.info in + matchArgParam (fst e.info) e_t first_t |> M.ESC.lift |> M.lift >>| fun _ -> e ) el in - let+ el = ListM.sequence el in - let t = ArrayType (first_t, List.length el) in - buildExp (loc,t) (ArrayStatic el) - - | MethodCall ((l,name) as lid,source,el) -> - let* (el: expression list) = ListM.map lower_rexp el in - let* mod_loc,(_realname,m) = find_function_source e.info None lid source el in - let+ ret = ES.throw_if_none (Error.make e.info "methods in expressions should return a value") m.ret in - buildExp (loc,ret) (MethodCall ((l,name),mod_loc,el)) - + let* el = ListM.sequence el in + let t = dummy_pos,ArrayType (first_t, List.length el) in + let+ t_id = M.get_type_id t in + buildExp (loc,t_id) (ArrayStatic el) + + | MethodCall (lid,source,args) -> + let* (args: expression list) = ListM.map lower_rexp args in + let* mod_loc,(_realname,m) = find_function_source e.info None lid source args |> M.ESC.lift |> M.lift in + let* ret = M.throw_if_none (make_msg e.info "methods in expressions should return a value") m.ret in + let* ret_t = M.get_type_id ret in + let* x = M.fresh_fvar in + M.write {info=loc; stmt=DeclVar (false, x, Some ret, None)} >>= fun () -> + M.write {info=loc; stmt=Invoke {args;id=lid; ret_var = Some x;import=mod_loc}} >>| fun () -> + buildExp (loc,ret_t) (Variable x) | ArrayRead _ -> lower_lexp e (* todo : some checking *) | Deref _ -> lower_lexp e (* todo : some checking *) | StructRead _ -> lower_lexp e (* todo : some checking *) | StructAlloc (origin,name,fields) -> - let* origin,(_l,strct) = find_struct_source name origin in + let* origin,(_l,strct) = find_struct_source name origin |> M.ESC.lift |> M.lift in let struct_fields = List.to_seq strct.fields in - let fields = FieldMap.(merge ( - fun n f1 f2 -> match f1,f2 with - | Some _, Some e -> Some(let+ e = lower_rexp e in (n,e)) - | None,None -> None - | None, Some (e:Hir.expression) -> Some (ES.throw @@ Error.make e.info @@ Fmt.str "no field '%s' in struct '%s'" n (snd name)) - | Some _, None -> Some (ES.throw @@ Error.make loc @@ Fmt.str "missing field '%s' from struct '%s'" n (snd name)) - ) (struct_fields |> of_seq) (fields |> List.to_seq |> of_seq ) |> to_seq) in + let fields = FieldMap.( + merge + ( + fun n f1 f2 -> match f1,f2 with + | Some _, Some (l,e) -> Some (let+ e = lower_rexp e in n,(l,e)) + | None,None -> None + | None, Some (l,_) -> Some (M.throw @@ make_msg l @@ Fmt.str "no field '%s' in struct '%s'" n (snd name)) + | Some _, None -> Some (M.throw @@ make_msg loc @@ Fmt.str "missing field '%s' from struct '%s'" n (snd name)) + ) + (struct_fields |> of_seq) + (fields |> List.to_seq |> of_seq) + |> to_seq + ) in - let* () = ES.throw_if (Error.make (fst name) "missing fields ") Seq.(length fields < Seq.length struct_fields) in + let* () = M.throw_if (make_msg (fst name) "missing fields ") Seq.(length fields < Seq.length struct_fields) in let* fields = SeqM.sequence (Seq.map snd fields) in - let+ () = SeqM.iter2 (fun (_name1,(e:expression)) (_name2,(_,t,_)) -> matchArgParam e.info t >>| fun _ -> ()) - fields - struct_fields + let* () = SeqM.iter2 (fun (_name1,(l,(e:expression))) (_name2,(_,t,_)) -> + let* e_t = M.get_type_from_id e.info in + matchArgParam l e_t t |> M.ESC.lift |> M.lift >>| fun _ -> () + ) + fields + struct_fields in - let ty = CompoundType {origin= Some origin;decl_ty=Some (S ()); name; generic_instances=[]} in - (buildExp (loc,ty) (StructAlloc (origin,name, List.of_seq fields))) + let ty = dummy_pos,CompoundType {origin= Some origin;decl_ty=Some (S ()); name; generic_instances=[]} in + let+ ty = M.get_type_id ty in + (buildExp (loc,ty) (StructAlloc (origin,name, List.of_seq fields) )) - | EnumAlloc _ -> ES.throw (Error.make loc "todo enum alloc ") + | EnumAlloc _ -> M.throw (make_msg loc "todo enum alloc ") in aux e - let lower_method (body,proto : _ * method_sig) (env:THIREnv.t) _ : (m_out * THIREnv.D.t) E.t = - let log_and_skip e = ES.log e >>| fun () -> buildStmt e.where Skip - in + let lower_method (body,proto : _ * method_sig) (env,tenv:THIREnv.t * _) _ : (m_out * THIREnv.D.t * _) M.E.t = + let open UseMonad(M.ESC) in + let module MF = MonadFunctions(M) in + let log_and_skip e = M.ESC.log e >>| fun () -> buildStmt e.where Skip in + - let rec aux s : m_out ES.t = + let rec aux s : m_out M.ESC.t = let loc = s.info in let buildStmt = buildStmt loc in + let buildSeq s1 s2 = {info=loc; stmt = Seq (s1, s2)} in + let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in + match s.stmt with | DeclVar (mut, id, opt_t, (opt_exp : Hir.expression option)) -> - let* ((ty,opt_e):sailtype * 'b) = + let* ((ty,opt_e,s):sailtype * 'b * 'c) = begin match opt_t,opt_exp with | Some t, Some e -> - let* e = lower_rexp e in - matchArgParam e.info t >>| fun _ -> t,Some e + let* e,s = lower_rexp e in + let* e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in + matchArgParam (fst e.info) e_t t |> M.ESC.lift >>| fun _ -> t,Some e,s | None,Some e -> - let+ e = lower_rexp e in - (snd e.info),Some e - | Some t,None -> return (t,None) - | None,None -> ES.throw (Error.make loc "can't infere type with no expression") + let* e,s = lower_rexp e in + let+ e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in + e_t,Some e,s + | Some t,None -> return (t,None,buildStmt Skip) + | None,None -> M.ESC.throw (make_msg loc "can't infere type with no expression") end in - ES.update (fun st -> THIREnv.declare_var id (loc,(mut,ty)) st) - >>| fun () -> (buildStmt @@ DeclVar (mut,id,Some ty,opt_e)) + let* ty_id = M.ES.get_type_id ty |> M.ESC.lift in + let decl_var = THIREnv.declare_var id (loc,fun e -> ty_id,e) in + M.ESC.update_env (fun (st,t) -> M.E.(bind (decl_var st) (fun st -> pure (st,t)))) + >>| fun () -> (buildSeqStmt s @@ DeclVar (mut,id,Some ty,opt_e)) | Assign(e1, e2) -> - let* e1 = lower_lexp e1 - and* e2 = lower_rexp e2 in - matchArgParam e2.info (snd e1.info) >>| - fun _ -> buildStmt (Assign(e1, e2)) + let* e1,s1 = lower_lexp e1 + and* e2,s2 = lower_rexp e2 in + let* e1_t = M.ES.get_type_from_id e1.info |> M.ESC.lift + and* e2_t = M.ES.get_type_from_id e2.info |> M.ESC.lift in + matchArgParam (fst e2.info) e2_t e1_t |> M.ESC.lift >>| + fun _ -> buildSeq s1 @@ buildSeqStmt s2 @@ Assign(e1, e2) | Seq(c1, c2) -> let* c1 = aux c1 in @@ -198,13 +240,14 @@ struct | If(cond_exp, then_s, else_s) -> - let* cond_exp = lower_rexp cond_exp in - let* _ = matchArgParam cond_exp.info Bool in + let* cond_exp,s = lower_rexp cond_exp in + let* cond_t = M.ES.get_type_from_id cond_exp.info |> M.ESC.lift in + let* _ = matchArgParam (fst cond_exp.info) cond_t (dummy_pos,Bool) |> M.ESC.lift in let* res = aux then_s in begin match else_s with - | None -> return @@ buildStmt (If(cond_exp, res, None)) - | Some s -> let+ s = aux s in buildStmt (If(cond_exp, res, Some s)) + | None -> return @@ buildSeqStmt s (If(cond_exp, res, None)) + | Some else_ -> let+ else_ = aux else_ in buildSeqStmt s (If(cond_exp, res, Some else_)) end | Loop c -> @@ -214,45 +257,45 @@ struct | Break -> return (buildStmt Break) | Case(e, _cases) -> - let+ e = lower_rexp e in - buildStmt (Case (e, [])) + let+ e,s = lower_rexp e in + buildSeqStmt s (Case (e, [])) - | Invoke (var, mod_loc, id, el) -> (* todo: handle var *) - let* el = ListM.map lower_rexp el in - let* origin,_ = find_function_source s.info var id mod_loc el in - buildStmt (Invoke(var,origin, id,el)) |> return + | Invoke i -> (* todo: handle var *) + let* args,s = MF.ListM.map lower_rexp i.args in + let* import,_ = find_function_source s.info i.ret_var i.id i.import args |> M.ESC.lift in + buildSeqStmt s (Invoke { i with import ; args} ) |> return | Return None as r -> if proto.rtype = None then return (buildStmt r) else - log_and_skip (Error.make loc @@ Printf.sprintf "void return but %s returns %s" proto.name (string_of_sailtype proto.rtype)) + log_and_skip (make_msg loc @@ Printf.sprintf "void return but %s returns %s" proto.name (string_of_sailtype proto.rtype)) | Return (Some e) -> - let* e = lower_rexp e in - let _,t as lt = e.info in + let* e,s = lower_rexp e in + let* t = M.ES.get_type_from_id e.info |> M.ESC.lift in begin match proto.rtype with | None -> - log_and_skip (Error.make loc @@ Printf.sprintf "returns %s but %s doesn't return anything" (string_of_sailtype (Some t)) proto.name) + log_and_skip (make_msg loc @@ Printf.sprintf "returns %s but %s doesn't return anything" (string_of_sailtype (Some t)) proto.name) | Some r -> - matchArgParam lt r >>| fun _ -> - buildStmt (Return (Some e)) + matchArgParam (fst e.info) t r |> M.ESC.lift >>| fun _ -> + buildSeqStmt s (Return (Some e)) end | Block c -> - let* () = ES.update (fun e -> THIREnv.new_frame e |> E.pure) in + let* () = M.ESC.update_env (fun (e,t) -> (THIREnv.new_frame e,t) |> M.E.pure) in let* res = aux c in - let+ () = ES.update (fun e -> THIREnv.pop_frame e |> E.pure) in + let+ () = M.ESC.update_env (fun (e,t) -> (THIREnv.pop_frame e,t) |> M.E.pure) in buildStmt (Block res) | Skip -> return (buildStmt Skip) in - ES.run (aux body env) |> Logger.recover (buildStmt dummy_pos Skip,snd env) + M.(E.bind (ESC.run aux body (env,tenv)) (fun (x,y) -> E.pure (x,snd env,y))) |> Logger.recover (buildStmt dummy_pos Skip,snd env,tenv) - let preprocess = Logger.pure + let preprocess = resolve_types (* todo : create semantic types + type inference *) - let lower_process (c:p_in process_defn) env _ = E.pure (c.p_body,snd env) + let lower_process (c:p_in process_defn) (env,tenv) _ = M.E.pure (c.p_body,snd env,tenv) end ) diff --git a/src/passes/ir/sailThir/thirMonad.ml b/src/passes/ir/sailThir/thirMonad.ml index 8776d9b..11b3e4f 100644 --- a/src/passes/ir/sailThir/thirMonad.ml +++ b/src/passes/ir/sailThir/thirMonad.ml @@ -1,36 +1,68 @@ open Common -open TypesCommon -module E = Error.Logger +module Make(MonoidSeq : Monad.Monoid) = struct + module V : Env.Variable with type t = Env.TypeEnv.t -> (string * Env.TypeEnv.t) = ( + struct + type t = Env.TypeEnv.t -> (string * Env.TypeEnv.t) + let string_of_var _v : string = "" + let param_to_var (p:TypesCommon.param) : t = fun te -> Env.TypeEnv.get_id p.ty te + end + ) -module V = ( - struct - type t = bool * sailtype - let string_of_var (_,t) = string_of_sailtype (Some t) - let param_to_var p = p.mut,p.ty - end -) - -module THIREnv = SailModule.SailEnv(V) + module THIREnv = SailModule.SailEnv(V) -module ES = struct - include MonadState.T(E)(THIREnv) + module M = struct + module E = Logging.Logger + module ES = struct + include MonadState.T(E)(struct type t = THIREnv.t * Env.TypeEnv.t end) - let get_decl id ty = bind get (fun e -> THIREnv.get_decl id ty e |> pure) + let get_decl id ty = bind get (fun (e,_) -> THIREnv.get_decl id ty e |> pure) + let get_var id = bind get (fun (e,_) -> THIREnv.get_var id e |> pure) + let throw e = E.throw e |> lift + let log e = E.log e |> lift + let log_if b e = E.log_if b e |> lift + let throw_if e b = E.throw_if e b |> lift + let throw_if_none e opt = E.throw_if_none e opt |> lift + let run (f : 'a -> 'b t) env = fun x -> E.bind (f x env) (fun (e,(_,s)) -> E.pure (e,s)) + let recover (type a) (default:a) (x: a t) : a t = + fun e -> E.recover (default,e) (x e) + + let get_type_from_id id = bind get (fun (_,te) -> Env.TypeEnv.get_from_id id te |> lift) - let get_var id = bind get (fun e -> THIREnv.get_var id e |> pure) + let get_type_id ty : string t = fun (e,te) -> let id,te = Env.TypeEnv.get_id ty te in E.pure (id,(e,te)) + + end + + module ESC = struct + include MonadState.CounterTransformer(ES)(struct type t = int let succ = Int.succ let init = 0 end) + let get_decl id ty = ES.get_decl id ty |> lift + let get_var id = ES.get_var id |> lift + let throw_if_none e opt = ES.throw_if_none e opt |> lift + let throw_if e b = ES.throw_if e b |> lift + let throw e = ES.throw e |> lift + let recover def x = ES.recover def x |> lift + let log e = ES.log e |> lift + let update_env f = ES.update f |> lift - let throw e = E.throw e |> lift - let log e = E.log e |> lift - let log_if b e = E.log_if b e |> lift - let throw_if e b = E.throw_if e b |> lift + let run (f : 'a -> 'b t) = fun x e -> E.bind (ES.run (f x) e 0) (fun ((res,_),e) -> E.pure (res,e)) + end - let throw_if_none e opt = E.throw_if_none e opt |> lift + include MonadWriter.MakeTransformer(ESC)(MonoidSeq) + let get_decl id ty = ESC.get_decl id ty |> lift + let get_var id = ESC.get_var id |> lift + let throw_if_none e opt = ESC.throw_if_none e opt |> lift + let throw_if e b = ESC.throw_if e b |> lift + let throw e = ESC.throw e |> lift + let recover def x = ESC.recover def x |> lift + let log e = ESC.log e |> lift + let get_env = ES.get |> ESC.lift |> lift - let run e = E.bind e (fun (e,(_,s)) -> E.pure (e,s)) + let set_env e = ES.set e |> ESC.lift |> lift - let recover (type a) (default:a) (x: a t) : a t = - fun e -> E.recover (default,e) (x e) + let get_type_from_id id = ES.get_type_from_id id |> ESC.lift |> lift + let get_type_id ty : string t = ES.get_type_id ty|> ESC.lift |> lift + let fresh_fvar = bind (ESC.fresh |> lift) (fun n -> pure @@ "__f" ^ string_of_int n) + end end \ No newline at end of file diff --git a/src/passes/ir/sailThir/thirUtils.ml b/src/passes/ir/sailThir/thirUtils.ml index 1df54aa..09e4fb4 100644 --- a/src/passes/ir/sailThir/thirUtils.ml +++ b/src/passes/ir/sailThir/thirUtils.ml @@ -1,37 +1,45 @@ open Common open TypesCommon -open ThirMonad open Monad open IrHir module D = SailModule.Declarations -open MonadSyntax(ES) -open MonadOperator(ES) -open MonadFunctions(ES) - -type expression = (loc * sailtype, l_str) AstHir.expression +type expression = (loc * string, l_str) AstHir.expression (* string is the key for the type map *) type statement = (loc,l_str,expression) AstHir.statement -let rec resolve_alias loc : sailtype -> (sailtype,string) Either.t ES.t = function -| CompoundType {origin;name=(_,name);decl_ty=Some (T ());_} -> - let* (_,mname) = ES.throw_if_none (Error.make loc @@ "unknown type '" ^ name ^ "' , all types must have an origin (problem with HIR)") origin in - let* ty_t = ES.get_decl name (Specific (mname,Type)) - >>= ES.throw_if_none (Error.make loc @@ Fmt.str "declaration '%s' requires importing module '%s'" name mname) in - begin - match ty_t.ty with - | Some (CompoundType _ as ct) -> resolve_alias loc ct - | Some t -> return (Either.left t) - | None -> return (Either.right name) (* abstract type, only look at name *) +module M = ThirMonad.Make(struct + type t = statement + let mempty : t = {info=dummy_pos; stmt=Skip} + let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} end -| t -> return (Either.left t) +) + +open M +open UseMonad(M.ES) + + + +let rec resolve_alias (l,ty : sailtype) : (sailtype,string) Either.t M.ES.t = + match ty with + | CompoundType {origin;name=(_,name);decl_ty=Some (T ());_} -> + let* (_,mname) = M.ES.throw_if_none Logging.(make_msg l @@ "unknown type '" ^ name ^ "' , all types must have an origin (problem with HIR)") origin in + let* ty_t = M.ES.get_decl name (Specific (mname,Type)) + >>= M.ES.throw_if_none Logging.(make_msg l @@ Fmt.str "declaration '%s' requires importing module '%s'" name mname) in + begin + match ty_t.ty with + | Some (_,CompoundType _ as ct) -> resolve_alias ct + | Some t -> return (Either.left t) + | None -> return (Either.right name) (* abstract type, only look at name *) + end + | _ -> return (Either.left (l,ty)) -let string_of_sailtype_thir (t : sailtype option) : string ES.t = +let string_of_sailtype_thir (t : sailtype option) : string M.ES.t = let+ res = match t with - | Some CompoundType {origin; name=(loc,x); _} -> - let* (_,mname) = ES.throw_if_none (Error.make loc "no origin in THIR (problem with HIR)") origin in - let+ decl = ES.(get_decl x (Specific (mname,Filter [E (); S (); T()])) - >>= throw_if_none (Error.make loc "decl is null (problem with HIR)")) in + | Some (_,CompoundType {origin; name=(loc,x); _}) -> + let* (_,mname) = M.ES.throw_if_none Logging.(make_msg loc "no origin in THIR (problem with HIR)") origin in + let+ decl = M.ES.(get_decl x (Specific (mname,Filter [E (); S (); T()])) + >>= throw_if_none Logging.(make_msg loc "decl is null (problem with HIR)")) in begin match decl with | T ty_def -> @@ -43,89 +51,101 @@ let string_of_sailtype_thir (t : sailtype option) : string ES.t = | S (_,s) -> Fmt.str " (= struct <%s>)" (List.map (fun (n,(_,t,_)) -> Fmt.str "%s:%s" n @@ string_of_sailtype (Some t) ) s.fields |> String.concat ", ") | _ -> failwith "can't happen" end - | _ -> ES.pure "" + | _ -> return "" in (string_of_sailtype t) ^ res -let matchArgParam (l,arg: loc * sailtype) (m_param : sailtype) : sailtype ES.t = - let open MonadSyntax(ES) in - - let rec aux (a:sailtype) (m:sailtype) = - let* lt = resolve_alias l a in - let* rt = resolve_alias l m in +let matchArgParam (loc : loc) (arg: sailtype) (m_param : sailtype) : sailtype M.ES.t = + let rec aux (a:sailtype) (m:sailtype) : sailtype M.ES.t = + let* lt = resolve_alias a in + let* rt = resolve_alias m in match lt,rt with - | Left Bool,Left Bool -> return Bool - | Left (Int i1), Left (Int i2) when i1 = i2 -> return (Int i1) - | Left Float,Left Float -> return Float - | Left Char,Left Char -> return Char - | Left String,Left String -> return String - | Left ArrayType (at,s),Left ArrayType (mt,s') -> - if s = s' then - let+ t = aux at mt in ArrayType (t,s) - else - ES.throw @@ Error.make l (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s) + | Left (loc_l,l), Left (_,r) -> + begin + match l,r with + | Bool, Bool -> return (loc_l,Bool) + | (Int i1), (Int i2) when i1 = i2 -> return (loc_l,Int i1) + | Float, Float -> return (loc_l,Float) + | Char, Char -> return (loc_l,Char) + | String, String -> return (loc_l,String) + | ArrayType (at,s), ArrayType (mt,s') -> + if s = s' then + let+ t = aux at mt in loc_l,ArrayType (t,s) + else + M.ES.throw Logging.(make_msg loc_l (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s)) + + | Box _at, Box _mt -> M.ES.throw Logging.(make_msg loc_l "todo box") + + | RefType (at,am), RefType (mt,mm) -> + if am <> mm then M.ES.throw Logging.(make_msg loc_l "different mutability") + else aux at mt + + | at, GenericType _ + | GenericType _, at -> return (loc_l,at) + + | CompoundType c1, CompoundType c2 when snd c1.name = snd c2.name -> + return arg + + | _ -> let* param = string_of_sailtype_thir (Some m_param) and* arg = string_of_sailtype_thir (Some arg) in + M.ES.throw Logging.(make_msg loc_l @@ Printf.sprintf "wants %s but %s provided" param arg) + end | Right name, Right name' -> - let+ () = ES.throw_if (Error.make l @@ Fmt.str "want abstract type %s but abstract type %s provided" name name') (name <> name') in + let+ () = M.ES.throw_if Logging.(make_msg loc @@ Fmt.str "want abstract type %s but abstract type %s provided" name name') (name <> name') in arg - - | Left Box _at, Left Box _mt -> ES.throw (Error.make l "todo box") - | Left RefType (at,am), Left RefType (mt,mm) -> if am <> mm then ES.throw (Error.make l "different mutability") else aux at mt - | Left at,Left GenericType _ |Left GenericType _,Left at -> return at - | Left CompoundType {name=(_,name1);origin=_;_}, Left CompoundType {name=(_,name2);_} when name1 = name2 -> - return arg - | _ -> let* param = string_of_sailtype_thir (Some m_param) - and* arg = string_of_sailtype_thir (Some arg) in - ES.throw @@ Error.make l @@ Printf.sprintf "wants %s but %s provided" param arg + and* arg' = string_of_sailtype_thir (Some arg) in + M.ES.throw Logging.(make_msg loc @@ Printf.sprintf "wants %s but %s provided" param arg') in aux arg m_param -let check_binop op l r : sailtype ES.t = - let open MonadSyntax(ES) in +let check_binop op l r : string M.ES.t = + let* l_t = M.ES.get_type_from_id l + and* r_t = M.ES.get_type_from_id r in match op with | Lt | Le | Gt | Ge | Eq | NEq -> - let+ _ = matchArgParam r (snd l) in Bool + let* _ = matchArgParam (fst l_t) r_t l_t in M.ES.get_type_id (fst l_t,Bool) | And | Or -> - let+ _ = matchArgParam l Bool - and* _ = matchArgParam r Bool in Bool + let* _ = matchArgParam (fst l_t) l_t (fst l_t,Bool) + and* _ = matchArgParam (fst l_t) r_t (fst l_t,Bool) + in M.ES.get_type_id (fst l_t,Bool) | Plus | Mul | Div | Minus | Rem -> - let+ _ = matchArgParam r (snd l) in snd l + let+ _ = matchArgParam (fst l_t) r_t l_t in snd l -let check_call (name:string) (f : method_proto) (args: expression list) loc : unit ES.t = +let check_call (name:string) (f : method_proto) (args: expression list) loc : unit M.ES.t = (* if variadic, we just make sure there is at least minimum number of arguments needed *) let args = if f.variadic then List.filteri (fun i _ -> i < (List.length f.args)) args else args in let nb_args = List.length args and nb_params = List.length f.args in - ES.throw_if (Error.make loc (Printf.sprintf "unexpected number of arguments passed to %s : expected %i but got %i" name nb_params nb_args)) + M.ES.throw_if Logging.(make_msg loc (Printf.sprintf "unexpected number of arguments passed to %s : expected %i but got %i" name nb_params nb_args)) (nb_args <> nb_params) >>= fun () -> ListM.iter2 ( fun (ca:expression) ({ty=a;_}:param) -> - let+ _ = matchArgParam ca.info a in () + let* rty = M.ES.get_type_from_id ca.info in + let+ _ = matchArgParam (fst ca.info) rty a in () ) args f.args -let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (import:l_str option) (el: expression list) : (l_str * D.method_decl) ES.t = - let* _,env = ES.get in - let* mname,def = HirUtils.find_symbol_source ~filt:[M ()] name import env |> ES.lift in +let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (import:l_str option) (el: expression list) : (l_str * D.method_decl) M.ES.t = + let* (_,env),_ = M.ES.get in + let* mname,def = HirUtils.find_symbol_source ~filt:[M ()] name import env |> M.ES.lift in match def with | M decl -> - let _x = fun_loc and _y = el in let+ _ = check_call (snd name) (snd decl) el fun_loc in mname,decl - (* return (mname,decl) *) + (* let _x = fun_loc and _y = el in return (mname,decl) *) | _ -> failwith "non method returned" (* cannot happen because we only requested methods *) (* ES.throw - @@ Error.make fun_loc ~hint:(Some (None,"specify one with '::' annotation")) + @@ make_msg fun_loc ~hint:(Some (None,"specify one with '::' annotation")) @@ Fmt.str "multiple definitions for function %s : \n\t%s" id (List.map ( fun (i,((_,_,m):SailModule.Declarations.method_decl)) -> @@ -134,11 +154,110 @@ let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (imp (string_of_sailtype m.ret) ) choice |> String.concat "\n\t") *) -let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct_decl) ES.t = - let* _,env = ES.get in - let+ origin,def = HirUtils.find_symbol_source ~filt:[S()] name import env |> ES.lift in +let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct_decl) M.ES.t = + let* (_,env),_ = M.ES.get in + let+ origin,def = HirUtils.find_symbol_source ~filt:[S()] name import env |> M.ES.lift in begin match def with | S decl -> origin,decl | _ -> failwith "non struct returned" end + + + + + let resolve_types (sm : ('a,'b) SailModule.methods_processes SailModule.t) = + let open HirUtils in + + let module ES = struct + module T = struct type t = {decls: D.t ; types : Env.TypeEnv.t} end + module S = MonadState.M(T) + include Logging.MakeTransformer(S) + let update_env f = S.update f |> lift + let set_env e = S.set e |> lift + let get_env = S.get |> lift + end + in + let open UseMonad(ES) in + let module TEnv = MakeFromSequencable(SailModule.DeclEnv.TypeSeq) in + let module SEnv = MakeFromSequencable(SailModule.DeclEnv.StructSeq) in + let open SailModule.DeclEnv in + + (* resolving aliases *) + let sm = ( + + let* env = ES.get_env in + + let* () = + TEnv.iter ( + fun (id,({ty; _} as def)) -> + let* ty = match ty with + | None -> return None + | Some t -> + let* env = ES.get_env in + let* t,decls = (follow_type t env.decls) |> ES.S.lift in + let+ () = ES.set_env {env with decls} in Some t + in + ES.update_env (fun e -> {e with decls=update_decl id {def with ty} (Self Type) e.decls}) + ) (get_own_decls env.decls |> get_decls Type) in + + let* env = ES.get_env in + + let* () = SEnv.iter ( + fun (id,(l,{fields; generics})) -> + let* fields = ListM.map ( + fun (name,(l,t,n)) -> + let* env = ES.get_env in + let* t,decls = (follow_type t env.decls) |> ES.S.lift in + let+ () = ES.set_env {env with decls} in + name,(l,t,n) + ) fields + in + let proto = l,{fields;generics} in + ES.update_env (fun e -> {e with decls=update_decl id proto (Self Struct) e.decls}) + ) (get_own_decls env.decls |> get_decls Struct) + in + + let* methods = ListM.map ( + fun ({m_proto;m_body} as m) -> + let* rtype = match m_proto.rtype with + | None -> return None + | Some t -> + let* env = ES.get_env in + let* t,decls = (follow_type t env.decls) |> ES.S.lift in + let+ () = ES.set_env {env with decls} in Some t + in + let* params = ListM.map ( + fun (({ty;_}:param) as p) -> + let* env = ES.get_env in + let* ty,decls = (follow_type ty env.decls) |> ES.S.lift in + let+ () = ES.set_env {env with decls} in {p with ty} + ) m_proto.params in + let m = {m with m_proto={m_proto with params; rtype}} in + let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in + let+ () = ES.update_env + (fun e -> + let decls = update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method) e.decls in + {e with decls} + ) + in m + ) sm.body.methods in + + let+ processes = ListM.map ( + fun proc -> + let* p_params = ListM.map ( + fun (({ty;_}:param) as p) -> + let* env = ES.get_env in + let* ty,decls = (follow_type ty env.decls) |> ES.S.lift in + let+ () = ES.set_env {env with decls} in {p with ty} + ) proc.p_interface.p_params in + let p = {proc with p_interface={proc.p_interface with p_params}} in + let+ () = ES.update_env (fun e -> {e with decls=update_decl p.p_name (p.p_pos, defn_to_proto (Process p)) (Self Process) e.decls}) + in p + ) sm.body.processes in + + + {sm with body=SailModule.{methods; processes}; declEnv=env.decls; typeEnv=env.types} + + ) {decls=sm.declEnv;types=sm.typeEnv} |> fst in + sm \ No newline at end of file diff --git a/src/passes/misc/cfg_analysis.ml b/src/passes/misc/cfg_analysis.ml index 5ff8f6e..8e42a2f 100644 --- a/src/passes/misc/cfg_analysis.ml +++ b/src/passes/misc/cfg_analysis.ml @@ -8,7 +8,7 @@ open AstMir open Monad -module E = Error.Logger +module E = Logging.Logger open UseMonad(E) let cfg_returns ({input;blocks;_} : cfg) : (loc option * _ basicBlock BlockMap.t) E.t = let rec aux lbl blocks = @@ -19,7 +19,7 @@ let cfg_returns ({input;blocks;_} : cfg) : (loc option * _ basicBlock BlockMap.t begin match bb.terminator with | None -> (Some bb.location, blocks') |> E.pure - | Some Break -> E.throw @@ Error.make bb.location "there should be no break at this point" >>= fun () -> aux input blocks' + | Some Break -> E.throw Logging.(make_msg bb.location "there should be no break at this point") >>= fun () -> aux input blocks' | Some Return _ -> (None, blocks') |> E.pure | Some (Invoke {next;_}) -> aux next blocks' | Some (Goto lbl) -> aux lbl blocks' @@ -30,20 +30,22 @@ let cfg_returns ({input;blocks;_} : cfg) : (loc option * _ basicBlock BlockMap.t in aux input blocks +(* not correct let check_unreachable (proto : method_sig) (_,cfg : mir_function) : unit E.t = let* ret,unreachable_blocks = cfg_returns cfg in if Option.is_some ret && proto.rtype <> None then - E.throw @@ Error.make (Option.get ret) @@ Printf.sprintf "%s doesn't always return !" proto.name + E.throw Logging.(make_msg (Option.get ret) @@ Printf.sprintf "%s doesn't always return !" proto.name) else - let () = BlockMap.iter (fun lbl {location=_;_} -> Logs.debug (fun m -> m "unreachable block %i" lbl)) unreachable_blocks in + let () = BlockMap.iter (fun lbl {location=l;_} -> Logs.debug (fun m -> m "unreachable block %i, line %i" lbl (fst l).pos_lnum)) unreachable_blocks in let blocks = BlockMap.filter (fun _ {location;_} -> location <> dummy_pos) unreachable_blocks in match BlockMap.choose_opt blocks with | Some (_,bb) -> - let _loc = match List.nth_opt bb.assignments 0 with + let loc = match List.nth_opt bb.assignments 0 with | Some v -> v.location | None -> bb.location in - E.throw @@ Error.make bb.location "unreachable code" - | None -> E.pure () + E.log Logging.(make_msg loc"unreachable code") + | None -> E.pure () + *) let check_returns (proto : method_sig) (decls,cfg : mir_function) : mir_function E.t = @@ -53,7 +55,7 @@ let check_returns (proto : method_sig) (decls,cfg : mir_function) : mir_functio | None,None -> (* we insert void return when return type is void*) E.pure {block with terminator= Some (Return None)} | None,Some _ -> - E.throw Error.(make proto.pos + E.throw Logging.(make_msg proto.pos @@ Printf.sprintf "no return statement but must return %s" @@ string_of_sailtype proto.rtype) | _ -> E.pure block @@ -72,9 +74,9 @@ module Pass = Make(struct fun f -> let+ m_body = match f.m_body with | Left _ -> E.pure f.m_body - | Right b -> - check_unreachable f.m_proto b >>= fun () -> - check_returns f.m_proto b >>| fun b -> Either.right b in + | Right b -> + (* check_unreachable f.m_proto b >>= fun () -> not correct *) + check_returns f.m_proto b >>| fun b -> Either.right b in {f with m_body} ) sm.body.methods in diff --git a/src/passes/misc/imports.ml b/src/passes/misc/imports.ml index 252540d..3bb1572 100644 --- a/src/passes/misc/imports.ml +++ b/src/passes/misc/imports.ml @@ -5,7 +5,7 @@ open IrMir open IrHir open SailParser -module E = Common.Error.Logger +module E = Common.Logging.Logger module Env = SailModule.DeclEnv open Monad.UseMonad(E) diff --git a/src/passes/misc/methodCall.ml b/src/passes/misc/methodCall.ml deleted file mode 100644 index a40ac91..0000000 --- a/src/passes/misc/methodCall.ml +++ /dev/null @@ -1,179 +0,0 @@ -open Common -open TypesCommon -open Monad -open IrThir -open IrHir -open SailParser - -module V = ( - struct - type t = bool * sailtype - let string_of_var (_,t) = string_of_sailtype (Some t) - let param_to_var (p:param) = p.mut,p.ty - end -) -module THIREnv = SailModule.SailEnv(V) -module E = Error.Logger -module EC = MonadState.CounterTransformer(E)(struct type t = int let succ = Int.succ let init = 0 end) -module ECS = struct - include MonadState.T(EC)(THIREnv) - - let fresh = EC.fresh |> lift - let throw e = E.throw e |> EC.lift |> lift - let log e = E.log e |> EC.lift |> lift - let log_if b e = E.log_if b e |> EC.lift |> lift - let throw_if b e = E.throw_if b e |> EC.lift |> lift - let run e = let e = EC.run e in E.bind e (fun (e,(_,s)) -> E.pure (e,s)) - - let get_decl id ty = bind get (fun e -> THIREnv.get_decl id ty e |> pure) -end - -module ECSW = struct - include MonadWriter.MakeTransformer(ECS)( struct - type t = Thir.statement - let mempty : t = {info=dummy_pos; stmt=Skip} - let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} - end) - - let fresh = ECS.fresh |> lift - let throw e = ECS.throw e |> lift - let log e = ECS.log e |> lift - let get_env = ECS.get |> lift - - let get_decl id ty = ECS.bind ECS.get (fun e -> THIREnv.get_decl id ty e |> ECS.pure) |> lift -end - -let get_hint id env = Option.bind (List.nth_opt (THIREnv.get_closest id env) 0) (fun id -> Some (None,Printf.sprintf "Did you mean %s ?" id)) - - - -module Pass = Pass.MakeFunctionPass(V) -( - struct - let name = "Extract method call out of expressions (fixme : should be in hir but requires type inference)" - - type m_in = ThirUtils.statement - type m_out = m_in - - type p_in = (HirUtils.statement,HirUtils.expression) AstParser.process_body - type p_out = p_in - - open MonadFunctions(ECSW) - open MakeOrderedFunctions(String) - open AstHir - open Thir - - let lower_expression e : expression ECSW.t = - let open MonadSyntax(ECSW) in - let open MonadOperator(ECSW) in - let rec aux (e:expression) = - let loc,_ as info = e.info in - match e.exp with - | Variable id -> return {info; exp=Variable id} - | Deref e -> - let+ e = aux e in {info;exp=Deref e} - | StructRead (o,e, id) -> - let+ e = aux e in {info; exp=StructRead (o,e, id)} - | ArrayRead (e1, e2) -> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=ArrayRead(e1,e2)} - | Literal l -> return {info; exp=Literal l} - | UnOp (op, e) -> - let+ e = aux e in {info;exp=UnOp (op, e)} - | BinOp(op,e1,e2)-> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=BinOp (op, e1, e2)} - | Ref (b, e) -> - let+ e = aux e in {info;exp=Ref(b, e)} - | ArrayStatic el -> - let+ el = ListM.map aux el in {info;exp=ArrayStatic el} - | StructAlloc (o,id, m) -> - let+ m = ListM.map (pairMap2 aux) m in {info; exp=StructAlloc (o,id, m)} - | EnumAlloc (id, el) -> - let+ el = ListM.map aux el in {info;exp=EnumAlloc (id, el)} - | MethodCall ((l_id,id), ((_,mname) as origin), el) -> (* THIS IS THE PROBLEM : WE NEED TO KNOW THE RETURN TYPE !! *) - let* m = ECSW.get_decl id (Specific (mname,Method)) in - match m with - | Some (_proto_loc,proto) -> - begin - match proto.ret with - | Some rtype -> - let* n = ECSW.fresh in - let x = "__f" ^ string_of_int n in - let* el = ListM.map aux el in - let* () = ECSW.write {info=loc; stmt=DeclVar (false, x, Some rtype, None)} in - let+ () = ECSW.write {info=loc; stmt=Invoke(Some x,origin, (l_id,id), el)} in - {info;exp=Variable x} - - | None -> ECSW.throw (Error.make loc "methods in expressions should return a value (problem with THIR)") - end - | _ -> let* env = ECSW.get_env in let hint = get_hint id env in ECSW.throw (Error.make l_id "unknown method" ~hint) - in aux e - - - let lower_method (body,_proto : m_in * method_sig ) env _ : (m_out * THIREnv.D.t) E.t = - let open MonadSyntax(ECS) in - let open MonadOperator(ECS) in - let rec aux (s : statement) : statement ECS.t = - - let buildSeq s1 s2 = {info=dummy_pos; stmt = Seq (s1, s2)} in - let buildStmt stmt = {info=dummy_pos;stmt} in - let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in - - let info = s.info in - match s.stmt with - | DeclVar (mut, id, t, e ) -> - begin match e with - | Some e -> let+ (e, s) = lower_expression e in - buildSeqStmt s (DeclVar (mut,id, t, Some e)) - | None -> return {info;stmt=DeclVar (mut,id, t, None)} - end - | Skip -> return {info;stmt=Skip} - | Assign(e1, e2) -> - let* e1,s1 = lower_expression e1 in - let+ e2,s2 = lower_expression e2 in - buildSeq s1 @@ buildSeqStmt s2 (Assign (e1, e2)) - - | Seq (c1, c2) -> let+ c1 = aux c1 and* c2 = aux c2 in {info;stmt=Seq (c1, c2)} - | If (e, c1, Some c2) -> - let+ e,s = lower_expression e and* c1 = aux c1 and* c2 = aux c2 in - buildSeqStmt s (If (e, c1, Some c2)) - - | If ( e, c1, None) -> - let+ (e, s) = lower_expression e and* c1 = aux c1 in - buildSeqStmt s (If (e, c1, None)) - - | Loop c -> - let+ c = aux c in - buildStmt (Loop c) - - | Break -> return {info; stmt=Break} - | Case (e, _cases) -> let+ e,s = lower_expression e in - buildSeqStmt s (Case (e, [])) - - | Invoke (target, origin, lid, el) -> - let+ el,s = ListM.map lower_expression el in - buildSeqStmt s (Invoke(target, origin, lid,el)) - - | Return e -> - begin match e with - | None -> return @@ buildStmt (Return None) - | Some e -> let+ e,s = lower_expression e in - buildSeqStmt s (Return (Some e)) - end - | Block c -> let+ c = aux c in buildStmt (Block c) - - in - ECS.run (aux body env) |> E.recover ({info=dummy_pos;stmt=Skip},snd env) - - let preprocess = Error.Logger.pure - - let lower_process (c:p_in process_defn) env _ = E.pure (c.p_body,snd env) - - end -) - - - diff --git a/src/passes/monomorphization/monomorphization.ml b/src/passes/monomorphization/monomorphization.ml index 039e6d8..2a63bf0 100644 --- a/src/passes/monomorphization/monomorphization.ml +++ b/src/passes/monomorphization/monomorphization.ml @@ -1,16 +1,13 @@ open Common open Monad open TypesCommon -module E = Common.Error +module E = Common.Logging open Monad.MonadSyntax (E.Logger) open IrMir.AstMir open MonomorphizationMonad module M = MonoMonad open MonomorphizationUtils -open MonadSyntax(M) -open MonadOperator(M) -open MonadFunctions(M) - +open UseMonad(M) module Pass = Pass.Make (struct let name = "Monomorphization" @@ -22,20 +19,26 @@ module Pass = Pass.Make (struct let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t = - let mono_exp (e : expression) : sailtype M.t = + let mono_exp (e : expression) (decls :declaration list) : sailtype M.t = let rec aux (e : expression) : sailtype M.t = match e.exp with - | Variable s -> M.get_var s >>| fun v -> (v |> Option.get |> snd).ty + | Variable s -> + M.get_var s + <&> (function + | Some v -> Some (snd v).ty (* var is a function param *) + | None -> Option.bind (List.find_opt (fun v -> v.id = s) decls) (fun decl -> Some decl.varType) (* var is function declaration *) + ) + >>= M.throw_if_none Logging.(make_msg (fst e.info) @@ Fmt.str "compiler error : var '%s' not found" s) | Literal l -> return (sailtype_of_literal l) | ArrayRead (e, idx) -> begin - let* t = aux e in + let* l,t = aux e in match t with | ArrayType (t, _) -> let+ idx_t = aux idx in - let _ = resolveType idx_t (Int 32) [] [] in + let _ = resolveType idx_t (l,Int 32) [] [] in t | _ -> failwith "cannot happen" end @@ -49,25 +52,25 @@ module Pass = Pass.Make (struct | Ref (m, e) -> let+ t = aux e in - RefType (t, m) + dummy_pos,RefType (t, m) | Deref e -> ( - let+ t = aux e in + let+ l,t = aux e in match t with - | RefType _ -> t + | RefType _ -> l,t | _ -> failwith "cannot happen" ) | ArrayStatic (e :: h) -> - let* t = aux e in - let+ t = - ListM.fold_left (fun last_t e -> - let+ next_t = aux e in - let _ = resolveType next_t last_t [] [] in - next_t - ) t h - in - ArrayType (t, List.length (e :: h)) + let* t = aux e in + let+ t = + ListM.fold_left (fun last_t e -> + let+ next_t = aux e in + let _ = resolveType next_t last_t [] [] in + next_t + ) t h + in + dummy_pos,ArrayType (t, List.length (e :: h)) | ArrayStatic [] -> failwith "error : empty array" | StructAlloc (_, _, _) -> failwith "todo: struct alloc" @@ -78,7 +81,7 @@ module Pass = Pass.Make (struct aux e in - let construct_call (calle : string) (el : expression list) : (string * sailtype option) M.t = + let construct_call (calle : string) (el : expression list) decls : (string * sailtype option) M.t = (* we construct the types of the args (and collect extra new calls) *) Logs.debug (fun m -> m "contructing call to %s from %s" calle f.m_proto.name); let* monos = M.get_monos and* funs = M.get_funs in @@ -90,7 +93,7 @@ module Pass = Pass.Make (struct ListM.fold_left (fun l e -> Logs.debug (fun m -> m "analyze param expression"); - let* t = mono_exp e in + let* t = mono_exp e decls in Logs.debug (fun m -> m "param is %s " @@ string_of_sailtype @@ Some t); return (t :: l) ) @@ -108,7 +111,7 @@ module Pass = Pass.Make (struct begin let* f = find_callable calle sm |> M.lift in match f with - | None -> (*import *) return (mname,Some (Int 32) (*fixme*)) + | None -> (*import *) return (mname,Some (dummy_pos,Int 32) (*fixme*)) | Some f -> begin Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); @@ -153,47 +156,49 @@ module Pass = Pass.Make (struct end in - let rec mono_body (lbl: label) (treated: LabelSet.t) (blocks : (VE.t,unit) basicBlock BlockMap.t): (LabelSet.t * (_,_) basicBlock BlockMap.t) MonoMonad.t = - (* collect calls and name correctly *) - if LabelSet.mem lbl treated then return (treated,blocks) - else - begin - let treated = LabelSet.add lbl treated in + let mono_body (lbl: label) (blocks : (VE.t,unit) basicBlock BlockMap.t) (decls : declaration list) : (_,_) basicBlock BlockMap.t MonoMonad.t = + let rec aux lbl (treated: LabelSet.t) blocks = + (* collect calls and name correctly *) + if LabelSet.mem lbl treated then return (treated,blocks) + else + begin + let treated = LabelSet.add lbl treated in - let bb = BlockMap.find lbl blocks in - let* () = M.set_ve bb.forward_info in - let* () = ListM.iter (fun assign -> mono_exp assign.target >>= fun _ty -> mono_exp assign.expression >>| fun _ty -> ()) bb.assignments - in - - match bb.terminator |> Option.get with - | Return e -> - let+ _ = - begin - match e with - | Some e -> let+ t = mono_exp e in Some t - | None -> return None - end - in treated,blocks + let bb = BlockMap.find lbl blocks in + let* () = M.set_ve bb.forward_info in + let* () = ListM.iter (fun assign -> mono_exp assign.target decls >>= fun _ty -> mono_exp assign.expression decls >>| fun _ty -> ()) bb.assignments + in + let* terminator = M.throw_if_none Logging.(make_msg bb.location @@ Fmt.str "no terminator for bb%i : mir is broken.." lbl) bb.terminator in + match terminator with + | Return e -> + let+ _ = + begin + match e with + | Some e -> let+ t = mono_exp e decls in Some t + | None -> return None + end + in treated,blocks - | Invoke new_f -> - let* (id,_) = construct_call new_f.id new_f.params in - mono_body new_f.next treated BlockMap.(update lbl (fun _ -> Some {bb with terminator=Some (Invoke {new_f with id})}) blocks) - - | Goto lbl -> mono_body lbl treated blocks + | Invoke new_f -> + let* (id,_) = construct_call new_f.id new_f.params decls in + aux new_f.next treated BlockMap.(update lbl (fun _ -> Some {bb with terminator=Some (Invoke {new_f with id})}) blocks) + + | Goto lbl -> aux lbl treated blocks - | SwitchInt si -> - let* _ = mono_exp si.choice in - let* treated,blocks = mono_body si.default treated blocks in - ListM.fold_left ( fun (treated,blocks) (_,lbl) -> - mono_body lbl treated blocks - ) (treated,blocks) si.paths + | SwitchInt si -> + let* _ = mono_exp si.choice decls in + let* treated,blocks = aux si.default treated blocks in + ListM.fold_left ( fun (treated,blocks) (_,lbl) -> + aux lbl treated blocks + ) (treated,blocks) si.paths - | Break -> failwith "no break should be there" - end + | Break -> failwith "no break should be there" + end + in aux lbl LabelSet.empty blocks <&> snd in - match f.m_body with - | Right (decls,cfg) -> mono_body cfg.input LabelSet.empty cfg.blocks >>= fun (_,blocks) -> + | Right (decls,cfg) -> + let* blocks = mono_body cfg.input cfg.blocks decls in let params = List.map (fun (p:param) -> p.ty) f.m_proto.params in let name = mangle_method_name f.m_proto.name params in let methd = {m_proto = f.m_proto; m_body=Right (decls,{cfg with blocks})} in @@ -244,7 +249,9 @@ module Pass = Pass.Make (struct end else return () in - let* empty = M.get_monos >>| (=) [] in M.throw_if Error.(make dummy_pos "no monomorphic callable (no main?)") empty >>= aux + let* _empty = M.get_monos >>| (=) [] in + (* M.throw_if Logging.(make_msg dummy_pos "no monomorphic callable (no main?)") empty >>= *) + aux () let transform (smdl : in_body SailModule.t) : out_body SailModule.t E.t = diff --git a/src/passes/monomorphization/monomorphizationMonad.ml b/src/passes/monomorphization/monomorphizationMonad.ml index 19b49a6..cbb6a33 100644 --- a/src/passes/monomorphization/monomorphizationMonad.ml +++ b/src/passes/monomorphization/monomorphizationMonad.ml @@ -6,17 +6,21 @@ open MonomorphizationUtils type env = {monos: monomorphics; functions : sailor_functions; env: varTypesMap} module MonoMonad = struct - module S = MonadState.T(Error.Logger)(struct type t = env end) + module S = MonadState.T(Logging.Logger)(struct type t = env end) open MonadSyntax(S) open MonadOperator(S) include S (* error *) let throw e = E.throw e |> lift let throw_if e c = E.throw_if e c |> lift + let throw_if_none e c = E.throw_if_none e c |> lift + let get_decl id ty = get >>| fun e -> Env.get_decl id ty e.env let add_decl id decl ty = update (fun e -> E.bind (Env.add_decl id decl ty e.env) (fun env -> E.pure {e with env})) let get_var id = get >>| fun e -> Env.get_var id e.env + (* let declare_var id v = get >>| fun e -> Env.declare_var id v e.env *) + let set_ve ve = update (fun e -> E.pure {e with env=(ve,snd e.env)}) @@ -27,7 +31,7 @@ module MonoMonad = struct let get_monos = let+ e = S.get in e.monos let pop_monos = let* e = S.get in match e.monos with - | [] -> throw Error.(make dummy_pos "empty_monos") + | [] -> throw Logging.(make_msg dummy_pos "empty_monos") | h::monos -> S.set {e with monos} >>| fun () -> h let run (decls:Env.D.t) (x: 'a t) : ('a * env) E.t = x {monos=[];functions=FieldMap.empty;env=Env.empty decls} @@ -36,6 +40,7 @@ end let mangle_method_name (name : string) (args : sailtype list) : string = + if name = "main" then "main" else let back = List.fold_left (fun s t -> s ^ string_of_sailtype (Some t) ^ "_") "" args in diff --git a/src/passes/monomorphization/monomorphizationUtils.ml b/src/passes/monomorphization/monomorphizationUtils.ml index 868b8f6..3511a66 100644 --- a/src/passes/monomorphization/monomorphizationUtils.ml +++ b/src/passes/monomorphization/monomorphizationUtils.ml @@ -2,7 +2,7 @@ open Common open TypesCommon open Monad open IrHir -module E = Error.Logger +module E = Logging.Logger module Env = SailModule.SailEnv(IrMir.AstMir.V) open UseMonad(E) @@ -35,53 +35,52 @@ let print_method_proto (name : string) (methd : in_body sailor_method) = let resolveType (arg : sailtype) (m_param : sailtype) (generics : string list) (resolved_generics : sailor_args) : (sailtype * sailor_args) E.t = - let rec aux (a : sailtype) (m : sailtype) (g : sailor_args) = - match (a, m) with - | Bool, Bool -> return (Bool, g) - | Int x, Int y when x = y -> return (Int x, g) - | Float, Float -> return (Float, g) - | Char, Char -> return (Char, g) - | String, String -> return (String, g) - | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in ArrayType (t, s), g - | GenericType _g1, GenericType _g2 -> return (Int 32,g) - (* E.throw Error.(make dummy_pos @@ Fmt.str "resolveType between generic %s and %s" g1 g2) *) - | at, GenericType gt -> - let* () = E.throw_if Error.(make dummy_pos @@ Fmt.str "generic type %s not declared" gt) (not @@ List.mem gt generics) in + let rec aux ((aloc, a) : sailtype) ((mloc, m) : sailtype) (g : sailor_args) : (sailtype * sailor_args) E.t = + match a,m with + | Bool, Bool -> return ((aloc,Bool), g) + | Int x, Int y when x = y -> return ((aloc,Int x), g) + | Float, Float -> return ((aloc,Float), g) + | Char, Char -> return ((aloc,Char), g) + | String, String -> return ((aloc,String), g) + | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in (aloc,ArrayType (t, s)), g + | GenericType _g1, GenericType _g2 -> return ((aloc,Int 32),g) + (* E.throw Logging.(make_msg dummy_pos @@ Fmt.str "resolveType between generic %s and %s" g1 g2) *) + + | _, GenericType gt -> + let* () = E.throw_if Logging.(make_msg mloc @@ Fmt.str "generic type %s not declared" gt) (not @@ List.mem gt generics) in begin match List.assoc_opt gt g with - | None -> return (at, (gt, at) :: g) - | Some t -> + | None -> return ((aloc,a), (gt, (aloc,a)) :: g) + | Some (lt,t) -> E.throw_if - Error.(make dummy_pos @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt (string_of_sailtype (Some t)) (string_of_sailtype (Some at))) - (t <> at) - >>| fun () -> at, g + Logging.(make_msg lt @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt (string_of_sailtype (Some (lt,t))) (string_of_sailtype (Some (aloc,a)))) + (t <> a) + >>| fun () -> (aloc,a), g end | RefType (at, _), RefType (mt, _) -> aux at mt g | CompoundType _, CompoundType _ -> failwith "todocompoundtype" | Box _at, Box _mt -> failwith "todobox" - | _ -> E.throw Error.(make dummy_pos @@ Fmt.str "cannot happen : %s vs %s" (string_of_sailtype (Some a)) (string_of_sailtype (Some m))) + | _ -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "cannot happen : %s vs %s" (string_of_sailtype (Some (aloc,a))) (string_of_sailtype (Some (mloc,m)))) in aux arg m_param resolved_generics let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = - let rec aux = function - | Bool -> return Bool - | Int n -> return (Int n) - | Float -> return Float - | Char -> return Char - | String -> return String - | ArrayType (t, s) -> let+ t = aux t in ArrayType (t, s) - | Box t -> let+ t = aux t in Box t - | RefType (t, m) -> let+ t = aux t in RefType (t, m) - | GenericType _t when generics = [] -> - (* E.throw Error.(make dummy_pos @@ Fmt.str "generic type %s present but empty generics list" t) *) - return (Int 32) - - | GenericType _n -> - (* E.throw_if_none Error.(make dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) (List.assoc_opt n generics) *) - return (Int 32) - | CompoundType _ -> failwith "todo compoundtype" + let rec aux (l,t) = + let+ t = match t with + | Bool -> return Bool + | Int n -> return (Int n) + | Float -> return Float + | Char -> return Char + | String -> return String + | ArrayType (t, s) -> let+ t = aux t in ArrayType (t, s) + | Box t -> let+ t = aux t in Box t + | RefType (t, m) -> let+ t = aux t in RefType (t, m) + | GenericType n -> + let+ t = E.throw_if_none Logging.(make_msg dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) (List.assoc_opt n generics) in + snd t + | CompoundType _ -> failwith "todo compoundtype" + in l,t in aux t @@ -106,7 +105,7 @@ let find_callable (name : string) (sm : _ SailModule.methods_processes SailModul match SailModule.DeclEnv.find_decl name (All (Method)) sm.declEnv with | [_,_] -> - return @@ List.find_opt (fun m -> print_string m.m_proto.name; print_newline (); m.m_proto.name = name) sm.body.methods + return @@ List.find_opt (fun m -> m.m_proto.name = name) sm.body.methods - | [] -> E.throw Error.(make dummy_pos @@ Fmt.str "mono : %s not found" name) - | l -> E.throw Error.(make dummy_pos @@ Fmt.str "multiple symbols for %s : %s" name (List.map (fun (i,_) -> i.mname) l |> String.concat " ")) \ No newline at end of file + | [] -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "mono : %s not found" name) + | l -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "multiple symbols for %s : %s" name (List.map (fun (i,_) -> i.mname) l |> String.concat " ")) \ No newline at end of file diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index d5b7187..dbaa83d 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -5,7 +5,7 @@ open SailParser open ProcessUtils module H = HirUtils module HirS = AstHir.Syntax -module E = Error.Logger +module E = Logging.Logger open ProcessMonad open Monad.UseMonad(M) @@ -20,13 +20,13 @@ module Pass = Pass.Make(struct let closed = FieldSet.add pi.proc closed in (* no cycle *) let* p = find_process_source (l,pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in - let* p = M.throw_if_none Error.(make l @@ Fmt.str "process '%s' is unknown" pi.proc) p in + let* p = M.throw_if_none Logging.(make_msg l @@ Fmt.str "process '%s' is unknown" pi.proc) p in let* tag = M.fresh_prefix p.p_name in let prefix = (Fmt.str "%s_%s_" tag) in let read_params,write_params = p.p_interface.p_shared_vars in let param_arg_mismatch name p a = let pl,al = List.(length p,length a) in - M.throw_if Error.(make l @@ Fmt.str "number of %s params provided and required differ : %i vs %i" name al pl) (pl <> al) in + M.throw_if Logging.(make_msg l @@ Fmt.str "number of %s params provided and required differ : %i vs %i" name al pl) (pl <> al) in let* () = param_arg_mismatch "read" read_params pi.read in let* () = param_arg_mismatch "write" write_params pi.write in @@ -37,7 +37,7 @@ module Pass = Pass.Make(struct let rename = fun id -> match List.assoc_opt id rename_l with Some v -> v | None -> id in (* add process local (but persistant) vars *) - ListM.iter (fun (l,(id,(_,ty))) -> + ListM.iter (fun ((l,id),ty) -> let* ty,_ = HirUtils.follow_type ty sm.declEnv |> M.EC.lift |> M.ECW.lift |> M.lift in M.(write_decls HirS.(var (l,prefix id,ty))) ) p.p_body.locals >>= fun () -> @@ -65,9 +65,9 @@ module Pass = Pass.Make(struct return (process_cond cond s) | Run ((l,id),cond) -> - M.throw_if Error.(make l "not allowed to call Main process explicitely") (id = Constants.main_process) >>= fun () -> - M.throw_if Error.(make l "not allowed to have recursive process") (FieldSet.mem id closed) >>= fun () -> - let* l,pi = M.throw_if_none Error.(make l @@ Fmt.str "no proc init called '%s'" id) (List.find_opt (fun (_,p: loc * _ proc_init) -> p.id = id) p.p_body.proc_init) in + M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id = Constants.main_process) >>= fun () -> + M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id closed) >>= fun () -> + let* l,pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id) (List.find_opt (fun (_,p: loc * _ proc_init) -> p.id = id) p.p_body.proc_init) in let read = List.map (fun (l,id) -> l,prefix id) pi.read in let write = List.map (fun (l,id) -> l,prefix id) pi.write in let params = List.map (H.rename_var_exp prefix) pi.params in @@ -82,7 +82,7 @@ module Pass = Pass.Make(struct let open Monad.MonadOperator(E) in ( - let* m = M.throw_if_none (Error.make dummy_pos "need main process") + let* m = M.throw_if_none Logging.(make_msg dummy_pos "need main process") (List.find_opt (fun p -> p.p_name = Constants.main_process) procs) in let (pi: _ proc_init) = {mloc=None; read = []; write = [] ; params = [] ; id = Constants.main_process ; proc = Constants.main_process} in diff --git a/src/passes/process/processMonad.ml b/src/passes/process/processMonad.ml index 1638026..e4fcc1c 100644 --- a/src/passes/process/processMonad.ml +++ b/src/passes/process/processMonad.ml @@ -14,7 +14,7 @@ module V = ( module M = struct open AstHir - module E = Error.Logger + module E = Logging.Logger module Env = Env.VariableDeclEnv(SailModule.Declarations)(V) module SeqMonoid = struct diff --git a/src/passes/process/processUtils.ml b/src/passes/process/processUtils.ml index ea5e83d..cff855a 100644 --- a/src/passes/process/processUtils.ml +++ b/src/passes/process/processUtils.ml @@ -3,7 +3,7 @@ open TypesCommon open ProcessMonad open Monad.UseMonad(M) open IrHir -module E = Error.Logger +module E = Logging.Logger module D = SailModule.Declarations type body = (Hir.statement,(Hir.statement,Hir.expression) SailParser.AstParser.process_body) SailModule.methods_processes @@ -43,7 +43,7 @@ let find_process_source (name: l_str) (import : l_str option) procs : 'a process if origin = HirUtils.D.get_name env then return procs else let find_import = List.find_opt (fun i -> i.mname = origin) (HirUtils.D.get_imports env) in - let+ i = M.throw_if_none Error.(make dummy_pos "can't happen") find_import in + let+ i = M.throw_if_none Logging.(make_msg dummy_pos "can't happen") find_import in let sm = In_channel.with_open_bin (i.dir ^ i.mname ^ Constants.mir_file_ext) @@ fun c -> (Marshal.from_channel c : Mono.MonomorphizationUtils.out_body SailModule.t) in sm.body.processes in diff --git a/test/blackbox-tests/sailor.t/test_utils.sl b/test/blackbox-tests/sailor.t/test_utils.sl index 05c4cef..91e8714 100644 --- a/test/blackbox-tests/sailor.t/test_utils.sl +++ b/test/blackbox-tests/sailor.t/test_utils.sl @@ -15,3 +15,7 @@ method print_string(v : string) { method print_newline() { printf("\n") } + +method quit() { + exit(0); +} \ No newline at end of file From cda4f0a673c860eccd554f2bd767744a7cde62f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Sat, 2 Sep 2023 00:57:08 +0200 Subject: [PATCH 3/7] Generic Ast - Ast is shared between the passes - Subtyping with GADT allows to keep constraints to differenciate between expression and statements and between transformations --- src/codegen/codegenEnv.ml | 30 +- src/codegen/codegenUtils.ml | 20 +- src/codegen/codegen_.ml | 103 ++++--- src/common/builtins.ml | 2 +- src/common/env.ml | 6 +- src/common/monadic/monad.ml | 23 +- src/common/ppCommon.ml | 8 +- src/common/sailModule.ml | 2 +- src/common/typesCommon.ml | 50 +-- src/parsing/astParser.ml | 16 +- src/parsing/parser.mly | 26 +- src/passes/ir/baseAst/ast.ml | 134 ++++++++ src/passes/ir/baseAst/dune | 3 + src/passes/ir/baseAst/ppAst.ml | 73 +++++ src/passes/ir/baseAst/utils.ml | 157 ++++++++++ src/passes/ir/sailHir/astHir.ml | 95 ------ src/passes/ir/sailHir/dune | 2 +- src/passes/ir/sailHir/hir.ml | 142 +++++---- src/passes/ir/sailHir/hirMonad.ml | 5 +- src/passes/ir/sailHir/hirUtils.ml | 205 +++++------- src/passes/ir/sailHir/pp_hir.ml | 64 ++-- src/passes/ir/sailMir/mir.ml | 120 ++++---- .../ir/sailMir/{astMir.ml => mirAst.ml} | 4 +- src/passes/ir/sailMir/mirMonad.ml | 2 +- src/passes/ir/sailMir/mirUtils.ml | 8 +- src/passes/ir/sailMir/pp_mir.ml | 28 +- src/passes/ir/sailThir/thir.ml | 291 +++++++++--------- src/passes/ir/sailThir/thirUtils.ml | 104 ++++--- src/passes/misc/cfg_analysis.ml | 4 +- .../monomorphization/monomorphization.ml | 127 ++++---- .../monomorphization/monomorphizationUtils.ml | 56 ++-- src/passes/process/dune | 2 +- src/passes/process/process.ml | 50 +-- src/passes/process/processMonad.ml | 10 +- src/passes/process/processUtils.ml | 12 +- 35 files changed, 1138 insertions(+), 846 deletions(-) create mode 100644 src/passes/ir/baseAst/ast.ml create mode 100644 src/passes/ir/baseAst/dune create mode 100644 src/passes/ir/baseAst/ppAst.ml create mode 100644 src/passes/ir/baseAst/utils.ml delete mode 100644 src/passes/ir/sailHir/astHir.ml rename src/passes/ir/sailMir/{astMir.ml => mirAst.ml} (96%) diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index 6bf9e9e..f1dc565 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -12,7 +12,7 @@ open MakeOrderedFunctions(ImportCmp) module Declarations = struct include SailModule.Declarations type process_decl = unit - type method_decl = {defn : AstMir.mir_function method_defn ; llval : llvalue ; extern : bool} + type method_decl = {defn : MirAst.mir_function method_defn ; llval : llvalue ; extern : bool} type struct_decl = {defn : struct_proto; ty : lltype} type enum_decl = unit end @@ -34,9 +34,9 @@ open Declarations type in_body = Monomorphization.Pass.out_body -let getLLVMBasicType f t llc llm : lltype E.t = - let rec aux t = - match snd t with +let getLLVMBasicType f ty llc llm : lltype E.t = + let rec aux ty = + match ty.value with | Bool -> i1_type llc |> return | Int n -> integer_type llc n |> return | Float -> double_type llc |> return @@ -44,13 +44,13 @@ let getLLVMBasicType f t llc llm : lltype E.t = | String -> i8_type llc |> pointer_type |> return | ArrayType (t,s) -> let+ t = aux t in array_type t s | Box t | RefType (t,_) -> aux t <&> pointer_type - | GenericType _ -> E.throw Logging.(make_msg (fst t) "no generic type in codegen") - | CompoundType {name=(_,name); _} when name = "_value" -> i64_type llc |> return (* for extern functions *) + | GenericType _ -> E.throw Logging.(make_msg ty.loc "no generic type in codegen") + | CompoundType {name; _} when name.value = "_value" -> i64_type llc |> return (* for extern functions *) | CompoundType {origin=None;_} - | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg (fst t) "compound type with no origin or decl_ty") - | CompoundType {origin=Some (_,mname); name=(_,name); decl_ty=Some d;_} -> - f (mname,name,d) llc llm aux - in aux t + | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg ty.loc "compound type with no origin or decl_ty") + | CompoundType {origin=Some mname; name; decl_ty=Some d;_} -> + f (mname.value,name.value,d) llc llm aux + in aux ty let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> lltype E.t) : lltype E.t = @@ -80,7 +80,7 @@ let getLLVMBasicType f t llc llm : lltype E.t = | Some (E _enum) -> failwith "todo enum" | Some (S (_,defn)) -> let _,f_types = List.split defn.fields in - let* elts = ListM.map (fun (_,t,_) -> aux t) f_types <&> Array.of_list in + let* elts = ListM.map (fun ty -> aux (fst ty.value)) f_types <&> Array.of_list in begin match type_by_name llm ("struct." ^ name) with | Some ty -> return ty @@ -129,7 +129,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = ); let valueify_method_sig (m:method_sig) : method_sig = - let value = fun pos -> dummy_pos,CompoundType{origin=None;name=(pos,"_value");generic_instances=[];decl_ty=None} in + let value = fun pos -> mk_locatable dummy_pos @@ CompoundType{origin=None;name=(mk_locatable pos "_value");generic_instances=[];decl_ty=None} in let rtype = m.rtype in (* keep the current type *) let params = List.map (fun (p:param) -> {p with ty=(value p.loc)}) m.params in {m with params; rtype} @@ -137,8 +137,8 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = (* because the imports are at the mir stage, we also have to do some codegen for them *) - let load_methods (methods: IrMir.AstMir.mir_function method_defn list) is_import env = - ListM.fold_left ( fun d (m:IrMir.AstMir.mir_function method_defn) -> + let load_methods (methods: IrMir.MirAst.mir_function method_defn list) is_import env = + ListM.fold_left ( fun d (m:IrMir.MirAst.mir_function method_defn) -> let extern,proto = if (Either.is_left m.m_body) then (* extern method, all parameters must be of type value *) true,valueify_method_sig m.m_proto @@ -167,7 +167,7 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = let load_structs structs write_env = SEnv.fold (fun acc (name,(_,defn)) -> let _,f_types = List.split defn.fields in - let* elts = ListM.map (fun (_,t,_) -> _getLLVMType sm.declEnv t llc llm) f_types <&> Array.of_list in + let* elts = ListM.map (fun ty-> _getLLVMType sm.declEnv (fst ty.value) llc llm) f_types <&> Array.of_list in let ty = match type_by_name llm ("struct." ^ name) with | Some ty -> ty | None -> let ty = named_struct_type llc ("struct." ^ name) in diff --git a/src/codegen/codegenUtils.ml b/src/codegen/codegenUtils.ml index d0d9d44..8cf2fe4 100644 --- a/src/codegen/codegenUtils.ml +++ b/src/codegen/codegenUtils.ml @@ -20,16 +20,16 @@ let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue = | LChar c -> const_int (i8_type llvm.c) (Char.code c) | LString s -> build_global_stringptr s ".str" llvm.b -let ty_of_alias(t:sailtype) env : sailtype = - match snd t with - | CompoundType {origin=Some (_,mname); name=(_,name);decl_ty=Some T ();_} -> +let ty_of_alias(ty:sailtype) env : sailtype = + match ty.value with + | CompoundType {origin=Some mname; name;decl_ty=Some T ();_} -> begin - match DeclEnv.find_decl name (Specific (mname,Type)) env with + match DeclEnv.find_decl name.value (Specific (mname.value,Type)) env with | Some {ty=Some t';_} -> t' - | Some {ty=None;_} -> t - | None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some t)) mname + | Some {ty=None;_} -> ty + | None -> failwith @@ Fmt.str "ty_of_alias : '%s' not found in %s" (string_of_sailtype (Some ty)) mname.value end - | _ -> t + | _ -> ty let unary (op:unOp) (t,v) : llbuilder -> llvalue = let f = @@ -37,7 +37,7 @@ let unary (op:unOp) (t,v) : llbuilder -> llvalue = | Float,Neg -> build_fneg | Int _,Neg -> build_neg | _,Not -> build_not - | _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some t)) |> failwith + | _ -> Printf.sprintf "bad unary operand type : '%s'. Only double and int are supported" (string_of_sailtype (Some (mk_locatable (fst t) (snd t)))) |> failwith in f v "" @@ -76,8 +76,8 @@ let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llva | And -> "and" | Or -> "or" | Le -> "le" | Lt -> "lt" | Ge -> "ge" | Gt -> "gt" | Mul -> "mul" | NEq -> "neq" | Div -> "div" in - let t = if snd t = Bool then fst t,Int 1 else t in (* thir will have checked for correctness *) - let l = operators (snd t) in + let t = if t.value = Bool then mk_locatable t.loc @@ Int 1 else t in (* thir will have checked for correctness *) + let l = operators t.value in let open Common.Monad.MonadOperator(Common.MonadOption.M) in match l >>| List.assoc_opt op |> Option.join with | Some oper -> oper l1 l2 "" diff --git a/src/codegen/codegen_.ml b/src/codegen/codegen_.ml index 871ec93..960cffc 100644 --- a/src/codegen/codegen_.ml +++ b/src/codegen/codegen_.ml @@ -6,10 +6,10 @@ open IrMir open Monad.UseMonad(E) module L = Llvm module E = Logging.Logger -let get_type (e:AstMir.expression) = snd e.info +let get_type (e:MirAst.expression) = e.tag.ty -let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: AstMir.expression) : L.llvalue E.t = - match x.exp with +let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp: MirAst.expression) : L.llvalue E.t = + match exp.node with | Variable x -> let+ _,v = match (SailEnv.get_var x venv) with | Some (_,n) -> return n @@ -18,50 +18,57 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x: | Deref x -> eval_r env llvm x - | ArrayRead (array_exp, index_exp) -> - let* array_val = eval_l env llvm array_exp in - let+ index = eval_r env llvm index_exp in + | ArrayRead a -> + let* array_val = eval_l env llvm a.array in + let+ index = eval_r env llvm a.idx in let llvm_array = L.build_in_bounds_gep array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in llvm_array - | StructRead ((_,mname),struct_exp,(_,field)) -> - let* st = eval_l env llvm struct_exp in - let+ st_type_name = Env.TypeEnv.get_from_id struct_exp.info tenv >>= function _,CompoundType c -> return (snd c.name) | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") in - let fields = (SailEnv.get_decl st_type_name (Specific (mname,Struct)) venv |> Option.get).defn.fields in - let _,_,idx = List.assoc field fields in + | StructRead2 s -> + let* st = eval_l env llvm s.value.strct in + let* st_type_name = Env.TypeEnv.get_from_id (mk_locatable s.value.strct.tag.loc s.value.strct.tag.ty) tenv >>= function + | {value=CompoundType c;_} -> return c.name.value + | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") + in + let+ decl = (SailEnv.get_decl st_type_name (Specific (s.import.value,Struct)) venv + |> E.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : no decl '%s' found" st_type_name)) in + + let fields = decl.defn.fields in + let {value=_,idx;_} = List.assoc s.value.field.value fields in L.build_struct_gep st idx "" llvm.b - | StructAlloc (_,(_,name),fields) -> - let _,fieldlist = fields |> List.split in - let* strct_ty = match L.type_by_name llvm.m ("struct." ^ name) with + | StructAlloc2 s -> + let _,fieldlist = s.value.fields |> List.split in + let* strct_ty = match L.type_by_name llvm.m ("struct." ^ s.value.name.value) with | Some s -> return s | None -> - E.throw Logging.(make_msg (fst x.info) @@ "unknown structure : " ^ ("struct." ^ name)) + E.throw Logging.(make_msg exp.tag.loc @@ "unknown structure : " ^ ("struct." ^ s.value.name.value)) in let struct_v = L.build_alloca strct_ty "" llvm.b in - let+ () = ListM.iteri ( fun i (_,f) -> - let+ v = eval_r env llvm f in + let+ () = ListM.iteri ( fun i f -> + let+ v = eval_r env llvm f.value in let v_f = L.build_struct_gep struct_v i "" llvm.b in L.build_store v v_f llvm.b |> ignore ) fieldlist in struct_v - | _ -> E.throw Logging.(make_msg (fst x.info) "unexpected rvalue for codegen") + | Literal _ | UnOp _ | BinOp _ | Ref _ | ArrayStatic _ | EnumAlloc _ -> + E.throw Logging.(make_msg exp.tag.loc "unexpected lvalue for codegen") -and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMir.expression) : L.llvalue E.t = - let* ty = Env.TypeEnv.get_from_id x.info tenv in - match x.exp with - | Variable _ | StructRead _ | ArrayRead _ | StructAlloc _ -> let+ v = eval_l env llvm x in L.build_load v "" llvm.b +and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:MirAst.expression) : L.llvalue E.t = + let* ty = Env.TypeEnv.get_from_id (mk_locatable exp.tag.loc exp.tag.ty) tenv in + match exp.node with + | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load v "" llvm.b | Literal l -> return @@ getLLVMLiteral l llvm - | UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (ty_of_alias ty (snd venv),l) llvm.b + | UnOp (op,e) -> let+ l = eval_r env llvm e in unary op (let ty = ty_of_alias ty (snd venv) in (ty.loc,ty.value),l) llvm.b - | BinOp (op,e1, e2) -> - let+ l1 = eval_r env llvm e1 - and* l2 = eval_r env llvm e2 - in binary op (ty_of_alias ty (snd venv)) l1 l2 llvm.b + | BinOp bop -> + let+ l1 = eval_r env llvm bop.left + and* l2 = eval_r env llvm bop.right + in binary bop.op (ty_of_alias ty (snd venv)) l1 l2 llvm.b | Ref (_,e) -> eval_l env llvm e | Deref e -> let+ v = eval_l env llvm e in L.build_load v "" llvm.b @@ -80,32 +87,31 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (x:AstMi L.build_load array "" llvm.b end - | EnumAlloc _ -> E.throw Logging.(make_msg (fst x.info) "enum allocation unimplemented") + | EnumAlloc _ -> E.throw Logging.(make_msg exp.tag.loc "enum allocation unimplemented") - | _ -> E.throw Logging.(make_msg (fst x.info) "problem with thir") -and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t = - let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.info,r) args >>| List.split +and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (venv,tenv as env : SailEnv.t*Env.TypeEnv.t) (llvm:llvm_args) : L.llvalue E.t = + let* args_type,llargs = ListM.map (fun arg -> let+ r = eval_r env llvm arg in arg.tag,r) args >>| List.split in (* let mname = mangle_method_name name origin.mname args_type in *) - let mangled_name = "_" ^ mname ^ "_" ^ name in + let mangled_name = "_" ^ mname.value ^ "_" ^ name in Logs.debug (fun m -> m "constructing call to %s" name); - let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname,Method)) venv with + let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname.value,Method)) venv with | None -> begin - match SailEnv.get_decl name (Specific (mname,Method)) venv with + match SailEnv.get_decl name (Specific (mname.value,Method)) venv with | Some {llval;extern;_} -> return (llval,extern) - | None -> E.throw Logging.(make_msg loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) + | None -> E.throw Logging.(make_msg mname.loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) end | Some {llval;extern;_} -> return (llval,extern) in let+ args = if ext then - ListM.map2 (fun t v -> - let+ t = Env.TypeEnv.get_from_id t tenv in + ListM.map2 (fun (t:IrThir.ThirUtils.tag) v -> + let+ t = Env.TypeEnv.get_from_id (mk_locatable t.loc t.ty) tenv in let builder = - match snd (ty_of_alias t (snd venv)) with + match (ty_of_alias t (snd venv)).value with | Bool | Int _ | Char -> L.build_zext | Float -> L.build_bitcast | CompoundType _ -> fun v _ _ _ -> v @@ -118,16 +124,16 @@ and construct_call (name:string) ((loc,mname):l_str) (args:AstMir.expression lis in L.build_call llval (Array.of_list args) "" llvm.b -open AstMir +open MirAst let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,tenv : SailEnv.t*Env.TypeEnv.t) : unit E.t = - let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:AstMir.expression option) (venv : SailEnv.t) : SailEnv.t E.t = + let declare_var (mut:bool) (name:string) (ty:sailtype) (exp:MirAst.expression option) (venv : SailEnv.t) : SailEnv.t E.t = let _ = mut in (* todo manage mutable types *) let entry_b = L.(entry_block proto |> instr_begin |> builder_at llvm.c) in let* v = match exp with | Some e -> - let* t = Env.TypeEnv.get_from_id e.info tenv + let* t = Env.TypeEnv.get_from_id (mk_locatable e.tag.loc e.tag.ty) tenv and* v = eval_r (venv,tenv) llvm e in let+ ty = getLLVMType (snd venv) t llvm.c llvm.m in let x = L.build_alloca ty name entry_b in @@ -154,20 +160,21 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t let llvm_bbs = BlockMap.add lbl llvm_bb llvm_bbs in L.position_at_end llvm_bb llvm.b; let* () = ListM.iter (fun x -> assign_var x.target x.expression (venv,tenv)) bb.assignments in - match bb.terminator with - | Some (Return e) -> + let* terminator = E.throw_if_none Logging.(make_msg bb.location "no terminator : mir is broken") bb.terminator in + match terminator with + | Return e -> let+ ret = match e with | Some r -> let+ v = eval_r (venv,tenv) llvm r in L.build_ret v | None -> return L.build_ret_void in ret llvm.b |> ignore; llvm_bbs - | Some (Goto lbl) -> + | Goto lbl -> let+ llvm_bbs = aux lbl llvm_bbs venv in L.position_at_end llvm_bb llvm.b; let _ = L.build_br (BlockMap.find lbl llvm_bbs) llvm.b in llvm_bbs - | Some (Invoke f) -> + | Invoke f -> let* c = construct_call f.id f.origin f.params (venv,tenv) llvm in begin match f.target with @@ -179,7 +186,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t L.build_br (BlockMap.find f.next llvm_bbs) llvm.b |> ignore; llvm_bbs - | Some (SwitchInt si) -> + | SwitchInt si -> let* sw_val = eval_r (venv,tenv) llvm si.choice in let sw_val = L.build_intcast sw_val (L.i32_type llvm.c) "" llvm.b in (* for condition, expression val will be bool *) let* llvm_bbs = aux si.default llvm_bbs venv in @@ -192,9 +199,7 @@ let cfgToIR (proto:L.llvalue) (decls,cfg: mir_function) (llvm:llvm_args) (venv,t in L.add_case sw n (BlockMap.find lbl bm); bm ) llvm_bbs si.paths - - | None -> E.throw Logging.(make_msg bb.location "no terminator : mir is broken") - | Some Break -> E.throw Logging.(make_msg bb.location "no break should be there") + | Break -> E.throw Logging.(make_msg bb.location "no break should be there") end in ( diff --git a/src/common/builtins.ml b/src/common/builtins.ml index 8f00dab..0a507a2 100644 --- a/src/common/builtins.ml +++ b/src/common/builtins.ml @@ -5,4 +5,4 @@ let register_builtin name generics p rtype variadic l: method_sig list = let get_builtins () : method_sig list = [] - |> register_builtin "box" ["T"] [dummy_pos,GenericType "T"] (Some (dummy_pos,Box (dummy_pos,GenericType "T"))) false \ No newline at end of file + |> register_builtin "box" ["T"] [mk_locatable dummy_pos (GenericType "T")] (Some (mk_locatable dummy_pos (Box (mk_locatable dummy_pos (GenericType "T"))))) false \ No newline at end of file diff --git a/src/common/env.ml b/src/common/env.ml index 6f4368a..40eae86 100644 --- a/src/common/env.ml +++ b/src/common/env.ml @@ -388,7 +388,7 @@ module TypeEnv = struct let empty = FieldMap.empty let get_id ty (te :t) : string * t = let add_if_no_exists s = FieldMap.update s (Option.fold ~none:(Some ty) ~some:Option.some) in - let s = match snd ty with + let s = match ty.value with | Bool -> "bool" | Float -> "float" | Char -> "char" @@ -396,12 +396,12 @@ module TypeEnv = struct | Int n -> "int" ^ string_of_int n | ArrayType _ -> "array" | GenericType t -> t - | CompoundType t -> snd t.name + | CompoundType t -> t.name.value | Box _ -> "box" | RefType _ -> "ref" in s,add_if_no_exists s te - let get_from_id (lid,id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg lid @@ Fmt.str "id '%s' not found" id) (FieldMap.find_opt id te) + let get_from_id (id:l_str) te : sailtype E.t = E.throw_if_none Logging.(make_msg id.loc @@ Fmt.str "id '%s' not found" id.value) (FieldMap.find_opt id.value te) end diff --git a/src/common/monadic/monad.ml b/src/common/monadic/monad.ml index ac37192..20435a2 100644 --- a/src/common/monadic/monad.ml +++ b/src/common/monadic/monad.ml @@ -43,12 +43,26 @@ module type Functor = sig val fmap : ('a -> 'b) -> 'a t -> 'b t end +module FunctorOperator(F : Functor) = struct + let (<$>) = F.fmap + let (<&>) = fun x f -> F.fmap f x + let (<$) (type a b) : a -> b F.t -> a F.t = fun x y -> F.fmap (fun _ -> x) y + let ($>) = fun x y -> (<$) y x + let void = fun x -> () <$ x +end + module type Applicative = sig include Functor val pure : 'a -> 'a t val apply : ('a -> 'b ) t -> 'a t -> 'b t end +module ApplicativeOperator (A : Applicative) = struct + include FunctorOperator(A) + let (<*>) : ('a -> 'b) A.t -> 'a A.t -> 'b A.t = A.apply + let ( *>) : 'a A.t -> 'b A.t -> 'b A.t = fun x y -> (Fun.id <$ x) <*> y +end + module type Monad = sig include Applicative val bind : 'a t -> ('a -> 'b t) -> 'b t @@ -69,11 +83,14 @@ module MonadIdentity : Monad with type 'a t = 'a = struct end module MonadOperator (M : Monad) = struct - let (<*>) = M.apply - let (<$>) = M.fmap - let (<&>) = fun x f -> M.fmap f x + include ApplicativeOperator(M) let (>>=) = M.bind + let (=<<) = fun x y -> M.bind y x let (>>|) x f = x >>= fun x -> f x |> M.pure + let (>=>) : ('a -> 'b M.t) -> ('b -> 'c M.t) -> 'a -> 'c M.t = fun f1 f2 x -> f1 x >>= f2 + let (<=<) : 'a -> ('a -> 'b M.t) -> ('b -> 'c M.t) -> 'c M.t = fun x f1 f2 -> (f1 >=> f2) x + + let ap : ('a -> 'b) M.t -> 'a M.t -> 'b M.t = fun f x -> f >>= fun f -> f <$> x end module MonadSyntax (M : Monad ) = struct diff --git a/src/common/ppCommon.ml b/src/common/ppCommon.ml index f6064b8..adee878 100644 --- a/src/common/ppCommon.ml +++ b/src/common/ppCommon.ml @@ -40,16 +40,16 @@ let pp_binop pf b = - let rec pp_type (pf : formatter) (_,t : sailtype) : unit = - match t with + let rec pp_type (pf : formatter) (t : sailtype) : unit = + match t.value with Bool -> pp_print_string pf "bool" | Int n -> Format.fprintf pf "i%i" n | Float -> pp_print_string pf "float" | Char -> pp_print_string pf "char" | String -> pp_print_string pf "string" | ArrayType (t,s) -> Format.fprintf pf "array<%a;%d>" pp_type t s - | CompoundType {name=(_,x); generic_instances;_} -> - Format.fprintf pf "%s<%a>" x (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances + | CompoundType {name; generic_instances;_} -> + Format.fprintf pf "%s<%a>" name.value (pp_print_list ~pp_sep:pp_comma pp_type) generic_instances | Box(t) -> Format.fprintf pf "ref<%a>" pp_type t | RefType (t,b) -> if b then Format.fprintf pf "&mut %a" pp_type t diff --git a/src/common/sailModule.ml b/src/common/sailModule.ml index fd0dc1d..78f9605 100644 --- a/src/common/sailModule.ml +++ b/src/common/sailModule.ml @@ -49,4 +49,4 @@ let method_decl_of_defn (d : 'a method_defn) : Declarations.method_decl = and args = d.m_proto.params and generics = d.m_proto.generics and variadic = d.m_proto.variadic in - ((pos,name),{ret;args;generics;variadic}) \ No newline at end of file + (mk_locatable pos name,{ret;args;generics;variadic}) \ No newline at end of file diff --git a/src/common/typesCommon.ml b/src/common/typesCommon.ml index bfdae09..100a326 100644 --- a/src/common/typesCommon.ml +++ b/src/common/typesCommon.ml @@ -26,9 +26,20 @@ module FieldSet = Set.Make (String) type loc = Lexing.position * Lexing.position let dummy_pos : loc = Lexing.dummy_pos,Lexing.dummy_pos +type ('i,'n) tagged_node = {tag : 'i; node : 'n} +let mk_tagged_node tag node = {tag;node} + +type ('i,'v) importable = {import : 'i; value : 'v} +let mk_importable import value = {import;value} + +type 'v locatable = {loc : loc; value : 'v} +let mk_locatable loc value = {loc;value} + + + type 'a dict = (string * 'a) list -type l_str = loc * string +type l_str = string locatable type ('m,'p,'s,'e,'t) decl_sum = M of 'm | P of 'p | S of 's | E of 'e | T of 't @@ -50,7 +61,7 @@ let string_of_decl : (_,_,_,_,_) decl_sum -> string = function | T _ -> "type" -type sailtype = loc * sailtype_ and sailtype_ = +type sailtype = sailtype_ locatable and sailtype_ = | Bool | Int of int | Float @@ -75,27 +86,27 @@ type literal = | LString of string let sailtype_of_literal = function -| LBool _ -> dummy_pos,Bool -| LFloat _ -> dummy_pos,Float -| LInt l -> dummy_pos,Int l.size -| LChar _ -> dummy_pos,Char -| LString _ -> dummy_pos,String +| LBool _ -> mk_locatable dummy_pos Bool +| LFloat _ -> mk_locatable dummy_pos Float +| LInt l -> mk_locatable dummy_pos @@ Int l.size +| LChar _ -> mk_locatable dummy_pos Char +| LString _ -> mk_locatable dummy_pos String let rec string_of_sailtype (t : sailtype option) : string = let open Printf in match t with - | Some (_,t) -> + | Some t -> begin - match t with + match t.value with | Bool -> "bool" | Int size -> "i" ^ string_of_int size | Float -> "float" | Char -> "char" | String -> "string" | ArrayType (t,s) -> sprintf "array<%s;%d>" (string_of_sailtype (Some t)) s - | CompoundType {name=(_,x); generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" x - | CompoundType {name=(_,x); generic_instances;_} -> sprintf "%s<%s>" x (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) + | CompoundType {name; generic_instances=[];_} -> (* empty compound type -> lookup what it binds to *) sprintf "%s" name.value + | CompoundType {name; generic_instances;_} -> sprintf "%s<%s>" name.value (String.concat ", " (List.map (fun t -> string_of_sailtype (Some t)) generic_instances)) | Box(t) -> sprintf "ref<%s>" (string_of_sailtype (Some t)) | RefType (t,b) -> if b then sprintf "&mut %s" (string_of_sailtype (Some t)) @@ -136,7 +147,11 @@ type enum_defn = -type interface = {p_params: param list ; p_shared_vars: ((loc * (string * sailtype)) list * (loc * (string * sailtype)) list)} +type interface = +{ + p_params: param list ; + p_shared_vars: (string * sailtype) locatable list * (string * sailtype) locatable list +} type 'a process_defn = { @@ -148,7 +163,7 @@ type 'a process_defn = } type 'e proc_init = { - mloc : l_str option; + mloc : l_str option; id : string; proc : string; params : 'e list; @@ -189,7 +204,7 @@ type enum_proto = type struct_proto = { generics : string list; - fields : (loc * sailtype * int) dict + fields : (sailtype * int) locatable dict } type method_proto = @@ -202,8 +217,8 @@ type method_proto = type process_proto = { - read : (loc * (string * sailtype)) list; - write : (loc * (string * sailtype)) list; + read : (string * sailtype) locatable list; + write :(string * sailtype) locatable list; params : param list; generics : string list; } @@ -226,7 +241,7 @@ let defn_to_proto (type proto) (decl: proto decl) : proto = match decl with and generics = d.p_generics and params = d.p_interface.p_params in {read;write;generics;params} -| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i ((l,n),t) -> n,(l,t,i)) d.s_fields} +| Struct d -> {generics=d.s_generics;fields=List.mapi (fun i (n,t) -> n.value,mk_locatable n.loc (t,i)) d.s_fields} | Enum d -> {generics=d.e_generics;injections=d.e_injections} type import = @@ -237,7 +252,6 @@ type import = proc_order: int; } - module ImportCmp = struct type t = import let compare i1 i2 = String.compare i1.mname i2.mname end module ImportSet = Set.Make(ImportCmp) diff --git a/src/parsing/astParser.ml b/src/parsing/astParser.ml index 93e94c9..6bd08bd 100644 --- a/src/parsing/astParser.ml +++ b/src/parsing/astParser.ml @@ -24,7 +24,7 @@ open Common open TypesCommon (* expressions are control free *) -type expression = loc * expression_ and expression_ = +type expression = expression_ locatable and expression_ = Variable of string | Deref of expression | StructRead of expression * l_str @@ -34,7 +34,7 @@ type expression = loc * expression_ and expression_ = | BinOp of binOp * expression * expression | Ref of bool * expression | ArrayStatic of expression list - | StructAlloc of l_str option * l_str * (loc * expression) dict + | StructAlloc of l_str option * l_str * expression locatable dict | EnumAlloc of l_str * expression list | MethodCall of l_str option * l_str * expression list @@ -43,10 +43,10 @@ type pattern = | PVar of string | PCons of string * pattern list -type statement = loc * statement_ and statement_ = +type statement = statement_ locatable and statement_ = | DeclVar of bool * string * sailtype option * expression option | Skip - | Assign of expression * expression + | Assign of {path:expression; value:expression} | Seq of statement * statement | If of expression * statement * statement option | While of expression * statement @@ -61,7 +61,7 @@ type statement = loc * statement_ and statement_ = type pgroup_ty = Sequence | Parallel -type ('s,'e) p_statement = loc * ('s,'e) p_statement_ and ('s,'e) p_statement_ = +type ('s,'e) p_statement = ('s,'e) p_statement_ locatable and ('s,'e) p_statement_ = | Run of l_str * 'e option | Statement of 's * 'e option | PGroup of {p_ty : pgroup_ty ; cond : 'e option ; children : ('s,'e) p_statement list} @@ -69,7 +69,7 @@ type ('s,'e) p_statement = loc * ('s,'e) p_statement_ and ('s,'e) p_statement_ = type ('s,'e) process_body = { locals : (l_str * sailtype) list; init : 's; - proc_init : (loc * 'e proc_init) list; + proc_init : ('e proc_init locatable) list; loop : ('s,'e) p_statement; } @@ -100,7 +100,7 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement in (env,m,p) | Struct d -> - let s_fields = List.sort_uniq (fun ((_,s1),_) ((_,s2),_) -> String.compare s1 s2) d.s_fields in + let s_fields = List.sort_uniq (fun (s1,_) (s2,_) -> String.compare s1.value s2.value) d.s_fields in E.throw_if Logging.(make_msg d.s_pos "duplicate fields" ) (List.(length s_fields <> length d.s_fields)) >>= fun () -> let+ env = DeclEnv.add_decl d.s_name (d.s_pos, defn_to_proto (Struct {d with s_fields})) Struct e |> rethrow d.s_pos @@ -116,7 +116,7 @@ let mk_program (md:metadata) (imports: ImportSet.t) l : (statement, (statement ListM.fold_left (fun (e,f) d -> let* () = E.throw_if Logging.(make_msg d.m_proto.pos "calling a method 'main' is not allowed") (d.m_proto.name = "main") in let true_name = (match d.m_body with Left (sname,_) -> sname | Right _ -> d.m_proto.name) in - let+ env = DeclEnv.add_decl d.m_proto.name ((d.m_proto.pos,true_name),defn_to_proto (Method d)) Method e + let+ env = DeclEnv.add_decl d.m_proto.name (mk_locatable d.m_proto.pos true_name,defn_to_proto (Method d)) Method e in (env,d::f) ) (e,m) d in (env,funs,p) diff --git a/src/parsing/parser.mly b/src/parsing/parser.mly index 71b8fb7..8df4866 100644 --- a/src/parsing/parser.mly +++ b/src/parsing/parser.mly @@ -24,6 +24,7 @@ open Common open TypesCommon open AstParser + module SailParser = struct end %} %token TYPE_BOOL TYPE_FLOAT TYPE_CHAR TYPE_STRING %token TYPE_INT @@ -131,9 +132,10 @@ let process_body := proc_init = midrule(P_PROC_INIT ; ":" ; ~ = list(located(proc_init)) ; <>)? ; loop = midrule(P_LOOP ; ":" ; loop?)? ; { - let init = Option.(join init |> value ~default:($loc,Skip)) in + + let init = Option.(join init |> value ~default:(mk_locatable $loc Skip)) in let proc_init = Option.value proc_init ~default:[] in - let loop = Option.join loop |> function Some x -> x | None -> $loc,(Statement (($loc,Skip),None)) in + let loop = Option.join loop |> function Some x -> x | None -> mk_locatable $loc (Statement (mk_locatable $loc Skip,None)) in {locals;init;proc_init;loop} } @@ -167,7 +169,7 @@ let generic := loption(delimited("<", separated_list(",", UID), ">")) let returnType := preceded(":", sailtype)? -let mutable_var(X) := (loc,id) = located(ID) ; ":" ; mut = mut ; ty =X ; { {id;mut;loc;ty} } +let mutable_var(X) := id = located(ID) ; ":" ; mut = mut ; ty =X ; { {id=id.value;mut;loc=id.loc;ty} } let separated_nonempty_list_opt(separator, X) := | x = X ; separator? ; { [ x ] } @@ -200,8 +202,10 @@ let expression := | "*" ; ~ = expression ; %prec UNARY | e1 = expression ; op =binOp ; e2 =expression ; { BinOp(op,e1,e2) } | ~ = delimited ("[", separated_list(",", expression), "]") ; - | ~ = ioption(module_loc) ; ~ =located(ID) ; ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); - {List.fold_left (fun l ((ly,y),z) -> (y,(ly,z))::l) [] l}) ; + | ~ = ioption(module_loc) ; ~ =located(ID) ; + ~ = midrule(l = brace_del_sep_list(",", id_colon(expression)); + { List.fold_left (fun accu (f,e) -> (f.value,mk_locatable f.loc e)::accu) [] l } + ) ; | ~ = located(UID) ; ~ = loption(parenthesized (separated_list(",", expression))) ; | ~ = ioption(module_loc) ; ~ = located(ID) ; ~ = parenthesized(separated_list (",", expression)) ; ) @@ -244,15 +248,15 @@ let iterable_or_range := | rl = INT ; "," ; rr = INT ; { let rl = Z.to_int rl in let rr = Z.to_int rr in - ArrayStatic (List.init (rr - rl) (fun i -> dummy_pos,Literal (LInt {l=Z.of_int (i + rl); size=32}))) + ArrayStatic (List.init (rr - rl) (fun i -> mk_locatable dummy_pos (Literal (LInt {l=Z.of_int (i + rl); size=32})) ) ) } -| e = expression ; {snd e} +| e = expression ; {e.value} let single_statement := | located ( | vardecl - | l = expression ; "=" ; e = expression ; + | path = expression ; "=" ; value = expression ; {Assign {path;value} } | CASE ; ~ = parenthesized(expression) ; ~ = brace_del_sep_list(",", case) ; | ~ = ioption(module_loc) ; ~ = located(ID) ; ~ = parenthesized(separated_list(",", expression)) ; | RETURN ; ~ = expression? ; @@ -265,13 +269,15 @@ let vardecl := VAR ; ~ = mut ; ~ = ID ; ~ = preceded(":", sailtype)? ; ~ = prece let brace_del_sep_list(sep,x) := delimited("{", separated_nonempty_list(sep, x), "}") -let located(x) == ~ = x ; { ($loc,x) } +let located(x) == ~ = x ; { mk_locatable $loc x } let case := separated_pair(pattern, ":", statement) let mut := boption(MUT) -let module_loc := ~ = located(ID); "::" ; <> | x = located(SELF) ; "::" ; { (fst x),Constants.sail_module_self} +let module_loc := + | ~ = located(ID); "::" ; <> + | located(SELF) ; "::" ; { mk_locatable dummy_pos Constants.sail_module_self } let parenthesized(e) == delimited("(",e,")") diff --git a/src/passes/ir/baseAst/ast.ml b/src/passes/ir/baseAst/ast.ml new file mode 100644 index 0000000..cfe1756 --- /dev/null +++ b/src/passes/ir/baseAst/ast.ml @@ -0,0 +1,134 @@ +(**************************************************************************) +(* *) +(* SAIL *) +(* *) +(* Frédéric Dabrowski, LMV, Orléans University *) +(* *) +(* Copyright (C) 2022 Frédéric Dabrowski *) +(* *) +(* This program is free software: you can redistribute it and/or modify *) +(* it under the terms of the GNU General Public License as published by *) +(* the Free Software Foundation, either version 3 of the License, or *) +(* (at your option) any later version. *) +(* *) +(* This program is distributed in the hope that it will be useful, *) +(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) +(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) +(* GNU General Public License for more details. *) +(* *) +(* You should have received a copy of the GNU General Public License *) +(* along with this program. If not, see . *) +(**************************************************************************) + +open Common.TypesCommon + +type 'e struct_read = {field: l_str; strct: 'e} +type 'e struct_alloc = {name:l_str; fields: 'e locatable dict} +type 'e fcall = {ret_var : string option; id: l_str; args:'e list} + +type yes = Yes +type no = No + +type exp = Expression +type stmt = Statement + +(* based on https://icfp23.sigplan.org/details/ocaml-2023-papers/4/Modern-DSL-compiler-architecture-in-OCaml-our-experience-with-Catala *) + +type ('tag,'kind,'features) generic_ast = ('tag, 'kind, 'features,'features) relaxed_ast +and ('tag,'kind, 'deep_features,'shallow_features) relaxed_ast = ('tag,('tag,'kind,'deep_features,'shallow_features) relaxed_ast_) tagged_node +and ('tag,'kind,'deep_features,'shallow_features) relaxed_ast_ = + (* expressions *) + | Variable : string -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | Deref : ('tag,exp,'deep_features) generic_ast -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | Ref : bool * ('tag,exp,'deep_features) generic_ast -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | Literal : literal -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | UnOp : unOp * ('tag,exp,'deep_features) generic_ast -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | BinOp : {left: ('tag,exp,'deep_features) generic_ast; right:('tag,exp,'deep_features) generic_ast; op: binOp} -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | ArrayRead : {array: ('tag,exp,'deep_features) generic_ast ; idx: ('tag,exp,'deep_features) generic_ast} -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | ArrayStatic : ('tag,exp,'deep_features) generic_ast list -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + | MethodCall : (l_str option, ('tag,exp,'deep_features) generic_ast fcall) importable -> ('tag,exp,'deep_features,) relaxed_ast_ + + | StructRead : (l_str option, ('tag,exp,'deep_features) generic_ast struct_read) importable -> ('tag,exp,'deep_features, ) relaxed_ast_ + + | StructRead2 : (l_str, ('tag,exp,'deep_features) generic_ast struct_read) importable -> ('tag,exp,'deep_features,) relaxed_ast_ + + | StructAlloc : (l_str option, ('tag,exp,'deep_features) generic_ast struct_alloc) importable -> ('tag,exp,'deep_features,) relaxed_ast_ + + | StructAlloc2 : (l_str, ('tag,exp,'deep_features) generic_ast struct_alloc) importable -> ('tag,exp,'deep_features, ) relaxed_ast_ + + | EnumAlloc : l_str * ('tag,exp,'deep_features) generic_ast list -> ('tag,exp,'deep_features,'shallow_features) relaxed_ast_ + + (* statements *) + + | DeclVar : {mut:bool; id:string; ty: sailtype option; value: ('tag,exp,'deep_features) generic_ast option} -> ('tag,stmt,'deep_features, ) relaxed_ast_ + + | DeclVar2 : {mut:bool; id:string; ty: sailtype} -> ('tag,stmt,'deep_features, ) relaxed_ast_ + + | Invoke : (l_str option, ('tag,exp,'deep_features) generic_ast fcall) importable -> ('tag,stmt, 'deep_features, ) relaxed_ast_ + + | Invoke2 : (l_str, ('tag,exp,'deep_features) generic_ast fcall) importable -> ('tag,stmt,'deep_features,) relaxed_ast_ + + | Skip : ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Assign : {path: ('tag,exp,'deep_features) generic_ast ; value: ('tag,exp,'deep_features) generic_ast} -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Seq : ('tag,stmt,'deep_features) generic_ast * ('tag,stmt,'deep_features) generic_ast -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | If : + { + cond: ('tag,exp,'deep_features) generic_ast; + then_:('tag,stmt,'deep_features) generic_ast ; + else_: ('tag,stmt,'deep_features) generic_ast option + } -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Loop : ('tag,stmt,'deep_features) generic_ast -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Break : ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Case : {switch : ('tag,exp,'deep_features) generic_ast ; cases : (string * string list * ('tag, stmt, 'deep_features) generic_ast) list} -> ('tag,stmt, 'deep_features, 'shallow_features) relaxed_ast_ + + | Return : ('tag, exp,'deep_features) generic_ast option -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + | Block : ('tag,stmt,'deep_features) generic_ast -> ('tag,stmt,'deep_features,'shallow_features) relaxed_ast_ + + + + +let buildExp (tag : 't) (node : ('t,exp,'a,'b) relaxed_ast_) : ('t,exp,'a,'b) relaxed_ast = {tag;node} +let buildStmt (tag : 't) (node : ('t,stmt,'a,'b) relaxed_ast_) : ('t,stmt,'a,'b) relaxed_ast = {tag;node} +let buildSeq tag s1 s2 = buildStmt tag (Seq (s1, s2)) +let buildSeqStmt tag s1 s2 = buildSeq tag s1 @@ buildStmt tag s2 + + +module Syntax = struct + + let skip : unit -> (loc,stmt,'a) generic_ast = fun () -> buildStmt dummy_pos Skip + + let (=) path value = buildStmt dummy_pos (Assign {path;value}) + + let var loc id ty value = buildStmt loc (DeclVar {mut=true;id;ty=Some ty;value}) + + let true_ : unit -> (_,exp,_) generic_ast = fun () -> buildExp dummy_pos (Literal (LBool true)) + let false_ : unit -> (_,exp,_) generic_ast = fun () -> buildExp dummy_pos (Literal (LBool false)) + + let (+) = fun left right -> buildExp dummy_pos (BinOp {op=Plus;left;right}) + let (%) = fun left right -> buildExp dummy_pos (BinOp {op=Rem;left;right}) + let (==) = fun left right -> buildExp dummy_pos (BinOp {op=Eq;left;right}) + + let (&&) = fun s1 s2 -> buildStmt dummy_pos (Seq (s1,s2)) + let (!@) = fun id -> buildExp dummy_pos (Variable id) + let (!) = fun n -> buildExp dummy_pos (Literal (LInt {l=Z.of_int n; size=32})) + let (!!) = fun b -> buildStmt dummy_pos (Block b) + + let if_ cond then_ else_ = + let else_ = match else_.node with Skip -> None | node -> Some {else_ with node} in + buildStmt dummy_pos (If {cond;then_;else_}) +end \ No newline at end of file diff --git a/src/passes/ir/baseAst/dune b/src/passes/ir/baseAst/dune new file mode 100644 index 0000000..1055df4 --- /dev/null +++ b/src/passes/ir/baseAst/dune @@ -0,0 +1,3 @@ +(library + (libraries common) + (name irAst)) \ No newline at end of file diff --git a/src/passes/ir/baseAst/ppAst.ml b/src/passes/ir/baseAst/ppAst.ml new file mode 100644 index 0000000..039285d --- /dev/null +++ b/src/passes/ir/baseAst/ppAst.ml @@ -0,0 +1,73 @@ +open Common +open PpCommon +open Format +open Ast +open TypesCommon + +let rec ppPrintExpression (pf : Format.formatter) (e : _) : unit = + let open Format in + match e.node with + | Variable s -> fprintf pf "%s" s + | Deref e -> fprintf pf "*%a" ppPrintExpression e + | StructRead st -> fprintf pf "%a.%s" ppPrintExpression st.value.strct st.value.field.value + | ArrayRead ar -> fprintf pf "%a[%a]" ppPrintExpression ar.array ppPrintExpression ar.idx + | Literal (l) -> fprintf pf "%a" PpCommon.pp_literal l + | UnOp (o, e) -> fprintf pf "%a %a" pp_unop o ppPrintExpression e + | BinOp bop -> fprintf pf "%a %a %a" ppPrintExpression bop.left pp_binop bop.op ppPrintExpression bop.right + | Ref (true,e) -> fprintf pf "&mut %a" ppPrintExpression e + | Ref (false,e) -> fprintf pf "&%a" ppPrintExpression e + | ArrayStatic el -> + fprintf pf "[%a]" + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el + | StructAlloc st -> + let pp_field pf (x, (y: _ locatable)) = fprintf pf "%s:%a" x ppPrintExpression y.value in + fprintf pf "%s{%a}" st.value.name.value + (pp_print_list ~pp_sep:pp_comma pp_field) + st.value.fields + | EnumAlloc (id,el) -> + fprintf pf "[%s(%a)]" id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el + | MethodCall m -> + fprintf pf "%a%s(%a)" + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) m.import + m.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) m.value.args + +let rec ppPrintStatement (pf : Format.formatter) (s : _) : unit = match s.node with +| DeclVar d -> fprintf pf "\nvar %s%a%a;" d.id + (pp_print_option (fun fmt -> fprintf fmt " : %a" pp_type)) d.ty + (pp_print_option (fun fmt -> fprintf fmt " = %a" ppPrintExpression)) d.value + +| Assign a -> fprintf pf "\n%a = %a;" ppPrintExpression a.path ppPrintExpression a.value +| Seq(c1, c2) -> fprintf pf "%a%a" ppPrintStatement c1 ppPrintStatement c2 +| If if_ -> fprintf pf "\nif (%a) {\n%a\n}\n%a" + ppPrintExpression if_.cond + ppPrintStatement if_.then_ + (pp_print_option (fun pf -> fprintf pf "else {%a\n}" ppPrintStatement)) if_.else_ +| Loop c -> fprintf pf "\nloop {%a\n}" ppPrintStatement c +| Break -> fprintf pf "break;" +| Case _ -> () +| Invoke i -> fprintf pf "\n%a%a%s(%a);" + (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.value.ret_var + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) i.import + i.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.value.args +| Return e -> fprintf pf "\nreturn %a;" (pp_print_option ppPrintExpression) e +| Block c -> fprintf pf "\n{\n@[ %a @]\n}" ppPrintStatement c +| Skip -> () + +let ppPrintMethodSig (pf : Format.formatter) (s : TypesCommon.method_sig) : unit = + match s.rtype with + None -> + fprintf pf "%s(%a)" s.name (pp_print_list ~pp_sep:pp_comma (pp_field pp_type)) s.params +| Some t -> + fprintf pf "%s(%a) -> %a" s.name (pp_print_list ~pp_sep:pp_comma (pp_field pp_type)) s.params pp_type t + +let ppPrintMethod (pf : Format.formatter) (m: _ TypesCommon.method_defn) : unit = + match m.m_body with + | Right s -> fprintf pf "fn %a{\n@[%a@]\n}\n" ppPrintMethodSig m.m_proto ppPrintStatement s + | Left _ -> fprintf pf "extern fn %a\n" ppPrintMethodSig m.m_proto + + +(* let ppPrintProcess (pf : Format.formatter) (p : (declaration list * cfg) Common.TypesCommon.process_defn) : unit = + fprintf pf "proc %s() {\n%a\n%a}\n" p.p_name (pp_print_list ~pp_sep:pp_semicr ppPrintDeclaration) (fst p.p_body) ppPrintCfg (snd p.p_body) *) diff --git a/src/passes/ir/baseAst/utils.ml b/src/passes/ir/baseAst/utils.ml new file mode 100644 index 0000000..68947b5 --- /dev/null +++ b/src/passes/ir/baseAst/utils.ml @@ -0,0 +1,157 @@ +open Common +open TypesCommon +open Monad + + +module AstMonad(M : Monad) = struct + open Ast + open UseMonad(M) + + let opt_f f = function None -> return None | Some x -> let+ r = f x in Some r + + + type ('in_tag,'out_tag,'in_features,'out_features) map = {f: 'kind. ('in_tag,'kind,'in_features) Ast.generic_ast -> ('out_tag,'kind,'out_features) Ast.generic_ast M.t} + + + let shallow_exp_map (type src_features dst_features in_tag out_tag) + (map: (in_tag,exp,src_features) Ast.generic_ast -> (out_tag,exp,dst_features) Ast.generic_ast M.t) + (tag : out_tag) + (x : (in_tag,exp,src_features,dst_features) Ast.relaxed_ast) + : (out_tag,exp,dst_features) Ast.generic_ast M.t = + let (node : (out_tag,exp,dst_features,dst_features) Ast.relaxed_ast_ M.t) = + match x.node with + | Variable id -> return @@ Variable id + | Deref e -> let+ e = map e in Deref e + | ArrayRead ar -> let+ array = map ar.array and* idx = map ar.idx in ArrayRead {array;idx} + | Literal l -> return @@ Literal l + | UnOp (op, e) -> let+ e = map e in UnOp (op, e) + | BinOp bop -> let+ left = map bop.left and* right = map bop.right in BinOp {bop with left;right} + | Ref (b, e) -> let+ e = map e in Ref (b,e) + | ArrayStatic el -> let+ el = ListM.map map el in ArrayStatic el + | EnumAlloc (id, el) -> let+ el = ListM.map map el in EnumAlloc (id,el) + | MethodCall mc -> let+ args = ListM.map map mc.value.args in MethodCall {mc with value={mc.value with args}} + | StructAlloc st -> + let+ fields = ListM.map (fun (s,(fi:_ locatable)) -> let+ value = map fi.value in s,{fi with value}) st.value.fields in + StructAlloc {st with value={st.value with fields}} + + | StructAlloc2 st -> + let+ fields = ListM.map (fun (s,(fi:_ locatable)) -> let+ value = map fi.value in s,{fi with value}) st.value.fields in + StructAlloc2 {st with value={st.value with fields}} + + | StructRead s -> let+ strct = map s.value.strct in StructRead {s with value={s.value with strct} } + | StructRead2 s -> let+ strct = map s.value.strct in StructRead2 {s with value={s.value with strct} } + in let+ node in {node;tag} + + + let shallow_stmt_map (type src_features dst_features in_tag out_tag) + (map_stmt: (in_tag,stmt,src_features) Ast.generic_ast -> (out_tag,stmt,dst_features) Ast.generic_ast M.t) + (map_exp: (in_tag,exp,src_features) Ast.generic_ast -> (out_tag,exp,dst_features) Ast.generic_ast M.t) + (tag : out_tag) + (x : (in_tag,stmt,src_features,dst_features) Ast.relaxed_ast) + : (out_tag,stmt,dst_features) Ast.generic_ast M.t = + + + let (node : (out_tag,Ast.stmt,dst_features,dst_features) Ast.relaxed_ast_ M.t) = + match x.node with + | DeclVar v -> let+ value = opt_f map_exp v.value in DeclVar {v with value} + | DeclVar2 v -> return (DeclVar2 v) + | Assign a -> + let+ path = map_exp a.path + and* value = map_exp a.value in + Assign {path;value} + + | Seq(c1, c2) -> + let+ c1 = map_stmt c1 + and* c2 = map_stmt c2 in + Seq (c1, c2) + + | If if_ -> + let+ cond = map_exp if_.cond + and* then_ = map_stmt if_.then_ + and* else_ = opt_f map_stmt if_.else_ in + If {cond;then_;else_} + + | Loop c -> let+ c = map_stmt c in Loop c + | Break -> return Break + | Case c -> + let+ switch = map_exp c.switch + and* cases = ListM.map (fun (s,s2,s3) -> let+ s3 = map_stmt s3 in s,s2,s3) c.cases in + Case {switch;cases} + + | Invoke i -> + let+ args = ListM.map map_exp i.value.args in + Invoke {i with value = {i.value with args}} + + | Invoke2 i -> + let+ args = ListM.map map_exp i.value.args in + Invoke2 {i with value = {i.value with args}} + + | Return e -> let+ e = opt_f map_exp e in Return e + | Block c -> let+ c = map_stmt c in Block c + | Skip -> return Skip + in let+ node in {node;tag} + + + let separate_kind (type src_features dst_features in_tag out_tag kind) + (f_stmt : (in_tag,stmt,src_features,dst_features) Ast.relaxed_ast -> (out_tag,stmt,dst_features) Ast.generic_ast M.t) + (f_exp : (in_tag,exp,src_features,dst_features) Ast.relaxed_ast -> (out_tag,exp,dst_features) Ast.generic_ast M.t) + (x: (in_tag,kind,src_features,dst_features) Ast.relaxed_ast) : (out_tag, kind, dst_features) generic_ast M.t = + + let f_stmt : (in_tag, stmt,src_features,dst_features) relaxed_ast_ -> (out_tag, stmt, dst_features) generic_ast M.t = fun node -> f_stmt {node;tag=x.tag} in + let f_exp : (in_tag,exp,src_features,dst_features) relaxed_ast_ -> (out_tag, exp, dst_features) generic_ast M.t = fun node -> f_exp {node;tag=x.tag} in + + match (x.node : (in_tag,kind,src_features,dst_features) Ast.relaxed_ast_) with + | StructAlloc _ as exp -> f_exp exp + | StructAlloc2 _ as exp -> f_exp exp + | StructRead _ as exp -> f_exp exp + | StructRead2 _ as exp -> f_exp exp + | MethodCall _ as exp -> f_exp exp + | Variable _ as exp -> f_exp exp + | Deref _ as exp -> f_exp exp + | ArrayRead _ as exp -> f_exp exp + | Literal _ as exp -> f_exp exp + | UnOp _ as exp -> f_exp exp + | BinOp _ as exp -> f_exp exp + | Ref _ as exp -> f_exp exp + | ArrayStatic _ as exp -> f_exp exp + | EnumAlloc _ as exp -> f_exp exp + | DeclVar _ as stmt -> f_stmt stmt + | DeclVar2 _ as stmt -> f_stmt stmt + | Assign _ as stmt -> f_stmt stmt + | Seq _ as stmt -> f_stmt stmt + | If _ as stmt -> f_stmt stmt + | Loop _ as stmt -> f_stmt stmt + | Break as stmt -> f_stmt stmt + | Case _ as stmt -> f_stmt stmt + | Invoke _ as stmt -> f_stmt stmt + | Invoke2 _ as stmt -> f_stmt stmt + | Return _ as stmt -> f_stmt stmt + | Block _ as stmt -> f_stmt stmt + | Skip as stmt -> f_stmt stmt + + let shallow_map (type src_features dst_features in_tag out_tag kind) + (map: (in_tag,out_tag,src_features,dst_features) map) + (tag : out_tag) + (x : (in_tag,kind,src_features,dst_features) Ast.relaxed_ast) + : (out_tag,kind,dst_features) Ast.generic_ast M.t = + separate_kind (shallow_stmt_map map.f map.f tag) (shallow_exp_map map.f tag) x + +end + +let rename_var (f: string -> string) x = + let module Map = AstMonad(MonadIdentity) in + let open Ast in + let rec aux : type kind features tag. (tag,kind,features) generic_ast -> (tag,kind,features) generic_ast = fun x -> + let ret : (_,kind,features,features) relaxed_ast_ -> (_,kind,features) generic_ast = fun n -> mk_tagged_node x.tag n in + match x.node with + | Variable id -> ret @@ Variable (f id) + | DeclVar v -> + let value = MonadOption.M.fmap aux v.value in + let id = f v.id in + ret @@ DeclVar {v with value; id} + | Invoke i -> + let args = List.map aux i.value.args in + let ret_var = MonadOption.M.fmap f i.value.ret_var in + ret @@ Invoke {i with value = {i.value with ret_var;args}} + | _ as x' -> Map.shallow_map {f=aux} x.tag {tag=x.tag;node=x'} + in aux x diff --git a/src/passes/ir/sailHir/astHir.ml b/src/passes/ir/sailHir/astHir.ml deleted file mode 100644 index 8dfd46c..0000000 --- a/src/passes/ir/sailHir/astHir.ml +++ /dev/null @@ -1,95 +0,0 @@ -(**************************************************************************) -(* *) -(* SAIL *) -(* *) -(* Frédéric Dabrowski, LMV, Orléans University *) -(* *) -(* Copyright (C) 2022 Frédéric Dabrowski *) -(* *) -(* This program is free software: you can redistribute it and/or modify *) -(* it under the terms of the GNU General Public License as published by *) -(* the Free Software Foundation, either version 3 of the License, or *) -(* (at your option) any later version. *) -(* *) -(* This program is distributed in the hope that it will be useful, *) -(* but WITHOUT ANY WARRANTY; without even the implied warranty of *) -(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *) -(* GNU General Public License for more details. *) -(* *) -(* You should have received a copy of the GNU General Public License *) -(* along with this program. If not, see . *) -(**************************************************************************) - -open Common.TypesCommon - -type ('info,'import) expression = {info: 'info ; exp: ('info,'import) _expression} and ('info,'import) _expression = - | Variable of string - | Deref of ('info,'import) expression - | StructRead of 'import * ('info,'import) expression * l_str - | ArrayRead of ('info,'import) expression * ('info,'import) expression - | Literal of literal - | UnOp of unOp * ('info,'import) expression - | BinOp of binOp * ('info,'import) expression * ('info,'import) expression - | Ref of bool * ('info,'import) expression - | ArrayStatic of ('info,'import) expression list - | StructAlloc of 'import * l_str * (loc * ('info,'import) expression) dict - | EnumAlloc of l_str * ('info,'import) expression list - | MethodCall of l_str * 'import * ('info,'import) expression list - - - -type ('info,'import,'exp) statement = {info: 'info; stmt: ('info,'import,'exp) _statement} and ('info,'import,'exp) _statement = - | DeclVar of bool * string * sailtype option * 'exp option - | Skip - | Assign of 'exp * 'exp - | Seq of ('info,'import,'exp) statement * ('info,'import,'exp) statement - | If of 'exp * ('info,'import,'exp) statement * ('info,'import,'exp) statement option - | Loop of ('info,'import,'exp) statement - | Break - | Case of 'exp * (string * string list * ('info,'import,'exp) statement) list - | Invoke of {ret_var:string option; import: 'import; id: l_str; args:'exp list} - | Return of 'exp option - (* - | DeclSignal of string - | Emit of string - | Await of string - | When of string * ('info,'import,'exp) statement - | Watching of string * ('info,'import,'exp) statement - | Par of ('info,'import,'exp) statement * ('info,'import,'exp) statement - *) - | Block of ('info,'import,'exp) statement - -let buildExp info (exp: (_,_) _expression) : (_,_) expression = {info;exp} -let buildStmt info stmt : (_,_,_) statement = {info;stmt} - - -module Syntax = struct - let skip = buildStmt dummy_pos Skip - - let (=) = fun l r -> buildStmt dummy_pos (Assign (l,r)) - - let var (loc,id,ty) = buildStmt loc (DeclVar (true,id,Some ty,None)) - - let _true = buildExp dummy_pos (Literal (LBool true)) - let _false = buildExp dummy_pos (Literal (LBool false)) - - - let (+) = fun l r -> buildExp dummy_pos (BinOp(Plus,l,r)) - let (%) = fun l r -> buildExp dummy_pos (BinOp(Rem,l,r)) - let (==) = fun l r -> buildExp dummy_pos (BinOp(Eq, l,r)) - - let (&&) = fun s1 s2 -> buildStmt dummy_pos (Seq (s1,s2)) - - - let (!@) = fun id -> buildExp dummy_pos (Variable id) - - let (!) = fun n -> buildExp dummy_pos (Literal (LInt {l=Z.of_int n; size=32})) - - let (!!) = fun b -> buildStmt dummy_pos (Block b) - - let _if cond _then _else = - let _else = match _else.stmt with Skip -> None | stmt -> Some {_else with stmt} in - buildStmt dummy_pos (If (cond,_then,_else)) - - -end \ No newline at end of file diff --git a/src/passes/ir/sailHir/dune b/src/passes/ir/sailHir/dune index 44d16a9..6a94dfa 100644 --- a/src/passes/ir/sailHir/dune +++ b/src/passes/ir/sailHir/dune @@ -1,3 +1,3 @@ (library - (libraries common sailParser) + (libraries common sailParser irAst) (name irHir)) diff --git a/src/passes/ir/sailHir/hir.ml b/src/passes/ir/sailHir/hir.ml index 53e384a..6c41648 100644 --- a/src/passes/ir/sailHir/hir.ml +++ b/src/passes/ir/sailHir/hir.ml @@ -2,11 +2,9 @@ open SailParser open Common open TypesCommon open Monad -open AstHir +open IrAst open HirUtils open M -type expression = HirUtils.expression -type statement = HirUtils.statement module Pass = Pass.MakeFunctionPass (V)( @@ -31,16 +29,16 @@ struct let open MonadSyntax(M.ECS) in let open MonadOperator(M.ECS) in - let rec aux (info,s : m_in) : m_out M.ECS.t = + let rec aux (stmt:m_in) : m_out M.ECS.t = - let buildSeq s1 s2 = {info; stmt = Seq (s1, s2)} in - let buildStmt stmt = {info;stmt} in + let buildSeq = Ast.buildSeq stmt.loc in + let buildStmt = Ast.buildStmt stmt.loc in let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in - match s with + match stmt.value with | DeclVar (mut, id, t, e ) -> - M.ECS.set_var info id >>= fun () -> - let* t = match t with + M.ECS.set_var stmt.loc id >>= fun () -> + let* ty = match t with | None -> return None | Some t -> let* (ve,d) = M.ECS.get in @@ -50,69 +48,78 @@ struct in begin match e with | Some e -> let+ (e, s) = lower_expression e in - buildSeqStmt s (DeclVar (mut,id, t, Some e)) - | None -> return {info;stmt=DeclVar (mut,id, t, None)} + buildSeqStmt s (DeclVar {mut;id;ty;value=Some e}) + | None -> return (buildStmt (DeclVar {mut;id;ty;value=None})) end - | Skip -> return {info;stmt=Skip} - | Assign(e1, e2) -> - let* e1,s1 = lower_expression e1 in - let+ e2,s2 = lower_expression e2 in - buildSeq s1 @@ buildSeqStmt s2 @@ Assign (e1, e2) - - | Seq (c1, c2) -> let+ c1 = aux c1 and* c2 = aux c2 in {info;stmt=Seq (c1, c2)} - | If (e, c1, Some c2) -> - let+ e,s = lower_expression e and* c1 = aux c1 and* c2 = aux c2 in - buildSeqStmt s (If (e, c1, Some c2)) - - | If ( e, c1, None) -> - let+ (e, s) = lower_expression e and* c1 = aux c1 in - buildSeqStmt s (If (e, c1, None)) + + | Skip -> return (buildStmt Skip) + + | Assign a -> + let+ path,s1 = lower_expression a.path + and* value,s2 = lower_expression a.value in + buildSeq s1 @@ buildSeqStmt s2 @@ Assign {path; value} + + | Seq (c1, c2) -> + (buildSeq <$> aux c1) <*> aux c2 + + | If (e, then_, else_) -> + let+ cond,s = lower_expression e + and* then_ = aux then_ + and* else_ = match else_ with None -> return None | Some else_ -> aux else_ >>| Option.some in + buildSeqStmt s (If {cond;then_;else_}) | While (e, c) -> - let+ e,s = lower_expression e and* c = aux c in - let c = buildStmt (If (e,c,Some (buildStmt Break))) in + let+ cond,s = lower_expression e + and* then_ = aux c in + let c = buildStmt (If {cond;then_;else_=Some (buildStmt Break)}) in buildSeqStmt s (Loop c) | Loop c -> let+ c = aux c in buildStmt (Loop c) - | For {var;iterable;body} -> + | For {var;iterable;body} -> (* fixme : temporary *) begin - match iterable with - | _,ArrayStatic el -> + match iterable.value with + | ArrayStatic el -> let open AstParser in let arr_id = "_for_arr_" ^ var in let i_id = "_for_i_" ^ var in let arr_length = List.length el in - - let tab_decl = info,DeclVar (false, arr_id, Some (dummy_pos,ArrayType ((dummy_pos,Int 32),arr_length)), Some iterable) in - let var_decl = info,DeclVar (true, var, Some (dummy_pos,Int 32), None) in - let i_decl = info,DeclVar (true, i_id, Some (dummy_pos,Int 32), Some (info,(Literal (LInt {l=Z.zero;size=32})))) in - - let tab = info,Variable arr_id in - let var = info,Variable var in - let i = info,Variable i_id in + let loc x = mk_locatable iterable.loc x in + + let tab_decl = loc @@ DeclVar (false, arr_id, Some (loc @@ ArrayType (loc (Int 32),arr_length)), Some iterable) in + let var_decl = loc @@ DeclVar (true, var, Some (loc @@ Int 32), None) in + let i_decl = loc @@ DeclVar (true, i_id, Some (loc @@ Int 32), Some (mk_locatable dummy_pos @@ (Literal (LInt {l=Z.zero;size=32})))) in + + let tab = loc (Variable arr_id) in + let var = loc (Variable var) in + let i =loc (Variable i_id) in - let cond = info,BinOp (Lt, i, (info,Literal (LInt {l=Z.of_int arr_length;size=32}))) in - let incr = info,Assign (i,(info,BinOp (Plus, i, (info, Literal (LInt {l=Z.one;size=32}))))) in - let init = info,Seq ((info,Seq (tab_decl,var_decl)), i_decl) in - let vari = info, Assign (var,(info,ArrayRead(tab,i))) in - - let body = info,Seq((info,Seq(vari,body)),incr) in - let _while = info,While (cond,body) in - let _for = info,Seq(init,_while) in + let cond = loc @@ BinOp (Lt, i, (mk_locatable dummy_pos @@ Literal (LInt {l=Z.of_int arr_length;size=32}))) in + let incr = loc @@ Assign {path=i;value= loc @@ BinOp (Plus, i, (loc @@ Literal (LInt {l=Z.one;size=32})))} in + let init = loc @@ Seq (loc @@ Seq (tab_decl,var_decl), i_decl) in + let vari = loc @@ Assign {path=var;value=loc @@ ArrayRead(tab,i)} in + + let body = loc @@ Seq(loc @@ Seq(vari,body),incr) in + let _while = loc @@ While (cond,body) in + let _for = loc @@ Seq(init,_while) in aux _for - | loc,_ -> M.ECS.throw Logging.(make_msg loc "for loop only allows a static array expression at the moment") + | _ -> M.ECS.throw Logging.(make_msg iterable.loc "for loop only allows a static array expression at the moment") end - | Break () -> return {info; stmt=Break} + | Break () -> return (buildStmt Break) + (* | Case(loc, e, cases) -> Case (loc, e, List.map (fun (p,c) -> (p, aux c)) cases) *) - | Case (e, _cases) -> let+ e,s = lower_expression e in - buildSeqStmt s (Case (e, [])) + | Case (e, cases) -> + let open MonadFunctions(M.ECS) in + let+ switch,s = lower_expression e + and* _cases = ListM.map (pairMap2 aux) cases in + buildSeqStmt s (Case {switch;cases=[]}) | Invoke (mod_loc, id, args) -> - let+ args,s = ListM.map lower_expression args in - buildSeqStmt s (Invoke {ret_var=None;import=mod_loc;id;args}) + let+ args,s1 = ListM.map lower_expression args in + let s2 = Ast.Invoke (mk_importable mod_loc Ast.{ret_var=None;id;args}) in + buildSeqStmt s1 s2 | Return e -> begin match e with @@ -129,7 +136,7 @@ struct in - M.E.(bind (M.run aux env body) (fun (r,venv) -> pure (r,venv,tenv))) |> M.E.recover ({info=dummy_pos;stmt=Skip},snd env,tenv) + M.E.(bind (M.run aux env body) (fun (r,venv) -> pure (r,venv,tenv))) |> M.E.recover (MonoidSeq.mempty,snd env,tenv) let lower_process (p:p_in process_defn) ((env,decls),tenv: HIREnv.t * _ ) sm : (p_out * HIREnv.D.t * _) M.E.t = @@ -138,10 +145,10 @@ struct let open UseMonad(E) in let params = List.to_seq p.p_interface.p_params |> Seq.map (fun (p:param) -> p.id,p.loc) |> FieldMap.of_seq in - let locals = List.to_seq p.p_body.locals |> Seq.map (fun ((l,id),_) -> id,l ) |> FieldMap.of_seq in + let locals = List.to_seq p.p_body.locals |> Seq.map (fun (id,_) -> id.value,id.loc ) |> FieldMap.of_seq in let read,write = p.p_interface.p_shared_vars in - let read = List.to_seq read |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in - let write = List.to_seq write |> Seq.map (fun (l,(id,_)) -> id,l) |> FieldMap.of_seq in + let read = List.to_seq read |> Seq.map (fun r -> fst r.value,r.loc) |> FieldMap.of_seq in + let write = List.to_seq write |> Seq.map (fun w -> fst w.value,w.loc) |> FieldMap.of_seq in let union_no_dupl = FieldMap.union (fun _k _loca _locb -> None) in @@ -153,28 +160,33 @@ struct E.throw_if Logging.(make_msg dummy_pos @@ Fmt.str "process '%s' : name conflict between params,local decls or shared variables" p.p_name) has_name_conflict >>= fun () -> - let add_locals v e = ListM.fold_left (fun e ((l,id),_) -> HIREnv.declare_var id (l,()) e) (e,decls) v >>| fst in - let add_rw r e = ListM.fold_left (fun e (l,(id,_)) -> HIREnv.declare_var id (l,()) e) (e,decls) r >>| fst in + let add_locals v e = ListM.fold_left (fun e (id,_) -> HIREnv.declare_var id.value (id.loc,()) e) (e,decls) v >>| fst in + let add_rw r e = ListM.fold_left (fun e rw -> HIREnv.declare_var (fst rw.value) (rw.loc,()) e) (e,decls) r >>| fst in let* env = add_locals p.p_body.locals env >>= add_rw (fst p.p_interface.p_shared_vars) >>= add_rw (snd p.p_interface.p_shared_vars) in - let* init,decls,tenv = lower_method (p.p_body.init,()) ((env,decls),tenv) sm |> E.recover ({info=dummy_pos;stmt=Skip},decls,tenv) in + let* init,decls,tenv = lower_method (p.p_body.init,()) ((env,decls),tenv) sm |> E.recover (MonoidSeq.mempty,decls,tenv) in - let* (proc_init,_),decls = F.ListM.map (fun ((l,p): loc * _ proc_init) -> + let* (proc_init,_),decls = F.ListM.map (fun (p : _ proc_init locatable) -> let open UseMonad(M) in - let+ params = F.ListM.map lower_expression p.params in l,{p with params} + let+ params = F.ListM.map lower_expression p.value.params in + mk_locatable p.loc {p.value with params} ) p.p_body.proc_init (env,decls) |> M.ECS.run in let+ loop = let process_cond = function None -> return None | Some c -> let+ (cond,_),_ = lower_expression c (env,decls) |> M.ECS.run in Some cond in - let rec aux (l,s) = match s with - | Statement (s,cond) -> let+ s,_,_ = lower_method (s,()) ((env,decls),tenv) sm and* cond = process_cond cond in l,Statement (s,cond) - | Run (proc,cond) -> let+ cond = process_cond cond in l,Run (proc,cond) + let rec aux (stmt :(m_in, AstParser.expression) SailParser.AstParser.p_statement) = match stmt.value with + | Statement (s,cond) -> + let+ s,_,_ = lower_method (s,()) ((env,decls),tenv) sm + and* cond = process_cond cond in + mk_locatable stmt.loc (Statement (s,cond)) + + | Run (proc,cond) -> let+ cond = process_cond cond in mk_locatable stmt.loc @@ Run (proc,cond) | PGroup g -> let* cond = process_cond g.cond in let+ children = ListM.map aux g.children in - l,PGroup {g with cond ; children} + mk_locatable stmt.loc @@ PGroup {g with cond ; children} in aux p.p_body.loop in {p.p_body with init ; proc_init; loop},decls,tenv diff --git a/src/passes/ir/sailHir/hirMonad.ml b/src/passes/ir/sailHir/hirMonad.ml index 53ac723..714f3b3 100644 --- a/src/passes/ir/sailHir/hirMonad.ml +++ b/src/passes/ir/sailHir/hirMonad.ml @@ -5,7 +5,10 @@ module Make(MonoidSeq : Monad.Monoid) = struct type t = unit let string_of_var _ = "" let param_to_var _ = () - end + end + + module MonoidSeq = MonoidSeq + module HIREnv = SailModule.SailEnv(V) diff --git a/src/passes/ir/sailHir/hirUtils.ml b/src/passes/ir/sailHir/hirUtils.ml index 4ce2e08..8eec5d1 100644 --- a/src/passes/ir/sailHir/hirUtils.ml +++ b/src/passes/ir/sailHir/hirUtils.ml @@ -2,102 +2,105 @@ open Common open TypesCommon open Monad open SailParser +open IrAst -type expression = (loc,l_str option) AstHir.expression -type statement = (loc,l_str option,expression) AstHir.statement +type hir = + +type expression = (loc,Ast.exp,hir) Ast.generic_ast +type statement = (loc,Ast.stmt,hir) Ast.generic_ast module M = HirMonad.Make( struct type t = statement - let mempty : t = {info=dummy_pos; stmt=Skip} - let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} + let mempty = Ast.buildStmt dummy_pos Skip + let mconcat = fun x y -> Ast.buildStmt dummy_pos (Seq (x,y)) end ) open M module D = SailModule.DeclEnv -let lower_expression (e : AstParser.expression) : expression M.t = +let lower_expression (exp : AstParser.expression) : expression M.t = let open UseMonad(M) in - let rec aux (info,e : AstParser.expression) : expression M.t = - let open AstHir in - match e with - | Variable id -> - let* v = (M.ECS.find_var id |> M.lift) in - M.throw_if_none Logging.(make_msg info @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> - {info; exp=Variable id} + let rec aux (exp : AstParser.expression) : expression M.t = + let open Ast in + let v = match exp.value with + | Variable id -> + let* v = (M.ECS.find_var id |> M.lift) in + M.throw_if_none Logging.(make_msg exp.loc @@ Fmt.str "undeclared variable '%s'" id) v >>| fun _ -> + Variable id - | Deref e -> - let+ e = aux e in {info;exp=Deref e} + | Deref e -> + let+ e = aux e in Deref e - | StructRead (e, id) -> - let+ e = aux e in {info; exp=StructRead (None, e, id)} + | StructRead (e, field) -> + let+ strct = aux e in StructRead {import=None;value={field;strct}} - | ArrayRead (e1, e2) -> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=ArrayRead(e1,e2)} - | Literal l -> return {info; exp=Literal l} + | ArrayRead (array, idx) -> + let+ array = aux array + and* idx = aux idx in + ArrayRead {array;idx} - | UnOp (op, e) -> - let+ e = aux e in {info;exp=UnOp (op, e)} + | Literal l -> return @@ Literal l - | BinOp(op,e1,e2)-> - let* e1 = aux e1 in - let+ e2 = aux e2 in - {info; exp=BinOp (op, e1, e2)} + | UnOp (op, e) -> + let+ e = aux e in UnOp (op, e) - | Ref (b, e) -> - let+ e = aux e in {info;exp=Ref(b, e)} + | BinOp(op,left,right)-> + let+ left = aux left + and* right = aux right in + BinOp {op;left;right} - | ArrayStatic el -> - let+ el = ListM.map aux el in {info;exp=ArrayStatic el} + | Ref (m,e) -> let+ e = aux e in Ref (m,e) - | StructAlloc (origin,id, m) -> - let m' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) m in - let* () = M.throw_if Logging.(make_msg info "duplicate fields") List.(length m <> length m') in - let+ m' = ListM.map (aux |> pairMap2 |> pairMap2) m' in - {info; exp=StructAlloc (origin, id, m')} + | ArrayStatic el -> let+ el = ListM.map aux el in ArrayStatic el - | EnumAlloc (id, el) -> - let+ el = ListM.map aux el in {info;exp=EnumAlloc (id, el)} + | StructAlloc (import,name,fields) -> + let fields' = List.sort_uniq (fun (id1,_) (id2,_) -> String.compare id1 id2) fields in + let* () = M.throw_if Logging.(make_msg exp.loc "duplicate fields") List.(length fields <> length fields') in + let+ fields = ListM.map (pairMap2 (fun e -> let+ value = aux e.value in {e with value})) fields' in + StructAlloc (mk_importable import {name;fields}) - - | MethodCall (import, id, el) -> let+ el = ListM.map aux el in {info ; exp=MethodCall(id, import, el)} - in aux e + | EnumAlloc (id, el) -> let+ el = ListM.map aux el in EnumAlloc (id, el) + + + | MethodCall (import, id, args) -> let+ args = ListM.map aux args in MethodCall (mk_importable import {id;args;ret_var=None}) + + in v <&> buildExp exp.loc + in aux exp open UseMonad(M.E) -let find_symbol_source ?(filt = [E (); S (); T ()] ) (loc,id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) M.E.t = +let find_symbol_source ?(filt = [E (); S (); T ()] ) (id: l_str) (import : l_str option) (env : D.t) : (l_str * D.decls) M.E.t = match import with - | Some (iloc,name) -> - if name = Constants.sail_module_self || name = D.get_name env then + | Some iname -> + if iname.value = Constants.sail_module_self || iname.value = D.get_name env then let+ decl = - D.find_decl id (Self (Filter filt)) env - |> M.E.throw_if_none Logging.(make_msg loc @@ "no declaration named '" ^ id ^ "' in current module ") + D.find_decl id.value (Self (Filter filt)) env + |> M.E.throw_if_none Logging.(make_msg id.loc @@ "no declaration named '" ^ id.value ^ "' in current module ") in - (iloc,D.get_name env),decl + {iname with value=D.get_name env},decl else let+ t = - M.E.throw_if_none Logging.(make_msg iloc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" name)) @@ "unknown module " ^ name) - (List.find_opt (fun {mname;_} -> mname = name) (D.get_imports env)) >>= fun _ -> - M.E.throw_if_none Logging.(make_msg loc @@ "declaration " ^ id ^ " not found in module " ^ name) - (D.find_decl id (Specific (name, Filter filt)) env) + M.E.throw_if_none Logging.(make_msg iname.loc ~hint:(Some (None,Fmt.str "try importing the module with 'import %s'" iname.value)) @@ "unknown module " ^ iname.value) + (List.find_opt (fun {mname;_} -> mname = iname.value) (D.get_imports env)) >>= fun _ -> + M.E.throw_if_none Logging.(make_msg id.loc @@ "declaration " ^ id.value ^ " not found in module " ^ iname.value) + (D.find_decl id.value (Specific (iname.value, Filter filt)) env) in - (iloc,name),t + iname,t | None -> (* find it ourselves *) begin - let decl = D.find_decl id (All (Filter filt)) env in + let decl = D.find_decl id.value (All (Filter filt)) env in match decl with | [i,m] -> (* Logs.debug (fun m -> m "'%s' is from %s" id i.mname); *) - return ((dummy_pos,i.mname),m) + return (mk_locatable dummy_pos i.mname,m) - | [] -> M.E.throw Logging.(make_msg loc @@ "unknown declaration " ^ id) + | [] -> M.E.throw Logging.(make_msg id.loc @@ "unknown declaration " ^ id.value) | _ as choice -> M.E.throw - @@ Logging.make_msg loc ~hint:(Some (None,"specify one with '::' annotation")) - @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id + @@ Logging.make_msg id.loc ~hint:(Some (None,"specify one with '::' annotation")) + @@ Fmt.str "multiple definitions for declaration %s : \n\t%s" id .value (List.map (fun (i,def) -> match def with T def -> Fmt.str "from %s : %s" i.mname (string_of_sailtype (def.ty)) | _ -> "") choice |> String.concat "\n\t") end @@ -106,10 +109,10 @@ let follow_type ty env : (sailtype * D.t) M.E.t = (* Logs.debug (fun m -> m "following type '%s'" (string_of_sailtype (Some ty))); *) - let rec aux (l_ty,ty') path : (sailtype * ty_defn list) M.E.t = + let rec aux ty path : (sailtype * ty_defn list) M.E.t = (* Logs.debug (fun m -> m "path: %s" (List.map (fun ({name;_}:ty_defn) -> name)path |> String.concat " ")); *) let+ (t,path : sailtype_ * ty_defn list) = - match ty' with + match ty.value with | ArrayType (t,n) -> let+ t,path = aux t path in ArrayType (t,n),path | Box t -> let+ t,path = aux t path in Box t,path @@ -118,11 +121,11 @@ let follow_type ty env : (sailtype * D.t) M.E.t = (* Logs.debug (fun m -> m "'%s' resolves to '%s'" (string_of_sailtype (Some ty)) (string_of_sailtype (Some ty'))); *) return (t,path) | CompoundType {origin;name=id;generic_instances;_} -> (* compound type, find location and definition *) - let* (l,origin),def = find_symbol_source id origin env in - let default = fun ty -> CompoundType {origin=Some (l,origin);name=id; generic_instances;decl_ty=Some ty} in + let* origin,def = find_symbol_source id origin env in + let default = fun ty -> CompoundType {origin=Some origin;name=id; generic_instances;decl_ty=Some ty} in begin match def with - | T def when origin=current -> + | T def when origin.value = current -> begin match def.ty with | Some ty -> ( @@ -133,7 +136,7 @@ let follow_type ty env : (sailtype * D.t) M.E.t = ) (List.find_opt (fun (d:ty_defn) -> d.name = def.name) path) - >>= fun () -> let+ ((_,t),p) = aux ty (def::path) in t,p + >>= fun () -> let+ (t,p) = aux ty (def::path) in t.value,p ) | None -> (* abstract type *) (* Logs.debug (fun m -> m "'%s' resolves to abstract type '%s' " (string_of_sailtype (Some ty)) def.name); *) @@ -141,7 +144,7 @@ let follow_type ty env : (sailtype * D.t) M.E.t = end | _ -> return (default @@ unit_decl_of_decl def,path) (* must point to an enum or struct, nothing to resolve *) end - in (l_ty,t),path + in {ty with value=t},path in let+ res,p = aux ty [] in (* p only contains type_def from the current module *) @@ -161,12 +164,12 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit M.E.t = (List.mem id checked) in let checked = id::checked in ListM.iter ( - fun (_,(l,t,_)) -> match snd t with + fun (_,{value=(ty,_);_}) -> match ty.value with - | CompoundType {name=_,name;origin=Some (_,origin); decl_ty = Some S ();_} -> + | CompoundType {name;origin=Some origin; decl_ty = Some S ();_} -> begin - match D.find_decl name (Specific (origin,(Filter [S ()]))) env with - | Some (S (_,d)) -> aux name l d checked + match D.find_decl name.value (Specific (origin.value,(Filter [S ()]))) env with + | Some (S (_,d)) -> aux name.value l d checked | _ -> failwith "invariant : all compound types must have a correct origin and type at this step" end | CompoundType {origin=None;decl_ty=None;_} -> M.E.throw Logging.(make_msg l "follow type not called") @@ -175,66 +178,6 @@ let check_non_cyclic_struct (name:string) (l,proto) env : unit M.E.t = in aux name l proto [] -let rename_var_exp (f: string -> string) (e: _ AstHir.expression) = - let open AstHir in - let rec aux (e : _ expression) = - let buildExp = buildExp e.info in - match e.exp with - | Variable id -> buildExp @@ Variable (f id) - | Deref e -> let e = aux e in buildExp @@ Deref e - | StructRead (mod_loc,e, id) -> let e = aux e in buildExp @@ StructRead(mod_loc,e,id) - | ArrayRead (e1, e2) -> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ ArrayRead (e1,e2) - | Literal _ as l -> buildExp l - | UnOp (op, e) -> let e = aux e in buildExp @@ UnOp (op,e) - | BinOp(op,e1,e2)-> - let e1 = aux e1 in - let e2 = aux e2 in - buildExp @@ BinOp(op,e1,e2) - | Ref (b, e) -> - let e = aux e in buildExp @@ Ref(b,e) - | ArrayStatic el -> let el = List.map aux el in buildExp @@ ArrayStatic el - | StructAlloc (origin,id, m) -> let m = List.map (fun (n,(l,e)) -> n,(l,aux e)) m in buildExp @@ StructAlloc (origin,id,m) - | EnumAlloc (id, el) -> let el = List.map aux el in buildExp @@ EnumAlloc (id,el) - | MethodCall (mod_loc, id, el) -> let el = List.map aux el in buildExp @@ MethodCall (mod_loc,id,el) - in aux e - - let rename_var_stmt (f:string -> string) s = - let open AstHir in - let rec aux (s : _ statement) = - let buildStmt = buildStmt s.info in - match s.stmt with - | DeclVar (mut, id, opt_t,opt_exp) -> - let e = MonadOption.M.fmap (rename_var_exp f) opt_exp in - buildStmt @@ DeclVar (mut,f id,opt_t,e) - | Assign(e1, e2) -> - let e1 = rename_var_exp f e1 - and e2 = rename_var_exp f e2 in - buildStmt @@ Assign(e1, e2) - | Seq(c1, c2) -> - let c1 = aux c1 in - let c2 = aux c2 in - buildStmt @@ Seq(c1, c2) - | If(cond_exp, then_s, else_s) -> - let cond_exp = rename_var_exp f cond_exp in - let then_s = aux then_s in - let else_s = MonadOption.M.fmap aux else_s in - buildStmt (If(cond_exp, then_s, else_s)) - | Loop c -> let c = aux c in buildStmt (Loop c) - | Break -> buildStmt Break - | Case(e, _cases) -> let e = rename_var_exp f e in buildStmt (Case (e, [])) - | Invoke i -> - let args = List.map (rename_var_exp f) i.args in - let ret_var = MonadOption.M.fmap f i.ret_var in - buildStmt @@ Invoke {i with ret_var;args} - | Return e -> let e = MonadOption.M.fmap (rename_var_exp f) e in buildStmt @@ Return e - | Block c -> let c = aux c in buildStmt (Block c) - | Skip -> buildStmt Skip - in - aux s - let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = let module ES = struct @@ -275,11 +218,11 @@ let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = let* () = SEnv.iter ( fun (id,(l,{fields; generics})) -> let* fields = ListM.map ( - fun (name,(l,t,n)) -> + fun (name,({value=(t,i);_} as f)) -> let* env = ES.get_env in let* t,env = (follow_type t env) |> ES.S.lift in let+ () = ES.set_env env in - name,(l,t,n) + name,{f with value=t,i} ) fields in let proto = l,{fields;generics} in @@ -304,7 +247,7 @@ let resolve_names (sm : ('a,'b) SailModule.methods_processes SailModule.t) = ) m_proto.params in let m = {m with m_proto={m_proto with params; rtype}} in let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in - let+ () = ES.update_env (update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method)) + let+ () = ES.update_env (update_decl m_proto.name (mk_locatable m_proto.pos true_name, defn_to_proto (Method m)) (Self Method)) in m ) sm.body.methods in diff --git a/src/passes/ir/sailHir/pp_hir.ml b/src/passes/ir/sailHir/pp_hir.ml index 05b04f3..cd1feab 100644 --- a/src/passes/ir/sailHir/pp_hir.ml +++ b/src/passes/ir/sailHir/pp_hir.ml @@ -1,57 +1,59 @@ open Common open PpCommon open Format -open AstHir +(* open HirAst *) open Hir +open HirUtils +open TypesCommon - -let rec ppPrintExpression (pf : Format.formatter) (e : expression) : unit = +let rec ppPrintExpression (pf : Format.formatter) (e : expression) : unit = let open Format in - match e.exp with + match e.node with | Variable s -> fprintf pf "%s" s - | Deref e -> fprintf pf "*%a" ppPrintExpression e - | StructRead (_,e, (_,s)) -> fprintf pf "%a.%s" ppPrintExpression e s - | ArrayRead (e1, e2) -> fprintf pf "%a[%a]" ppPrintExpression e1 ppPrintExpression e2 + | Deref e -> fprintf pf "*%a" ppPrintExpression e + | StructRead st -> fprintf pf "%a.%s" ppPrintExpression st.value.strct st.value.field.value + | ArrayRead ar -> fprintf pf "%a[%a]" ppPrintExpression ar.array ppPrintExpression ar.idx | Literal (l) -> fprintf pf "%a" PpCommon.pp_literal l | UnOp (o, e) -> fprintf pf "%a %a" pp_unop o ppPrintExpression e - | BinOp ( o, e1, e2) -> fprintf pf "%a %a %a" ppPrintExpression e1 pp_binop o ppPrintExpression e2 + | BinOp bop -> fprintf pf "%a %a %a" ppPrintExpression bop.left pp_binop bop.op ppPrintExpression bop.right | Ref (true,e) -> fprintf pf "&mut %a" ppPrintExpression e | Ref (false,e) -> fprintf pf "&%a" ppPrintExpression e | ArrayStatic el -> fprintf pf "[%a]" (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - |StructAlloc (_,id, m) -> - let pp_field pf (x, (_,y)) = fprintf pf "%s:%a" x ppPrintExpression y in - fprintf pf "%s{%a}" (snd id) + | StructAlloc st -> + let pp_field pf (x, (y: _ locatable)) = fprintf pf "%s:%a" x ppPrintExpression y.value in + fprintf pf "%s{%a}" st.value.name.value (pp_print_list ~pp_sep:pp_comma pp_field) - m + st.value.fields | EnumAlloc (id,el) -> - fprintf pf "[%s(%a)]" (snd id) + fprintf pf "[%s(%a)]" id.value (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - | MethodCall ((_,id),mod_loc,el) -> + | MethodCall m -> fprintf pf "%a%s(%a)" - (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) mod_loc - id - (pp_print_list ~pp_sep:pp_comma ppPrintExpression) el + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) m.import + m.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) m.value.args + +let rec ppPrintStatement (pf : Format.formatter) (s : statement) : unit = match s.node with +| DeclVar d -> fprintf pf "\nvar %s%a%a;" d.id + (pp_print_option (fun fmt -> fprintf fmt " : %a" pp_type)) d.ty + (pp_print_option (fun fmt -> fprintf fmt " = %a" ppPrintExpression)) d.value -let rec ppPrintStatement (pf : Format.formatter) (s : statement) : unit = match s.stmt with -| DeclVar (_mut, id, opt_t,opt_exp) -> fprintf pf "\nvar %s%a%a;" id - (pp_print_option (fun fmt -> fprintf fmt " : %a" pp_type)) opt_t - (pp_print_option (fun fmt -> fprintf fmt " = %a" ppPrintExpression)) opt_exp -| Assign(e1, e2) -> fprintf pf "\n%a = %a;" ppPrintExpression e1 ppPrintExpression e2 +| Assign a -> fprintf pf "\n%a = %a;" ppPrintExpression a.path ppPrintExpression a.value | Seq(c1, c2) -> fprintf pf "%a%a" ppPrintStatement c1 ppPrintStatement c2 -| If(cond_exp, then_s,else_s) -> fprintf pf "\nif (%a) {\n%a\n}\n%a" - ppPrintExpression cond_exp - ppPrintStatement then_s - (pp_print_option (fun pf -> fprintf pf "else {%a\n}" ppPrintStatement)) else_s +| If if_ -> fprintf pf "\nif (%a) {\n%a\n}\n%a" + ppPrintExpression if_.cond + ppPrintStatement if_.then_ + (pp_print_option (fun pf -> fprintf pf "else {%a\n}" ppPrintStatement)) if_.else_ | Loop c -> fprintf pf "\nloop {%a\n}" ppPrintStatement c | Break -> fprintf pf "break;" -| Case(_e, _cases) -> () +| Case _ -> () | Invoke i -> fprintf pf "\n%a%a%s(%a);" - (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.ret_var - (pp_print_option (fun fmt (_,ml) -> fprintf fmt "%s::" ml)) i.import - (snd i.id) - (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.args + (pp_print_option (fun fmt v -> fprintf fmt "%s = " v)) i.value.ret_var + (pp_print_option (fun fmt ml -> fprintf fmt "%s::" ml.value)) i.import + i.value.id.value + (pp_print_list ~pp_sep:pp_comma ppPrintExpression) i.value.args | Return e -> fprintf pf "\nreturn %a;" (pp_print_option ppPrintExpression) e | Block c -> fprintf pf "\n{\n@[ %a @]\n}" ppPrintStatement c | Skip -> () diff --git a/src/passes/ir/sailMir/mir.ml b/src/passes/ir/sailMir/mir.ml index 23f0387..dd21895 100644 --- a/src/passes/ir/sailMir/mir.ml +++ b/src/passes/ir/sailMir/mir.ml @@ -1,4 +1,4 @@ -open AstMir +open MirAst open Common open TypesCommon open Monad @@ -16,78 +16,68 @@ module Pass = MakeFunctionPass(V)( struct let name = "MIR" - type m_in = Thir.statement + type m_in = ThirUtils.statement type m_out = mir_function type p_in = (HirUtils.statement,HirUtils.expression) AstParser.process_body type p_out = p_in - let rec lexpr (e : Thir.expression) : expression M.t = - let open AstHir in - let lt = e.info in - match e.exp with + let rec lexpr (exp : ThirUtils.expression) : expression M.t = + let open IrAst.Ast in + match exp.node with | Variable name -> let* id = find_scoped_var name in - let+ () = M.update_var (fst lt) id assign_var in buildExp lt (Variable id) + let+ () = M.update_var exp.tag.loc id assign_var in buildExp exp.tag (Variable id) | Deref e -> rexpr e - | ArrayRead (e1, e2) -> let+ e1' = lexpr e1 and* e2' = rexpr e2 in buildExp lt (ArrayRead(e1',e2')) - | StructRead (origin,e,field) -> let+ e = lexpr e in buildExp lt (StructRead (origin,e,field)) - | Ref _ -> M.error Logging.(make_msg (fst lt) "todo") - | _ -> M.error Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") - and rexpr (e : Thir.expression) : expression M.t = - let lt = e.info in - let open AstHir in - match e.exp with + | ArrayRead a -> let+ array = lexpr a.array and* idx = rexpr a.idx in + buildExp exp.tag (ArrayRead {array;idx}) + + | StructRead2 s -> let+ strct = lexpr s.value.strct in buildExp exp.tag (StructRead2 {s with value={s.value with strct}}) + + | Ref _ -> M.error Logging.(make_msg exp.tag.loc "todo") + + | Literal _ | UnOp _ | BinOp _ | ArrayStatic _ | StructAlloc2 _ | EnumAlloc _ -> + M.error Logging.(make_msg exp.tag.loc @@ "compiler error : not a lexpr") + + and rexpr (exp : ThirUtils.expression) : expression M.t = + let open IrAst.Ast in + match exp.node with | Variable name -> - let+ id = find_scoped_var name in buildExp lt (Variable id) - | Literal l -> buildExp lt (Literal l) |> M.pure + let+ id = find_scoped_var name in buildExp exp.tag (Variable id) + | Literal l -> buildExp exp.tag (Literal l) |> M.pure | Deref e -> lexpr e - | ArrayRead (array_exp,idx) -> let+ arr = rexpr array_exp and* idx' = rexpr idx in buildExp lt (ArrayRead(arr,idx')) - | UnOp (o, e) -> let+ e' = rexpr e in buildExp lt (UnOp (o, e')) - | BinOp (o ,e1, e2) -> let+ e1' = rexpr e1 and* e2' = rexpr e2 in buildExp lt (BinOp(o, e1', e2')) - | Ref (b, e) -> let+ e' = rexpr e in buildExp lt (Ref(b, e')) - | ArrayStatic el -> let+ el' = ListM.map rexpr el in buildExp lt (ArrayStatic el') - | StructRead (origin,struct_exp,field) -> - let+ exp = rexpr struct_exp in - buildExp lt (StructRead (origin,exp,field)) + | ArrayRead a -> let+ array = rexpr a.array and* idx = rexpr a.idx in buildExp exp.tag (ArrayRead {array;idx}) + | UnOp (o, e) -> let+ e' = rexpr e in buildExp exp.tag (UnOp (o, e')) + | BinOp bop -> let+ left = rexpr bop.left and* right = rexpr bop.right in buildExp exp.tag (BinOp {bop with left;right}) + | Ref (b, e) -> let+ e' = rexpr e in buildExp exp.tag (Ref(b, e')) + | ArrayStatic el -> let+ el' = ListM.map rexpr el in buildExp exp.tag (ArrayStatic el') + | StructRead2 s -> + let+ strct = rexpr s.value.strct in + buildExp exp.tag (StructRead2 {s with value={s.value with strct}}) - | StructAlloc (origin,id, fields) -> - let+ fields = ListM.map (rexpr |> pairMap2 |> pairMap2) fields in - buildExp lt (StructAlloc(origin,id,fields)) - | MethodCall _ - | _ -> M.error @@ Logging.(make_msg (fst lt) @@ "thir didn't lower correctly this expression") - + | StructAlloc2 s -> + let+ fields = ListM.map (fun f -> pairMap2 (fun f -> let+ value = rexpr f.value in {f with value}) f) s.value.fields in + buildExp exp.tag (StructAlloc2 {s with value={s.value with fields}}) + | EnumAlloc _ -> M.error @@ Logging.(make_msg exp.tag.loc @@ "compiler error : not a rexpr") open UseMonad(M.E) let lower_method (body,_ : m_in * method_sig) (env,tenv) (_sm: (m_in,p_in) SailModule.methods_processes SailModule.t) : (m_out * SailModule.DeclEnv.t * _) M.E.t = - let rec aux (s : Thir.statement) : m_out M.t = + let rec aux (s : ThirUtils.statement) : m_out M.t = let open UseMonad(M) in - let loc = s.info in - match s.stmt with - | DeclVar(mut, id, Some ty, None) -> + let loc = s.tag.loc in + match s.node with + | DeclVar2 d -> let* bb = emptyBasicBlock loc in - let* id = M.fresh_scoped_var >>| get_scoped_var id in - let+ () = M.declare_var loc id {ty;mut;id;loc} in - [{location=loc; mut; id; varType=ty}],bb - - | DeclVar(mut, id, Some ty, Some e) -> - let* id_ty = M.get_type_id ty in - let* expression = rexpr e in - let* id = M.fresh_scoped_var >>| get_scoped_var id in - let* () = M.declare_var loc id {ty;mut;id;loc} in - let target = AstHir.buildExp (loc,id_ty) (Variable id) in - let+ bn = assignBasicBlock loc {location=loc; target; expression } in - [{location=loc; mut; id=id; varType=ty}],bn - (* ++ other statements *) - - | DeclVar (_,name,None,_) -> failwith @@ "thir broken : variable declaration should have a type : " ^name + let* id = M.fresh_scoped_var >>| get_scoped_var d.id in + let+ () = M.declare_var loc id {ty=d.ty;mut=d.mut;id;loc} in + [{location=loc; mut=d.mut; id; varType=d.ty}],bb | Skip -> let+ bb = emptyBasicBlock loc in ([], bb) - | Assign (e1, e2) -> - let* expression = rexpr e2 and* target = lexpr e1 in + | Assign a -> + let* expression = rexpr a.value and* target = lexpr a.path in let+ bb = assignBasicBlock loc {location=loc; target; expression} in [],bb | Seq (s1, s2) -> @@ -97,16 +87,16 @@ struct (* let* () = M.set_env env in *) let+ bb = buildSeq cfg1 cfg2 in d1@d2,bb - | If (e, s, None) -> - let* e' = rexpr e in - let* d, cfg = aux s in - let+ ite = buildIfThen loc e' cfg in + | If ({else_=None;_} as if_) -> + let* cond = rexpr if_.cond in + let* d, cfg = aux if_.then_ in + let+ ite = buildIfThen loc cond cfg in (d,ite) - | If (e, s1, Some s2) -> - let* e' = rexpr e in - let* d1,cfg1 = aux s1 and* d2,cfg2 = aux s2 in - let+ ite = buildIfThenElse loc e' cfg1 cfg2 in + | If ({else_=Some else_;_} as if_) -> + let* cond = rexpr if_.cond in + let* d1,cfg1 = aux if_.then_ and* d2,cfg2 = aux else_ in + let+ ite = buildIfThenElse loc cond cfg1 cfg2 in (d1@d2, ite) | Loop s -> @@ -120,12 +110,12 @@ struct let+ cfg = singleBlock bb in ([],cfg) - | Invoke i -> - let* ((_,realname),_) = M.throw_if_none Logging.(make_msg loc @@ Fmt.str "Compiler Error : function '%s' must exist" (snd i.id)) - (SailModule.DeclEnv.find_decl (snd i.id) (Specific (snd i.import,Method)) (snd env)) + | Invoke2 i -> + let* (realname,_) = M.throw_if_none Logging.(make_msg loc @@ Fmt.str "Compiler Error : function '%s' must exist" i.value.id.value) + (SailModule.DeclEnv.find_decl i.value.id.value (Specific (i.import.value,Method)) (snd env)) in - let* args = ListM.map rexpr i.args in - let+ invoke = buildInvoke loc i.import (fst i.id,realname) i.ret_var args in + let* args = ListM.map rexpr i.value.args in + let+ invoke = buildInvoke loc i.import (mk_locatable i.value.id.loc realname.value) i.value.ret_var args in ([], invoke) | Return e -> diff --git a/src/passes/ir/sailMir/astMir.ml b/src/passes/ir/sailMir/mirAst.ml similarity index 96% rename from src/passes/ir/sailMir/astMir.ml rename to src/passes/ir/sailMir/mirAst.ml index 22f2b24..bc482d4 100644 --- a/src/passes/ir/sailMir/astMir.ml +++ b/src/passes/ir/sailMir/mirAst.ml @@ -31,8 +31,8 @@ type statement = Assign of lvalue * rvalue | Drop of drop_kind * lvalue *) -type expression = Thir.expression -type statement = Thir.statement +type expression = ThirUtils.expression +type statement = ThirUtils.statement type declaration = {location : loc; mut : bool; id : string; varType : sailtype} type assignment = {location : loc; target : expression; expression : expression} diff --git a/src/passes/ir/sailMir/mirMonad.ml b/src/passes/ir/sailMir/mirMonad.ml index b3af329..8c52509 100644 --- a/src/passes/ir/sailMir/mirMonad.ml +++ b/src/passes/ir/sailMir/mirMonad.ml @@ -1,4 +1,4 @@ -open AstMir +open MirAst open Common module M = struct diff --git a/src/passes/ir/sailMir/mirUtils.ml b/src/passes/ir/sailMir/mirUtils.ml index f6f755d..2baee78 100644 --- a/src/passes/ir/sailMir/mirUtils.ml +++ b/src/passes/ir/sailMir/mirUtils.ml @@ -1,9 +1,10 @@ -open AstMir +open MirAst open Common open TypesCommon open Monad open MirMonad open UseMonad(M) + let assign_var (var_l,v:VE.variable) = (var_l,v) |> M.E.pure @@ -179,7 +180,7 @@ let buildInvoke (l : loc) (origin:l_str) (id : l_str) (target : string option) ( forward_info = env; backward_info = (); location=l; - terminator = Some (Invoke {id = (snd id); origin; target; params = el; next = returnLbl}) + terminator = Some (Invoke {id=id.value; origin; target; params = el; next = returnLbl}) } in let returnBlock = {assignments = []; predecessors = LabelSet.singleton invokeLbl ; forward_info = env; backward_info = () ; location = dummy_pos; terminator = None} in { @@ -220,7 +221,8 @@ let find_scoped_var name : string M.t = let seqOfList (l : statement list) : statement = - List.fold_left (fun s l : statement -> {info=dummy_pos; stmt=Seq (s, l)}) {info=dummy_pos;stmt=Skip} l + let tag = IrThir.ThirUtils.empty_tag in + List.fold_left (fun s l : statement -> {tag; node=Seq (s, l)}) {tag;node=Skip} l diff --git a/src/passes/ir/sailMir/pp_mir.ml b/src/passes/ir/sailMir/pp_mir.ml index e1f7509..b1732ae 100644 --- a/src/passes/ir/sailMir/pp_mir.ml +++ b/src/passes/ir/sailMir/pp_mir.ml @@ -1,30 +1,30 @@ open Common open PpCommon open Format -open AstMir +open MirAst +open TypesCommon -let rec ppPrintExpression (pf : Format.formatter) (e : AstMir.expression) : unit = - match e.exp with +let rec ppPrintExpression (pf : Format.formatter) (e : MirAst.expression) : unit = + match e.node with | Variable s -> fprintf pf "%s" s | Deref e -> fprintf pf "*%a" ppPrintExpression e - | StructRead (_,e, (_,s)) -> fprintf pf "%a.%s" ppPrintExpression e s - | ArrayRead (e1, e2) -> fprintf pf "%a[%a]" ppPrintExpression e1 ppPrintExpression e2 + | StructRead2 s -> fprintf pf "%a.%s" ppPrintExpression s.value.strct s.value.field.value + | ArrayRead a -> fprintf pf "%a[%a]" ppPrintExpression a.array ppPrintExpression a.idx | Literal (l) -> fprintf pf "%a" PpCommon.pp_literal l | UnOp (o, e) -> fprintf pf "%a %a" pp_unop o ppPrintExpression e - | BinOp ( o, e1, e2) -> fprintf pf "%a %a %a" ppPrintExpression e1 pp_binop o ppPrintExpression e2 + | BinOp bop -> fprintf pf "%a %a %a" ppPrintExpression bop.left pp_binop bop.op ppPrintExpression bop.right | Ref (true,e) -> fprintf pf "&mut %a" ppPrintExpression e | Ref (false,e) -> fprintf pf "&%a" ppPrintExpression e | ArrayStatic el -> Format.fprintf pf "[%a]" (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - |StructAlloc (_,id, m) -> - let pp_field pf (x, (_ , y)) = Format.fprintf pf "%s:%a" x ppPrintExpression y in - Format.fprintf pf "%s{%a}" (snd id) - (Format.pp_print_list ~pp_sep:pp_comma pp_field) m + |StructAlloc2 s -> + let pp_field pf (x, (y: 'a locatable)) = Format.fprintf pf "%s:%a" x ppPrintExpression y.value in + Format.fprintf pf "%s{%a}" s.value.name.value + (Format.pp_print_list ~pp_sep:pp_comma pp_field) s.value.fields | EnumAlloc (id,el) -> - Format.fprintf pf "[%s(%a)]" (snd id) + Format.fprintf pf "[%s(%a)]" id.value (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) el - | MethodCall _ -> () let ppPrintPredecessors (pf : Format.formatter) (preds : LabelSet.t ) : unit = if LabelSet.is_empty preds then fprintf pf "// no precedessors" @@ -40,8 +40,8 @@ let ppPrintAssignement (pf : Format.formatter) (a : assignment) : unit = let ppPrintTerminator (pf : Format.formatter) (t : terminator) : unit = match t with | Goto lbl -> fprintf pf "\t\tgoto %d;" lbl - | Invoke {id; params;next;origin=(_,mname);target} -> fprintf pf "\t\t%a%s(%a) -> [return: bb%d]" - (Format.pp_print_option (fun fmt id -> fprintf fmt "%s = %s::" id mname)) target + | Invoke {id; params;next;origin;target} -> fprintf pf "\t\t%a%s(%a) -> [return: bb%d]" + (Format.pp_print_option (fun fmt id -> fprintf fmt "%s = %s::" id origin.value)) target id (Format.pp_print_list ~pp_sep:pp_comma ppPrintExpression) params next diff --git a/src/passes/ir/sailThir/thir.ml b/src/passes/ir/sailThir/thir.ml index fa93a4c..df9b8bc 100644 --- a/src/passes/ir/sailThir/thir.ml +++ b/src/passes/ir/sailThir/thir.ml @@ -3,251 +3,256 @@ open TypesCommon open Logging open Monad open IrHir -open AstHir open SailParser +open IrAst open ThirUtils open M open UseMonad(M) module SM = SailModule - -type expression = ThirUtils.expression -type statement = ThirUtils.statement - module Pass = Pass.MakeFunctionPass(V)( struct let name = "THIR" type m_in = HirUtils.statement - type m_out = statement + type m_out = ThirUtils.statement type p_in = (m_in,HirUtils.expression) AstParser.process_body type p_out = p_in - let rec lower_lexp (e : Hir.expression) : expression M.t = - let rec aux (e:Hir.expression) : expression M.t = - let loc = e.info in match e.exp with + + let rec lower_lexp (exp : HirUtils.expression) : expression M.t = + let rec aux (exp:HirUtils.expression) : expression M.t = + + match exp.node with | Variable id -> - let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg exp.tag @@ Printf.sprintf "unknown variable %s" id) in let* venv,tenv = M.get_env in let t,tenv = t tenv in let+ () = M.set_env (venv,tenv) in - buildExp (loc,t) @@ Variable id + buildExp exp.tag t @@ Variable id | Deref e -> let* e = lower_rexp e in (* return @@ Deref((l,extract_exp_loc_ty e |> snd), e) *) begin - match e.exp with - | Ref (_,r) -> return @@ buildExp r.info @@ Deref e + match e.node with + | Ref (_,r) -> return @@ buildExp r.tag.loc r.tag.ty @@ Deref e | _ -> return e end - | ArrayRead (array_exp,idx) -> - let* array_exp = aux array_exp and* idx_exp = lower_rexp idx in - let* array_ty = M.get_type_from_id (array_exp.info) - and* idx_ty = M.get_type_from_id (idx_exp.info) in + + | ArrayRead ar -> + let* array = aux ar.array and* idx = lower_rexp ar.idx in + let* array_ty = M.get_type_from_id (mk_locatable array.tag.loc array.tag.ty) and* idx_ty = M.get_type_from_id (mk_locatable idx.tag.loc idx.tag.ty) in begin - match array_ty with - | l,ArrayType (t,sz) -> + match array_ty.value with + | ArrayType (t,sz) -> let* t = M.get_type_id t in - let* _ = matchArgParam l idx_ty (dummy_pos,Int 32) |> M.ESC.lift |> M.lift in + let* _ = matchArgParam array_ty.loc idx_ty (mk_locatable dummy_pos @@ Int 32) |> M.ESC.lift |> M.lift in begin (* can do a simple oob check if the type is an int literal *) - match idx.exp with + match idx.node with | Literal (LInt n) -> - M.throw_if (make_msg (fst idx_exp.info) @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" + M.throw_if (make_msg idx.tag.loc @@ Printf.sprintf "index out of bounds : must be between 0 and %i (got %s)" (sz - 1) Z.(to_string n.l) ) Z.( n.l < ~$0 || n.l > ~$sz) | _ -> return () - end >>| fun () -> buildExp (loc,t) @@ ArrayRead (array_exp,idx_exp) - | _ -> M.throw (make_msg loc "not an array !") + end >>| fun () -> buildExp exp.tag t @@ ArrayRead {array;idx} + | _ -> M.throw (make_msg exp.tag "not an array !") end - | StructRead (origin,e,(fl,field)) -> - let* e = lower_lexp e in - let* ty_e = M.get_type_from_id e.info in - let+ origin,t = + + | StructRead s -> + let* strct = lower_lexp s.value.strct in + let* ty = M.get_type_from_id (mk_locatable strct.tag.loc strct.tag.ty) in + let+ import,t = begin - match ty_e with - | _,CompoundType {name=l,name;decl_ty=Some S ();_} -> - let* origin,(_,strct) = find_struct_source (l,name) origin |> M.ESC.lift |> M.lift in - let* _,t,_ = List.assoc_opt field strct.fields - |> M.throw_if_none (make_msg fl @@ Fmt.str "field '%s' is not part of structure '%s'" field name) + match ty.value with + | CompoundType {name;decl_ty=Some S ();_} -> + let* origin,(_,strct) = find_struct_source name s.import |> M.ESC.lift |> M.lift in + let* f = List.assoc_opt s.value.field.value strct.fields + |> M.throw_if_none (make_msg s.value.field.loc @@ Fmt.str "field '%s' is not part of structure '%s'" s.value.field.value name.value) in - let+ t_id = M.get_type_id t in + let+ t_id = M.get_type_id (fst f.value) in origin,t_id - | l,t -> - let* str = string_of_sailtype_thir (Some (l,t)) |> M.ESC.lift |> M.lift in - M.throw (make_msg l @@ Fmt.str "expected a structure but got type '%s'" str) + | t -> + let* str = string_of_sailtype_thir (Some (mk_locatable ty.loc t)) |> M.ESC.lift |> M.lift in + M.throw (make_msg ty.loc @@ Fmt.str "expected a structure but got type '%s'" str) end in - let x : expression = buildExp (loc,t) (StructRead (origin,e,(fl,field))) in - x - - | _ -> M.throw (make_msg loc "not a lvalue !") - - in aux e - and lower_rexp (e : Hir.expression) : expression M.t = - let rec aux (e:Hir.expression) : expression M.t = - let loc = e.info in match e.exp with + buildExp ty.loc t @@ StructRead2 (mk_importable import Ast.{field=s.value.field;strct}) + + | BinOp _ | Literal _ | UnOp _ | Ref _ | ArrayStatic _ | StructAlloc _ | EnumAlloc _ | MethodCall _ -> + M.throw (make_msg exp.tag "not a lvalue !") + + in aux exp + and lower_rexp (exp : HirUtils.expression) : expression M.t = + let rec aux (exp: HirUtils.expression) : expression M.t = + match exp.node with | Variable id -> - let* _,t = M.get_var id >>= M.throw_if_none (make_msg loc @@ Printf.sprintf "unknown variable %s" id) in + let* _,t = M.get_var id >>= M.throw_if_none (make_msg exp.tag @@ Printf.sprintf "unknown variable %s" id) in let* venv,tenv = M.get_env in let t,tenv = t tenv in let+ () = M.set_env (venv,tenv) in - buildExp (loc,t) @@ Variable id + buildExp exp.tag t @@ Variable id | Literal li -> let* () = match li with | LInt t -> - let* () = M.throw_if Logging.(make_msg loc "signed integers use a minimum of 2 bits") (t.size < 2) in + let* () = M.throw_if Logging.(make_msg exp.tag "signed integers use a minimum of 2 bits") (t.size < 2) in let max_int = Z.( ~$2 ** t.size - ~$1) in let min_int = Z.( ~-max_int + ~$1) in M.throw_if ( - make_msg loc @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" + make_msg exp.tag @@ Fmt.str "type suffix can't contain int literal : i%i is between %s and %s but literal is %s" t.size (Z.to_string min_int) (Z.to_string max_int) (Z.to_string t.l) ) Z.(lt t.l min_int || gt t.l max_int) | _ -> return () in let+ t = M.get_type_id (sailtype_of_literal li) in - buildExp (loc,t) @@ Literal li + buildExp exp.tag t @@ Literal li - | UnOp (op,e) -> let+ e = aux e in buildExp e.info @@ UnOp (op,e) + | UnOp (op,e) -> let+ e = aux e in buildExp exp.tag e.tag.ty @@ UnOp (op,e) - | BinOp (op,le,re) -> - let* le = aux le in - let* re = aux re in - let+ t = check_binop op le.info re.info |> M.recover (snd le.info) in - buildExp (loc,t) @@ BinOp (op,le,re) + | BinOp bop -> + let* left = aux bop.left in + let* right = aux bop.right in + let+ t = check_binop bop.op (mk_locatable left.tag.loc left.tag.ty) (mk_locatable right.tag.loc right.tag.ty) |> M.recover left.tag.ty in + buildExp exp.tag t @@ BinOp {op=bop.op;left;right} | Ref (mut,e) -> let* e = lower_lexp e in - let* e_t = M.get_type_from_id e.info in - let+ t = M.get_type_id (dummy_pos,RefType (e_t,mut)) in - buildExp (loc,t) @@ Ref(mut, e) + let* e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + let+ t = M.get_type_id (mk_locatable dummy_pos @@ RefType (e_t,mut)) in + buildExp exp.tag t @@ Ref(mut, e) | ArrayStatic el -> let* first_t = aux (List.hd el) in - let* first_t = M.get_type_from_id first_t.info in + let* first_t = M.get_type_from_id (mk_locatable first_t.tag.loc first_t.tag.ty) in let* el = ListM.map (fun e -> let* e = aux e in - let+ e_t = M.get_type_from_id e.info in - matchArgParam (fst e.info) e_t first_t |> M.ESC.lift |> M.lift >>| fun _ -> e + let+ e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + matchArgParam e.tag.loc e_t first_t |> M.ESC.lift |> M.lift >>| fun _ -> e ) el in let* el = ListM.sequence el in - let t = dummy_pos,ArrayType (first_t, List.length el) in - let+ t_id = M.get_type_id t in - buildExp (loc,t_id) (ArrayStatic el) - - | MethodCall (lid,source,args) -> - let* (args: expression list) = ListM.map lower_rexp args in - let* mod_loc,(_realname,m) = find_function_source e.info None lid source args |> M.ESC.lift |> M.lift in - let* ret = M.throw_if_none (make_msg e.info "methods in expressions should return a value") m.ret in - let* ret_t = M.get_type_id ret in - let* x = M.fresh_fvar in - M.write {info=loc; stmt=DeclVar (false, x, Some ret, None)} >>= fun () -> - M.write {info=loc; stmt=Invoke {args;id=lid; ret_var = Some x;import=mod_loc}} >>| fun () -> - buildExp (loc,ret_t) (Variable x) - - | ArrayRead _ -> lower_lexp e (* todo : some checking *) - | Deref _ -> lower_lexp e (* todo : some checking *) - | StructRead _ -> lower_lexp e (* todo : some checking *) - | StructAlloc (origin,name,fields) -> - let* origin,(_l,strct) = find_struct_source name origin |> M.ESC.lift |> M.lift in + let t = exp.tag,ArrayType (first_t, List.length el) in + let+ t_id = M.get_type_id (mk_locatable (fst t) (snd t)) in + buildExp exp.tag t_id (ArrayStatic el) + + | MethodCall mc -> + let* args = ListM.map lower_rexp mc.value.args in + let* import,(_realname,m) = find_function_source exp.tag None mc.value.id mc.import args |> M.ESC.lift |> M.lift in + let* ty = M.throw_if_none (make_msg exp.tag "methods in expressions should return a value") m.ret in + let* ty_t = M.get_type_id ty in + let* ret_var = M.fresh_fvar in + M.write {tag={loc=exp.tag;ty=""}; node=DeclVar2 {mut=false;id=ret_var;ty}} >>= fun () -> + let x = Ast.{args;id=mc.value.id; ret_var = Some ret_var} in + M.write {tag={loc=exp.tag;ty=""}; node=Invoke2 (mk_importable import x)} >>| fun () -> + buildExp exp.tag ty_t (Variable ret_var) + + | ArrayRead _ + | Deref _ + | StructRead _ -> lower_lexp exp (* todo : some checking *) + + | StructAlloc s -> + let* import,(_l,strct) = find_struct_source s.value.name s.import |> M.ESC.lift |> M.lift in let struct_fields = List.to_seq strct.fields in let fields = FieldMap.( merge ( fun n f1 f2 -> match f1,f2 with - | Some _, Some (l,e) -> Some (let+ e = lower_rexp e in n,(l,e)) + | Some _, Some e -> Some (let+ e = lower_rexp e.value in n,e) | None,None -> None - | None, Some (l,_) -> Some (M.throw @@ make_msg l @@ Fmt.str "no field '%s' in struct '%s'" n (snd name)) - | Some _, None -> Some (M.throw @@ make_msg loc @@ Fmt.str "missing field '%s' from struct '%s'" n (snd name)) + | None, Some l -> Some (M.throw @@ make_msg l.loc @@ Fmt.str "no field '%s' in struct '%s'" n s.value.name.value) + | Some _, None -> Some (M.throw @@ make_msg s.value.name.loc @@ Fmt.str "missing field '%s' from struct '%s'" n s.value.name.value) ) (struct_fields |> of_seq) - (fields |> List.to_seq |> of_seq) + (s.value.fields |> List.to_seq |> of_seq) |> to_seq ) in - let* () = M.throw_if (make_msg (fst name) "missing fields ") Seq.(length fields < Seq.length struct_fields) in + let* () = M.throw_if (make_msg s.value.name.loc "missing fields ") Seq.(length fields < Seq.length struct_fields) in - let* fields = SeqM.sequence (Seq.map snd fields) in + let* fields: (string * expression) Seq.t = SeqM.sequence (Seq.map snd fields) in - let* () = SeqM.iter2 (fun (_name1,(l,(e:expression))) (_name2,(_,t,_)) -> - let* e_t = M.get_type_from_id e.info in - matchArgParam l e_t t |> M.ESC.lift |> M.lift >>| fun _ -> () - ) - fields - struct_fields + let* () = + SeqM.iter2 (fun (_name1,e) (_name2,t) -> + let* e_t = M.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) in + matchArgParam e.tag.loc e_t (fst t.value) |> M.ESC.lift |> M.lift >>| fun _ -> () + ) + fields + struct_fields in - let ty = dummy_pos,CompoundType {origin= Some origin;decl_ty=Some (S ()); name; generic_instances=[]} in - let+ ty = M.get_type_id ty in - (buildExp (loc,ty) (StructAlloc (origin,name, List.of_seq fields) )) + let _fields = List.of_seq fields in + let l,ty = dummy_pos,CompoundType {origin= Some import;decl_ty=Some (S ()); name=s.value.name; generic_instances=[]} in + let+ ty = M.get_type_id (mk_locatable l ty) in + buildExp exp.tag ty @@ StructAlloc2 (mk_importable import Ast.{name=s.value.name;fields=[] (* FIXMEEEEEEEEEEEEE *)} ) - | EnumAlloc _ -> M.throw (make_msg loc "todo enum alloc ") - in aux e + | EnumAlloc _ -> M.throw (make_msg exp.tag "todo enum alloc ") + in aux exp let lower_method (body,proto : _ * method_sig) (env,tenv:THIREnv.t * _) _ : (m_out * THIREnv.D.t * _) M.E.t = let open UseMonad(M.ESC) in let module MF = MonadFunctions(M) in - let log_and_skip e = M.ESC.log e >>| fun () -> buildStmt e.where Skip in + let log_and_skip e = M.ESC.log e >>| fun () -> Ast.buildStmt {loc=e.where;ty="unit"} Skip in - let rec aux s : m_out M.ESC.t = - let loc = s.info in - let buildStmt = buildStmt loc in - let buildSeq s1 s2 = {info=loc; stmt = Seq (s1, s2)} in + let rec aux (s: m_in) : m_out M.ESC.t = + let loc = s.tag in + let buildStmt = Ast.buildStmt {loc;ty="unit"} in + let buildSeq = Ast.buildSeq {loc;ty="unit"} in let buildSeqStmt s1 s2 = buildSeq s1 @@ buildStmt s2 in - match s.stmt with - | DeclVar (mut, id, opt_t, (opt_exp : Hir.expression option)) -> - let* ((ty,opt_e,s):sailtype * 'b * 'c) = + match s.node with + | DeclVar d -> + let* ty,opt_e,s = begin - match opt_t,opt_exp with + match d.ty,d.value with | Some t, Some e -> let* e,s = lower_rexp e in - let* e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in - matchArgParam (fst e.info) e_t t |> M.ESC.lift >>| fun _ -> t,Some e,s + let* e_t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in + matchArgParam e.tag.loc e_t t |> M.ESC.lift >>| fun _ -> t,Some e,s | None,Some e -> let* e,s = lower_rexp e in - let+ e_t = M.ES.get_type_from_id e.info |> M.ESC.lift in + let+ e_t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in e_t,Some e,s | Some t,None -> return (t,None,buildStmt Skip) | None,None -> M.ESC.throw (make_msg loc "can't infere type with no expression") end in let* ty_id = M.ES.get_type_id ty |> M.ESC.lift in - let decl_var = THIREnv.declare_var id (loc,fun e -> ty_id,e) in - M.ESC.update_env (fun (st,t) -> M.E.(bind (decl_var st) (fun st -> pure (st,t)))) - >>| fun () -> (buildSeqStmt s @@ DeclVar (mut,id,Some ty,opt_e)) + let decl_var = THIREnv.declare_var d.id (loc,fun e -> ty_id,e) in + M.ESC.update_env (fun (st,t) -> M.E.(bind (decl_var st) (fun st -> pure (st,t)))) >>| fun () -> + let s1 = buildSeqStmt s @@ DeclVar2 {mut=d.mut;ty;id=d.id;} in + let s2 = match opt_e with None -> Ast.Skip | Some value -> Assign {path=buildExp loc ty_id (Variable d.id); value} in + buildSeqStmt s1 s2 - | Assign(e1, e2) -> - let* e1,s1 = lower_lexp e1 - and* e2,s2 = lower_rexp e2 in - let* e1_t = M.ES.get_type_from_id e1.info |> M.ESC.lift - and* e2_t = M.ES.get_type_from_id e2.info |> M.ESC.lift in - matchArgParam (fst e2.info) e2_t e1_t |> M.ESC.lift >>| - fun _ -> buildSeq s1 @@ buildSeqStmt s2 @@ Assign(e1, e2) - - | Seq(c1, c2) -> + | Assign a -> + let* value,s1 = lower_rexp a.value + and* path,s2 = lower_lexp a.path in + let* value_t = M.ES.get_type_from_id (mk_locatable value.tag.loc value.tag.ty) |> M.ESC.lift + and* path_t = M.ES.get_type_from_id (mk_locatable path.tag.loc path.tag.ty) |> M.ESC.lift in + matchArgParam path.tag.loc path_t value_t |> M.ESC.lift >>| + fun _ -> buildSeq s1 @@ buildSeqStmt s2 @@ Assign {path;value} + + | Seq (c1, c2) -> let* c1 = aux c1 in let+ c2 = aux c2 in - buildStmt (Seq(c1, c2)) + buildStmt (Seq (c1, c2)) - | If(cond_exp, then_s, else_s) -> - let* cond_exp,s = lower_rexp cond_exp in - let* cond_t = M.ES.get_type_from_id cond_exp.info |> M.ESC.lift in - let* _ = matchArgParam (fst cond_exp.info) cond_t (dummy_pos,Bool) |> M.ESC.lift in - let* res = aux then_s in + | If if_ -> + let* cond,s = lower_rexp if_.cond in + let* cond_t = M.ES.get_type_from_id (mk_locatable cond.tag.loc cond.tag.ty) |> M.ESC.lift in + let* _ = matchArgParam cond.tag.loc cond_t (mk_locatable dummy_pos Bool) |> M.ESC.lift in + let* then_ = aux if_.then_ in begin - match else_s with - | None -> return @@ buildSeqStmt s (If(cond_exp, res, None)) - | Some else_ -> let+ else_ = aux else_ in buildSeqStmt s (If(cond_exp, res, Some else_)) + match if_.else_ with + | None -> return @@ buildSeqStmt s (If {cond;then_;else_=None}) + | Some else_ -> let+ else_ = aux else_ in buildSeqStmt s (If {cond;then_;else_=Some else_}) end | Loop c -> @@ -256,29 +261,29 @@ struct | Break -> return (buildStmt Break) - | Case(e, _cases) -> - let+ e,s = lower_rexp e in - buildSeqStmt s (Case (e, [])) + | Case c -> + let+ switch,s = lower_rexp c.switch in + buildSeqStmt s (Case {switch;cases=[]}) | Invoke i -> (* todo: handle var *) - let* args,s = MF.ListM.map lower_rexp i.args in - let* import,_ = find_function_source s.info i.ret_var i.id i.import args |> M.ESC.lift in - buildSeqStmt s (Invoke { i with import ; args} ) |> return + let* args,s = MF.ListM.map lower_rexp i.value.args in + let* import,_ = find_function_source s.tag.loc i.value.ret_var i.value.id i.import args |> M.ESC.lift in + buildSeqStmt s (Invoke2 (mk_importable import Ast.{args;ret_var=i.value.ret_var;id=i.value.id} )) |> return - | Return None as r -> - if proto.rtype = None then return (buildStmt r) else + | Return None -> + if proto.rtype = None then return (buildStmt (Return None)) else log_and_skip (make_msg loc @@ Printf.sprintf "void return but %s returns %s" proto.name (string_of_sailtype proto.rtype)) | Return (Some e) -> let* e,s = lower_rexp e in - let* t = M.ES.get_type_from_id e.info |> M.ESC.lift in + let* t = M.ES.get_type_from_id (mk_locatable e.tag.loc e.tag.ty) |> M.ESC.lift in begin match proto.rtype with | None -> log_and_skip (make_msg loc @@ Printf.sprintf "returns %s but %s doesn't return anything" (string_of_sailtype (Some t)) proto.name) | Some r -> - matchArgParam (fst e.info) t r |> M.ESC.lift >>| fun _ -> + matchArgParam e.tag.loc t r |> M.ESC.lift >>| fun _ -> buildSeqStmt s (Return (Some e)) end @@ -291,7 +296,7 @@ struct | Skip -> return (buildStmt Skip) in - M.(E.bind (ESC.run aux body (env,tenv)) (fun (x,y) -> E.pure (x,snd env,y))) |> Logger.recover (buildStmt dummy_pos Skip,snd env,tenv) + M.(E.bind (ESC.run aux body (env,tenv)) (fun (x,y) -> E.pure (x,snd env,y))) |> Logger.recover (Ast.buildStmt {loc=dummy_pos; ty="unit"} Skip,snd env,tenv) let preprocess = resolve_types (* todo : create semantic types + type inference *) diff --git a/src/passes/ir/sailThir/thirUtils.ml b/src/passes/ir/sailThir/thirUtils.ml index 09e4fb4..49775c4 100644 --- a/src/passes/ir/sailThir/thirUtils.ml +++ b/src/passes/ir/sailThir/thirUtils.ml @@ -1,16 +1,21 @@ open Common open TypesCommon open Monad -open IrHir +open IrAst module D = SailModule.Declarations -type expression = (loc * string, l_str) AstHir.expression (* string is the key for the type map *) -type statement = (loc,l_str,expression) AstHir.statement +type tag = {loc:loc; ty:string} +type thir = +type expression = (tag,Ast.exp,thir) Ast.generic_ast +type statement = (tag,Ast.stmt,thir) Ast.generic_ast +let buildExp loc ty = Ast.buildExp {loc;ty} + +let empty_tag = {loc=dummy_pos;ty="unit"} module M = ThirMonad.Make(struct type t = statement - let mempty : t = {info=dummy_pos; stmt=Skip} - let mconcat : t -> t -> t = fun x y -> {info=dummy_pos; stmt=Seq (x,y)} + let mempty : t = {tag=empty_tag; node=Skip} + let mconcat : t -> t -> t = fun x y -> {tag=empty_tag; node=Seq (x,y)} end ) @@ -19,27 +24,27 @@ open UseMonad(M.ES) -let rec resolve_alias (l,ty : sailtype) : (sailtype,string) Either.t M.ES.t = - match ty with - | CompoundType {origin;name=(_,name);decl_ty=Some (T ());_} -> - let* (_,mname) = M.ES.throw_if_none Logging.(make_msg l @@ "unknown type '" ^ name ^ "' , all types must have an origin (problem with HIR)") origin in - let* ty_t = M.ES.get_decl name (Specific (mname,Type)) - >>= M.ES.throw_if_none Logging.(make_msg l @@ Fmt.str "declaration '%s' requires importing module '%s'" name mname) in +let rec resolve_alias (ty : sailtype) : (sailtype,string) Either.t M.ES.t = + match ty.value with + | CompoundType {origin;name;decl_ty=Some (T ());_} -> + let* mname = M.ES.throw_if_none Logging.(make_msg name.loc @@ "unknown type '" ^ name.value ^ "' , all types must have an origin (problem with HIR)") origin in + let* ty_t = M.ES.get_decl name.value (Specific (mname.value,Type)) + >>= M.ES.throw_if_none Logging.(make_msg name.loc @@ Fmt.str "declaration '%s' requires importing module '%s'" name.value mname.value) in begin match ty_t.ty with - | Some (_,CompoundType _ as ct) -> resolve_alias ct + | Some ({value=CompoundType _;_} as ct) -> resolve_alias ct | Some t -> return (Either.left t) - | None -> return (Either.right name) (* abstract type, only look at name *) + | None -> return (Either.right name.value) (* abstract type, only look at name *) end - | _ -> return (Either.left (l,ty)) + | _ -> return (Either.left ty) let string_of_sailtype_thir (t : sailtype option) : string M.ES.t = let+ res = match t with - | Some (_,CompoundType {origin; name=(loc,x); _}) -> - let* (_,mname) = M.ES.throw_if_none Logging.(make_msg loc "no origin in THIR (problem with HIR)") origin in - let+ decl = M.ES.(get_decl x (Specific (mname,Filter [E (); S (); T()])) - >>= throw_if_none Logging.(make_msg loc "decl is null (problem with HIR)")) in + | Some {value=CompoundType {origin; name; _};_} -> + let* mname = M.ES.throw_if_none Logging.(make_msg name.loc "no origin in THIR (problem with HIR)") origin in + let+ decl = M.ES.(get_decl name.value (Specific (mname.value,Filter [E (); S (); T()])) + >>= throw_if_none Logging.(make_msg name.loc "decl is null (problem with HIR)")) in begin match decl with | T ty_def -> @@ -48,7 +53,7 @@ let string_of_sailtype_thir (t : sailtype option) : string M.ES.t = | Some t -> Fmt.str " (= %s)" @@ string_of_sailtype (Some t) | None -> "(abstract)" end - | S (_,s) -> Fmt.str " (= struct <%s>)" (List.map (fun (n,(_,t,_)) -> Fmt.str "%s:%s" n @@ string_of_sailtype (Some t) ) s.fields |> String.concat ", ") + | S (_,s) -> Fmt.str " (= struct <%s>)" (List.map (fun (n,f) -> Fmt.str "%s:%s" n @@ string_of_sailtype (Some (fst f.value)) ) s.fields |> String.concat ", ") | _ -> failwith "can't happen" end | _ -> return "" @@ -59,36 +64,36 @@ let matchArgParam (loc : loc) (arg: sailtype) (m_param : sailtype) : sailtype let rec aux (a:sailtype) (m:sailtype) : sailtype M.ES.t = let* lt = resolve_alias a in let* rt = resolve_alias m in - + let mk_locatable = fun x -> mk_locatable loc x |> return in match lt,rt with - | Left (loc_l,l), Left (_,r) -> + | Left l, Left r -> begin - match l,r with - | Bool, Bool -> return (loc_l,Bool) - | (Int i1), (Int i2) when i1 = i2 -> return (loc_l,Int i1) - | Float, Float -> return (loc_l,Float) - | Char, Char -> return (loc_l,Char) - | String, String -> return (loc_l,String) + match l.value,r.value with + | Bool, Bool -> mk_locatable Bool + | (Int i1), (Int i2) when i1 = i2 -> mk_locatable (Int i1) + | Float, Float -> mk_locatable Float + | Char, Char -> mk_locatable Char + | String, String -> mk_locatable String | ArrayType (at,s), ArrayType (mt,s') -> if s = s' then - let+ t = aux at mt in loc_l,ArrayType (t,s) + let* t = aux at mt in mk_locatable (ArrayType (t,s)) else - M.ES.throw Logging.(make_msg loc_l (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s)) + M.ES.throw Logging.(make_msg l.loc (Printf.sprintf "array length mismatch : wants %i but %i provided" s' s)) - | Box _at, Box _mt -> M.ES.throw Logging.(make_msg loc_l "todo box") + | Box _at, Box _mt -> M.ES.throw Logging.(make_msg l.loc "todo box") | RefType (at,am), RefType (mt,mm) -> - if am <> mm then M.ES.throw Logging.(make_msg loc_l "different mutability") + if am <> mm then M.ES.throw Logging.(make_msg l.loc "different mutability") else aux at mt | at, GenericType _ - | GenericType _, at -> return (loc_l,at) + | GenericType _, at -> mk_locatable at - | CompoundType c1, CompoundType c2 when snd c1.name = snd c2.name -> + | CompoundType c1, CompoundType c2 when c1.name.value = c2.name.value -> return arg | _ -> let* param = string_of_sailtype_thir (Some m_param) and* arg = string_of_sailtype_thir (Some arg) in - M.ES.throw Logging.(make_msg loc_l @@ Printf.sprintf "wants %s but %s provided" param arg) + M.ES.throw Logging.(make_msg l.loc @@ Printf.sprintf "wants %s but %s provided" param arg) end | Right name, Right name' -> @@ -105,15 +110,18 @@ let matchArgParam (loc : loc) (arg: sailtype) (m_param : sailtype) : sailtype let check_binop op l r : string M.ES.t = let* l_t = M.ES.get_type_from_id l and* r_t = M.ES.get_type_from_id r in + + let mk_locatable = fun x -> mk_locatable l_t.loc x in + match op with | Lt | Le | Gt | Ge | Eq | NEq -> - let* _ = matchArgParam (fst l_t) r_t l_t in M.ES.get_type_id (fst l_t,Bool) + let* _ = matchArgParam l_t.loc r_t l_t in M.ES.get_type_id (mk_locatable Bool) | And | Or -> - let* _ = matchArgParam (fst l_t) l_t (fst l_t,Bool) - and* _ = matchArgParam (fst l_t) r_t (fst l_t,Bool) - in M.ES.get_type_id (fst l_t,Bool) + let* _ = matchArgParam l_t.loc l_t (mk_locatable Bool) + and* _ = matchArgParam l_t.loc r_t (mk_locatable Bool) + in M.ES.get_type_id (mk_locatable Bool) | Plus | Mul | Div | Minus | Rem -> - let+ _ = matchArgParam (fst l_t) r_t l_t in snd l + let+ _ = matchArgParam l_t.loc r_t l_t in l.value let check_call (name:string) (f : method_proto) (args: expression list) loc : unit M.ES.t = @@ -126,8 +134,8 @@ let check_call (name:string) (f : method_proto) (args: expression list) loc : un ListM.iter2 ( fun (ca:expression) ({ty=a;_}:param) -> - let* rty = M.ES.get_type_from_id ca.info in - let+ _ = matchArgParam (fst ca.info) rty a in () + let* rty = M.ES.get_type_from_id (mk_locatable ca.tag.loc ca.tag.ty) in + let+ _ = matchArgParam ca.tag.loc rty a in () ) args f.args @@ -136,10 +144,10 @@ let check_call (name:string) (f : method_proto) (args: expression list) loc : un let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (import:l_str option) (el: expression list) : (l_str * D.method_decl) M.ES.t = let* (_,env),_ = M.ES.get in - let* mname,def = HirUtils.find_symbol_source ~filt:[M ()] name import env |> M.ES.lift in + let* mname,def = IrHir.HirUtils.find_symbol_source ~filt:[M ()] name import env |> M.ES.lift in match def with | M decl -> - let+ _ = check_call (snd name) (snd decl) el fun_loc in mname,decl + let+ _ = check_call name.value (snd decl) el fun_loc in mname,decl (* let _x = fun_loc and _y = el in return (mname,decl) *) | _ -> failwith "non method returned" (* cannot happen because we only requested methods *) @@ -156,7 +164,7 @@ let find_function_source (fun_loc:loc) (_var: string option) (name : l_str) (imp let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct_decl) M.ES.t = let* (_,env),_ = M.ES.get in - let+ origin,def = HirUtils.find_symbol_source ~filt:[S()] name import env |> M.ES.lift in + let+ origin,def = IrHir.HirUtils.find_symbol_source ~filt:[S()] name import env |> M.ES.lift in begin match def with | S decl -> origin,decl @@ -167,7 +175,7 @@ let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct let resolve_types (sm : ('a,'b) SailModule.methods_processes SailModule.t) = - let open HirUtils in + let open IrHir.HirUtils in let module ES = struct module T = struct type t = {decls: D.t ; types : Env.TypeEnv.t} end @@ -206,11 +214,11 @@ let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct let* () = SEnv.iter ( fun (id,(l,{fields; generics})) -> let* fields = ListM.map ( - fun (name,(l,t,n)) -> + fun (name, ({value=t,n;_} as tn)) -> let* env = ES.get_env in let* t,decls = (follow_type t env.decls) |> ES.S.lift in let+ () = ES.set_env {env with decls} in - name,(l,t,n) + name,mk_locatable tn.loc (t,n) ) fields in let proto = l,{fields;generics} in @@ -237,7 +245,7 @@ let find_struct_source (name: l_str) (import : l_str option) : (l_str * D.struct let true_name = (match m_body with Left (sname,_) -> sname | Right _ -> m_proto.name) in let+ () = ES.update_env (fun e -> - let decls = update_decl m_proto.name ((m_proto.pos,true_name), defn_to_proto (Method m)) (Self Method) e.decls in + let decls = update_decl m_proto.name (mk_locatable m_proto.pos true_name, defn_to_proto (Method m)) (Self Method) e.decls in {e with decls} ) in m diff --git a/src/passes/misc/cfg_analysis.ml b/src/passes/misc/cfg_analysis.ml index 8e42a2f..4dc462c 100644 --- a/src/passes/misc/cfg_analysis.ml +++ b/src/passes/misc/cfg_analysis.ml @@ -4,7 +4,7 @@ open Pass open IrMir open SailParser open IrHir -open AstMir +open MirAst open Monad @@ -66,7 +66,7 @@ let check_returns (proto : method_sig) (decls,cfg : mir_function) : mir_functio module Pass = Make(struct let name = "analysis on mir" - type in_body = (AstMir.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes + type in_body = (MirAst.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes type out_body = in_body let transform (sm: in_body SailModule.t) : out_body SailModule.t E.t = diff --git a/src/passes/monomorphization/monomorphization.ml b/src/passes/monomorphization/monomorphization.ml index 2a63bf0..ac6701e 100644 --- a/src/passes/monomorphization/monomorphization.ml +++ b/src/passes/monomorphization/monomorphization.ml @@ -3,7 +3,7 @@ open Monad open TypesCommon module E = Common.Logging open Monad.MonadSyntax (E.Logger) -open IrMir.AstMir +open IrMir.MirAst open MonomorphizationMonad module M = MonoMonad open MonomorphizationUtils @@ -19,45 +19,45 @@ module Pass = Pass.Make (struct let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t = - let mono_exp (e : expression) (decls :declaration list) : sailtype M.t = - let rec aux (e : expression) : sailtype M.t = - match e.exp with + let mono_exp (exp : expression) (decls :declaration list) : sailtype M.t = + let rec aux (exp : expression) : sailtype M.t = + match exp.node with | Variable s -> M.get_var s <&> (function | Some v -> Some (snd v).ty (* var is a function param *) | None -> Option.bind (List.find_opt (fun v -> v.id = s) decls) (fun decl -> Some decl.varType) (* var is function declaration *) ) - >>= M.throw_if_none Logging.(make_msg (fst e.info) @@ Fmt.str "compiler error : var '%s' not found" s) + >>= M.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : var '%s' not found" s) | Literal l -> return (sailtype_of_literal l) - | ArrayRead (e, idx) -> + | ArrayRead a -> begin - let* l,t = aux e in - match t with + let* t = aux a.array in + match t.value with | ArrayType (t, _) -> - let+ idx_t = aux idx in - let _ = resolveType idx_t (l,Int 32) [] [] in + let+ idx_t = aux a.idx in + let _ = resolveType idx_t (mk_locatable t.loc @@ Int 32) [] [] in t | _ -> failwith "cannot happen" end | UnOp (_, e) -> aux e - | BinOp (_, e1, e2) -> - let* t1 = aux e1 in - let+ t2 = aux e2 in - let _ = resolveType t1 t2 [] [] in - t1 + | BinOp bop -> + let* left = aux bop.left in + let+ right = aux bop.right in + let _ = resolveType left right [] [] in + left | Ref (m, e) -> let+ t = aux e in - dummy_pos,RefType (t, m) + mk_locatable exp.tag.loc @@ RefType (t, m) | Deref e -> ( - let+ l,t = aux e in - match t with - | RefType _ -> l,t + let+ t = aux e in + match t.value with + | RefType _ -> t | _ -> failwith "cannot happen" ) @@ -70,15 +70,14 @@ module Pass = Pass.Make (struct next_t ) t h in - dummy_pos,ArrayType (t, List.length (e :: h)) + mk_locatable exp.tag.loc @@ ArrayType (t, List.length (e :: h)) | ArrayStatic [] -> failwith "error : empty array" - | StructAlloc (_, _, _) -> failwith "todo: struct alloc" + | StructAlloc2 _ -> failwith "todo: struct alloc" | EnumAlloc (_, _) -> failwith "todo: enum alloc" - | StructRead (_, _, _) -> failwith "todo: struct read" - | MethodCall _ -> failwith "no method call at this stage" + | StructRead2 _ -> failwith "todo: struct read" in - aux e + aux exp in let construct_call (calle : string) (el : expression list) decls : (string * sailtype option) M.t = @@ -111,47 +110,47 @@ module Pass = Pass.Make (struct begin let* f = find_callable calle sm |> M.lift in match f with - | None -> (*import *) return (mname,Some (dummy_pos,Int 32) (*fixme*)) + | None -> (*import *) return (mname,Some (mk_locatable dummy_pos @@ Int 32) (*fixme*)) | Some f -> - begin - Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); - match f.m_body with - | Right _ -> - (* process and method - - we make sure they correspond to what the callable wants - if the callable is generic we check all the generic types are present at least once - - we build a (string*sailtype) list of generic to type correspondance - if the generic is not found in the list, we add it with the corresponding type - if the generic already exists with the same type as the new one, we are good else we fail - *) - let* resolved_generics = check_args call_args f |> M.lift in - List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; - - let* () = M.push_monos calle resolved_generics in - - let* rtype = - match f.m_proto.rtype with - | Some t -> - (* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) - let+ t = (degenerifyType t resolved_generics|> M.lift) in - (* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) - Some t - | None -> return None - in - - let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in - let name = mname in - let methd = { f with m_proto = { f.m_proto with rtype ; params } } in - let+ () = - let* f = M.get_decl name (Self Method) in - if Option.is_none f then - M.add_decl name ((dummy_pos,name),(defn_to_proto (Method methd))) Method - else return () - in - mname,rtype - | Left _ -> (* external method *) return (calle,f.m_proto.rtype) + begin + Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); + match f.m_body with + | Right _ -> + (* process and method + + we make sure they correspond to what the callable wants + if the callable is generic we check all the generic types are present at least once + + we build a (string*sailtype) list of generic to type correspondance + if the generic is not found in the list, we add it with the corresponding type + if the generic already exists with the same type as the new one, we are good else we fail + *) + let* resolved_generics = check_args call_args f |> M.lift in + List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; + + let* () = M.push_monos calle resolved_generics in + + let* rtype = + match f.m_proto.rtype with + | Some t -> + (* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) + let+ t = (degenerifyType t resolved_generics|> M.lift) in + (* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) + Some t + | None -> return None + in + + let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in + let name = mname in + let methd = { f with m_proto = { f.m_proto with rtype ; params } } in + let+ () = + let* f = M.get_decl name (Self Method) in + if Option.is_none f then + M.add_decl name ((mk_locatable dummy_pos name),(defn_to_proto (Method methd))) Method + else return () + in + mname,rtype + | Left _ -> (* external method *) return (calle,f.m_proto.rtype) end end in diff --git a/src/passes/monomorphization/monomorphizationUtils.ml b/src/passes/monomorphization/monomorphizationUtils.ml index 3511a66..5f39551 100644 --- a/src/passes/monomorphization/monomorphizationUtils.ml +++ b/src/passes/monomorphization/monomorphizationUtils.ml @@ -3,10 +3,10 @@ open TypesCommon open Monad open IrHir module E = Logging.Logger -module Env = SailModule.SailEnv(IrMir.AstMir.V) +module Env = SailModule.SailEnv(IrMir.MirAst.V) open UseMonad(E) -type in_body = IrMir.AstMir.mir_function +type in_body = IrMir.MirAst.mir_function type out_body = { monomorphics : in_body method_defn list; polymorphics : in_body method_defn list; @@ -35,39 +35,46 @@ let print_method_proto (name : string) (methd : in_body sailor_method) = let resolveType (arg : sailtype) (m_param : sailtype) (generics : string list) (resolved_generics : sailor_args) : (sailtype * sailor_args) E.t = - let rec aux ((aloc, a) : sailtype) ((mloc, m) : sailtype) (g : sailor_args) : (sailtype * sailor_args) E.t = - match a,m with - | Bool, Bool -> return ((aloc,Bool), g) - | Int x, Int y when x = y -> return ((aloc,Int x), g) - | Float, Float -> return ((aloc,Float), g) - | Char, Char -> return ((aloc,Char), g) - | String, String -> return ((aloc,String), g) - | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in (aloc,ArrayType (t, s)), g - | GenericType _g1, GenericType _g2 -> return ((aloc,Int 32),g) + let rec aux (at : sailtype) (mt : sailtype) (g : sailor_args) : (sailtype * sailor_args) E.t = + match at.value,mt.value with + | Bool, Bool -> return ((mk_locatable at.loc Bool), g) + | Int x, Int y when x = y -> return ((mk_locatable at.loc @@ Int x), g) + | Float, Float -> return ((mk_locatable at.loc Float), g) + | Char, Char -> return ((mk_locatable at.loc Char), g) + | String, String -> return ((mk_locatable at.loc String), g) + | ArrayType (at, s), ArrayType (mt, _) -> let+ t,g = aux at mt g in (mk_locatable at.loc @@ ArrayType (t, s)), g + | GenericType _g1, GenericType _g2 -> return ((mk_locatable at.loc @@ Int 32),g) (* E.throw Logging.(make_msg dummy_pos @@ Fmt.str "resolveType between generic %s and %s" g1 g2) *) | _, GenericType gt -> - let* () = E.throw_if Logging.(make_msg mloc @@ Fmt.str "generic type %s not declared" gt) (not @@ List.mem gt generics) in + let* () = E.throw_if + Logging.(make_msg mt.loc @@ Fmt.str "generic type %s not declared" gt) + (not @@ List.mem gt generics) in begin match List.assoc_opt gt g with - | None -> return ((aloc,a), (gt, (aloc,a)) :: g) - | Some (lt,t) -> + | None -> return (at, (gt, at) :: g) + | Some t -> E.throw_if - Logging.(make_msg lt @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt (string_of_sailtype (Some (lt,t))) (string_of_sailtype (Some (aloc,a)))) - (t <> a) - >>| fun () -> (aloc,a), g + Logging.(make_msg t.loc @@ Fmt.str "generic type mismatch : %s -> %s vs %s" gt + (string_of_sailtype (Some t)) + (string_of_sailtype (Some at))) + (t.value <> at.value) + >>| fun () -> at, g end + | RefType (at, _), RefType (mt, _) -> aux at mt g | CompoundType _, CompoundType _ -> failwith "todocompoundtype" | Box _at, Box _mt -> failwith "todobox" - | _ -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "cannot happen : %s vs %s" (string_of_sailtype (Some (aloc,a))) (string_of_sailtype (Some (mloc,m)))) + | _ -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "cannot happen : %s vs %s" + (string_of_sailtype (Some at)) + (string_of_sailtype (Some mt))) in aux arg m_param resolved_generics let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = - let rec aux (l,t) = - let+ t = match t with + let rec aux t = + let+ t' = match t.value with | Bool -> return Bool | Int n -> return (Int n) | Float -> return Float @@ -77,10 +84,13 @@ let degenerifyType (t : sailtype) (generics : sailor_args) : sailtype E.t = | Box t -> let+ t = aux t in Box t | RefType (t, m) -> let+ t = aux t in RefType (t, m) | GenericType n -> - let+ t = E.throw_if_none Logging.(make_msg dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) (List.assoc_opt n generics) in - snd t + let+ t = E.throw_if_none + Logging.(make_msg dummy_pos @@ Fmt.str "generic type %s not present in the generics list" n) + (List.assoc_opt n generics) + in + t.value | CompoundType _ -> failwith "todo compoundtype" - in l,t + in mk_locatable t.loc t' in aux t diff --git a/src/passes/process/dune b/src/passes/process/dune index 4512632..d23d938 100644 --- a/src/passes/process/dune +++ b/src/passes/process/dune @@ -1,3 +1,3 @@ (library -(libraries common sailParser irHir irThir irMir mono) +(libraries common sailParser irHir irThir irMir irAst mono) (name processPass)) diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index dbaa83d..23f7a31 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -3,8 +3,9 @@ open TypesCommon open IrHir open SailParser open ProcessUtils -module H = HirUtils -module HirS = AstHir.Syntax +module HirU = HirUtils +module AstU = IrAst.Utils +module HirS = IrAst.Ast.Syntax module E = Logging.Logger open ProcessMonad open Monad.UseMonad(M) @@ -15,11 +16,11 @@ module Pass = Pass.Make(struct type out_body = in_body let transform (sm:in_body SailModule.t) : out_body SailModule.t E.t = - let lower_processes (procs : (H.statement,H.expression) AstParser.process_body process_defn list) : _ method_defn E.t = - let rec compute_tree closed (l,pi:loc * _ proc_init) : H.statement M.t = + let lower_processes (procs : (HirU.statement,HirU.expression) AstParser.process_body process_defn list) : _ method_defn E.t = + let rec compute_tree closed (l,pi:loc * _ proc_init) : HirU.statement M.t = let closed = FieldSet.add pi.proc closed in (* no cycle *) - let* p = find_process_source (l,pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in + let* p = find_process_source (mk_locatable l pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in let* p = M.throw_if_none Logging.(make_msg l @@ Fmt.str "process '%s' is unknown" pi.proc) p in let* tag = M.fresh_prefix p.p_name in let prefix = (Fmt.str "%s_%s_" tag) in @@ -33,48 +34,51 @@ module Pass = Pass.Make(struct let* () = param_arg_mismatch "init" p.p_interface.p_params pi.params in - let rename_l = List.map2 (fun (_,subx) (_,(x,_)) -> (x,subx) ) (pi.read @ pi.write) (fst p.p_interface.p_shared_vars @ snd p.p_interface.p_shared_vars) in + let rename_l = List.map2 (fun subx x -> (fst x.value,subx.value) ) + (pi.read @ pi.write) + (fst p.p_interface.p_shared_vars @ snd p.p_interface.p_shared_vars) in let rename = fun id -> match List.assoc_opt id rename_l with Some v -> v | None -> id in (* add process local (but persistant) vars *) - ListM.iter (fun ((l,id),ty) -> + ListM.iter (fun (id,ty) -> let* ty,_ = HirUtils.follow_type ty sm.declEnv |> M.EC.lift |> M.ECW.lift |> M.lift in - M.(write_decls HirS.(var (l,prefix id,ty))) + M.(write_decls HirS.(var l (prefix id.value) ty None)) ) p.p_body.locals >>= fun () -> let* params = ListM.fold_right2 (fun (p:param) arg params -> let param = prefix p.id in (* add process parameters to the decls *) - M.(write_decls HirS.(var (p.loc,param,p.ty))) >>| fun () -> + M.(write_decls HirS.(var p.loc param p.ty None)) >>| fun () -> HirS.(params && !@param = arg) ) p.p_interface.p_params pi.params M.SeqMonoid.empty in (* add process init *) - let init = H.rename_var_stmt prefix p.p_body.init in + let init = IrAst.Utils.rename_var prefix p.p_body.init in M.write_init HirS.(!! (params && init)) >>= fun () -> (* inline process calls *) - let rec aux ((_,s) : (H.statement, H.expression) AstParser.p_statement) (_ty:AstParser.pgroup_ty) : H.statement M.t = + let rec aux (stmt : (HirU.statement, HirU.expression) AstParser.p_statement) (_ty:AstParser.pgroup_ty) : HirU.statement M.t = let replace_or_prefix = fun id -> let new_id = rename id in if new_id <> id then new_id else prefix id in - let process_cond c s = match c with Some c -> HirS.(_if (H.rename_var_exp replace_or_prefix c) s skip) | None -> s in + let process_cond c s = match c with Some c -> HirS.(if_ (AstU.rename_var replace_or_prefix c) s (skip ())) | None -> s in - match s with + match stmt.value with | Statement (s,cond) -> - let s = H.rename_var_stmt replace_or_prefix s in + let s = AstU.rename_var replace_or_prefix s in return (process_cond cond s) - | Run ((l,id),cond) -> - M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id = Constants.main_process) >>= fun () -> - M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id closed) >>= fun () -> - let* l,pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id) (List.find_opt (fun (_,p: loc * _ proc_init) -> p.id = id) p.p_body.proc_init) in - let read = List.map (fun (l,id) -> l,prefix id) pi.read in - let write = List.map (fun (l,id) -> l,prefix id) pi.write in - let params = List.map (H.rename_var_exp prefix) pi.params in - compute_tree closed (l,{pi with read ; write ; params}) >>| process_cond cond + | Run (id,cond) -> + M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id.value = Constants.main_process) >>= fun () -> + M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id.value closed) >>= fun () -> + let* pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id.value) + (List.find_opt (fun p -> p.value.id = id.value) p.p_body.proc_init) in + let read = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.read in + let write = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.write in + let params = List.map (AstU.rename_var prefix) pi.value.params in + compute_tree closed (l,{pi.value with read ; write ; params}) >>| process_cond cond | PGroup g -> - ListM.fold_right (fun child s -> let+ res = aux child g.p_ty in HirS.(s && res)) g.children HirS.skip >>| process_cond g.cond + ListM.fold_right (fun child s -> let+ res = aux child g.p_ty in HirS.(s && res)) g.children (HirS.skip ()) >>| process_cond g.cond in aux p.p_body.loop Parallel diff --git a/src/passes/process/processMonad.ml b/src/passes/process/processMonad.ml index e4fcc1c..716ac6f 100644 --- a/src/passes/process/processMonad.ml +++ b/src/passes/process/processMonad.ml @@ -12,7 +12,7 @@ module V = ( module M = struct - open AstHir + open IrAst.Ast module E = Logging.Logger module Env = Env.VariableDeclEnv(SailModule.Declarations)(V) @@ -21,11 +21,11 @@ module M = struct type t = {decls : HirUtils.statement; init : HirUtils.statement ; loop : HirUtils.statement} - let empty = {info=dummy_pos ; stmt=Skip} - let concat s1 s2 = match s1.stmt,s2.stmt with + let empty = {tag=dummy_pos ; node=Skip} + let concat s1 s2 = match s1.node,s2.node with | Skip, Skip -> empty - | Skip,stmt | stmt,Skip -> {info=dummy_pos ; stmt} - | _ -> {info=dummy_pos ; stmt=Seq (s1,s2)} + | Skip,node | node,Skip -> {tag=dummy_pos ; node} + | _ -> {tag=dummy_pos ; node=Seq (s1,s2)} let mempty = {decls=empty; init=empty ; loop=empty} let mconcat s1 s2 = {decls=concat s1.decls s2.decls; init = concat s1.init s2.init ; loop = concat s1.loop s2.loop} diff --git a/src/passes/process/processUtils.ml b/src/passes/process/processUtils.ml index cff855a..e2421f7 100644 --- a/src/passes/process/processUtils.ml +++ b/src/passes/process/processUtils.ml @@ -6,7 +6,7 @@ open IrHir module E = Logging.Logger module D = SailModule.Declarations -type body = (Hir.statement,(Hir.statement,Hir.expression) SailParser.AstParser.process_body) SailModule.methods_processes +type body = (HirUtils.statement,(HirUtils.statement,HirUtils.expression) SailParser.AstParser.process_body) SailModule.methods_processes let method_of_main_process (p : 'a process_defn): 'a method_defn = let m_proto = {pos=p.p_pos; name="main"; generics = p.p_generics; params = p.p_interface.p_params; variadic=false; rtype=None; extern=false} @@ -15,7 +15,7 @@ let method_of_main_process (p : 'a process_defn): 'a method_defn = let finalize (proc_def,(new_body: M.ECW.elt)) = - let open AstHir in + let open IrAst.Ast in let (++) = M.SeqMonoid.concat in let main = method_of_main_process proc_def in @@ -38,13 +38,13 @@ let ppPrintModule (pf : Format.formatter) (m : body SailModule.t ) : unit = let find_process_source (name: l_str) (import : l_str option) procs : 'a process_defn option M.t = let* _,env = M.get in - let* (_,origin),_ = HirUtils.find_symbol_source ~filt:[P()] name import env |> M.from_error in + let* origin,_ = HirUtils.find_symbol_source ~filt:[P()] name import env |> M.from_error in let+ procs = - if origin = HirUtils.D.get_name env then return procs + if origin.value = HirUtils.D.get_name env then return procs else - let find_import = List.find_opt (fun i -> i.mname = origin) (HirUtils.D.get_imports env) in + let find_import = List.find_opt (fun i -> i.mname = origin.value) (HirUtils.D.get_imports env) in let+ i = M.throw_if_none Logging.(make_msg dummy_pos "can't happen") find_import in let sm = In_channel.with_open_bin (i.dir ^ i.mname ^ Constants.mir_file_ext) @@ fun c -> (Marshal.from_channel c : Mono.MonomorphizationUtils.out_body SailModule.t) in sm.body.processes in - List.find_opt (fun (p:_ process_defn) -> p.p_name = snd name) procs + List.find_opt (fun (p:_ process_defn) -> p.p_name = name.value) procs From 89df7ec475840340159d66fbc598bef1e267a993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Wed, 6 Sep 2023 01:13:53 +0200 Subject: [PATCH 4/7] update to LLVM 15 & OCaml 5.1.0 --- .github/workflows/build.yml | 13 ++++---- bin/dune | 5 ++- dune-project | 7 +++-- sail-pl.opam | 6 ++-- src/codegen/codegenEnv.ml | 59 ++++++++++++++++++----------------- src/codegen/codegenUtils.ml | 4 +-- src/codegen/codegen_.ml | 35 +++++++++++++-------- src/passes/process/process.ml | 2 +- 8 files changed, 72 insertions(+), 59 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c2bc86c..0919037 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,10 +16,10 @@ jobs: - uses: actions/checkout@v3 - - name: setup llvm 14 repo + - name: setup llvm 15 repo run: | - echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-14 main" | sudo tee -a /etc/apt/sources.list - echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-14 main" | sudo tee -a /etc/apt/sources.list + echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-15 main" | sudo tee -a /etc/apt/sources.list + echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-15 main" | sudo tee -a /etc/apt/sources.list wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt update @@ -29,19 +29,18 @@ jobs: # uses: ocaml/setup-ocaml@6d924c1a7769aa5cdd74bdd901f6e24eb05024b1 uses: ocaml/setup-ocaml@v2 with: - ocaml-compiler: 4.14.X + ocaml-compiler: 5.1.X - run: opam install . --deps-only - run: opam exec -- dune build - - name: Archive saili and sailor + - name: Archive sailor uses: actions/upload-artifact@v3 with: - name: saili and sailor for ${{ steps.system-info.outputs.release }} + name: sailor for ${{ steps.system-info.outputs.release }} path: | - _build/install/default/bin/saili _build/install/default/bin/sailor if-no-files-found: error diff --git a/bin/dune b/bin/dune index 4826ea5..bd40c2b 100755 --- a/bin/dune +++ b/bin/dune @@ -8,6 +8,9 @@ fmt fmt.tty fmt.cli + ctypes.foreign logs.cli ) - (public_name sailor)) + (public_name sailor) + (modes byte exe) +) diff --git a/dune-project b/dune-project index f4fc6cd..4b8aefa 100644 --- a/dune-project +++ b/dune-project @@ -1,8 +1,9 @@ -(lang dune 3.2) +(lang dune 3.7) (name "sail-pl") (version 0.1) (using menhir 2.0) (generate_opam_files true) +(map_workspace_root false) (license GPL) (authors "Frederic Dabrowski") @@ -15,13 +16,13 @@ (synopsis "SAIL: Safe Interactive Language") (description "SAIL means Safe Interactive Language.") (depends - (ocaml (>= 4.13.1)) + (ocaml (>= 5.1.0)) (cmdliner (>= 1.1.1)) (fmt (>= 0.9.0)) (menhir (>= 2.0)) (logs (>= 0.7)) (mtime (>= 1.3.0)) (ctypes-foreign (>= 0.18.0)) - (llvm (>= 13.0.0)) + (llvm (= 15.0.7+nnp-2)) zarith )) diff --git a/sail-pl.opam b/sail-pl.opam index 5f0db52..4e41564 100644 --- a/sail-pl.opam +++ b/sail-pl.opam @@ -9,15 +9,15 @@ license: "GPL" homepage: "https://sail-pl.github.io" bug-reports: "https://sail-pl.github.io" depends: [ - "dune" {>= "3.2"} - "ocaml" {>= "4.13.1"} + "dune" {>= "3.7"} + "ocaml" {>= "5.1.0"} "cmdliner" {>= "1.1.1"} "fmt" {>= "0.9.0"} "menhir" {>= "2.0"} "logs" {>= "0.7"} "mtime" {>= "1.3.0"} "ctypes-foreign" {>= "0.18.0"} - "llvm" {>= "13.0.0"} + "llvm" {= "15.0.7+nnp-2"} "zarith" "odoc" {with-doc} ] diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index f1dc565..bda04d0 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -1,10 +1,10 @@ -open Llvm open Common open TypesCommon open Env open Mono open IrMir module E = Logging.Logger +module L = Llvm open Monad.UseMonad(E) open MakeOrderedFunctions(ImportCmp) @@ -12,8 +12,8 @@ open MakeOrderedFunctions(ImportCmp) module Declarations = struct include SailModule.Declarations type process_decl = unit - type method_decl = {defn : MirAst.mir_function method_defn ; llval : llvalue ; extern : bool} - type struct_decl = {defn : struct_proto; ty : lltype} + type method_decl = {defn : MirAst.mir_function method_defn ; llval : L.llvalue ; llty : L.lltype; extern : bool} + type struct_decl = {defn : struct_proto; ty : L.lltype} type enum_decl = unit end @@ -21,10 +21,10 @@ module DeclEnv = DeclarationsEnv (Declarations) module SailEnv = VariableDeclEnv (Declarations)( struct - type t = bool * llvalue + type t = bool * L.llvalue let string_of_var _ = "" - let param_to_var (p:param) = p.mut,global_context () |> i1_type |> const_null (*fixme : make specialized var env for passes to not have this ugly thing *) + let param_to_var (p:param) = L.(p.mut,global_context () |> i1_type |> const_null) (*fixme : make specialized var env for passes to not have this ugly thing *) end ) @@ -34,18 +34,18 @@ open Declarations type in_body = Monomorphization.Pass.out_body -let getLLVMBasicType f ty llc llm : lltype E.t = +let getLLVMBasicType f ty llc llm : L.lltype E.t = let rec aux ty = match ty.value with - | Bool -> i1_type llc |> return - | Int n -> integer_type llc n |> return - | Float -> double_type llc |> return - | Char -> i8_type llc |> return - | String -> i8_type llc |> pointer_type |> return - | ArrayType (t,s) -> let+ t = aux t in array_type t s - | Box t | RefType (t,_) -> aux t <&> pointer_type + | Bool -> L.i1_type llc |> return + | Int n -> L.integer_type llc n |> return + | Float -> L.double_type llc |> return + | Char -> L.i8_type llc |> return + | String -> L.pointer_type2 llc |> return + | ArrayType (t,s) -> let+ t = aux t in L.array_type t s + | Box _ | RefType _ -> L.pointer_type2 llc |> return | GenericType _ -> E.throw Logging.(make_msg ty.loc "no generic type in codegen") - | CompoundType {name; _} when name.value = "_value" -> i64_type llc |> return (* for extern functions *) + | CompoundType {name; _} when name.value = "_value" -> L.i64_type llc |> return (* for extern functions *) | CompoundType {origin=None;_} | CompoundType {decl_ty=None;_} -> E.throw Logging.(make_msg ty.loc "compound type with no origin or decl_ty") | CompoundType {origin=Some mname; name; decl_ty=Some d;_} -> @@ -53,13 +53,13 @@ let getLLVMBasicType f ty llc llm : lltype E.t = in aux ty - let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> lltype E.t) : lltype E.t = + let handle_compound_type_codegen env (mname,name,d) llc _llm (aux : sailtype -> L.lltype E.t) : L.lltype E.t = match DeclEnv.find_decl name (Specific (mname,Filter [d])) env with | Some (T tdef) -> begin match tdef with | {ty=Some t;_} -> aux t - | {ty=None;_} -> i64_type llc |> return + | {ty=None;_} -> L.i64_type llc |> return end | Some (E _enum) -> failwith "todo enum" | Some (S {ty;_}) -> return ty @@ -69,23 +69,23 @@ let getLLVMBasicType f ty llc llm : lltype E.t = let getLLVMType = fun e -> getLLVMBasicType (handle_compound_type_codegen e) - let handle_compound_type env (mname,name,d) llc llm (aux : sailtype -> lltype E.t) : lltype E.t = + let handle_compound_type env (mname,name,d) llc llm (aux : sailtype -> L.lltype E.t) : L.lltype E.t = match SailModule.DeclEnv.find_decl name (Specific (mname,Filter [d])) env with | Some (T tdef) -> begin match tdef with | {ty=Some t;_} -> aux t - | {ty=None;_} -> i64_type llc |> return + | {ty=None;_} -> L.i64_type llc |> return end | Some (E _enum) -> failwith "todo enum" | Some (S (_,defn)) -> let _,f_types = List.split defn.fields in let* elts = ListM.map (fun ty -> aux (fst ty.value)) f_types <&> Array.of_list in begin - match type_by_name llm ("struct." ^ name) with + match L.type_by_name llm ("struct." ^ name) with | Some ty -> return ty - | None -> (let ty = named_struct_type llc ("struct." ^ name) in - struct_set_body ty elts false; return ty) + | None -> (let ty = L.named_struct_type llc ("struct." ^ name) in + L.struct_set_body ty elts false; return ty) end | Some _ -> failwith "something is broken" | None -> failwith @@ Fmt.str "getLLVMType : %s '%s' not found in module '%s'" (string_of_decl d) name mname @@ -96,12 +96,13 @@ let getLLVMBasicType f ty llc llm : lltype E.t = let llvm_proto_of_method_sig (m:method_sig) env llc llm = let* llvm_rt = match m.rtype with | Some t -> getLLVMType env t llc llm - | None -> void_type llc |> return + | None -> L.void_type llc |> return in let+ args_type = ListM.map (fun ({ty;_}: param) -> getLLVMType env ty llc llm) m.params <&> Array.of_list in - let method_t = if m.variadic then var_arg_function_type else function_type in + let method_t = if m.variadic then L.var_arg_function_type else L.function_type in let name = if not (m.extern || m.name = "main") then Fmt.str "_%s_%s" (DeclEnv.get_name env) m.name else m.name in - declare_function name (method_t llvm_rt args_type ) llm + let ty = method_t llvm_rt args_type in + ty,L.declare_function name ty llm let collect_monos (sm: in_body SailModule.t) = let open SailModule.DeclEnv in @@ -145,14 +146,14 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = else false,m.m_proto in - let* llproto = llvm_proto_of_method_sig proto env llc llm + let* llty,llproto = llvm_proto_of_method_sig proto env llc llm in let m_body = if is_import then Either.left (m.m_proto.name,[]) (* decls body from imports are opaque *) else m.m_body in - DeclEnv.add_decl m.m_proto.name {extern; defn = {m with m_body}; llval=llproto} Method d + DeclEnv.add_decl m.m_proto.name {extern; defn = {m with m_body}; llval=llproto; llty} Method d ) env methods in @@ -168,10 +169,10 @@ let get_declarations (sm: in_body SailModule.t) llc llm : DeclEnv.t E.t = SEnv.fold (fun acc (name,(_,defn)) -> let _,f_types = List.split defn.fields in let* elts = ListM.map (fun ty-> _getLLVMType sm.declEnv (fst ty.value) llc llm) f_types <&> Array.of_list in - let ty = match type_by_name llm ("struct." ^ name) with + let ty = match L.type_by_name llm ("struct." ^ name) with | Some ty -> ty - | None -> let ty = named_struct_type llc ("struct." ^ name) in - struct_set_body ty elts false; ty + | None -> let ty = L.named_struct_type llc ("struct." ^ name) in + L.struct_set_body ty elts false; ty in DeclEnv.add_decl name {defn;ty} Struct acc ) write_env structs diff --git a/src/codegen/codegenUtils.ml b/src/codegen/codegenUtils.ml index 8cf2fe4..507df51 100644 --- a/src/codegen/codegenUtils.ml +++ b/src/codegen/codegenUtils.ml @@ -18,7 +18,7 @@ let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue = | LInt i -> const_int_of_string (integer_type llvm.c i.size) (Z.to_string i.l) 10 | LFloat f -> const_float (double_type llvm.c) f | LChar c -> const_int (i8_type llvm.c) (Char.code c) - | LString s -> build_global_stringptr s ".str" llvm.b + | LString s -> let s = build_global_stringptr s ".str" llvm.b in build_pointercast s (pointer_type2 llvm.c) "" llvm.b let ty_of_alias(ty:sailtype) env : sailtype = match ty.value with @@ -93,7 +93,7 @@ let toLLVMArgs (args: param list ) (env:DeclEnv.t) (llvm:llvm_args) : (bool * sa let get_memcpy_intrinsic llvm = - let args_type = [|i8_type llvm.c |> pointer_type; i8_type llvm.c |> pointer_type ; i64_type llvm.c; i1_type llvm.c|] in + let args_type = [|pointer_type2 llvm.c; pointer_type2 llvm.c; i64_type llvm.c; i1_type llvm.c|] in let f = declare_function "llvm.memcpy.p0i8.p0i8.i64" (function_type (void_type llvm.c) args_type ) llvm.m in f \ No newline at end of file diff --git a/src/codegen/codegen_.ml b/src/codegen/codegen_.ml index 960cffc..f9f22b9 100644 --- a/src/codegen/codegen_.ml +++ b/src/codegen/codegen_.ml @@ -13,29 +13,36 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp | Variable x -> let+ _,v = match (SailEnv.get_var x venv) with | Some (_,n) -> return n - | None -> E.throw Logging.(make_msg dummy_pos @@ Fmt.str "var '%s' not found" x) + | None -> E.throw Logging.(make_msg exp.tag.loc @@ Fmt.str "var '%s' not found" x) in v | Deref x -> eval_r env llvm x | ArrayRead a -> let* array_val = eval_l env llvm a.array in + let* llty = + Env.TypeEnv.get_from_id (mk_locatable a.array.tag.loc a.array.tag.ty) tenv >>= fun t -> + getLLVMType (snd venv) t llvm.c llvm.m in let+ index = eval_r env llvm a.idx in - let llvm_array = L.build_in_bounds_gep array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in + let llvm_array = L.build_in_bounds_gep2 llty array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in llvm_array | StructRead2 s -> let* st = eval_l env llvm s.value.strct in + let* llty = + Env.TypeEnv.get_from_id (mk_locatable s.value.strct.tag.loc s.value.strct.tag.ty) tenv >>= fun t -> + getLLVMType (snd venv) t llvm.c llvm.m in + let* st_type_name = Env.TypeEnv.get_from_id (mk_locatable s.value.strct.tag.loc s.value.strct.tag.ty) tenv >>= function | {value=CompoundType c;_} -> return c.name.value - | _ -> E.throw Logging.(make_msg dummy_pos "problem with structure type") + | _ -> E.throw Logging.(make_msg exp.tag.loc "problem with structure type") in let+ decl = (SailEnv.get_decl st_type_name (Specific (s.import.value,Struct)) venv |> E.throw_if_none Logging.(make_msg exp.tag.loc @@ Fmt.str "compiler error : no decl '%s' found" st_type_name)) in let fields = decl.defn.fields in let {value=_,idx;_} = List.assoc s.value.field.value fields in - L.build_struct_gep st idx "" llvm.b + L.build_struct_gep2 llty st idx "" llvm.b | StructAlloc2 s -> let _,fieldlist = s.value.fields |> List.split in @@ -47,7 +54,7 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp let struct_v = L.build_alloca strct_ty "" llvm.b in let+ () = ListM.iteri ( fun i f -> let+ v = eval_r env llvm f.value in - let v_f = L.build_struct_gep struct_v i "" llvm.b in + let v_f = L.build_struct_gep2 (L.pointer_type2 llvm.c) struct_v i "" llvm.b in L.build_store v v_f llvm.b |> ignore ) fieldlist in struct_v @@ -58,8 +65,9 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:MirAst.expression) : L.llvalue E.t = let* ty = Env.TypeEnv.get_from_id (mk_locatable exp.tag.loc exp.tag.ty) tenv in + let* llty = getLLVMType (snd venv) ty llvm.c llvm.m in match exp.node with - | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load v "" llvm.b + | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load2 llty v "" llvm.b | Literal l -> return @@ getLLVMLiteral l llvm @@ -71,7 +79,7 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:Mir in binary bop.op (ty_of_alias ty (snd venv)) l1 l2 llvm.b | Ref (_,e) -> eval_l env llvm e - | Deref e -> let+ v = eval_l env llvm e in L.build_load v "" llvm.b + | Deref e -> let+ v = eval_l env llvm e in L.build_load2 llty v "" llvm.b | ArrayStatic elements -> begin @@ -84,7 +92,7 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:Mir L.set_linkage L.Linkage.Private array; L.set_unnamed_addr true array; L.set_global_constant true array; - L.build_load array "" llvm.b + L.build_load2 llty array "" llvm.b end | EnumAlloc _ -> E.throw Logging.(make_msg exp.tag.loc "enum allocation unimplemented") @@ -96,14 +104,15 @@ and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (ve (* let mname = mangle_method_name name origin.mname args_type in *) let mangled_name = "_" ^ mname.value ^ "_" ^ name in Logs.debug (fun m -> m "constructing call to %s" name); - let* llval,ext = match SailEnv.get_decl mangled_name (Specific (mname.value,Method)) venv with + + let* llval,ext,llty = match SailEnv.get_decl mangled_name (Specific (mname.value,Method)) venv with | None -> begin match SailEnv.get_decl name (Specific (mname.value,Method)) venv with - | Some {llval;extern;_} -> return (llval,extern) + | Some d -> return (d.llval,d.extern,d.llty) | None -> E.throw Logging.(make_msg mname.loc @@ Printf.sprintf "implementation of %s not found" mangled_name ) end - | Some {llval;extern;_} -> return (llval,extern) + | Some d -> return (d.llval,d.extern,d.llty) in let+ args = @@ -122,7 +131,7 @@ and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (ve else return llargs in - L.build_call llval (Array.of_list args) "" llvm.b + L.build_call2 llty llval (Array.of_list args) "" llvm.b open MirAst @@ -218,7 +227,7 @@ let methodToIR (llc:L.llcontext) (llm:L.llmodule) (decl:Declarations.method_decl Logs.info (fun m -> m "codegen of %s" name); let builder = L.builder llc in let llvm = {b=builder; c=llc ; m = llm; layout=Llvm_target.DataLayout.of_string (L.data_layout llm)} in - let* () = E.throw_if Logging.(make_msg dummy_pos @@ "redefinition of function " ^ name) (L.block_begin decl.llval <> At_end decl.llval) in + let* () = E.throw_if Logging.(make_msg decl.defn.m_proto.pos @@ "redefinition of function " ^ name) (L.block_begin decl.llval <> At_end decl.llval) in let bb = L.append_block llvm.c "" decl.llval in L.position_at_end bb llvm.b; diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index 23f7a31..402c177 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -89,7 +89,7 @@ module Pass = Pass.Make(struct let* m = M.throw_if_none Logging.(make_msg dummy_pos "need main process") (List.find_opt (fun p -> p.p_name = Constants.main_process) procs) in - let (pi: _ proc_init) = {mloc=None; read = []; write = [] ; params = [] ; id = Constants.main_process ; proc = Constants.main_process} in + let (pi: _ proc_init) = {mloc=Some (mk_locatable dummy_pos Constants.sail_module_self); read = []; write = [] ; params = [] ; id = Constants.main_process ; proc = Constants.main_process} in let* body = compute_tree FieldSet.empty (dummy_pos,pi) in let+ () = M.write_loop body in m ) |> M.run sm.declEnv >>| finalize From c5fd645feba7bb2951ed1a1ba0e1188835234c77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Fri, 13 Dec 2024 21:26:32 +0100 Subject: [PATCH 5/7] don't attempt unmarshaling when forcing a rebuild When the marshalled data structures have changed between two versions of the compiler, attempting to unmarshal them will produce a runtime crash. In that case, a forceful rebuild can be used to skip the unmarshalling and rebuild them. Prior to this commit, the condition was present, but too late. --- bin/sailor.ml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bin/sailor.ml b/bin/sailor.ml index 8c42a49..5f99f0a 100644 --- a/bin/sailor.ml +++ b/bin/sailor.ml @@ -236,15 +236,17 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum let import = fun m -> {i with dir=Filename.(dirname m ^ dir_sep); proc_order=(List.length compiling)} in match find_file_opt source ~paths:(Filename.current_dir_name::paths),find_file_opt mir_name with - | Some s,Some m when let mir = unmarshal_sm m in + | Some s,Some m when List.length force_comp < 2 && let mir = unmarshal_sm m in Digest.(equal mir.md.hash @@ file s) && - List.length force_comp < 2 && mir.md.version = Const.sailor_version -> (* mir up-to-date with source -> use mir *) return (treated,import m) | None, Some m -> (* mir but no source -> use mir *) + let* () = E.throw_if Logging.(make_msg dummy_pos + @@ Printf.sprintf "module '%s' has no source but forceful rebuild requested, aborting..." source) (List.length force_comp = 2) + in let mir = unmarshal_sm m in E.throw_if - Logging.(make_msg dummy_pos @@ Printf.sprintf "module %s was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) + Logging.(make_msg dummy_pos @@ Printf.sprintf "module '%s' was compiled with sailor %s, current is %s" mir.md.name mir.md.version Const.sailor_version) (mir.md.version <> Const.sailor_version) >>| fun () -> treated,import m | None,None -> (* nothing to work with *) From 3087d23feaa06d87d773fad569bee089f6eb100d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Wed, 5 Mar 2025 11:09:23 +0100 Subject: [PATCH 6/7] update to LLVM 19 --- .github/workflows/build.yml | 20 +++++++++----------- bin/dune | 1 + bin/sailor.ml | 37 +++++++++++-------------------------- dune-project | 2 +- sail-pl.opam | 2 +- src/codegen/codegenEnv.ml | 4 ++-- src/codegen/codegenUtils.ml | 23 +++++++++++++---------- src/codegen/codegen_.ml | 14 +++++++------- src/codegen/dune | 2 +- src/common/logging.ml | 4 ++++ 10 files changed, 50 insertions(+), 59 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0919037..0add69a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,29 +7,27 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: kenchan0130/actions-system-info@master + - uses: kenchan0130/actions-system-info@v1.3.1 id: system-info # Checks-out the repository - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - name: setup llvm 15 repo + - name: setup llvm 19 repo run: | - echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-15 main" | sudo tee -a /etc/apt/sources.list - echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-15 main" | sudo tee -a /etc/apt/sources.list + echo "deb http://apt.llvm.org/noble/ llvm-toolchain-noble-19 main" | sudo tee -a /etc/apt/sources.list + echo "deb-src http://apt.llvm.org/noble/ llvm-toolchain-noble-19 main" | sudo tee -a /etc/apt/sources.list wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - sudo apt update - name: Set up OCaml - # You may pin to the exact commit or the version. - # uses: ocaml/setup-ocaml@6d924c1a7769aa5cdd74bdd901f6e24eb05024b1 - uses: ocaml/setup-ocaml@v2 + uses: ocaml/setup-ocaml@v3 with: - ocaml-compiler: 5.1.X + ocaml-compiler: 5 - run: opam install . --deps-only @@ -37,7 +35,7 @@ jobs: - run: opam exec -- dune build - name: Archive sailor - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: sailor for ${{ steps.system-info.outputs.release }} path: | diff --git a/bin/dune b/bin/dune index bd40c2b..0d2ed13 100755 --- a/bin/dune +++ b/bin/dune @@ -10,6 +10,7 @@ fmt.cli ctypes.foreign logs.cli + llvm.passbuilder ) (public_name sailor) (modes byte exe) diff --git a/bin/sailor.ml b/bin/sailor.ml index 5f99f0a..2f46b5c 100644 --- a/bin/sailor.ml +++ b/bin/sailor.ml @@ -8,6 +8,7 @@ module C = Codegen (* llvm *) module L = Llvm module T = Llvm_target +module P = Llvm_passbuilder (* passes *) @@ -62,24 +63,6 @@ let set_target (llm : Llvm.llmodule) (triple:string) : Llvm_target.Target.t * Ll (target,machine) -let add_opt_passes (pm : [`Module] Llvm.PassManager.t) : unit = - (* seems to be deprecated - TargetMachine.add_analysis_passes pm machine; *) - - (* base needed for other passes *) - Llvm_scalar_opts.add_memory_to_register_promotion pm; - (* eleminates redundant values and loads *) - Llvm_scalar_opts.add_gvn pm; - (* reassociate binary expressions *) - Llvm_scalar_opts.add_reassociation pm; - (* dead code elimination, basic block merging and more *) - Llvm_scalar_opts.add_cfg_simplification pm; - - Llvm_ipo.add_global_optimizer pm; - Llvm_ipo.add_constant_merge pm; - Llvm_ipo.add_function_inlining pm - - let link ?(is_lib = false) (llm:Llvm.llmodule) (module_name : string) (basepath:string) (imports: string list) (libs : string list) (target, machine) clang_args : int = let f = Filename.(concat basepath module_name ^ Const.object_file_ext) in let triple = T.TargetMachine.triple machine in @@ -156,21 +139,23 @@ let sailor (files: string list) (intermediate:bool) (jit:bool) (noopt:bool) (dum let compile sail_module basepath (comp_mode : Cli.comp_mode) : unit E.t = let* m = apply_passes sail_module comp_mode dump_ir in - let+ llm = C.Codegen_.moduleToIR m verify_ir in + let* llm = C.Codegen_.moduleToIR m verify_ir in (* only generate mir file if codegen succeeds *) marshal_sm Filename.(concat basepath m.md.name ^ Const.mir_file_ext) m; let tm = set_target llm target_triple in - if not noopt && comp_mode <> Library then - L.PassManager.( - let pm = create () in add_opt_passes pm; - let res = run_module llm pm in - Logs.debug (fun m -> m "pass manager executed, module modified : %b" res); - dispose pm + let+ () = if not noopt && comp_mode <> Library then + P.( + let options = create_passbuilder_options () in + Logs.debug (fun m -> m "LLVM: running passes"); + let res = run_passes llm "default" (snd tm) options in + dispose_passbuilder_options options; + E.throw_if_result Logging.(fun m -> make_msg dummy_pos m) res ) - ; + else E.pure () + in if intermediate then L.print_module Filename.(concat basepath m.md.name ^ Const.llvm_ir_ext) llm; diff --git a/dune-project b/dune-project index 4b8aefa..85c50a3 100644 --- a/dune-project +++ b/dune-project @@ -23,6 +23,6 @@ (logs (>= 0.7)) (mtime (>= 1.3.0)) (ctypes-foreign (>= 0.18.0)) - (llvm (= 15.0.7+nnp-2)) + (llvm (= 19-shared)) zarith )) diff --git a/sail-pl.opam b/sail-pl.opam index 4e41564..50ff778 100644 --- a/sail-pl.opam +++ b/sail-pl.opam @@ -17,7 +17,7 @@ depends: [ "logs" {>= "0.7"} "mtime" {>= "1.3.0"} "ctypes-foreign" {>= "0.18.0"} - "llvm" {= "15.0.7+nnp-2"} + "llvm" {= "19-shared"} "zarith" "odoc" {with-doc} ] diff --git a/src/codegen/codegenEnv.ml b/src/codegen/codegenEnv.ml index bda04d0..9269b94 100644 --- a/src/codegen/codegenEnv.ml +++ b/src/codegen/codegenEnv.ml @@ -41,9 +41,9 @@ let getLLVMBasicType f ty llc llm : L.lltype E.t = | Int n -> L.integer_type llc n |> return | Float -> L.double_type llc |> return | Char -> L.i8_type llc |> return - | String -> L.pointer_type2 llc |> return + | String -> L.pointer_type llc |> return | ArrayType (t,s) -> let+ t = aux t in L.array_type t s - | Box _ | RefType _ -> L.pointer_type2 llc |> return + | Box _ | RefType _ -> L.pointer_type llc |> return | GenericType _ -> E.throw Logging.(make_msg ty.loc "no generic type in codegen") | CompoundType {name; _} when name.value = "_value" -> L.i64_type llc |> return (* for extern functions *) | CompoundType {origin=None;_} diff --git a/src/codegen/codegenUtils.ml b/src/codegen/codegenUtils.ml index 507df51..8afb653 100644 --- a/src/codegen/codegenUtils.ml +++ b/src/codegen/codegenUtils.ml @@ -1,10 +1,10 @@ -open Llvm open Common +module L = Llvm open TypesCommon open CodegenEnv open Monad.UseMonad(Logging.Logger) -type llvm_args = { c:llcontext; b:llbuilder;m:llmodule; layout : Llvm_target.DataLayout.t} +type llvm_args = { c:L.llcontext; b:L.llbuilder;m:L.llmodule; layout : Llvm_target.DataLayout.t} let mangle_method_name (name:string) (mname:string) (args: sailtype list ) : string = let back = List.fold_left (fun s t -> s ^ string_of_sailtype (Some t) ^ "_" ) "" args in let front = "_" ^ mname ^ "_" ^ name ^ "_" in @@ -12,13 +12,14 @@ let mangle_method_name (name:string) (mname:string) (args: sailtype list ) : str (* Logs.debug (fun m -> m "renamed %s to %s" name res); *) res -let getLLVMLiteral (l:literal) (llvm:llvm_args) : llvalue = +let getLLVMLiteral (l:literal) (llvm:llvm_args) : L.llvalue = + let open L in match l with | LBool b -> const_int (i1_type llvm.c) (Bool.to_int b) | LInt i -> const_int_of_string (integer_type llvm.c i.size) (Z.to_string i.l) 10 | LFloat f -> const_float (double_type llvm.c) f | LChar c -> const_int (i8_type llvm.c) (Char.code c) - | LString s -> let s = build_global_stringptr s ".str" llvm.b in build_pointercast s (pointer_type2 llvm.c) "" llvm.b + | LString s -> let s = build_global_stringptr s ".str" llvm.b in build_pointercast s (pointer_type llvm.c) "" llvm.b let ty_of_alias(ty:sailtype) env : sailtype = match ty.value with @@ -31,7 +32,8 @@ let ty_of_alias(ty:sailtype) env : sailtype = end | _ -> ty -let unary (op:unOp) (t,v) : llbuilder -> llvalue = +let unary (op:unOp) (t,v) : L.llbuilder -> L.llvalue = + let open L in let f = match snd t,op with | Float,Neg -> build_fneg @@ -41,7 +43,8 @@ let unary (op:unOp) (t,v) : llbuilder -> llvalue = in f v "" -let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llvalue = +let binary (op:binOp) (t:sailtype) (l1:L.llvalue) (l2:L.llvalue) : L.llbuilder -> L.llvalue = + let open L in let operators = function | Int _ -> Some [ @@ -84,16 +87,16 @@ let binary (op:binOp) (t:sailtype) (l1:llvalue) (l2:llvalue) : llbuilder -> llva | None -> Printf.sprintf "codegen: bad usage of binop '%s' with type %s" (string_of_binop op) (string_of_sailtype @@ Some t) |> failwith -let toLLVMArgs (args: param list ) (env:DeclEnv.t) (llvm:llvm_args) : (bool * sailtype * llvalue) array E.t = +let toLLVMArgs (args: param list ) (env:DeclEnv.t) (llvm:llvm_args) : (bool * sailtype * L.llvalue) array E.t = ListM.map ( fun {id;mut;ty=t;_} -> let+ ty = getLLVMType env t llvm.c llvm.m in - mut,t,build_alloca ty id llvm.b + mut,t,L.build_alloca ty id llvm.b ) args <&> Array.of_list let get_memcpy_intrinsic llvm = - let args_type = [|pointer_type2 llvm.c; pointer_type2 llvm.c; i64_type llvm.c; i1_type llvm.c|] in - + let open L in + let args_type = [|pointer_type llvm.c; pointer_type llvm.c; i64_type llvm.c; i1_type llvm.c|] in let f = declare_function "llvm.memcpy.p0i8.p0i8.i64" (function_type (void_type llvm.c) args_type ) llvm.m in f \ No newline at end of file diff --git a/src/codegen/codegen_.ml b/src/codegen/codegen_.ml index f9f22b9..a42a8c0 100644 --- a/src/codegen/codegen_.ml +++ b/src/codegen/codegen_.ml @@ -24,7 +24,7 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp Env.TypeEnv.get_from_id (mk_locatable a.array.tag.loc a.array.tag.ty) tenv >>= fun t -> getLLVMType (snd venv) t llvm.c llvm.m in let+ index = eval_r env llvm a.idx in - let llvm_array = L.build_in_bounds_gep2 llty array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in + let llvm_array = L.build_in_bounds_gep llty array_val [|L.(const_int (i64_type llvm.c) 0 ); index|] "" llvm.b in llvm_array | StructRead2 s -> @@ -42,7 +42,7 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp let fields = decl.defn.fields in let {value=_,idx;_} = List.assoc s.value.field.value fields in - L.build_struct_gep2 llty st idx "" llvm.b + L.build_struct_gep llty st idx "" llvm.b | StructAlloc2 s -> let _,fieldlist = s.value.fields |> List.split in @@ -54,7 +54,7 @@ let rec eval_l (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp let struct_v = L.build_alloca strct_ty "" llvm.b in let+ () = ListM.iteri ( fun i f -> let+ v = eval_r env llvm f.value in - let v_f = L.build_struct_gep2 (L.pointer_type2 llvm.c) struct_v i "" llvm.b in + let v_f = L.build_struct_gep (L.pointer_type llvm.c) struct_v i "" llvm.b in L.build_store v v_f llvm.b |> ignore ) fieldlist in struct_v @@ -67,7 +67,7 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:Mir let* ty = Env.TypeEnv.get_from_id (mk_locatable exp.tag.loc exp.tag.ty) tenv in let* llty = getLLVMType (snd venv) ty llvm.c llvm.m in match exp.node with - | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load2 llty v "" llvm.b + | Variable _ | StructRead2 _ | ArrayRead _ | StructAlloc2 _ -> let+ v = eval_l env llvm exp in L.build_load llty v "" llvm.b | Literal l -> return @@ getLLVMLiteral l llvm @@ -79,7 +79,7 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:Mir in binary bop.op (ty_of_alias ty (snd venv)) l1 l2 llvm.b | Ref (_,e) -> eval_l env llvm e - | Deref e -> let+ v = eval_l env llvm e in L.build_load2 llty v "" llvm.b + | Deref e -> let+ v = eval_l env llvm e in L.build_load llty v "" llvm.b | ArrayStatic elements -> begin @@ -92,7 +92,7 @@ and eval_r (venv,tenv as env:SailEnv.t* Env.TypeEnv.t) (llvm:llvm_args) (exp:Mir L.set_linkage L.Linkage.Private array; L.set_unnamed_addr true array; L.set_global_constant true array; - L.build_load2 llty array "" llvm.b + L.build_load llty array "" llvm.b end | EnumAlloc _ -> E.throw Logging.(make_msg exp.tag.loc "enum allocation unimplemented") @@ -131,7 +131,7 @@ and construct_call (name:string) (mname:l_str) (args:MirAst.expression list) (ve else return llargs in - L.build_call2 llty llval (Array.of_list args) "" llvm.b + L.build_call llty llval (Array.of_list args) "" llvm.b open MirAst diff --git a/src/codegen/dune b/src/codegen/dune index 4fd7ad6..b3a85b7 100644 --- a/src/codegen/dune +++ b/src/codegen/dune @@ -1,3 +1,3 @@ (library - (libraries common logs passes ir mono sailParser llvm llvm.analysis llvm.executionengine llvm.scalar_opts llvm.ipo llvm.all_backends) + (libraries common logs passes ir mono sailParser llvm llvm.analysis llvm.executionengine llvm.all_backends) (name codegen)) diff --git a/src/common/logging.ml b/src/common/logging.ml index c5a114a..8191629 100644 --- a/src/common/logging.ml +++ b/src/common/logging.ml @@ -181,6 +181,10 @@ end | Some r -> throw (f r) | None -> pure () + let throw_if_result (f: 'b -> msg) (x: ('a,'b) Result.t) : 'a t = match x with + | Ok x -> pure x + | Error e -> throw (f e) + let get_warnings (f : msg list -> unit) (x : 'a t) : 'a t = let+ v,l = x in f l.warnings; (v,l) From f26f4013915bf5ff5ac365099b10c60ee1beff6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=A9rence=20Clastres?= Date: Mon, 6 Jan 2025 21:25:22 +0100 Subject: [PATCH 7/7] process: new composition syntax, misc fixes --- src/parsing/lexer.mll | 8 ++++++-- src/parsing/parser.mly | 7 ++++--- src/passes/process/process.ml | 9 +++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/parsing/lexer.mll b/src/parsing/lexer.mll index 5319474..6df64bc 100644 --- a/src/parsing/lexer.mll +++ b/src/parsing/lexer.mll @@ -89,8 +89,12 @@ rule read_token = parse | "with" { WITH } | "reads" { READS } | "writes" { WRITES } - | "Par" { PAR } - | "Seq" { SEQ } + (* | "Par" { PAR } + | "Seq" { SEQ } *) + | "{{" { LPARC } + | "}}" { RPARC } + | "[[" { LSEQC } + | "]]" { RSEQC } (* | "((" {P_LPAREN} | "))" {P_RPAREN} *) | "!" { NOT } diff --git a/src/parsing/parser.mly b/src/parsing/parser.mly index 8df4866..b0ffec4 100644 --- a/src/parsing/parser.mly +++ b/src/parsing/parser.mly @@ -37,6 +37,7 @@ %token CHAR %token LPAREN "(" RPAREN ")" LBRACE "{" RBRACE "}" LSQBRACE "[" RSQBRACE "]" LANGLE "<" RANGLE ">" (* ARROW "->" *) +%token LPARC RPARC LSEQC RSEQC %token COMMA "," COLON ":" DCOLON "::" SEMICOLON ";" DOT "." %token ASSIGN "=" %token IMPORT @@ -50,7 +51,7 @@ %token WRITES READS %token P_PROC_INIT // %token AWAIT EMIT WATCHING WHEN PAR "||" -%token PAR SEQ +// %token PAR SEQ %token P_INIT P_LOOP // %token P_LPAREN "((" P_RPAREN "))" %token WITH @@ -152,8 +153,8 @@ let loop := located( | ~ = statement ; ~ = pwhen ; | ~ = located(UID) ; ~ = pwhen; - | PAR ; cond = pwhen ; "{" ; children = separated_list(WITH,loop) ; "}" ; { PGroup {p_ty=Parallel; cond ; children} } - | SEQ ; cond = pwhen ; "{" ; children = separated_list(WITH,loop) ; "}" ; { PGroup {p_ty=Sequence; cond ; children} } + | LPARC ; children = separated_list(WITH,loop) ; RPARC ; cond = pwhen ; { PGroup {p_ty=Parallel; cond ; children} } + | LSEQC ; children = separated_list(WITH,loop) ; RSEQC ; cond = pwhen ; { PGroup {p_ty=Sequence; cond ; children} } ) let pwhen == midrule(WHEN ; LPAREN ; ~= expression ; RPAREN; <>)? diff --git a/src/passes/process/process.ml b/src/passes/process/process.ml index 402c177..9fbe054 100644 --- a/src/passes/process/process.ml +++ b/src/passes/process/process.ml @@ -20,8 +20,8 @@ module Pass = Pass.Make(struct let rec compute_tree closed (l,pi:loc * _ proc_init) : HirU.statement M.t = let closed = FieldSet.add pi.proc closed in (* no cycle *) - let* p = find_process_source (mk_locatable l pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) in - let* p = M.throw_if_none Logging.(make_msg l @@ Fmt.str "process '%s' is unknown" pi.proc) p in + let* p = find_process_source (mk_locatable l pi.proc) pi.mloc procs (*fixme : grammar to allow :: syntax *) + >>= M.throw_if_none Logging.(make_msg l @@ Fmt.str "process '%s' is unknown" pi.proc) in let* tag = M.fresh_prefix p.p_name in let prefix = (Fmt.str "%s_%s_" tag) in let read_params,write_params = p.p_interface.p_shared_vars in @@ -34,9 +34,11 @@ module Pass = Pass.Make(struct let* () = param_arg_mismatch "init" p.p_interface.p_params pi.params in + (* correspondence between shared variables names (given, original) *) let rename_l = List.map2 (fun subx x -> (fst x.value,subx.value) ) (pi.read @ pi.write) (fst p.p_interface.p_shared_vars @ snd p.p_interface.p_shared_vars) in + (* if the id corresponds to a shared variable, replace it by the name of the provided variable *) let rename = fun id -> match List.assoc_opt id rename_l with Some v -> v | None -> id in (* add process local (but persistant) vars *) @@ -68,10 +70,9 @@ module Pass = Pass.Make(struct return (process_cond cond s) | Run (id,cond) -> - M.throw_if Logging.(make_msg l "not allowed to call Main process explicitely") (id.value = Constants.main_process) >>= fun () -> - M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem id.value closed) >>= fun () -> let* pi = M.throw_if_none Logging.(make_msg l @@ Fmt.str "no proc init called '%s'" id.value) (List.find_opt (fun p -> p.value.id = id.value) p.p_body.proc_init) in + M.throw_if Logging.(make_msg l "not allowed to have recursive process") (FieldSet.mem pi.value.proc closed) >>= fun () -> let read = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.read in let write = List.map (fun (id:l_str) -> mk_locatable id.loc @@ prefix id.value) pi.value.write in let params = List.map (AstU.rename_var prefix) pi.value.params in