diff --git a/datagen/rule2code/cwe2code.py b/datagen/rule2code/cwe2code.py index bff493a..b43fe54 100644 --- a/datagen/rule2code/cwe2code.py +++ b/datagen/rule2code/cwe2code.py @@ -219,17 +219,17 @@ def generate_followup_prompt(): """ -def _create_client(remote_api=False): - if remote_api: - return OpenAI(base_url="https://api.deepseek.com"), "deepseek-reasoner" - # Otherwise sglang - return OpenAI(api_key="none", base_url="http://0.0.0.0:30000/v1"), "default" +def _create_client(): + return ( + OpenAI(api_key="none", base_url="http://localhost:30000/v1"), + "default", + ) -def datagen_for_one_cwe(cwe_id, markdown, depth, remote_api=False): +def datagen_for_one_cwe(cwe_id, markdown, depth): assert depth > 0 - client, model = _create_client(remote_api=remote_api) + client, model = _create_client() common_args = {"model": model, "temperature": 0.6} rprint(f"[bold yellow]Processing: CWE ID: {cwe_id}[/bold yellow]") @@ -277,10 +277,10 @@ def main( parallel=256, output_path="outputs/rule2code/cwe2code.jsonl", depth=1, - remote_api=False, ): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + collection = create_cwe_information() - # each line: cwe_id, conversation finished = set() if os.path.exists(output_path): @@ -294,9 +294,7 @@ def main( if cwe_id in finished: continue futures.append( - executor.submit( - datagen_for_one_cwe, cwe_id, markdown, depth, remote_api - ) + executor.submit(datagen_for_one_cwe, cwe_id, markdown, depth) ) for future in tqdm(as_completed(futures), total=len(futures)): diff --git a/datagen/rule2code/guru2code.py b/datagen/rule2code/guru2code.py index ac64a35..8f3d4d6 100644 --- a/datagen/rule2code/guru2code.py +++ b/datagen/rule2code/guru2code.py @@ -137,10 +137,7 @@ def generate_followup_prompt(seed_data): --- END OF EXAMPLE ---""" -def _create_client(remote_api=False): - if remote_api: - load_dotenv() - return None, "bedrock/converse/us.deepseek.r1-v1:0" +def _create_client(): return ( OpenAI(api_key="none", base_url="http://localhost:30000/v1"), "default", @@ -152,9 +149,8 @@ def datagen_for_one_seed( output_file, finished_pairs, depth=1, - remote_api=False, ): - client, model = _create_client(remote_api=remote_api) + client, model = _create_client() common_args = { "model": model, "temperature": 0.8, @@ -173,13 +169,7 @@ def datagen_for_one_seed( ] for i in range(depth): - if remote_api: - response = batch_completion( - model=model, - messages=[messages], - )[0] - else: - response = client.chat.completions.create(messages=messages, **common_args) + response = client.chat.completions.create(messages=messages, **common_args) if response.choices[0].finish_reason == "length": break @@ -213,7 +203,6 @@ def main( parallel=256, output_path="outputs/rule2code/guru2code.jsonl", depth=1, - remote_api=False, ): os.makedirs(os.path.dirname(output_path), exist_ok=True) @@ -239,7 +228,6 @@ def main( output_path, finished_pairs, depth, - remote_api, ) )