From b45ddd4626153d28703280177bd8561ea099e2cc Mon Sep 17 00:00:00 2001 From: bogay Date: Wed, 30 Jul 2025 02:38:55 +0800 Subject: [PATCH] feat: add API to migrate submission output --- migrate.go | 115 +++++++++++++++++--------------------------- model/submission.py | 15 ++++++ mongo/submission.py | 98 ++++++++++++++++++++++++++++++++++++- 3 files changed, 156 insertions(+), 72 deletions(-) diff --git a/migrate.go b/migrate.go index fe46ae1..d555434 100644 --- a/migrate.go +++ b/migrate.go @@ -8,8 +8,6 @@ import ( "log" "net/http" "net/http/cookiejar" - "os" - "strconv" "sync" ) @@ -28,11 +26,11 @@ type LoginRequest struct { Password string `json:"password"` } -const baseAPIURL = "http://localhost:8080/api" +const baseApiUrl = "http://localhost:8080/api" // loginUser authenticates the user and returns an http.Client with session cookies func loginUser(username, password string) (*http.Client, error) { - loginURL := fmt.Sprintf("%s/auth/session", baseAPIURL) + loginURL := fmt.Sprintf("%s/auth/session", baseApiUrl) loginPayload := LoginRequest{Username: username, Password: password} payloadBytes, err := json.Marshal(loginPayload) if err != nil { @@ -68,7 +66,7 @@ func loginUser(username, password string) (*http.Client, error) { // fetchSubmissionIDs retrieves submission IDs from the API using the provided http.Client func fetchSubmissionIDs(offset, count int, client *http.Client) ([]string, error) { - url := fmt.Sprintf("%s/submission?count=%d&offset=%d", baseAPIURL, count, offset) + url := fmt.Sprintf("%s/submission?count=%d&offset=%d", baseApiUrl, count, offset) req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -100,7 +98,7 @@ func fetchSubmissionIDs(offset, count int, client *http.Client) ([]string, error // migrateSubmissionCode sends a POST request to migrate the code for a given submission ID using the provided http.Client func migrateSubmissionCode(submissionID string, client *http.Client) { log.Printf("Processing submissionId: %s", submissionID) - url := fmt.Sprintf("%s/submission/%s/migrate-code", baseAPIURL, submissionID) + url := fmt.Sprintf("%s/submission/%s/migrate-code", baseApiUrl, submissionID) req, err := http.NewRequest("POST", url, nil) if err != nil { @@ -127,6 +125,36 @@ func migrateSubmissionCode(submissionID string, client *http.Client) { log.Printf("Successfully triggered migration for submissionId: %s (Status: %s)", submissionID, resp.Status) } +// migrateSubmissionOutput sends a POST request to migrate the output for a given submission ID using the provided http.Client +func migrateSubmissionOutput(submissionID string, client *http.Client) { + log.Printf("Processing submissionId: %s for output migration", submissionID) + url := fmt.Sprintf("%s/submission/%s/migrate-output", baseApiUrl, submissionID) + + req, err := http.NewRequest("POST", url, nil) + if err != nil { + log.Printf("Error creating request for submissionId %s: %v", submissionID, err) + return + } + + resp, err := client.Do(req) + if err != nil { + log.Printf("Error migrating output for submissionId %s: %v", submissionID, err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusAccepted { + bodyBytes, err := ReadAll(resp) + if err != nil { + bodyBytes = []byte("failed to read response body") + } + log.Printf("Error migrating output for submissionId %s: status code %d, body: %s", submissionID, resp.StatusCode, string(bodyBytes)) + return + } + + log.Printf("Successfully triggered output migration for submissionId: %s (Status: %s)", submissionID, resp.Status) +} + // ReadAll is a helper, as ioutil.ReadAll is deprecated in Go 1.16+ // For older Go versions, use ioutil.ReadAll func ReadAll(r *http.Response) ([]byte, error) { @@ -146,20 +174,6 @@ func main() { numConsumers := flag.Int("consumers", 5, "Number of consumer goroutines for migration") flag.Parse() - if *username == "" || *password == "" { - envUser := os.Getenv("OJ_USERNAME") - envPass := os.Getenv("OJ_PASSWORD") - if *username == "" { - *username = envUser - } - if *password == "" { - *password = envPass - } - if *username == "" || *password == "" { - log.Fatal("Username and password must be provided either via flags (-username, -password) or environment variables (OJ_USERNAME, OJ_PASSWORD)") - } - } - log.Println("Attempting to login...") httpClient, err := loginUser(*username, *password) if err != nil { @@ -167,56 +181,20 @@ func main() { } log.Println("Login successful.") - // Positional argument parsing for offset and count (if no flags were set for them) - args := flag.Args() - if len(args) > 0 { - val, err := parseInt(args[0]) - if err == nil { - isOffsetDefault := true - flag.Visit(func(f *flag.Flag) { - if f.Name == "offset" { - isOffsetDefault = false - } - }) - if isOffsetDefault { - *offset = val - } - } else { - log.Printf("Warning: Could not parse first positional argument as offset: %v. Using default or flag value.", err) - } - } - if len(args) > 1 { - val, err := parseInt(args[1]) - if err == nil { - isCountDefault := true - flag.Visit(func(f *flag.Flag) { - if f.Name == "count" { - isCountDefault = false - } - }) - if isCountDefault { - *count = val - } - } else { - log.Printf("Warning: Could not parse second positional argument as count: %v. Using default or flag value.", err) - } - } - submissionIDChan := make(chan string, 1000) // Buffered channel for submission IDs - var consumerWg sync.WaitGroup // Producer goroutine go func() { defer close(submissionIDChan) // Close channel when producer is done const chunk = 100 - fetchOffset := offset - for *fetchOffset < *offset+*count { + fetchOffset := *offset + for fetchOffset < *offset+*count { fetchCount := chunk - if *fetchOffset+chunk > *offset+*count { - fetchCount = *offset + *count - *fetchOffset + if fetchOffset+chunk > *offset+*count { + fetchCount = *offset + *count - fetchOffset } - log.Printf("Fetching submission IDs with offset: %d, count: %d", *fetchOffset, fetchCount) - submissionIDs, err := fetchSubmissionIDs(*fetchOffset, fetchCount, httpClient) + log.Printf("Fetching submission IDs with offset: %d, count: %d", fetchOffset, fetchCount) + submissionIDs, err := fetchSubmissionIDs(fetchOffset, fetchCount, httpClient) if err != nil { log.Printf("Error fetching submission IDs: %v. Producer stopping.", err) return @@ -230,14 +208,15 @@ func main() { for _, id := range submissionIDs { submissionIDChan <- id } - *fetchOffset += chunk + fetchOffset += chunk } log.Println("Producer finished sending all submission IDs.") }() // Consumer goroutines - log.Printf("Starting %d consumer goroutines...", *numConsumers) - for i := 0; i < *numConsumers; i++ { + log.Printf("Starting %d consumer goroutines...", numConsumers) + var consumerWg sync.WaitGroup + for i := range *numConsumers { consumerWg.Add(1) go func(workerID int) { defer consumerWg.Done() @@ -245,6 +224,7 @@ func main() { for id := range submissionIDChan { log.Printf("Consumer %d processing submissionId: %s", workerID, id) migrateSubmissionCode(id, httpClient) + migrateSubmissionOutput(id, httpClient) } log.Printf("Consumer %d finished.", workerID) }(i + 1) @@ -253,8 +233,3 @@ func main() { consumerWg.Wait() // Wait for all consumers to finish log.Println("All processing finished.") } - -// Helper to parse int from string for positional arguments -func parseInt(s string) (int, error) { - return strconv.Atoi(s) -} diff --git a/model/submission.py b/model/submission.py index c5a8c85..3c9013a 100644 --- a/model/submission.py +++ b/model/submission.py @@ -538,3 +538,18 @@ def migrate_code(user: User, submission: Submission): submission.migrate_code_to_minio() return HTTPResponse('ok') + + +@submission_api.post('//migrate-output') +@login_required +@identity_verify(0) +@Request.doc('submission', Submission) +def migrate_output(user: User, submission: Submission): + if not submission.permission( + user, + Submission.Permission.MANAGER, + ): + return HTTPError('forbidden.', 403) + + submission.migrate_output_to_minio() + return HTTPResponse('ok') diff --git a/mongo/submission.py b/mongo/submission.py index e5950bf..fb972fd 100644 --- a/mongo/submission.py +++ b/mongo/submission.py @@ -360,7 +360,7 @@ def rejudge(self) -> bool: return self.send() def _generate_code_minio_path(self): - return f'submissions/{ULID()}.zip' + return f'submissions/{self.id}_{ULID()}.zip' def _put_code(self, code_file) -> str: ''' @@ -543,7 +543,7 @@ def _generate_output_minio_path(self, task_no: int, case_no: int) -> str: ''' generate a output file path for minio ''' - return f'submissions/task{task_no:02d}_case{case_no:02d}_{ULID()}.zip' + return f'submissions/{self.id}_task{task_no:02d}_case{case_no:02d}_{ULID()}.zip' def finish_judging(self): # update user's submission @@ -960,3 +960,97 @@ def _check_code_consistency(self): f"calculated minio checksum. submission={self.id} checksum={minio_checksum}" ) return minio_checksum == gridfs_checksum + + def migrate_output_to_minio(self): + """ + migrate output from gridfs to minio + """ + for (i, task) in enumerate(self.tasks): + for (j, case) in enumerate(task.cases): + self._migrate_case_output_to_minio(case, i, j) + + def _migrate_case_output_to_minio( + self, + case: engine.CaseResult, + i: int, + j: int, + ): + """ + migrate a single case output to minio + """ + minio_client = MinioClient() + + if case.output is None or case.output.grid_id is None: + self.logger.info( + f"no output to migrate. submission={self.id} task={i} case={j}" + ) + return + + if case.output_minio_path is None: + self.logger.info( + f"uploading output to minio. submission={self.id} task={i} case={j}" + ) + output_minio_path = self._generate_output_minio_path(i, j) + minio_client.client.put_object( + minio_client.bucket, + output_minio_path, + io.BytesIO(case.output.read()), + -1, + part_size=5 * 1024 * 1024, # 5MB + content_type='application/zip', + ) + case.output_minio_path = output_minio_path + self.save() + self.logger.info( + f"output uploaded to minio. submission={self.id} task={i} case={j}" + ) + + # remove output in gridfs if it is consistent + if self._check_case_output_consistency(case, i, j): + self.logger.info( + f"data consistency validated, removing output in gridfs. submission={self.id} task={i} case={j}" + ) + case.output.delete() + self.save() + else: + self.logger.warning( + f"data inconsistent, keeping output in gridfs. submission={self.id} task={i} case={j}" + ) + + def _check_case_output_consistency( + self, + case: engine.CaseResult, + i: int, + j: int, + ): + """ + check whether the case output is consistent + """ + if case.output is None or case.output.grid_id is None: + return False + gridfs_output = case.output.read() + if gridfs_output is None: + # if file is deleted but GridFS proxy is not updated + return False + gridfs_checksum = md5(gridfs_output).hexdigest() + self.logger.info( + f"calculated grid checksum. submission={self.id} task={i} case={j} checksum={gridfs_checksum}" + ) + + minio_client = MinioClient() + try: + resp = minio_client.client.get_object( + minio_client.bucket, + case.output_minio_path, + ) + minio_output = resp.read() + finally: + if 'resp' in locals(): + resp.close() + resp.release_conn() + + minio_checksum = md5(minio_output).hexdigest() + self.logger.info( + f"calculated minio checksum. submission={self.id} task={i} case={j} checksum={minio_checksum}" + ) + return minio_checksum == gridfs_checksum