From 116ca8cec90bb9b055b4d7f317c24f70598fb173 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Fri, 4 Apr 2025 10:03:54 -0400 Subject: [PATCH 1/9] added batch processing to repl --- REPL/Main.lean | 82 ++++++++++++++++++++++++++++++++++++++++++++++++-- stats.txt | 7 +++++ test.py | 45 +++++++++++++++++++++++++++ 3 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 stats.txt create mode 100644 test.py diff --git a/REPL/Main.lean b/REPL/Main.lean index 345bbb07..45e6ac59 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -232,6 +232,46 @@ def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do tactics infotree } +def runCommandsParrallel (commands : Array Command) : M IO (Array (CommandResponse ⊕ Error)) := do + let tasks ← (commands.mapM ((fun cmd => IO.asTask <| IO.processInput cmd.cmd none))) + let result := ← IO.wait <| ← IO.mapTasks (·.mapM' (pure ·)) tasks.toList + match result with + | .ok results => { + let womp : List (CommandResponse ⊕ Error) ← results.mapM + (fun x => do + match x with + | .ok ⟨_, messages, infotrees⟩ => do + return .inl + { env := 0, + messages := ← messages.mapM fun m => Message.of m, + sorries := ← sorries infotrees none, + tactics := [], -- ← tactics infotrees, + infotree := none} -- Json.arr (← infotrees.mapM fun t => t.toJson none).toArray} + | .error _ => pure <| .inr (Error.mk "failed") + ) + return womp.toArray + } + | .error _ => return #[] + +def runCommandsSequential (commands : Array Command) : M IO (Array (CommandResponse ⊕ Error)) := do + let (cmdSnapshot?, _) ← do match (← get).cmdStates[0]? with + | some env => pure (some env, false) + | none => pure (none, true) + let initialCmdState? := cmdSnapshot?.map fun c => c.cmdState + let results ← commands.mapM (fun cmd => do + IO.println "brr" + IO.processInput cmd.cmd initialCmdState? + ) + results.mapM (fun ⟨_, messages, infotrees⟩ => do + return .inl + { env := 0, + messages := ← messages.mapM fun m => Message.of m, + sorries := ← sorries infotrees none, + tactics := ← tactics infotrees, + infotree := Json.arr (← infotrees.mapM fun t => t.toJson none).toArray + } + ) + def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do try let cmd ← IO.FS.readFile s.path @@ -270,6 +310,22 @@ instance [ToJson α] [ToJson β] : ToJson (α ⊕ β) where | .inl a => toJson a | .inr b => toJson b + +structure Batch where + header : String + proofs : Array String +deriving FromJson, ToJson + + +def parseBatch (query : String) : IO Batch := do + let json := Json.parse query + match json with + | .error e => throw <| IO.userError <| toString <| toJson <| + (⟨"Could not parse JSON:\n" ++ e⟩ : Error) + | .ok j => match fromJson? j with + | .ok (r : Batch) => return r + | .error e => throw <| IO.userError <| toString <| toJson <| (⟨"Could not parse JSON batch:\n" ++ e⟩ : Error) + /-- Commands accepted by the REPL. -/ inductive Input | command : REPL.Command → Input @@ -328,7 +384,29 @@ where loop : M IO Unit := do printFlush "\n" -- easier to parse the output if there are blank lines loop + +def testSeqential: M IO Unit := do + let query ← getLines + let ⟨header, proofs⟩ ← parseBatch query + let _ ← (runCommand (⟨⟨none, none⟩, none, header⟩)) + let comm : Array REPL.Command := proofs.map (fun pf => ⟨⟨none, none⟩, some 0, pf⟩) + let q ← (runCommandsSequential comm) + for l in q do + IO.println (toJson l) + +-- #check Command.mk +-- #check CommandOptions.mk +def testParrallel : IO Unit := do + let _ ← StateT.run' (runCommand (⟨⟨none, none⟩, some 0, "#eval 0"⟩)) {} + let comm : Array REPL.Command := ((List.range 10).map (fun _ => ⟨⟨none, none⟩, some 0, "theorem womp (a b c : ℕ) : a + b + c = c + (b + a) := by sorry"⟩)).toArray + let q ← StateT.run' (runCommandsParrallel comm) {} + for l in q do + IO.println (toJson l) + +-- #eval (do testSeqential.run' {}) + /-- Main executable function, run as `lake exe repl`. -/ unsafe def main (_ : List String) : IO Unit := do - initSearchPath (← Lean.findSysroot) - repl + -- initSearchPath (← Lean.findSysroot) + -- repl + testSeqential.run' {} diff --git a/stats.txt b/stats.txt new file mode 100644 index 00000000..b011ed25 --- /dev/null +++ b/stats.txt @@ -0,0 +1,7 @@ +# all profiling is one using meaure-command + +# test run on 1000 samples of +# theorem womp (a b c : ℕ) : a + b + c = c + (b + a) := by sorry + +batch verify repl only, without env sharing: 78.0058819 seconds +batch verify repl only, with env sharing: 2.0209996 seconds \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 00000000..25f10afe --- /dev/null +++ b/test.py @@ -0,0 +1,45 @@ +from datasets import load_dataset +import json +import subprocess +import time + +header = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n" +# Login using e.g. `huggingface-cli login` to access this dataset +ds = load_dataset("Goedel-LM/Lean-workbook-proofs") + +proofs = [] + +for data in ds["train"].select(range(20)): + proof = data["full_proof"].split(header)[1] + proofs.append(proof) + # print(header, proof) + +batchCmd = { "header": header, "proofs" : proofs} +tmp = json.dumps(batchCmd) + +print("done loading") + +start_repl_time = time.time() + +process = subprocess.Popen( + ["lake", "env", "../../.lake/build/bin/repl"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, # Makes it work with strings + encoding="utf-8" +) + +# Write input directly to the process +process.stdin.write(tmp + "\n") +process.stdin.flush() # Ensure it's sent immediately + +# Read output +stdout, stderr = process.communicate() + +# Print results +print("STDOUT:", stdout) +print("STDERR:", stderr) + +repl_time = time.time() - start_repl_time +print(f"REPL execution completed in {repl_time:.2f} seconds.") \ No newline at end of file From 3e4e9c363e75bd58cd616e1702e461018981de05 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Sat, 5 Apr 2025 01:15:02 -0400 Subject: [PATCH 2/9] simplified enviroment lifecycle --- REPL/Main.lean | 139 ++++++++++++++++++++++++++++++++----------------- test.py | 19 ++++--- 2 files changed, 104 insertions(+), 54 deletions(-) diff --git a/REPL/Main.lean b/REPL/Main.lean index 45e6ac59..7c43d58c 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -232,45 +232,91 @@ def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do tactics infotree } -def runCommandsParrallel (commands : Array Command) : M IO (Array (CommandResponse ⊕ Error)) := do - let tasks ← (commands.mapM ((fun cmd => IO.asTask <| IO.processInput cmd.cmd none))) - let result := ← IO.wait <| ← IO.mapTasks (·.mapM' (pure ·)) tasks.toList - match result with - | .ok results => { - let womp : List (CommandResponse ⊕ Error) ← results.mapM - (fun x => do - match x with - | .ok ⟨_, messages, infotrees⟩ => do - return .inl - { env := 0, - messages := ← messages.mapM fun m => Message.of m, - sorries := ← sorries infotrees none, - tactics := [], -- ← tactics infotrees, - infotree := none} -- Json.arr (← infotrees.mapM fun t => t.toJson none).toArray} - | .error _ => pure <| .inr (Error.mk "failed") - ) - return womp.toArray - } - | .error _ => return #[] - -def runCommandsSequential (commands : Array Command) : M IO (Array (CommandResponse ⊕ Error)) := do - let (cmdSnapshot?, _) ← do match (← get).cmdStates[0]? with - | some env => pure (some env, false) - | none => pure (none, true) - let initialCmdState? := cmdSnapshot?.map fun c => c.cmdState - let results ← commands.mapM (fun cmd => do - IO.println "brr" - IO.processInput cmd.cmd initialCmdState? +def splitArray {α : Type} (arr : Array α) (n : Nat) : Array (Array α) := Id.run do + if n ≤ 0 then #[] + else if n = 1 then #[arr] + else if arr.size = 0 then Array.replicate n #[] + else + let baseSize := arr.size / n + let remainder := arr.size % n + + let mut result : Array (Array α) := #[] + let mut start : Nat := 0 + + for i in List.range n do + let extraElem := if i < remainder then 1 else 0 + let endPos := start + baseSize + extraElem + let subArray := arr.extract start endPos + result := result.push subArray + start := start + baseSize + extraElem + result + +#eval splitArray #[1] 4 + +unsafe def getHeaderEnv (header : String) : IO Command.State := do + Lean.initSearchPath (← Lean.findSysroot) + enableInitializersExecution + let inputCtx := Parser.mkInputContext header "" + let (header, parserState, messages) ← Parser.parseHeader inputCtx + let (env, _) ← processHeader header {} messages inputCtx + let commandState := (Command.mkState env messages {}) + let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState + pure s + +unsafe def runCommandsSequential (commandState : Command.State) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do + proofs.mapM (fun pf => do + let inputCtx := Parser.mkInputContext pf "" + let parserState := { : Parser.ModuleParserState } + let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState + return .inl ({ env := 0, + messages := ← msgs.mapM fun m => Message.of m, + sorries := [], + tactics := [], + infotree := none }) ) - results.mapM (fun ⟨_, messages, infotrees⟩ => do - return .inl - { env := 0, - messages := ← messages.mapM fun m => Message.of m, - sorries := ← sorries infotrees none, - tactics := ← tactics infotrees, - infotree := Json.arr (← infotrees.mapM fun t => t.toJson none).toArray - } + +unsafe def runCommandsParrallelNaive (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do + let commandState ← getHeaderEnv header + let tasks : Array (Task (Except IO.Error CommandResponse)) ← (proofs.mapM <| fun proof => IO.asTask <| do + let inputCtx := Parser.mkInputContext proof "" + let parserState := { : Parser.ModuleParserState } + let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState + return ({ env := 0, + messages := ← msgs.mapM fun m => Message.of m, + sorries := [], + tactics := [], + infotree := none + }) + ) + let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList + match result with + | .ok results => + return (results.map + fun x => + match x with + | .ok cmdres => .inl cmdres + | .error e => .inr (Error.mk e.toString) + ).toArray + | .error _ => return #[] + +#check Except +unsafe def runCommandsParrallel (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do + let commandState ← getHeaderEnv header + let tasks ← (splitArray proofs 100 |>.mapM <| fun bucket => IO.asTask ( (runCommandsSequential commandState bucket))) + let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList + match result with + | .ok results => { + let womp : List (Array (CommandResponse ⊕ Error)) ← results.mapM ( + fun x => do + match x with + | .ok bucket => pure bucket + | .error e => + IO.println e + pure #[] ) + return womp.toArray.flatMap id + } + | .error _ => return #[] def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do try @@ -385,21 +431,20 @@ where loop : M IO Unit := do loop -def testSeqential: M IO Unit := do +unsafe def testSeqential: M IO Unit := do let query ← getLines let ⟨header, proofs⟩ ← parseBatch query - let _ ← (runCommand (⟨⟨none, none⟩, none, header⟩)) - let comm : Array REPL.Command := proofs.map (fun pf => ⟨⟨none, none⟩, some 0, pf⟩) - let q ← (runCommandsSequential comm) + let commandState ← getHeaderEnv header + let q ← (runCommandsSequential commandState proofs) for l in q do IO.println (toJson l) --- #check Command.mk +#check Command.mk -- #check CommandOptions.mk -def testParrallel : IO Unit := do - let _ ← StateT.run' (runCommand (⟨⟨none, none⟩, some 0, "#eval 0"⟩)) {} - let comm : Array REPL.Command := ((List.range 10).map (fun _ => ⟨⟨none, none⟩, some 0, "theorem womp (a b c : ℕ) : a + b + c = c + (b + a) := by sorry"⟩)).toArray - let q ← StateT.run' (runCommandsParrallel comm) {} +unsafe def testParrallel : IO Unit := do + let query ← getLines + let ⟨header, proofs⟩ ← parseBatch query + let q ← (runCommandsParrallelNaive header proofs) for l in q do IO.println (toJson l) @@ -409,4 +454,4 @@ def testParrallel : IO Unit := do unsafe def main (_ : List String) : IO Unit := do -- initSearchPath (← Lean.findSysroot) -- repl - testSeqential.run' {} + testParrallel diff --git a/test.py b/test.py index 25f10afe..a3013e88 100644 --- a/test.py +++ b/test.py @@ -9,13 +9,13 @@ proofs = [] -for data in ds["train"].select(range(20)): +for data in ds["train"].select(range(1000)): proof = data["full_proof"].split(header)[1] proofs.append(proof) - # print(header, proof) + print(header, proof) -batchCmd = { "header": header, "proofs" : proofs} -tmp = json.dumps(batchCmd) +# batchCmd = { "header": header, "proofs": proofs} +# tmp = json.dumps(batchCmd) print("done loading") @@ -30,9 +30,14 @@ encoding="utf-8" ) -# Write input directly to the process -process.stdin.write(tmp + "\n") -process.stdin.flush() # Ensure it's sent immediately +# # Write input directly to the process +process.stdin.write(json.dumps({"header" : header, "proofs": proofs}) + "\n") +process.stdin.flush() +# process.stdin.write(json.dumps({ "cmd" : header}) + "\n\n") +# process.stdin.flush() +# for proof in proofs: +# process.stdin.write(json.dumps({ "cmd" : proof, "env" : 0}) + "\n\n") +# process.stdin.flush() # Ensure it's sent immediately # Read output stdout, stderr = process.communicate() From 07181a56b8e1cd582ea796e8ed3e1b04aaaf4220 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Tue, 8 Apr 2025 09:22:34 -0400 Subject: [PATCH 3/9] integrated into cli interface --- REPL/JSON.lean | 12 ++++++++++++ REPL/Main.lean | 42 +++++++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/REPL/JSON.lean b/REPL/JSON.lean index d5c5ba2d..74e58731 100644 --- a/REPL/JSON.lean +++ b/REPL/JSON.lean @@ -28,6 +28,18 @@ structure Command extends CommandOptions where cmd : String deriving ToJson, FromJson +structure BatchVerifyOptions where + /- + "sequential", "naive", "parrallel" + -/ + mode : Option String + buckets : Option Nat + +structure BatchVerify extends BatchVerifyOptions where + header : String + proofs : Array String +deriving ToJson, FromJson + /-- Process a Lean file in a fresh environment. -/ structure File extends CommandOptions where path : System.FilePath diff --git a/REPL/Main.lean b/REPL/Main.lean index 7c43d58c..aea10b2d 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -251,7 +251,7 @@ def splitArray {α : Type} (arr : Array α) (n : Nat) : Array (Array α) := Id.r start := start + baseSize + extraElem result -#eval splitArray #[1] 4 +-- #eval splitArray #[1] 4 unsafe def getHeaderEnv (header : String) : IO Command.State := do Lean.initSearchPath (← Lean.findSysroot) @@ -263,7 +263,7 @@ unsafe def getHeaderEnv (header : String) : IO Command.State := do let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState pure s -unsafe def runCommandsSequential (commandState : Command.State) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifySequential (commandState : Command.State) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do proofs.mapM (fun pf => do let inputCtx := Parser.mkInputContext pf "" let parserState := { : Parser.ModuleParserState } @@ -275,7 +275,7 @@ unsafe def runCommandsSequential (commandState : Command.State) (proofs : Array infotree := none }) ) -unsafe def runCommandsParrallelNaive (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifyParrallelNaive (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do let commandState ← getHeaderEnv header let tasks : Array (Task (Except IO.Error CommandResponse)) ← (proofs.mapM <| fun proof => IO.asTask <| do let inputCtx := Parser.mkInputContext proof "" @@ -299,10 +299,13 @@ unsafe def runCommandsParrallelNaive (header : String) (proofs : Array String) : ).toArray | .error _ => return #[] -#check Except -unsafe def runCommandsParrallel (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifyParrallel (header : String) (proofs : Array String) (buckets : Option Nat): IO (Array (CommandResponse ⊕ Error)) := do + let buckets := + match buckets with + | some x => x + | none => max 50 proofs.size let commandState ← getHeaderEnv header - let tasks ← (splitArray proofs 100 |>.mapM <| fun bucket => IO.asTask ( (runCommandsSequential commandState bucket))) + let tasks ← (splitArray proofs buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket))) let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList match result with | .ok results => { @@ -325,6 +328,18 @@ def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do catch e => pure <| .inr ⟨e.toString⟩ +unsafe def runBatchVerify (batch : BatchVerify) : IO (Array (CommandResponse ⊕ Error)) := do + match batch.mode with + | some x => + if x = "naive" then do + return ← batchVerifyParrallelNaive batch.header batch.proofs + if x = "parrallel" then do + return ← batchVerifyParrallel batch.header batch.proofs batch.buckets + | none => + pure () + let commandState ← getHeaderEnv batch.header + batchVerifySequential commandState batch.proofs + /-- Run a single tactic, returning the id of the new proof statement, and the new goals. -/ @@ -381,6 +396,7 @@ inductive Input | unpickleEnvironment : REPL.UnpickleEnvironment → Input | pickleProofSnapshot : REPL.PickleProofState → Input | unpickleProofSnapshot : REPL.UnpickleProofState → Input +| batchVerify : REPL.BatchVerify → Input /-- Parse a user input string to an input command. -/ def parse (query : String) : IO Input := do @@ -401,6 +417,8 @@ def parse (query : String) : IO Input := do | .error _ => match fromJson? j with | .ok (r : REPL.Command) => return .command r | .error _ => match fromJson? j with + | .ok (r : REPL.BatchVerify) => return .batchVerify r + | .error _ => match fromJson? j with | .ok (r : REPL.File) => return .file r | .error e => throw <| IO.userError <| toString <| toJson <| (⟨"Could not parse as a valid JSON command:\n" ++ e⟩ : Error) @@ -421,6 +439,7 @@ where loop : M IO Unit := do if query.startsWith "#" || query.startsWith "--" then loop else IO.println <| toString <| ← match ← parse query with | .command r => return toJson (← runCommand r) + | .batchVerify r => return toJson (← runBatchVerify r) | .file r => return toJson (← processFile r) | .proofStep r => return toJson (← runProofStep r) | .pickleEnvironment r => return toJson (← pickleCommandSnapshot r) @@ -435,23 +454,20 @@ unsafe def testSeqential: M IO Unit := do let query ← getLines let ⟨header, proofs⟩ ← parseBatch query let commandState ← getHeaderEnv header - let q ← (runCommandsSequential commandState proofs) + let q ← (batchVerifySequential commandState proofs) for l in q do IO.println (toJson l) -#check Command.mk -- #check CommandOptions.mk unsafe def testParrallel : IO Unit := do let query ← getLines let ⟨header, proofs⟩ ← parseBatch query - let q ← (runCommandsParrallelNaive header proofs) + let q ← (batchVerifyParrallelNaive header proofs) for l in q do IO.println (toJson l) --- #eval (do testSeqential.run' {}) /-- Main executable function, run as `lake exe repl`. -/ unsafe def main (_ : List String) : IO Unit := do - -- initSearchPath (← Lean.findSysroot) - -- repl - testParrallel + initSearchPath (← Lean.findSysroot) + repl From 3a194e9868dd488c4e5507d320ebcc7c9a2c4be6 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Wed, 9 Apr 2025 21:53:59 -0400 Subject: [PATCH 4/9] added support for enviroments --- REPL/Main.lean | 58 ++++++++++++++++++++++++++++---------------------- test.py | 22 ++++++++----------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/REPL/Main.lean b/REPL/Main.lean index aea10b2d..a12ca628 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -263,32 +263,22 @@ unsafe def getHeaderEnv (header : String) : IO Command.State := do let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState pure s -unsafe def batchVerifySequential (commandState : Command.State) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifySequential (commandState : Command.State) (proofs : Array String) : IO (Array (VerifyResponse ⊕ Error)) := do proofs.mapM (fun pf => do let inputCtx := Parser.mkInputContext pf "" let parserState := { : Parser.ModuleParserState } let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState - return .inl ({ env := 0, - messages := ← msgs.mapM fun m => Message.of m, - sorries := [], - tactics := [], - infotree := none }) + return .inl ({ messages := ← msgs.mapM fun m => Message.of m }) ) -unsafe def batchVerifyParrallelNaive (header : String) (proofs : Array String) : IO (Array (CommandResponse ⊕ Error)) := do - let commandState ← getHeaderEnv header - let tasks : Array (Task (Except IO.Error CommandResponse)) ← (proofs.mapM <| fun proof => IO.asTask <| do +unsafe def batchVerifyParrallelNaive (commandState : Command.State) (proofs : Array String) : IO (Array (VerifyResponse ⊕ Error)) := do + let tasks : Array (Task (Except IO.Error VerifyResponse)) ← (proofs.mapM <| fun proof => IO.asTask <| do let inputCtx := Parser.mkInputContext proof "" let parserState := { : Parser.ModuleParserState } let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState - return ({ env := 0, - messages := ← msgs.mapM fun m => Message.of m, - sorries := [], - tactics := [], - infotree := none - }) + return ({ messages := ← msgs.mapM fun m => Message.of m}) ) - let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList + let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList (prio := Task.Priority.max) match result with | .ok results => return (results.map @@ -299,17 +289,16 @@ unsafe def batchVerifyParrallelNaive (header : String) (proofs : Array String) : ).toArray | .error _ => return #[] -unsafe def batchVerifyParrallel (header : String) (proofs : Array String) (buckets : Option Nat): IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifyParrallel (commandState : Command.State) (proofs : Array String) (buckets : Option Nat): IO (Array (VerifyResponse ⊕ Error)) := do let buckets := match buckets with | some x => x | none => max 50 proofs.size - let commandState ← getHeaderEnv header let tasks ← (splitArray proofs buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket))) - let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList + let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList (prio := Task.Priority.max) match result with | .ok results => { - let womp : List (Array (CommandResponse ⊕ Error)) ← results.mapM ( + let womp : List (Array (VerifyResponse ⊕ Error)) ← results.mapM ( fun x => do match x with | .ok bucket => pure bucket @@ -328,17 +317,33 @@ def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do catch e => pure <| .inr ⟨e.toString⟩ -unsafe def runBatchVerify (batch : BatchVerify) : IO (Array (CommandResponse ⊕ Error)) := do +unsafe def runBatchVerify (batch : BatchVerify) : M IO (Array (VerifyResponse ⊕ Error) ⊕ Error) := do + let (cmdSnapshot?, notFound) ← do match batch.env with + | none => pure (none, false) + | some i => do match (← get).cmdStates[i]? with + | some env => pure (some env, false) + | none => pure (none, true) + if notFound then + return .inr ⟨"Unknown environment."⟩ + let cmdState? := cmdSnapshot?.map fun c => c.cmdState + let commandState ← match cmdState? with + | none => do + let inputCtx := Parser.mkInputContext "" "" + let (header, _, messages) ← Parser.parseHeader inputCtx + let (env, messages) ← processHeader header {} messages inputCtx + pure (Command.mkState env messages {}) + | some cmdState => do + pure cmdState match batch.mode with | some x => if x = "naive" then do - return ← batchVerifyParrallelNaive batch.header batch.proofs + return .inl <| ← batchVerifyParrallelNaive commandState batch.proofs if x = "parrallel" then do - return ← batchVerifyParrallel batch.header batch.proofs batch.buckets + return .inl <| ← batchVerifyParrallel commandState batch.proofs batch.buckets | none => pure () - let commandState ← getHeaderEnv batch.header - batchVerifySequential commandState batch.proofs + + return .inl <| ← batchVerifySequential commandState batch.proofs /-- Run a single tactic, returning the id of the new proof statement, and the new goals. @@ -462,7 +467,8 @@ unsafe def testSeqential: M IO Unit := do unsafe def testParrallel : IO Unit := do let query ← getLines let ⟨header, proofs⟩ ← parseBatch query - let q ← (batchVerifyParrallelNaive header proofs) + let commandState ← getHeaderEnv header + let q ← (batchVerifyParrallelNaive commandState proofs) for l in q do IO.println (toJson l) diff --git a/test.py b/test.py index a3013e88..5c748e89 100644 --- a/test.py +++ b/test.py @@ -9,14 +9,11 @@ proofs = [] -for data in ds["train"].select(range(1000)): +for data in ds["train"].select(range(2000)): proof = data["full_proof"].split(header)[1] proofs.append(proof) print(header, proof) -# batchCmd = { "header": header, "proofs": proofs} -# tmp = json.dumps(batchCmd) - print("done loading") start_repl_time = time.time() @@ -30,21 +27,20 @@ encoding="utf-8" ) +start = time.time() + +process.stdin.write(json.dumps({"cmd": header}) + "\n\n") + # # Write input directly to the process -process.stdin.write(json.dumps({"header" : header, "proofs": proofs}) + "\n") +process.stdin.write(json.dumps({"env": 0, "proofs": proofs, "mode": "parrallel", "buckets": 50}) + "\n\n") process.stdin.flush() -# process.stdin.write(json.dumps({ "cmd" : header}) + "\n\n") -# process.stdin.flush() -# for proof in proofs: -# process.stdin.write(json.dumps({ "cmd" : proof, "env" : 0}) + "\n\n") -# process.stdin.flush() # Ensure it's sent immediately -# Read output stdout, stderr = process.communicate() +end = time.time() + # Print results print("STDOUT:", stdout) print("STDERR:", stderr) -repl_time = time.time() - start_repl_time -print(f"REPL execution completed in {repl_time:.2f} seconds.") \ No newline at end of file +print("time: ", end - start) From b5c1a8fc544021aa2ea315a4e3ee95c4e13a71b5 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Thu, 10 Apr 2025 08:48:23 -0400 Subject: [PATCH 5/9] included json changes --- REPL/JSON.lean | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/REPL/JSON.lean b/REPL/JSON.lean index 74e58731..86d16189 100644 --- a/REPL/JSON.lean +++ b/REPL/JSON.lean @@ -36,7 +36,7 @@ structure BatchVerifyOptions where buckets : Option Nat structure BatchVerify extends BatchVerifyOptions where - header : String + env : Option Nat proofs : Array String deriving ToJson, FromJson @@ -139,6 +139,10 @@ structure CommandResponse where infotree : Option Json := none deriving FromJson +structure VerifyResponse where + messages : List Message := [] +deriving FromJson, ToJson + def Json.nonemptyList [ToJson α] (k : String) : List α → List (String × Json) | [] => [] | l => [⟨k, toJson l⟩] From 80f9f6cd920d21a1d3c78f733d9128d33042e388 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Fri, 11 Apr 2025 15:45:21 -0400 Subject: [PATCH 6/9] better test --- stats.txt | 19 ++++++++++++- stats_naive.txt | 44 ++++++++++++++++++++++++++++++ test.py | 72 ++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 122 insertions(+), 13 deletions(-) create mode 100644 stats_naive.txt diff --git a/stats.txt b/stats.txt index b011ed25..0ec719c0 100644 --- a/stats.txt +++ b/stats.txt @@ -4,4 +4,21 @@ # theorem womp (a b c : ℕ) : a + b + c = c + (b + a) := by sorry batch verify repl only, without env sharing: 78.0058819 seconds -batch verify repl only, with env sharing: 2.0209996 seconds \ No newline at end of file +batch verify repl only, with env sharing: 2.0209996 seconds + + +# first 20 from lean Goedel LM + +batch verify sequenctial: 472 seconds +batch verify with 2 parrallel tasks: 444.80 seconds + +# first 1000 from lean Goedel LM + +batch verify sequenctial: 8930.02 seconds +batch verify parrallel naive (1 proof per task): 1618.84 seconds +batch verify with 2 parrallel tasks: +batch verify with 5 parrallel tasks: +batch verify with 25 parrallel tasks: 1273.84 +batch verify with 50 parrallel tasks: 1253.51 +batch verify with 100 parrallel tasks: 1554.14 +batch verify with 200 parrallel tasks: 1613.46 diff --git a/stats_naive.txt b/stats_naive.txt new file mode 100644 index 00000000..81a07cdb --- /dev/null +++ b/stats_naive.txt @@ -0,0 +1,44 @@ +# single cmd vs batch command with 7 items (in naive mode) + +done loading +loadded header +[0] Time Cmd: 12.56s, Memory: 5.84 MB +[0] Time Batch: 11.94s, Memory: 5.84 MB +[1] Time Cmd: 12.21s, Memory: 5.84 MB +[1] Time Batch: 12.36s, Memory: 5.84 MB +[2] Time Cmd: 12.31s, Memory: 5.84 MB +[2] Time Batch: 12.16s, Memory: 5.84 MB +[3] Time Cmd: 16.09s, Memory: 5.84 MB +[3] Time Batch: 20.35s, Memory: 4.17 MB +[4] Time Cmd: 19.22s, Memory: 2.19 MB +[4] Time Batch: 22.80s, Memory: 0.45 MB +[5] Time Cmd: 19.09s, Memory: 0.45 MB +[5] Time Batch: 19.34s, Memory: 0.41 MB +[6] Time Cmd: 11.88s, Memory: 0.41 MB +[6] Time Batch: 11.79s, Memory: 0.41 MB +[7] Time Cmd: 15.72s, Memory: 0.41 MB +[7] Time Batch: 13.86s, Memory: 0.31 MB +[8] Time Cmd: 11.00s, Memory: 0.25 MB +[8] Time Batch: 11.92s, Memory: 1.43 MB +[9] Time Cmd: 11.34s, Memory: 1.43 MB +[9] Time Batch: 11.30s, Memory: 1.43 MB +[10] Time Cmd: 11.10s, Memory: 1.43 MB +[10] Time Batch: 11.33s, Memory: 1.43 MB +[11] Time Cmd: 10.84s, Memory: 1.43 MB +[11] Time Batch: 12.03s, Memory: 1.43 MB +[12] Time Cmd: 16.42s, Memory: 1.42 MB +[12] Time Batch: 11.26s, Memory: 1.48 MB +[13] Time Cmd: 10.94s, Memory: 1.48 MB +[13] Time Batch: 13.73s, Memory: 1.47 MB +[14] Time Cmd: 19.77s, Memory: 1.47 MB +[14] Time Batch: 16.88s, Memory: 1.47 MB +[15] Time Cmd: 12.80s, Memory: 1.47 MB +[15] Time Batch: 13.76s, Memory: 1.47 MB +[16] Time Cmd: 11.82s, Memory: 1.47 MB +[16] Time Batch: 16.25s, Memory: 1.45 MB +[17] Time Cmd: 15.12s, Memory: 1.45 MB +[17] Time Batch: 20.97s, Memory: 1.45 MB +[18] Time Cmd: 18.38s, Memory: 1.45 MB +[18] Time Batch: 20.46s, Memory: 1.45 MB +[19] Time Cmd: 18.88s, Memory: 1.51 MB +[19] Time Batch: 19.58s, Memory: 1.51 MB diff --git a/test.py b/test.py index 5c748e89..950a3505 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,7 @@ import json import subprocess import time +import psutil header = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n" # Login using e.g. `huggingface-cli login` to access this dataset @@ -9,10 +10,10 @@ proofs = [] -for data in ds["train"].select(range(2000)): +for data in ds["train"].select(range(7)): proof = data["full_proof"].split(header)[1] proofs.append(proof) - print(header, proof) + # print(header, proof) print("done loading") @@ -27,20 +28,67 @@ encoding="utf-8" ) -start = time.time() +p = psutil.Process(process.pid) process.stdin.write(json.dumps({"cmd": header}) + "\n\n") - -# # Write input directly to the process -process.stdin.write(json.dumps({"env": 0, "proofs": proofs, "mode": "parrallel", "buckets": 50}) + "\n\n") process.stdin.flush() +while True: + line = process.stdout.readline().strip() + if not line: + break + +print("loadded header") + +for i in range(20): + start = time.time() + + # # Write input directly to the process + process.stdin.write(json.dumps({"env": 0, "cmd": proofs[0]}) + "\n\n") + process.stdin.flush() + + output_lines = [] + while True: + line = process.stdout.readline().strip() + if not line: + break + output_lines.append(line) + + stdout = "\n".join(output_lines) + # print(stdout) + + end = time.time() + + # Monitor memory in MB + mem_info = p.memory_info() + memory_mb = mem_info.rss / (1024 ** 2) + + print(f"[{i}] Time Cmd: {end - start:.2f}s, Memory: {memory_mb:.2f} MB") + + # ---------------------------------- + + start = time.time() + + # # Write input directly to the process + process.stdin.write(json.dumps({"env": 0, "proofs": proofs, "mode": "naive"}) + "\n\n") + process.stdin.flush() + + output_lines = [] + while True: + line = process.stdout.readline().strip() + if not line: + break + output_lines.append(line) + + stdout = "\n".join(output_lines) + # print(stdout) -stdout, stderr = process.communicate() + end = time.time() -end = time.time() + # Monitor memory in MB + mem_info = p.memory_info() + memory_mb = mem_info.rss / (1024 ** 2) -# Print results -print("STDOUT:", stdout) -print("STDERR:", stderr) + print(f"[{i}] Time Batch: {end - start:.2f}s, Memory: {memory_mb:.2f} MB") -print("time: ", end - start) +process.stdin.close() +process.wait() From a8c5c67caf5078e00614d5fa8a29483d0659ec22 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Tue, 15 Apr 2025 12:44:52 -0400 Subject: [PATCH 7/9] added generalized batch commands --- REPL/JSON.lean | 20 ++--- REPL/Main.lean | 232 ++++++++++++++++++++++++------------------------- test.py | 6 +- 3 files changed, 128 insertions(+), 130 deletions(-) diff --git a/REPL/JSON.lean b/REPL/JSON.lean index 86d16189..733e31fe 100644 --- a/REPL/JSON.lean +++ b/REPL/JSON.lean @@ -26,18 +26,20 @@ If `env = some n`, builds on the existing environment `n`. structure Command extends CommandOptions where env : Option Nat cmd : String + gc : Option Bool := false deriving ToJson, FromJson -structure BatchVerifyOptions where +structure BatchCommandOptions extends CommandOptions where /- - "sequential", "naive", "parrallel" + mode = "sequential", "naive", "parrallel" + buckets is unused if mode is "sequential" or "naive" -/ mode : Option String buckets : Option Nat -structure BatchVerify extends BatchVerifyOptions where +structure BatchCommand extends BatchCommandOptions where env : Option Nat - proofs : Array String + cmds : Array String deriving ToJson, FromJson /-- Process a Lean file in a fresh environment. -/ @@ -132,24 +134,22 @@ A response to a Lean command. `env` can be used in later calls, to build on the stored environment. -/ structure CommandResponse where - env : Nat + env : Option Nat messages : List Message := [] sorries : List Sorry := [] tactics : List Tactic := [] infotree : Option Json := none deriving FromJson -structure VerifyResponse where - messages : List Message := [] -deriving FromJson, ToJson - def Json.nonemptyList [ToJson α] (k : String) : List α → List (String × Json) | [] => [] | l => [⟨k, toJson l⟩] instance : ToJson CommandResponse where toJson r := Json.mkObj <| .flatten [ - [("env", r.env)], + match r.env with + | some x => [("env", x)] + | none => [], Json.nonemptyList "messages" r.messages, Json.nonemptyList "sorries" r.sorries, Json.nonemptyList "tactics" r.tactics, diff --git a/REPL/Main.lean b/REPL/Main.lean index a12ca628..61deace5 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -110,6 +110,21 @@ def sorries (trees : List InfoTree) (env? : Option Environment) : M m (List Sorr let proofStateId ← proofState.mapM recordProofSnapshot return Sorry.of goal pos endPos proofStateId +def sorriesGC (trees : List InfoTree) (env? : Option Environment) : IO (List Sorry) := + trees.flatMap InfoTree.sorries |>.filter (fun t => match t.2.1 with + | .term _ none => false + | _ => true ) |>.mapM + fun ⟨ctx, g, pos, endPos⟩ => do + let (goal, _) ← match g with + | .tactic g => do + let s ← ProofSnapshot.create ctx none env? [g] + pure ("\n".intercalate <| (← s.ppGoals).map fun s => s!"{s}", some s) + | .term lctx (some t) => do + let s ← ProofSnapshot.create ctx lctx env? [] [t] + pure ("\n".intercalate <| (← s.ppGoals).map fun s => s!"{s}", some s) + | .term _ none => unreachable! + return Sorry.of goal pos endPos none + def ppTactic (ctx : ContextInfo) (stx : Syntax) : IO Format := ctx.runMetaM {} try Lean.PrettyPrinter.ppTactic ⟨stx⟩ @@ -125,6 +140,13 @@ def tactics (trees : List InfoTree) : M m (List Tactic) := let proofStateId ← proofState.mapM recordProofSnapshot return Tactic.of goals tactic pos endPos proofStateId ns +def tacticsGC (trees : List InfoTree) : IO (List Tactic) := + trees.flatMap InfoTree.tactics |>.mapM + fun ⟨ctx, stx, goals, pos, endPos, ns⟩ => do + let goals := s!"{(← ctx.ppGoals goals)}".trim + let tactic := Format.pretty (← ppTactic ctx stx) + return Tactic.of goals tactic pos endPos none ns + /-- Record a `ProofSnapshot` and generate a JSON response for it. -/ def createProofStepReponse (proofState : ProofSnapshot) (old? : Option ProofSnapshot := none) : M m ProofStepResponse := do @@ -184,18 +206,43 @@ def unpickleProofSnapshot (n : UnpickleProofState) : M IO (ProofStepResponse ⊕ let (proofState, _) ← ProofSnapshot.unpickle n.unpickleProofStateFrom cmdSnapshot? Sum.inl <$> createProofStepReponse proofState -/-- -Run a command, returning the id of the new environment, and any messages and sorries. --/ -def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do - let (cmdSnapshot?, notFound) ← do match s.env with + +def getCommandSnapshot (env : Option Nat) : M IO (Option CommandSnapshot × Bool) := do match env with | none => pure (none, false) | some i => do match (← get).cmdStates[i]? with | some env => pure (some env, false) | none => pure (none, true) - if notFound then - return .inr ⟨"Unknown environment."⟩ - let initialCmdState? := cmdSnapshot?.map fun c => c.cmdState + +def runCommandGCAux (initialCmdState? : Option Command.State) (s : Command) : IO (CommandResponse ⊕ Error):= do + let (_, messages, trees) ← try + IO.processInput s.cmd initialCmdState? + catch ex => + return .inr ⟨ex.toString⟩ + let messages ← messages.mapM fun m => Message.of m + -- For debugging purposes, sometimes we print out the trees here: + -- trees.forM fun t => do IO.println (← t.format) + let sorries ← sorriesGC trees (initialCmdState?.map (·.env)) + let tactics ← match s.allTactics with + | some true => tacticsGC trees + | _ => pure [] + let jsonTrees := match s.infotree with + | some "full" => trees + | some "tactics" => trees.flatMap InfoTree.retainTacticInfo + | some "original" => trees.flatMap InfoTree.retainTacticInfo |>.flatMap InfoTree.retainOriginal + | some "substantive" => trees.flatMap InfoTree.retainTacticInfo |>.flatMap InfoTree.retainSubstantive + | _ => [] + let infotree ← if jsonTrees.isEmpty then + pure none + else + pure <| some <| Json.arr (← jsonTrees.toArray.mapM fun t => t.toJson none) + return .inl + { env := none, + messages, + sorries, + tactics + infotree } + +def runCommandAux (cmdSnapshot? : Option CommandSnapshot) (initialCmdState? : Option Command.State) (s : Command) : M IO (CommandResponse ⊕ Error) := do let (cmdState, messages, trees) ← try IO.processInput s.cmd initialCmdState? catch ex => @@ -207,14 +254,6 @@ def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do let tactics ← match s.allTactics with | some true => tactics trees | _ => pure [] - let cmdSnapshot := - { cmdState - cmdContext := (cmdSnapshot?.map fun c => c.cmdContext).getD - { fileName := "", - fileMap := default, - snap? := none, - cancelTk? := none } } - let env ← recordCommandSnapshot cmdSnapshot let jsonTrees := match s.infotree with | some "full" => trees | some "tactics" => trees.flatMap InfoTree.retainTacticInfo @@ -225,13 +264,35 @@ def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do pure none else pure <| some <| Json.arr (← jsonTrees.toArray.mapM fun t => t.toJson none) + + let cmdSnapshot := + { cmdState + cmdContext := (cmdSnapshot?.map fun c => c.cmdContext).getD + { fileName := "", + fileMap := default, + snap? := none, + cancelTk? := none } } + let env ← recordCommandSnapshot cmdSnapshot return .inl - { env, + { env := some env, messages, sorries, tactics infotree } +/-- +Run a command, returning the id of the new environment, and any messages and sorries. +-/ +def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do + let (cmdSnapshot?, notFound) ← getCommandSnapshot s.env + if notFound then + return .inr ⟨"Unknown environment."⟩ + let initialCmdState? := cmdSnapshot?.map fun c => c.cmdState + if s.gc = some true then + runCommandGCAux initialCmdState? s + else + runCommandAux cmdSnapshot? initialCmdState? s + def splitArray {α : Type} (arr : Array α) (n : Nat) : Array (Array α) := Id.run do if n ≤ 0 then #[] else if n = 1 then #[arr] @@ -251,64 +312,33 @@ def splitArray {α : Type} (arr : Array α) (n : Nat) : Array (Array α) := Id.r start := start + baseSize + extraElem result --- #eval splitArray #[1] 4 - -unsafe def getHeaderEnv (header : String) : IO Command.State := do - Lean.initSearchPath (← Lean.findSysroot) - enableInitializersExecution - let inputCtx := Parser.mkInputContext header "" - let (header, parserState, messages) ← Parser.parseHeader inputCtx - let (env, _) ← processHeader header {} messages inputCtx - let commandState := (Command.mkState env messages {}) - let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState - pure s - -unsafe def batchVerifySequential (commandState : Command.State) (proofs : Array String) : IO (Array (VerifyResponse ⊕ Error)) := do - proofs.mapM (fun pf => do - let inputCtx := Parser.mkInputContext pf "" - let parserState := { : Parser.ModuleParserState } - let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState - return .inl ({ messages := ← msgs.mapM fun m => Message.of m }) - ) +unsafe def batchVerifySequential (initialCmdState : Command.State) (cmds : Array Command) : IO (Array (CommandResponse ⊕ Error)) := cmds.mapM (fun cmd => runCommandGCAux initialCmdState cmd) -unsafe def batchVerifyParrallelNaive (commandState : Command.State) (proofs : Array String) : IO (Array (VerifyResponse ⊕ Error)) := do - let tasks : Array (Task (Except IO.Error VerifyResponse)) ← (proofs.mapM <| fun proof => IO.asTask <| do - let inputCtx := Parser.mkInputContext proof "" - let parserState := { : Parser.ModuleParserState } - let (_, msgs, _) ← Lean.Elab.IO.processCommandsWithInfoTrees inputCtx parserState commandState - return ({ messages := ← msgs.mapM fun m => Message.of m}) +unsafe def batchVerifyParrallelNaive (initialCmdState : Command.State) (cmds : Array Command) : IO (Array (CommandResponse ⊕ Error)) := do + let tasks : Array (Task (Except IO.Error (CommandResponse ⊕ Error))) ← (cmds.mapM <| fun cmd => IO.asTask (runCommandGCAux initialCmdState cmd) ) - let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList (prio := Task.Priority.max) - match result with - | .ok results => - return (results.map - fun x => - match x with - | .ok cmdres => .inl cmdres - | .error e => .inr (Error.mk e.toString) - ).toArray - | .error _ => return #[] - -unsafe def batchVerifyParrallel (commandState : Command.State) (proofs : Array String) (buckets : Option Nat): IO (Array (VerifyResponse ⊕ Error)) := do + tasks.mapM fun task => do + try + match task.get with + | .ok cmdres => return cmdres + | .error e => return .inr ⟨e.toString⟩ + catch e => + return .inr ⟨e.toString⟩ + +unsafe def batchVerifyParrallel (commandState : Command.State) (cmds : Array Command) (buckets : Option Nat): IO (Array (CommandResponse ⊕ Error)) := do let buckets := match buckets with | some x => x - | none => max 50 proofs.size - let tasks ← (splitArray proofs buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket))) - let result := ← IO.wait <| ← IO.mapTasks (·.mapM (pure ·)) tasks.toList (prio := Task.Priority.max) - match result with - | .ok results => { - let womp : List (Array (VerifyResponse ⊕ Error)) ← results.mapM ( - fun x => do - match x with - | .ok bucket => pure bucket - | .error e => - IO.println e - pure #[] - ) - return womp.toArray.flatMap id - } - | .error _ => return #[] + | none => max 50 cmds.size + let tasks ← (splitArray cmds buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket))) + tasks.flatMapM <| + fun task => do + try + match task.get with + | .ok cmdres => return cmdres + | .error e => return Array.replicate buckets (.inr ⟨e.toString⟩) + catch e => + return Array.replicate buckets (.inr ⟨e.toString⟩) def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do try @@ -317,33 +347,30 @@ def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do catch e => pure <| .inr ⟨e.toString⟩ -unsafe def runBatchVerify (batch : BatchVerify) : M IO (Array (VerifyResponse ⊕ Error) ⊕ Error) := do - let (cmdSnapshot?, notFound) ← do match batch.env with - | none => pure (none, false) - | some i => do match (← get).cmdStates[i]? with - | some env => pure (some env, false) - | none => pure (none, true) +unsafe def runBatchVerify (batch : BatchCommand) : M IO (Array (CommandResponse ⊕ Error) ⊕ Error) := do + let (cmdSnapshot?, notFound) ← getCommandSnapshot batch.env if notFound then return .inr ⟨"Unknown environment."⟩ let cmdState? := cmdSnapshot?.map fun c => c.cmdState let commandState ← match cmdState? with - | none => do - let inputCtx := Parser.mkInputContext "" "" - let (header, _, messages) ← Parser.parseHeader inputCtx - let (env, messages) ← processHeader header {} messages inputCtx - pure (Command.mkState env messages {}) - | some cmdState => do - pure cmdState + | none => do + let inputCtx := Parser.mkInputContext "" "" + let (header, _, messages) ← Parser.parseHeader inputCtx + let (env, messages) ← processHeader header {} messages inputCtx + pure (Command.mkState env messages {}) + | some cmdState => do + pure cmdState + let cmds : Array Command := (batch.cmds.map fun cmd => { toCommandOptions := batch.toCommandOptions, env := none, cmd := cmd}) match batch.mode with | some x => if x = "naive" then do - return .inl <| ← batchVerifyParrallelNaive commandState batch.proofs + return .inl <| ← batchVerifyParrallelNaive commandState cmds if x = "parrallel" then do - return .inl <| ← batchVerifyParrallel commandState batch.proofs batch.buckets + return .inl <| ← batchVerifyParrallel commandState cmds batch.buckets | none => pure () - return .inl <| ← batchVerifySequential commandState batch.proofs + return .inl <| ← batchVerifySequential commandState cmds /-- Run a single tactic, returning the id of the new proof statement, and the new goals. @@ -382,16 +409,6 @@ structure Batch where proofs : Array String deriving FromJson, ToJson - -def parseBatch (query : String) : IO Batch := do - let json := Json.parse query - match json with - | .error e => throw <| IO.userError <| toString <| toJson <| - (⟨"Could not parse JSON:\n" ++ e⟩ : Error) - | .ok j => match fromJson? j with - | .ok (r : Batch) => return r - | .error e => throw <| IO.userError <| toString <| toJson <| (⟨"Could not parse JSON batch:\n" ++ e⟩ : Error) - /-- Commands accepted by the REPL. -/ inductive Input | command : REPL.Command → Input @@ -401,7 +418,7 @@ inductive Input | unpickleEnvironment : REPL.UnpickleEnvironment → Input | pickleProofSnapshot : REPL.PickleProofState → Input | unpickleProofSnapshot : REPL.UnpickleProofState → Input -| batchVerify : REPL.BatchVerify → Input +| batchVerify : REPL.BatchCommand → Input /-- Parse a user input string to an input command. -/ def parse (query : String) : IO Input := do @@ -422,7 +439,7 @@ def parse (query : String) : IO Input := do | .error _ => match fromJson? j with | .ok (r : REPL.Command) => return .command r | .error _ => match fromJson? j with - | .ok (r : REPL.BatchVerify) => return .batchVerify r + | .ok (r : REPL.BatchCommand) => return .batchVerify r | .error _ => match fromJson? j with | .ok (r : REPL.File) => return .file r | .error e => throw <| IO.userError <| toString <| toJson <| @@ -454,25 +471,6 @@ where loop : M IO Unit := do printFlush "\n" -- easier to parse the output if there are blank lines loop - -unsafe def testSeqential: M IO Unit := do - let query ← getLines - let ⟨header, proofs⟩ ← parseBatch query - let commandState ← getHeaderEnv header - let q ← (batchVerifySequential commandState proofs) - for l in q do - IO.println (toJson l) - --- #check CommandOptions.mk -unsafe def testParrallel : IO Unit := do - let query ← getLines - let ⟨header, proofs⟩ ← parseBatch query - let commandState ← getHeaderEnv header - let q ← (batchVerifyParrallelNaive commandState proofs) - for l in q do - IO.println (toJson l) - - /-- Main executable function, run as `lake exe repl`. -/ unsafe def main (_ : List String) : IO Unit := do initSearchPath (← Lean.findSysroot) diff --git a/test.py b/test.py index 950a3505..d70d3758 100644 --- a/test.py +++ b/test.py @@ -10,7 +10,7 @@ proofs = [] -for data in ds["train"].select(range(7)): +for data in ds["train"].select(range(8)): proof = data["full_proof"].split(header)[1] proofs.append(proof) # print(header, proof) @@ -39,7 +39,7 @@ print("loadded header") -for i in range(20): +for i in range(5): start = time.time() # # Write input directly to the process @@ -69,7 +69,7 @@ start = time.time() # # Write input directly to the process - process.stdin.write(json.dumps({"env": 0, "proofs": proofs, "mode": "naive"}) + "\n\n") + process.stdin.write(json.dumps({"env": 0, "cmds": proofs, "mode": "sequential"}) + "\n\n") process.stdin.flush() output_lines = [] From 9693c5a0cd0ce1d6e1b73c97e1fb0cf81c457da5 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Tue, 15 Apr 2025 14:00:44 -0400 Subject: [PATCH 8/9] added readme description --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index d6adc283..4f522d34 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,26 @@ Example output: showing any messages generated, and sorries with their goal states. +## Batch Command mode + +Multiple commands can be run in a single batch using + +```json +{ "cmds": ["theorem womp : 2 + 2 = 4 := by rfl", "theorem womp1 : 2 + 4 = 6 := by rfl"]} +``` + +All the same options from Command can be used and will be applied to each command in the `cmds` array. Additionally, you can specify the parrallelism mode using `mode` + +```json +{ "cmds": ["theorem womp : 2 + 2 = 4 := by rfl", "theorem womp1 : 2 + 4 = 6 := by rfl"], "mode": "sequential"} +{ "cmds": ["theorem womp : 2 + 2 = 4 := by rfl", "theorem womp1 : 2 + 4 = 6 := by rfl"], "mode": "naive"} +{ "cmds": ["theorem womp : 2 + 2 = 4 := by rfl", "theorem womp1 : 2 + 4 = 6 := by rfl"], "mode": "parrallel", "buckets": 10} +``` +`sequential` runs all of the commands sequentially. `naive` runs each command as its own multithreading task. `parrallel` splits the commands into `buckets` number of buckets and runs each bucket as a multithreading task. +By default the parrallelism mode will be `sequential` + +Note that for parrallelism, both Command and Proof snapshots will not be saved. + ## File mode There is a simple wrapper around command mode that allows reading in an entire file. From 0336da11eca8ac0b2c7ca3b5138e446f07878a0b Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Thu, 17 Apr 2025 20:47:55 -0400 Subject: [PATCH 9/9] added timeouts --- REPL/Frontend.lean | 15 +++++++++++++++ REPL/JSON.lean | 1 + REPL/Main.lean | 34 ++++++++++++++++++++++++---------- test.py | 27 ++++++++++++++++++++++++++- testtimeout.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 11 deletions(-) create mode 100644 testtimeout.py diff --git a/REPL/Frontend.lean b/REPL/Frontend.lean index 9cc09145..52eabe02 100644 --- a/REPL/Frontend.lean +++ b/REPL/Frontend.lean @@ -45,3 +45,18 @@ def processInput (input : String) (cmdState? : Option Command.State) | some cmdState => do pure ({ : Parser.ModuleParserState }, cmdState) processCommandsWithInfoTrees inputCtx parserState commandState + + +/- + asTask but with a timeout +-/ +def withTimeout {α : Type} (act : IO α) (timeoutMs : Nat) (prio := Task.Priority.default) : IO (Except IO.Error α) := do + let task ← IO.asTask act prio + for _ in [0:timeoutMs / 1000] do + if ← IO.hasFinished task then + return task.get + else + IO.sleep 1000 + IO.cancel task + -- I'm not sure what the actual error code should be + return .error <| IO.Error.timeExpired 0 "timeout exceeded" diff --git a/REPL/JSON.lean b/REPL/JSON.lean index 733e31fe..5ed2323b 100644 --- a/REPL/JSON.lean +++ b/REPL/JSON.lean @@ -36,6 +36,7 @@ structure BatchCommandOptions extends CommandOptions where -/ mode : Option String buckets : Option Nat + timeout : Option Nat structure BatchCommand extends BatchCommandOptions where env : Option Nat diff --git a/REPL/Main.lean b/REPL/Main.lean index 61deace5..20c72ade 100644 --- a/REPL/Main.lean +++ b/REPL/Main.lean @@ -213,7 +213,7 @@ def getCommandSnapshot (env : Option Nat) : M IO (Option CommandSnapshot × Boo | some env => pure (some env, false) | none => pure (none, true) -def runCommandGCAux (initialCmdState? : Option Command.State) (s : Command) : IO (CommandResponse ⊕ Error):= do +def runCommandGCAux (initialCmdState? : Option Command.State) (s : Command) : IO (CommandResponse ⊕ Error) := do let (_, messages, trees) ← try IO.processInput s.cmd initialCmdState? catch ex => @@ -242,6 +242,11 @@ def runCommandGCAux (initialCmdState? : Option Command.State) (s : Command) : IO tactics infotree } +def runCommandGCAuxTimeout (initialCmdState? : Option Command.State) (s : Command) (timeout : Nat) : IO (CommandResponse ⊕ Error) := do + match ← IO.withTimeout (runCommandGCAux initialCmdState? s) timeout with + | .ok res => return res + | .error e => return .inr ⟨e.toString⟩ + def runCommandAux (cmdSnapshot? : Option CommandSnapshot) (initialCmdState? : Option Command.State) (s : Command) : M IO (CommandResponse ⊕ Error) := do let (cmdState, messages, trees) ← try IO.processInput s.cmd initialCmdState? @@ -312,10 +317,19 @@ def splitArray {α : Type} (arr : Array α) (n : Nat) : Array (Array α) := Id.r start := start + baseSize + extraElem result -unsafe def batchVerifySequential (initialCmdState : Command.State) (cmds : Array Command) : IO (Array (CommandResponse ⊕ Error)) := cmds.mapM (fun cmd => runCommandGCAux initialCmdState cmd) - -unsafe def batchVerifyParrallelNaive (initialCmdState : Command.State) (cmds : Array Command) : IO (Array (CommandResponse ⊕ Error)) := do - let tasks : Array (Task (Except IO.Error (CommandResponse ⊕ Error))) ← (cmds.mapM <| fun cmd => IO.asTask (runCommandGCAux initialCmdState cmd) +unsafe def batchVerifySequential (initialCmdState : Command.State) (cmds : Array Command) (timeout : Option Nat): IO (Array (CommandResponse ⊕ Error)) := + let commandRunner := + match timeout with + | some t => fun cmd => runCommandGCAuxTimeout initialCmdState cmd t + | none => runCommandGCAux initialCmdState + cmds.mapM commandRunner + +unsafe def batchVerifyParrallelNaive (initialCmdState : Command.State) (cmds : Array Command) (timeout : Option Nat) : IO (Array (CommandResponse ⊕ Error)) := do + let commandRunner := + match timeout with + | some t => fun cmd => runCommandGCAuxTimeout initialCmdState cmd t + | none => runCommandGCAux initialCmdState + let tasks : Array (Task (Except IO.Error (CommandResponse ⊕ Error))) ← (cmds.mapM <| fun cmd => IO.asTask (commandRunner cmd) ) tasks.mapM fun task => do try @@ -325,12 +339,12 @@ unsafe def batchVerifyParrallelNaive (initialCmdState : Command.State) (cmds : A catch e => return .inr ⟨e.toString⟩ -unsafe def batchVerifyParrallel (commandState : Command.State) (cmds : Array Command) (buckets : Option Nat): IO (Array (CommandResponse ⊕ Error)) := do +unsafe def batchVerifyParrallel (commandState : Command.State) (cmds : Array Command) (buckets : Option Nat) (timeout : Option Nat) : IO (Array (CommandResponse ⊕ Error)) := do let buckets := match buckets with | some x => x | none => max 50 cmds.size - let tasks ← (splitArray cmds buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket))) + let tasks ← (splitArray cmds buckets |>.mapM <| fun bucket => IO.asTask ( (batchVerifySequential commandState bucket timeout))) tasks.flatMapM <| fun task => do try @@ -364,13 +378,13 @@ unsafe def runBatchVerify (batch : BatchCommand) : M IO (Array (CommandResponse match batch.mode with | some x => if x = "naive" then do - return .inl <| ← batchVerifyParrallelNaive commandState cmds + return .inl <| ← batchVerifyParrallelNaive commandState cmds batch.timeout if x = "parrallel" then do - return .inl <| ← batchVerifyParrallel commandState cmds batch.buckets + return .inl <| ← batchVerifyParrallel commandState cmds batch.buckets batch.timeout | none => pure () - return .inl <| ← batchVerifySequential commandState cmds + return .inl <| ← batchVerifySequential commandState cmds batch.timeout /-- Run a single tactic, returning the id of the new proof statement, and the new goals. diff --git a/test.py b/test.py index d70d3758..483257c7 100644 --- a/test.py +++ b/test.py @@ -64,12 +64,37 @@ print(f"[{i}] Time Cmd: {end - start:.2f}s, Memory: {memory_mb:.2f} MB") + # ---------------------------------- + start = time.time() + + # # Write input directly to the process + process.stdin.write(json.dumps({"env": 0, "cmd": proofs[0], "gc": True}) + "\n\n") + process.stdin.flush() + + output_lines = [] + while True: + line = process.stdout.readline().strip() + if not line: + break + output_lines.append(line) + + stdout = "\n".join(output_lines) + # print(stdout) + + end = time.time() + + # Monitor memory in MB + mem_info = p.memory_info() + memory_mb = mem_info.rss / (1024 ** 2) + + print(f"[{i}] Time Cmd GC: {end - start:.2f}s, Memory: {memory_mb:.2f} MB") + # ---------------------------------- start = time.time() # # Write input directly to the process - process.stdin.write(json.dumps({"env": 0, "cmds": proofs, "mode": "sequential"}) + "\n\n") + process.stdin.write(json.dumps({"env": 0, "cmds": proofs, "mode": "naive", "timeout": 60000}) + "\n\n") process.stdin.flush() output_lines = [] diff --git a/testtimeout.py b/testtimeout.py new file mode 100644 index 00000000..ee6a40c6 --- /dev/null +++ b/testtimeout.py @@ -0,0 +1,43 @@ +import subprocess +import json +import time + +process = subprocess.Popen( + ["lake", "env", "../../.lake/build/bin/repl"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, # Makes it work with strings + encoding="utf-8" +) + +header = 'import Mathlib\nimport Aesop' +proof = '\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/- Prove $x^2 + x + y^2 + y + 1 \\geq x y$ for all real x,y -/\ntheorem lean_workbook_1003 (x y: ℝ): x ^ 2 + x + y ^ 2 + y + 1 ≥ x * y := by\n /-\n To prove the inequality \\( x^2 + x + y^2 + y + 1 \\geq x y \\) for all real numbers \\( x \\) and \\( y \\), we start by considering the expression \\( x^2 + x + y^2 + y + 1 - x y \\). We need to show that this expression is non-negative for all \\( x \\) and \\( y \\).\n First, we rewrite the expression:\n \\[ x^2 + x + y^2 + y + 1 - x y \\]\n Next, we use the fact that the square of any real number is non-negative. We consider the squares of the following expressions:\n \\[ (x + 1)^2 \\]\n \\[ (y + 1)^2 \\]\n \\[ (x - y)^2 \\]\n Expanding these squares, we get:\n \\[ (x + 1)^2 = x^2 + 2x + 1 \\]\n \\[ (y + 1)^2 = y^2 + 2y + 1 \\]\n \\[ (x - y)^2 = x^2 - 2xy + y^2 \\]\n Adding these together, we have:\n \\[ (x + 1)^2 + (y + 1)^2 + (x - y)^2 = x^2 + 2x + 1 + y^2 + 2y + 1 + x^2 - 2xy + y^2 = 2x^2 + 2y^2 + 2x + 2y + 2 - 2xy \\]\n Simplifying, we get:\n \\[ 2x^2 + 2y^2 + 2x + 2y + 2 - 2xy = 2(x^2 + x + y^2 + y) + 2(1 - xy) \\]\n Since squares are non-negative, the sum \\( (x + 1)^2 + (y + 1)^2 + (x - y)^2 \\) is non-negative. Therefore, \\( 2(x^2 + x + y^2 + y) + 2(1 - xy) \\geq 0 \\), which implies:\n \\[ x^2 + x + y^2 + y + 1 - xy \\geq 0 \\]\n Thus, we have:\n \\[ x^2 + x + y^2 + y + 1 \\geq x y \\]\n -/\n -- Use the fact that squares are non-negative to prove the inequality.\n -- Consider the squares of the following expressions:\n -- (x + 1)^2, (y + 1)^2, and (x - y)^2.\n nlinarith [sq_nonneg (x + 1), sq_nonneg (y + 1), sq_nonneg (x - y),\n sq_nonneg (x + y), sq_nonneg (x + y + 1), sq_nonneg (x + y - 1)]\n' +proof1 = "\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/- Let $a,b,c$ be real numbers such that $a^{2}+b^{2}+c^{2}=3$ \\nShow: $|a|+|b|+|c|-abc\\leq 4$ -/\ntheorem lean_workbook_10036 (a b c : ℝ) (h : a^2 + b^2 + c^2 = 3) : |a| + |b| + |c| - a * b * c ≤ 4 := by\n /-\n Given real numbers \\(a, b, c\\) such that \\(a^2 + b^2 + c^2 = 3\\), we need to show that \\(|a| + |b| + |c| - abc \\leq 4\\). We will consider different cases based on the signs of \\(a, b, c\\) and use the given condition to derive the inequality.\n -/\n -- Consider different cases based on the signs of a, b, and c\n cases' le_total 0 a with ha ha <;>\n cases' le_total 0 b with hb hb <;>\n cases' le_total 0 c with hc hc <;>\n -- Simplify the absolute values based on the signs\n simp_all only [abs_of_nonneg, abs_of_nonpos, add_left_neg, add_right_neg, add_assoc] <;>\n -- Use linear arithmetic to prove the inequality in each case\n nlinarith [sq_nonneg (a - b), sq_nonneg (b - c), sq_nonneg (c - a), h, sq_nonneg (a + b), sq_nonneg (b + c), sq_nonneg (c + a)]\n" + +process.stdin.write(json.dumps({"cmd": header}) + "\n\n") +process.stdin.flush() +while True: + line = process.stdout.readline().strip() + if not line: + break + + # # Write input directly to the process +process.stdin.write(json.dumps({"env": 0, "cmds": [proof1], "mode": "naive"}) + "\n\n") +process.stdin.flush() + +start = time.time() + +output_lines = [] +while True: + line = process.stdout.readline().strip() + if not line: + break + output_lines.append(line) + +stdout = "\n".join(output_lines) +print(stdout) + +end = time.time() + +print("time: ", end - start) \ No newline at end of file