diff --git a/openviking/async_client.py b/openviking/async_client.py index 1dc1f1a8..cfc3920f 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -255,10 +255,10 @@ async def overview(self, uri: str) -> str: await self._ensure_initialized() return await self._client.overview(uri) - async def read(self, uri: str) -> str: + async def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file content""" await self._ensure_initialized() - return await self._client.read(uri) + return await self._client.read(uri, offset=offset, limit=limit) async def ls(self, uri: str, **kwargs) -> List[Any]: """ diff --git a/openviking/client/local.py b/openviking/client/local.py index 208d7312..27ce39c7 100644 --- a/openviking/client/local.py +++ b/openviking/client/local.py @@ -144,9 +144,9 @@ async def mv(self, from_uri: str, to_uri: str) -> None: # ============= Content Reading ============= - async def read(self, uri: str) -> str: + async def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file content.""" - return await self._service.fs.read(uri) + return await self._service.fs.read(uri, offset=offset, limit=limit) async def abstract(self, uri: str) -> str: """Read L0 abstract.""" diff --git a/openviking/parse/parsers/code/code.py b/openviking/parse/parsers/code/code.py index f22f0a42..3693384b 100644 --- a/openviking/parse/parsers/code/code.py +++ b/openviking/parse/parsers/code/code.py @@ -15,6 +15,8 @@ import stat import tempfile import time +import urllib.request +import zipfile from pathlib import Path, PurePosixPath from typing import Any, List, Optional, Tuple, Union from urllib.parse import unquote, urlparse @@ -106,14 +108,22 @@ async def parse(self, source: Union[str, Path], instruction: str = "", **kwargs) # 2. Fetch content (Clone or Extract) repo_name = "repository" + local_dir = Path(temp_local_dir) if source_str.startswith(("http://", "https://", "git://", "ssh://")): repo_url, branch, commit = self._parse_repo_source(source_str, **kwargs) - repo_name = await self._git_clone( - repo_url, - temp_local_dir, - branch=branch, - commit=commit, - ) + if self._is_github_url(repo_url) and not commit: + # Use GitHub ZIP API: single HTTPS download, no git history, much faster + local_dir, repo_name = await self._github_zip_download( + repo_url, branch, temp_local_dir + ) + else: + # Non-GitHub URL or specific commit: fall back to git clone + repo_name = await self._git_clone( + repo_url, + temp_local_dir, + branch=branch, + commit=commit, + ) elif str(source).endswith(".zip"): repo_name = await self._extract_zip(source_str, temp_local_dir) else: @@ -128,9 +138,7 @@ async def parse(self, source: Union[str, Path], instruction: str = "", **kwargs) logger.info(f"Uploading to VikingFS: {target_root_uri}") # 4. Upload to VikingFS (filtering on the fly) - file_count = await self._upload_directory( - Path(temp_local_dir), target_root_uri, viking_fs - ) + file_count = await self._upload_directory(local_dir, target_root_uri, viking_fs) logger.info(f"Uploaded {file_count} files to {target_root_uri}") @@ -185,9 +193,7 @@ async def parse_content( """Not supported for repositories.""" raise NotImplementedError("CodeRepositoryParser does not support parse_content") - def _parse_repo_source( - self, source: str, **kwargs - ) -> Tuple[str, Optional[str], Optional[str]]: + def _parse_repo_source(self, source: str, **kwargs) -> Tuple[str, Optional[str], Optional[str]]: branch = kwargs.get("branch") or kwargs.get("ref") commit = kwargs.get("commit") repo_url = source @@ -275,6 +281,91 @@ async def _has_commit(self, repo_dir: str, commit: str) -> bool: except RuntimeError: return False + @staticmethod + def _is_github_url(url: str) -> bool: + """Return True for github.com URLs (supports ZIP archive API).""" + return urlparse(url).netloc in ("github.com", "www.github.com") + + async def _github_zip_download( + self, + repo_url: str, + branch: Optional[str], + target_dir: str, + ) -> Tuple[Path, str]: + """Download a GitHub repo as a ZIP archive and extract it. + + Uses the GitHub archive API (single HTTPS GET, no git history). + + Returns: + (content_dir, repo_name) — content_dir is the extracted repo root. + """ + repo_name = self._get_repo_name(repo_url) + + # Build archive URL from owner/repo path components. + parsed = urlparse(repo_url) + path_parts = [p for p in parsed.path.split("/") if p] + owner = path_parts[0] + repo_raw = path_parts[1] + # Strip .git suffix for the archive URL (git clone keeps it, ZIP API does not). + repo_slug = repo_raw[:-4] if repo_raw.endswith(".git") else repo_raw + + if branch: + zip_url = f"https://github.com/{owner}/{repo_slug}/archive/refs/heads/{branch}.zip" + else: + zip_url = f"https://github.com/{owner}/{repo_slug}/archive/HEAD.zip" + + logger.info(f"Downloading GitHub ZIP: {zip_url}") + + zip_path = os.path.join(target_dir, "_archive.zip") + extract_dir = os.path.join(target_dir, "_extracted") + os.makedirs(extract_dir, exist_ok=True) + + # Download (blocking HTTP; run in thread pool to avoid stalling event loop). + def _download() -> None: + req = urllib.request.Request(zip_url, headers={"User-Agent": "OpenViking"}) + with urllib.request.urlopen(req) as resp, open(zip_path, "wb") as f: + shutil.copyfileobj(resp, f) + + try: + await asyncio.to_thread(_download) + except Exception as exc: + raise RuntimeError(f"Failed to download GitHub ZIP {zip_url}: {exc}") + + # Safe extraction with Zip Slip validation (mirrors _extract_zip logic). + target = Path(extract_dir).resolve() + with zipfile.ZipFile(zip_path, "r") as zf: + for info in zf.infolist(): + mode = info.external_attr >> 16 + if info.is_dir() or stat.S_ISDIR(mode): + continue + if stat.S_ISLNK(mode): + logger.warning(f"Skipping symlink entry in GitHub ZIP: {info.filename}") + continue + raw = info.filename.replace("\\", "/") + raw_parts = [p for p in raw.split("/") if p] + if ".." in raw_parts: + raise ValueError(f"Zip Slip detected in GitHub archive: {info.filename!r}") + if PurePosixPath(raw).is_absolute(): + raise ValueError(f"Zip Slip detected in GitHub archive: {info.filename!r}") + extracted = Path(zf.extract(info, extract_dir)).resolve() + if not extracted.is_relative_to(target): + extracted.unlink(missing_ok=True) + raise ValueError(f"Zip Slip detected in GitHub archive: {info.filename!r}") + + # Remove downloaded archive to free disk space. + try: + os.unlink(zip_path) + except OSError: + pass + + # GitHub ZIPs have a single top-level directory: {repo}-{branch}/ or {repo}-{sha}/. + # Return that directory as the content root so callers see bare repo files. + top_level = [d for d in Path(extract_dir).iterdir() if d.is_dir()] + content_dir = top_level[0] if len(top_level) == 1 else Path(extract_dir) + + logger.info(f"GitHub ZIP extracted to {content_dir} ({repo_name})") + return content_dir, repo_name + async def _git_clone( self, url: str, @@ -282,26 +373,15 @@ async def _git_clone( branch: Optional[str] = None, commit: Optional[str] = None, ) -> str: - """ - Clone git repository. + """Clone a git repository into target_dir; return the repo name. + + Uses --depth 1 for speed. If a specific commit is requested, it is + fetched and checked out after the shallow clone. Returns: - Repository name (e.g. "OpenViking" from "https://.../OpenViking.git") + Repository name derived from the URL (e.g. "OpenViking"). """ - # Extract repo name from URL name = self._get_repo_name(url) - - # Clone into a subdirectory to keep structure clean - # But here we clone content directly into target_dir? - # Actually, git clone clones INTO . - # But if we want the repo name directory to exist in VikingFS, we should clone into target_dir/name? - # No, parse logic says: - # temp_local_dir contains the files (e.g. .git, src, README) - # We upload temp_local_dir content to viking://temp/{uuid}/{repo_name}/ - - # So we clone current content directly into temp_local_dir - # git clone --depth 1 url target_dir - logger.info(f"Cloning {url} to {target_dir}...") clone_args = [ @@ -320,19 +400,30 @@ async def _git_clone( await self._run_git(["git", "-C", target_dir, "fetch", "origin", commit]) except RuntimeError: try: - await self._run_git(["git", "-C", target_dir, "fetch", "--all", "--tags", "--prune"]) + await self._run_git( + ["git", "-C", target_dir, "fetch", "--all", "--tags", "--prune"] + ) except RuntimeError: pass ok = await self._has_commit(target_dir, commit) if not ok: try: - await self._run_git(["git", "-C", target_dir, "fetch", "--unshallow", "origin"]) + await self._run_git( + ["git", "-C", target_dir, "fetch", "--unshallow", "origin"] + ) except RuntimeError: pass ok = await self._has_commit(target_dir, commit) if not ok: await self._run_git( - ["git", "-C", target_dir, "fetch", "origin", "+refs/heads/*:refs/remotes/origin/*"] + [ + "git", + "-C", + target_dir, + "fetch", + "origin", + "+refs/heads/*:refs/remotes/origin/*", + ] ) ok = await self._has_commit(target_dir, commit) if not ok: @@ -342,17 +433,7 @@ async def _git_clone( return name async def _extract_zip(self, zip_path: str, target_dir: str) -> str: - """Extract zip file.""" - import zipfile - - # We assume it's a local path if passed here? - # Actually logic in parse() handles local path check before calling here? - # Or if it's a URL ending in zip, HTMLParser might have downloaded it? - # Wait, HTMLParser handles download. If we are here, source IS a path or URL. - # If it's a URL, we need to download it first? - # CodeRepositoryParser is designed to handle "source" which can be URL. - # So I need to download zip if it is a URL. - + """Extract a local zip file into target_dir; return the archive stem as the repo name.""" if zip_path.startswith(("http://", "https://")): # TODO: implement download logic or rely on caller? # For now, assume it's implemented if needed, but raise error as strictly we only support git URL for now as per plan diff --git a/openviking/parse/parsers/upload_utils.py b/openviking/parse/parsers/upload_utils.py index 39f73744..50c80ed0 100644 --- a/openviking/parse/parsers/upload_utils.py +++ b/openviking/parse/parsers/upload_utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Shared upload utilities for directory and file uploading to VikingFS.""" +import asyncio import os import re from pathlib import Path @@ -65,11 +66,13 @@ def detect_and_convert_encoding(content: bytes, file_path: Union[str, Path] = "" # Check for potential binary content (null bytes in first 8KB) # Binary files often contain null bytes which can cause issues sample_size = min(8192, len(content)) - if b'\x00' in content[:sample_size]: - null_count = content[:sample_size].count(b'\x00') + if b"\x00" in content[:sample_size]: + null_count = content[:sample_size].count(b"\x00") # If more than 5% null bytes in sample, likely binary - don't process if null_count / sample_size > 0.05: - logger.debug(f"Detected binary content in {file_path} (null bytes: {null_count}), skipping encoding detection") + logger.debug( + f"Detected binary content in {file_path} (null bytes: {null_count}), skipping encoding detection" + ) return content detected_encoding: Optional[str] = None @@ -77,7 +80,7 @@ def detect_and_convert_encoding(content: bytes, file_path: Union[str, Path] = "" try: decoded = content.decode(encoding) # Additional validation: check for control characters that suggest binary - control_chars = sum(1 for c in decoded[:1000] if ord(c) < 32 and c not in '\t\n\r') + control_chars = sum(1 for c in decoded[:1000] if ord(c) < 32 and c not in "\t\n\r") if control_chars / min(1000, len(decoded)) > 0.05: # More than 5% control chars continue detected_encoding = encoding @@ -92,8 +95,8 @@ def detect_and_convert_encoding(content: bytes, file_path: Union[str, Path] = "" if detected_encoding not in UTF8_VARIANTS: decoded_content = content.decode(detected_encoding, errors="replace") # Remove null bytes from decoded content as they can cause issues downstream - if '\x00' in decoded_content: - decoded_content = decoded_content.replace('\x00', '') + if "\x00" in decoded_content: + decoded_content = decoded_content.replace("\x00", "") logger.debug(f"Removed null bytes from decoded content in {file_path}") content = decoded_content.encode("utf-8") logger.debug(f"Converted {file_path} from {detected_encoding} to UTF-8") @@ -195,6 +198,9 @@ async def upload_text_files( return uploaded_count, warnings +_UPLOAD_CONCURRENCY = 8 + + async def upload_directory( local_dir: Path, viking_uri_base: str, @@ -203,24 +209,26 @@ async def upload_directory( ignore_extensions: Optional[Set[str]] = None, max_file_size: int = 10 * 1024 * 1024, ) -> Tuple[int, List[str]]: - """Upload an entire directory recursively and return uploaded count with warnings.""" + """Upload an entire directory recursively and return uploaded count with warnings. + + Optimized: collects all files in one pass, pre-creates directories upfront, + then uploads all files concurrently (up to _UPLOAD_CONCURRENCY at a time). + """ effective_ignore_dirs = ignore_dirs if ignore_dirs is not None else IGNORE_DIRS effective_ignore_extensions = ( ignore_extensions if ignore_extensions is not None else IGNORE_EXTENSIONS ) - uploaded_count = 0 warnings: List[str] = [] - await viking_fs.mkdir(viking_uri_base, exist_ok=True) + # --- Phase 1: Collect files and unique parent directory URIs in one pass --- + files_to_upload: List[Tuple[Path, str]] = [] # (local_path, target_uri) + parent_uris: Set[str] = {viking_uri_base} for root, dirs, files in os.walk(local_dir): dirs[:] = [ - dir_name - for dir_name in dirs - if not should_skip_directory(dir_name, ignore_dirs=effective_ignore_dirs) + d for d in dirs if not should_skip_directory(d, ignore_dirs=effective_ignore_dirs) ] - for file_name in files: file_path = Path(root) / file_name should_skip, _ = should_skip_file( @@ -230,18 +238,69 @@ async def upload_directory( ) if should_skip: continue - rel_path_str = str(file_path.relative_to(local_dir)).replace(os.sep, "/") try: safe_rel = _sanitize_rel_path(rel_path_str) - target_uri = f"{viking_uri_base}/{safe_rel}" - content = file_path.read_bytes() - content = detect_and_convert_encoding(content, file_path) - await viking_fs.write_file_bytes(target_uri, content) - uploaded_count += 1 - except Exception as exc: - warning = f"Failed to upload {file_path}: {exc}" + except ValueError as exc: + warning = f"Skipping {file_path}: {exc}" warnings.append(warning) logger.warning(warning) + continue + target_uri = f"{viking_uri_base}/{safe_rel}" + files_to_upload.append((file_path, target_uri)) + parent_uris.add(target_uri.rsplit("/", 1)[0]) + + # --- Phase 2: Pre-create all directories --- + # Memoized mkdir: each unique agfs path is created at most once. + # This is equivalent to _ensure_parent_dirs but avoids redundant HTTP calls + # by tracking already-processed paths across all directories. + _created: Set[str] = set() + + def _mkdir_with_parents(agfs_path: str) -> None: + parts = agfs_path.lstrip("/").split("/") + for i in range(1, len(parts) + 1): + p = "/" + "/".join(parts[:i]) + if p in _created: + continue + try: + viking_fs.agfs.mkdir(p) + _created.add(p) + except Exception as e: + if "already" in str(e).lower(): + _created.add(p) + else: + logger.warning(f"Failed to create directory {p}: {e}") + + def _create_all_dirs() -> None: + for dir_uri in sorted(parent_uris): + _mkdir_with_parents(viking_fs._uri_to_path(dir_uri)) + + await asyncio.to_thread(_create_all_dirs) + + # --- Phase 3: Upload files concurrently --- + sem = asyncio.Semaphore(_UPLOAD_CONCURRENCY) + errors: List[Optional[str]] = [None] * len(files_to_upload) + + async def _upload_one(idx: int, file_path: Path, target_uri: str) -> None: + async with sem: + + def _do() -> None: + content = file_path.read_bytes() + encoded = detect_and_convert_encoding(content, file_path) + agfs_path = viking_fs._uri_to_path(target_uri) + viking_fs.agfs.write(agfs_path, encoded) + + try: + await asyncio.to_thread(_do) + except Exception as exc: + errors[idx] = f"Failed to upload {file_path}: {exc}" + + await asyncio.gather(*[_upload_one(i, fp, uri) for i, (fp, uri) in enumerate(files_to_upload)]) + + for err in errors: + if err: + warnings.append(err) + logger.warning(err) + uploaded_count = sum(1 for e in errors if e is None) return uploaded_count, warnings diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py index 7d3f5a89..2d202aa0 100644 --- a/openviking/parse/tree_builder.py +++ b/openviking/parse/tree_builder.py @@ -20,6 +20,7 @@ - Content splitting is handled by Parser, not TreeBuilder """ +import asyncio import logging from typing import TYPE_CHECKING, Optional @@ -139,7 +140,7 @@ async def finalize_from_temp( logger.info(f"[TreeBuilder] Finalizing from temp: {final_uri}") # 4. Move directory tree from temp to final location in AGFS - await self._move_directory_in_agfs(temp_doc_uri, final_uri) + await self._move_temp_to_dest(viking_fs, temp_doc_uri, final_uri) logger.info(f"[TreeBuilder] Moved temp tree: {temp_doc_uri} -> {final_uri}") # 5. Cleanup temporary root directory @@ -191,39 +192,15 @@ async def _exists(u: str) -> bool: raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts") - async def _move_directory_in_agfs(self, src_uri: str, dst_uri: str) -> None: - """Recursively move AGFS directory tree (copy + delete).""" - viking_fs = get_viking_fs() + async def _move_temp_to_dest(self, viking_fs, src_uri: str, dst_uri: str) -> None: + """Move temp directory to final destination using a single native AGFS mv call. - # 1. Ensure parent directories exist + Temp files have no vector records yet, so no vector index update is needed. + """ + src_path = viking_fs._uri_to_path(src_uri) + dst_path = viking_fs._uri_to_path(dst_uri) await self._ensure_parent_dirs(dst_uri) - - # 2. Create target directory - await viking_fs.mkdir(dst_uri) - - # 3. List source directory contents - entries = await viking_fs.ls(src_uri) - - for entry in entries: - name = entry.get("name", "") - if not name or name in [".", ".."]: - continue - - src_item = f"{src_uri}/{name}" - dst_item = f"{dst_uri}/{name}" - - if entry.get("isDir"): - # Recursively move subdirectory - await self._move_directory_in_agfs(src_item, dst_item) - else: - # Move file - await viking_fs.move_file(src_item, dst_item) - - # 4. Delete source directory (should be empty now) - try: - await viking_fs.rm(src_uri) - except Exception: - pass # Ignore error when deleting empty directory + await asyncio.to_thread(viking_fs.agfs.mv, src_path, dst_path) async def _ensure_parent_dirs(self, uri: str) -> None: """Recursively create parent directories.""" diff --git a/openviking/server/routers/content.py b/openviking/server/routers/content.py index c1e9d8ce..38a3cde3 100644 --- a/openviking/server/routers/content.py +++ b/openviking/server/routers/content.py @@ -15,11 +15,13 @@ @router.get("/read") async def read( uri: str = Query(..., description="Viking URI"), + offset: int = Query(0, description="Starting line number (0-indexed)"), + limit: int = Query(-1, description="Number of lines to read, -1 means read to end"), _ctx: RequestContext = Depends(get_request_context), ): """Read file content (L2).""" service = get_service() - result = await service.fs.read(uri) + result = await service.fs.read(uri, offset=offset, limit=limit) return Response(status="ok", result=result) diff --git a/openviking/service/core.py b/openviking/service/core.py index b040c24f..59d81163 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -89,7 +89,7 @@ def __init__( self._initialized = False # Initialize storage - self._init_storage(config.storage) + self._init_storage(config.storage, config.embedding.max_concurrent) # Initialize embedder self._embedder = config.embedding.get_embedder() @@ -97,7 +97,7 @@ def __init__( f"Initialized embedder (dim {config.embedding.dimension}, sparse {self._embedder.is_sparse})" ) - def _init_storage(self, config: StorageConfig) -> None: + def _init_storage(self, config: StorageConfig, max_concurrent_embedding: int = 1) -> None: """Initialize storage resources.""" if config.agfs.backend == "local": self._agfs_manager = AGFSManager(config=config.agfs) @@ -112,6 +112,7 @@ def _init_storage(self, config: StorageConfig) -> None: self._queue_manager = init_queue_manager( agfs_url=self._agfs_url, timeout=config.agfs.timeout, + max_concurrent_embedding=max_concurrent_embedding, ) else: logger.warning("AGFS URL not configured, skipping queue manager initialization") @@ -195,7 +196,7 @@ async def initialize(self) -> None: return if self._vikingdb_manager is None: - self._init_storage(self._config.storage) + self._init_storage(self._config.storage, self._config.embedding.max_concurrent) if self._embedder is None: self._embedder = self._config.embedding.get_embedder() @@ -204,10 +205,10 @@ async def initialize(self) -> None: # Initialize VikingFS and VikingDB with recorder if enabled enable_recorder = os.environ.get("OPENVIKING_ENABLE_RECORDER", "").lower() == "true" - + # Create context collection await init_context_collection(self._vikingdb_manager) - + self._viking_fs = init_viking_fs( agfs_url=self._agfs_url or "http://localhost:8080", query_embedder=self._embedder, diff --git a/openviking/service/fs_service.py b/openviking/service/fs_service.py index c57afb1e..4f5dcc45 100644 --- a/openviking/service/fs_service.py +++ b/openviking/service/fs_service.py @@ -121,10 +121,10 @@ async def stat(self, uri: str) -> Dict[str, Any]: viking_fs = self._ensure_initialized() return await viking_fs.stat(uri) - async def read(self, uri: str) -> str: + async def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file content.""" viking_fs = self._ensure_initialized() - return await viking_fs.read_file(uri) + return await viking_fs.read_file(uri, offset=offset, limit=limit) async def abstract(self, uri: str) -> str: """Read L0 abstract (.abstract.md).""" diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 584cc42b..c703b33b 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -7,6 +7,7 @@ similar to how init_viking_fs encapsulates VikingFS initialization. """ +import asyncio import hashlib import json from typing import Any, Dict, Optional @@ -149,7 +150,11 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, # Generate embedding vector(s) if self._embedder: - result: EmbedResult = self._embedder.embed(embedding_msg.message) + # embed() is a blocking HTTP call; offload to thread pool to avoid + # blocking the event loop and allow real concurrency. + result: EmbedResult = await asyncio.to_thread( + self._embedder.embed, embedding_msg.message + ) # Add dense vector if result.dense_vector: diff --git a/openviking/storage/queuefs/named_queue.py b/openviking/storage/queuefs/named_queue.py index 6baeb398..ca0e9b29 100644 --- a/openviking/storage/queuefs/named_queue.py +++ b/openviking/storage/queuefs/named_queue.py @@ -198,43 +198,58 @@ async def enqueue(self, data: Union[str, Dict[str, Any]]) -> str: msg_id = self._agfs.write(enqueue_file, data.encode("utf-8")) return msg_id if isinstance(msg_id, str) else str(msg_id) + def _read_queue_message(self) -> Optional[Dict[str, Any]]: + """Read and remove one message from the AGFS queue; return parsed dict or None. + + Normalises the various return types AGFSClient.read() may produce. + """ + content = self._agfs.read(f"{self.path}/dequeue") + if not content or content == b"{}": + return None + if isinstance(content, bytes): + raw = content + elif isinstance(content, str): + raw = content.encode("utf-8") + elif hasattr(content, "content") and content.content is not None: + raw = content.content + else: + raw = str(content).encode("utf-8") + return json.loads(raw.decode("utf-8")) + async def dequeue(self) -> Optional[Dict[str, Any]]: - """Get and remove message from queue (dequeue).""" + """Get and remove message from queue, then invoke the dequeue handler.""" await self._ensure_initialized() - dequeue_file = f"{self.path}/dequeue" - try: - content = self._agfs.read(dequeue_file) - if not content or content == b"{}": - return None - - # Handle different return types from AGFSClient - content_bytes = None - if isinstance(content, bytes): - content_bytes = content - elif isinstance(content, str): - content_bytes = content.encode("utf-8") - elif hasattr(content, "content"): # Response object - content_obj = content.content - if content_obj is not None: - content_bytes = content_obj - else: - content_bytes = str(content).encode("utf-8") - - if content_bytes is None: + data = self._read_queue_message() + if data is None: return None - data = json.loads(content_bytes.decode("utf-8")) - - # Dequeue success, mark in_progress if self._dequeue_handler: self._on_dequeue_start() data = await self._dequeue_handler.on_dequeue(data) - return data except Exception as e: logger.debug(f"[NamedQueue] Dequeue failed for {self.name}: {e}") return None + async def dequeue_raw(self) -> Optional[Dict[str, Any]]: + """Get and remove message from queue without invoking the handler.""" + await self._ensure_initialized() + try: + return self._read_queue_message() + except Exception as e: + logger.debug(f"[NamedQueue] Dequeue raw failed for {self.name}: {e}") + return None + + async def process_dequeued(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Invoke the dequeue handler on already-fetched raw data. + + NOTE: caller must call _on_dequeue_start() before invoking this method + so that in_progress is incremented atomically with the dequeue. + """ + if self._dequeue_handler: + return await self._dequeue_handler.on_dequeue(data) + return data + async def peek(self) -> Optional[Dict[str, Any]]: """Peek at head message without removing.""" await self._ensure_initialized() diff --git a/openviking/storage/queuefs/queue_manager.py b/openviking/storage/queuefs/queue_manager.py index 8832f8b0..7ceab4df 100644 --- a/openviking/storage/queuefs/queue_manager.py +++ b/openviking/storage/queuefs/queue_manager.py @@ -9,7 +9,8 @@ import atexit import threading import time -from typing import Any, Dict, Optional, Union +import traceback +from typing import Any, Dict, Optional, Set, Union from pyagfs import AGFSClient @@ -29,6 +30,7 @@ def init_queue_manager( agfs_url: str = "http://localhost:8080", timeout: int = 10, mount_point: str = "/queue", + max_concurrent_embedding: int = 1, ) -> "QueueManager": """Initialize QueueManager singleton.""" global _instance @@ -36,6 +38,7 @@ def init_queue_manager( agfs_url=agfs_url, timeout=timeout, mount_point=mount_point, + max_concurrent_embedding=max_concurrent_embedding, ) return _instance @@ -63,11 +66,13 @@ def __init__( agfs_url: str = "http://localhost:8080", timeout: int = 10, mount_point: str = "/queue", + max_concurrent_embedding: int = 1, ): """Initialize QueueManager.""" self._agfs_url = agfs_url self.timeout = timeout self.mount_point = mount_point + self._max_concurrent_embedding = max_concurrent_embedding self._agfs: Optional[Any] = None self._queues: Dict[str, NamedQueue] = {} self._started = False @@ -137,41 +142,101 @@ def _start_queue_worker(self, queue: NamedQueue) -> None: if thread.is_alive(): return + max_concurrent = self._max_concurrent_embedding if queue.name == self.EMBEDDING else 1 stop_event = threading.Event() self._queue_stop_events[queue.name] = stop_event thread = threading.Thread( target=self._queue_worker_loop, - args=(queue, stop_event), + args=(queue, stop_event, max_concurrent), daemon=True, ) self._queue_threads[queue.name] = thread thread.start() - def _queue_worker_loop(self, queue: NamedQueue, stop_event: threading.Event) -> None: - """Worker loop for a single queue, processes items sequentially.""" + def _queue_worker_loop( + self, queue: NamedQueue, stop_event: threading.Event, max_concurrent: int = 1 + ) -> None: + """Worker loop for a single queue. + + When max_concurrent > 1, items are fetched and processed in parallel + (up to max_concurrent at a time). Otherwise items are processed one by one. + """ loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - while not stop_event.is_set(): - try: - queue_size = loop.run_until_complete(queue.size()) - if queue.has_dequeue_handler() and queue_size > 0: - data = loop.run_until_complete(queue.dequeue()) - if data is not None: - logger.debug( - f"[QueueManager] Dequeued message from {queue.name}: {data}" - ) - else: + if max_concurrent > 1: + loop.run_until_complete( + self._worker_async_concurrent(queue, stop_event, max_concurrent) + ) + else: + while not stop_event.is_set(): + try: + queue_size = loop.run_until_complete(queue.size()) + if queue.has_dequeue_handler() and queue_size > 0: + data = loop.run_until_complete(queue.dequeue()) + if data is not None: + logger.debug( + f"[QueueManager] Dequeued message from {queue.name}: {data}" + ) + else: + stop_event.wait(self._poll_interval) + except Exception as e: + logger.error(f"[QueueManager] Worker error for {queue.name}: {e}") + traceback.print_exc() stop_event.wait(self._poll_interval) - except Exception as e: - logger.error(f"[QueueManager] Worker error for {queue.name}: {e}") - import traceback - - traceback.print_exc() - stop_event.wait(self._poll_interval) finally: loop.close() + async def _worker_async_concurrent( + self, queue: NamedQueue, stop_event: threading.Event, max_concurrent: int + ) -> None: + """Concurrent worker: drains the queue and processes items in parallel. + + A Semaphore caps inflight tasks at max_concurrent. + """ + sem = asyncio.Semaphore(max_concurrent) + active_tasks: Set[asyncio.Task] = set() + + async def process_one(data: Dict[str, Any]) -> None: + async with sem: + try: + await queue.process_dequeued(data) + except Exception as e: + # Handler did not call report_error; decrement in_progress manually. + queue._on_process_error(str(e), data) + logger.error(f"[QueueManager] Concurrent worker error for {queue.name}: {e}") + + while not stop_event.is_set(): + # Prune completed tasks + active_tasks = {t for t in active_tasks if not t.done()} + + # While capacity remains, keep draining the queue + while len(active_tasks) < max_concurrent: + try: + queue_size = await queue.size() + except Exception: + break + if not queue.has_dequeue_handler() or queue_size == 0: + break + data = await queue.dequeue_raw() + if data is None: + break + # Increment before task creation to close the race window where + # size=0 and in_progress=0 between dequeue_raw() and task execution. + queue._on_dequeue_start() + task = asyncio.create_task(process_one(data)) + active_tasks.add(task) + logger.debug( + f"[QueueManager] Dispatched concurrent task for {queue.name} " + f"(active={len(active_tasks)})" + ) + + await asyncio.sleep(self._poll_interval) + + # Drain remaining in-flight tasks on shutdown + if active_tasks: + await asyncio.gather(*active_tasks, return_exceptions=True) + def stop(self) -> None: """Stop QueueManager and release resources.""" global _instance diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index f01a4262..d47bcf40 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -111,6 +111,7 @@ def _enable_viking_fs_recorder(viking_fs: "VikingFS") -> None: recorder = get_recorder() if not recorder.enabled: from openviking.eval.recorder import init_recorder + init_recorder(enabled=True) global _instance @@ -964,11 +965,24 @@ async def write_file( async def read_file( self, uri: str, + offset: int = 0, + limit: int = -1, ) -> str: - """Read single file.""" + """Read single file, optionally sliced by line range. + + Args: + uri: Viking URI + offset: Starting line number (0-indexed). Default 0. + limit: Number of lines to read. -1 means read to end. Default -1. + """ path = self._uri_to_path(uri) content = self.agfs.read(path) - return self._handle_agfs_content(content) + text = self._handle_agfs_content(content) + if offset == 0 and limit == -1: + return text + lines = text.splitlines(keepends=True) + sliced = lines[offset:] if limit == -1 else lines[offset : offset + limit] + return "".join(sliced) async def read_file_bytes( self, diff --git a/openviking/sync_client.py b/openviking/sync_client.py index 8dc1f60b..936328bb 100644 --- a/openviking/sync_client.py +++ b/openviking/sync_client.py @@ -128,9 +128,9 @@ def overview(self, uri: str) -> str: """Read L1 overview""" return run_async(self._async_client.overview(uri)) - def read(self, uri: str) -> str: + def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file""" - return run_async(self._async_client.read(uri)) + return run_async(self._async_client.read(uri, offset=offset, limit=limit)) def ls(self, uri: str, **kwargs) -> List[Any]: """ diff --git a/openviking_cli/cli/commands/content.py b/openviking_cli/cli/commands/content.py index d20462e8..613eefa4 100644 --- a/openviking_cli/cli/commands/content.py +++ b/openviking_cli/cli/commands/content.py @@ -14,9 +14,11 @@ def register(app: typer.Typer) -> None: def read_command( ctx: typer.Context, uri: str = typer.Argument(..., help="Viking URI"), + offset: int = typer.Option(0, "--offset", "-s", help="Starting line number (0-indexed)"), + limit: int = typer.Option(-1, "--limit", "-n", help="Number of lines to read (-1 = all)"), ) -> None: """Read full file content (L2).""" - run(ctx, lambda client: client.read(uri)) + run(ctx, lambda client: client.read(uri, offset=offset, limit=limit)) @app.command("abstract") def abstract_command( diff --git a/openviking_cli/client/base.py b/openviking_cli/client/base.py index 7882b585..4141ba68 100644 --- a/openviking_cli/client/base.py +++ b/openviking_cli/client/base.py @@ -108,8 +108,14 @@ async def mv(self, from_uri: str, to_uri: str) -> None: # ============= Content Reading ============= @abstractmethod - async def read(self, uri: str) -> str: - """Read file content (L2).""" + async def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: + """Read file content (L2). + + Args: + uri: Viking URI + offset: Starting line number (0-indexed). Default 0. + limit: Number of lines to read. -1 means read to end. Default -1. + """ ... @abstractmethod diff --git a/openviking_cli/client/http.py b/openviking_cli/client/http.py index 859bdd95..13e44b31 100644 --- a/openviking_cli/client/http.py +++ b/openviking_cli/client/http.py @@ -417,12 +417,12 @@ async def mv(self, from_uri: str, to_uri: str) -> None: # ============= Content Reading ============= - async def read(self, uri: str) -> str: + async def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file content.""" uri = VikingURI.normalize(uri) response = await self._http.get( "/api/v1/content/read", - params={"uri": uri}, + params={"uri": uri, "offset": offset, "limit": limit}, ) return self._handle_response(response) diff --git a/openviking_cli/client/sync_http.py b/openviking_cli/client/sync_http.py index 7970ee26..40ea2e12 100644 --- a/openviking_cli/client/sync_http.py +++ b/openviking_cli/client/sync_http.py @@ -213,9 +213,9 @@ def mv(self, from_uri: str, to_uri: str) -> None: # ============= Content ============= - def read(self, uri: str) -> str: + def read(self, uri: str, offset: int = 0, limit: int = -1) -> str: """Read file content.""" - return run_async(self._async_client.read(uri)) + return run_async(self._async_client.read(uri, offset=offset, limit=limit)) def abstract(self, uri: str) -> str: """Read L0 abstract.""" diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index ba9aab96..5184ddab 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -15,7 +15,8 @@ class EmbeddingModelConfig(BaseModel): batch_size: int = Field(default=32, description="Batch size for embedding generation") input: str = Field(default="multimodal", description="Input type: 'text' or 'multimodal'") provider: Optional[str] = Field( - default="volcengine", description="Provider type: 'openai', 'volcengine', 'vikingdb', 'jina'" + default="volcengine", + description="Provider type: 'openai', 'volcengine', 'vikingdb', 'jina'", ) backend: Optional[str] = Field( default="volcengine", @@ -103,6 +104,10 @@ class EmbeddingConfig(BaseModel): sparse: Optional[EmbeddingModelConfig] = Field(default=None) hybrid: Optional[EmbeddingModelConfig] = Field(default=None) + max_concurrent: int = Field( + default=1, description="Maximum number of concurrent embedding requests" + ) + model_config = {"extra": "forbid"} @model_validator(mode="after") diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 4696b910..9504b436 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -14,25 +14,21 @@ import json import os import sys -import tempfile -import threading import unicodedata -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Dict, List, Optional -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from openviking.parse.parsers.upload_utils import ( +from openviking.parse.parsers.upload_utils import ( # noqa: I001 _sanitize_rel_path, detect_and_convert_encoding, is_text_file, upload_directory, ) -from openviking.storage.viking_fs import VikingFS from openviking.storage.vikingdb_interface import VikingDBInterface from openviking.utils.compression import CompressManager from openviking_cli.utils.uri import VikingURI @@ -40,24 +36,26 @@ class MockVikingDB(VikingDBInterface): """Mock vector database for testing.""" - + def __init__(self): self.collections: Dict[str, Dict] = {} self.data: Dict[str, List[Dict]] = {} self.deleted_ids: set = set() - + async def create_collection(self, name: str, schema: Dict) -> bool: if name in self.collections: return False self.collections[name] = schema self.data[name] = [] return True - - async def search_by_id(self, collection: str, doc_id: str, candidates: Optional[List[str]] = None) -> Optional[Dict]: + + async def search_by_id( + self, collection: str, doc_id: str, candidates: Optional[List[str]] = None + ) -> Optional[Dict]: """Search for document by ID with optional candidate filtering.""" if collection not in self.data: return None - + if candidates is None: # Search all documents for doc in self.data[collection]: @@ -68,17 +66,21 @@ async def search_by_id(self, collection: str, doc_id: str, candidates: Optional[ if not candidates: # Empty candidate list return None for doc in self.data[collection]: - if doc.get("id") == doc_id and doc_id in candidates and doc_id not in self.deleted_ids: + if ( + doc.get("id") == doc_id + and doc_id in candidates + and doc_id not in self.deleted_ids + ): return doc - + return None - + async def insert(self, collection: str, data: List[Dict]) -> bool: if collection not in self.data: return False self.data[collection].extend(data) return True - + async def delete(self, collection: str, doc_id: str) -> bool: self.deleted_ids.add(doc_id) return True @@ -86,62 +88,62 @@ async def delete(self, collection: str, doc_id: str) -> bool: class TestLongFilenames: """Test handling of very long filenames and path components.""" - + def test_filename_exactly_255_bytes(self): """Test filename with exactly 255 bytes (filesystem limit boundary).""" # Create a filename that's exactly 255 bytes in UTF-8 base_name = "a" * 251 # 251 + ".txt" = 255 bytes filename = base_name + ".txt" - - assert len(filename.encode('utf-8')) == 255 - + + assert len(filename.encode("utf-8")) == 255 + # Test sanitization doesn't break at exact boundary sanitized = _sanitize_rel_path(filename) assert sanitized is not None assert len(sanitized) > 0 - + def test_filename_256_bytes_boundary(self): """Test filename with 256 bytes (just over filesystem limit).""" # Create filename that's exactly 256 bytes - should be truncated - base_name = "b" * 252 # 252 + ".txt" = 256 bytes + base_name = "b" * 252 # 252 + ".txt" = 256 bytes filename = base_name + ".txt" - - assert len(filename.encode('utf-8')) == 256 - + + assert len(filename.encode("utf-8")) == 256 + sanitized = _sanitize_rel_path(filename) # Should be handled gracefully (truncated or rejected) assert sanitized is not None - + def test_very_long_filename_with_cjk(self): """Test extremely long filename with CJK characters (3 bytes per char in UTF-8).""" # Each CJK character is 3 bytes in UTF-8 cjk_chars = "测试文件名" * 30 # ~450 bytes filename = f"{cjk_chars}.py" - - assert len(filename.encode('utf-8')) > 400 - + + assert len(filename.encode("utf-8")) > 400 + sanitized = _sanitize_rel_path(filename) assert sanitized is not None # Should handle or truncate appropriately - + def test_filename_only_special_characters(self): """Test filename composed entirely of special characters.""" special_filename = "!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" + ".txt" - + sanitized = _sanitize_rel_path(special_filename) # Should sanitize dangerous characters while preserving valid ones assert sanitized is not None assert ".txt" in sanitized # Extension should be preserved - + def test_filename_with_path_traversal_attempts(self): """Test filename containing path traversal sequences.""" dangerous_filenames = [ "../../../etc/passwd", "..\\..\\windows\\system32\\config", "file/../../../secret.txt", - "normal_file_../../../dangerous.py" + "normal_file_../../../dangerous.py", ] - + for filename in dangerous_filenames: sanitized = _sanitize_rel_path(filename) # Should not contain path traversal sequences @@ -151,241 +153,237 @@ def test_filename_with_path_traversal_attempts(self): class TestSearchByIdEdgeCases: """Test search_by_id with various edge cases and None conditions.""" - + @pytest.mark.asyncio async def test_search_nonexistent_id(self): """Test searching for an ID that doesn't exist.""" mock_db = MockVikingDB() await mock_db.create_collection("test", {}) - + result = await mock_db.search_by_id("test", "nonexistent_id") assert result is None - + @pytest.mark.asyncio async def test_search_after_delete(self): """Test searching for an ID after it has been deleted.""" mock_db = MockVikingDB() await mock_db.create_collection("test", {}) - + # Insert document await mock_db.insert("test", [{"id": "doc1", "content": "test"}]) - + # Verify it exists result = await mock_db.search_by_id("test", "doc1") assert result is not None - + # Delete it await mock_db.delete("test", "doc1") - + # Search should return None result = await mock_db.search_by_id("test", "doc1") assert result is None - + @pytest.mark.asyncio async def test_search_with_empty_candidates(self): """Test search_by_id with empty candidate list.""" mock_db = MockVikingDB() await mock_db.create_collection("test", {}) - + # Insert document await mock_db.insert("test", [{"id": "doc1", "content": "test"}]) - + # Search with empty candidates should return None result = await mock_db.search_by_id("test", "doc1", candidates=[]) assert result is None - + @pytest.mark.asyncio async def test_search_with_none_candidates(self): """Test search_by_id with None candidates (should search all).""" mock_db = MockVikingDB() await mock_db.create_collection("test", {}) - + # Insert document await mock_db.insert("test", [{"id": "doc1", "content": "test"}]) - + # Search with None candidates should find document result = await mock_db.search_by_id("test", "doc1", candidates=None) assert result is not None assert result["id"] == "doc1" - + @pytest.mark.asyncio async def test_search_nonexistent_collection(self): """Test searching in a collection that doesn't exist.""" mock_db = MockVikingDB() - + result = await mock_db.search_by_id("nonexistent", "doc1") assert result is None class TestDuplicateFilenameHandling: """Test duplicate filename handling and case sensitivity.""" - + @pytest.mark.asyncio async def test_upload_same_file_multiple_times(self, tmp_path): """Test uploading the same file 10 times - should handle duplicates gracefully.""" # Create test file test_file = tmp_path / "duplicate_test.txt" test_file.write_text("This is a test file for duplicate testing.") - + # Mock VikingFS mock_fs = MagicMock() mock_fs.write_file_bytes = AsyncMock() mock_fs.mkdir = AsyncMock() - + # Upload the same file 10 times - for i in range(10): + for _ in range(10): await upload_text_files([str(test_file)], "viking://test/", mock_fs) - + # Should handle duplicates without crashing assert mock_fs.write_file_bytes.call_count == 10 - + def test_case_sensitivity_filenames(self): """Test filenames that differ only in case.""" - filenames = [ - "TestFile.txt", - "testfile.txt", - "TESTFILE.TXT", - "TestFile.TXT" - ] - + filenames = ["TestFile.txt", "testfile.txt", "TESTFILE.TXT", "TestFile.TXT"] + sanitized_names = [_sanitize_rel_path(name) for name in filenames] - + # All should be valid but may be treated differently on case-insensitive systems for name in sanitized_names: assert name is not None assert len(name) > 0 - + def test_unicode_normalization_differences(self): """Test filenames with different Unicode normalizations (NFC vs NFD).""" # Same logical character represented differently filename_nfc = "café.txt" # NFC: é is a single codepoint filename_nfd = "cafe\u0301.txt" # NFD: e + combining acute accent - + # These look the same but have different byte representations assert filename_nfc != filename_nfd - assert unicodedata.normalize('NFC', filename_nfd) == filename_nfc - + assert unicodedata.normalize("NFC", filename_nfd) == filename_nfc + sanitized_nfc = _sanitize_rel_path(filename_nfc) sanitized_nfd = _sanitize_rel_path(filename_nfd) - + assert sanitized_nfc is not None assert sanitized_nfd is not None class TestCompressionEdgeCases: """Test compression with various edge cases.""" - + def test_compress_empty_string(self): """Test compressing empty string.""" compressor = CompressManager() - + empty_text = "" compressed = compressor.compress_text(empty_text) - + assert compressed is not None assert len(compressed) >= 0 # May be empty or minimal header - + def test_compress_emoji_only(self): """Test compressing string containing only emoji.""" compressor = CompressManager() - + emoji_text = "😀😃😄😁😆😅😂🤣😊😇🙂🙃😉😌😍🥰😘😗😙😚" compressed = compressor.compress_text(emoji_text) - + assert compressed is not None # Emoji compression ratio might be poor due to lack of repetition assert len(compressed) > 0 - + def test_compress_large_text(self): """Test compressing 1MB of text.""" compressor = CompressManager() - + # Generate ~1MB of text with patterns (should compress well) large_text = "This is a repeating pattern for compression testing. " * 20000 - - assert len(large_text.encode('utf-8')) > 1_000_000 - + + assert len(large_text.encode("utf-8")) > 1_000_000 + compressed = compressor.compress_text(large_text) - + assert compressed is not None # Should achieve significant compression due to repetition - compression_ratio = len(compressed) / len(large_text.encode('utf-8')) + compression_ratio = len(compressed) / len(large_text.encode("utf-8")) assert compression_ratio < 0.1 # Should compress to less than 10% - + def test_compress_binary_like_data(self): """Test compressing text that looks like binary data.""" compressor = CompressManager() - + # Pseudo-random looking text (poor compression expected) import hashlib + binary_like = "" for i in range(1000): binary_like += hashlib.md5(f"seed{i}".encode()).hexdigest() - + compressed = compressor.compress_text(binary_like) assert compressed is not None # Compression ratio should be poor for random-looking data - compression_ratio = len(compressed) / len(binary_like.encode('utf-8')) + compression_ratio = len(compressed) / len(binary_like.encode("utf-8")) assert compression_ratio > 0.8 # Should not compress much class TestConcurrentOperations: """Test concurrent operations for race conditions and thread safety.""" - + @pytest.mark.asyncio async def test_concurrent_writes(self): """Test 20 parallel write operations.""" mock_fs = MagicMock() mock_fs.write_file_bytes = AsyncMock() mock_fs.mkdir = AsyncMock() - + # Create 20 concurrent write tasks async def write_task(i): content = f"Content for file {i}" uri = f"viking://concurrent/file_{i}.txt" - await mock_fs.write_file_bytes(uri, content.encode('utf-8')) - + await mock_fs.write_file_bytes(uri, content.encode("utf-8")) + tasks = [write_task(i) for i in range(20)] - + # Execute all tasks concurrently await asyncio.gather(*tasks) - + # Verify all writes were attempted assert mock_fs.write_file_bytes.call_count == 20 - - @pytest.mark.asyncio + + @pytest.mark.asyncio async def test_concurrent_search_while_writing(self): """Test 10 parallel searches while writing.""" mock_db = MockVikingDB() await mock_db.create_collection("concurrent", {}) - + # Insert initial data for i in range(5): await mock_db.insert("concurrent", [{"id": f"doc{i}", "content": f"content{i}"}]) - + async def search_task(): return await mock_db.search_by_id("concurrent", "doc1") - + async def write_task(): return await mock_db.insert("concurrent", [{"id": "new_doc", "content": "new_content"}]) - + # Mix of search and write operations tasks = [] tasks.extend([search_task() for _ in range(10)]) tasks.extend([write_task() for _ in range(5)]) - + results = await asyncio.gather(*tasks, return_exceptions=True) - + # No tasks should have raised exceptions for result in results: assert not isinstance(result, Exception) - + @pytest.mark.asyncio async def test_rapid_create_delete_cycles(self): """Test rapid create/delete cycles for race conditions.""" mock_db = MockVikingDB() await mock_db.create_collection("rapid", {}) - + async def create_delete_cycle(doc_id): # Create document await mock_db.insert("rapid", [{"id": doc_id, "content": "temp"}]) @@ -396,86 +394,86 @@ async def create_delete_cycle(doc_id): # Search again (should be None) deleted_result = await mock_db.search_by_id("rapid", doc_id) return result, deleted_result - + # Run 10 rapid create/delete cycles tasks = [create_delete_cycle(f"rapid_doc_{i}") for i in range(10)] results = await asyncio.gather(*tasks) - + # Verify results are consistent for found, deleted in results: assert found is not None # Should find before delete - assert deleted is None # Should not find after delete + assert deleted is None # Should not find after delete class TestUnicodeEdgeCases: """Test Unicode edge cases and special character handling.""" - + def test_zero_width_characters(self): """Test filenames containing zero-width characters.""" # Zero-width characters that might cause issues filename = "test\u200b\u200c\u200d\ufefffile.txt" # ZWSP, ZWNJ, ZWJ, BOM - + sanitized = _sanitize_rel_path(filename) assert sanitized is not None - + # Zero-width characters should ideally be stripped assert "\u200b" not in sanitized or len(sanitized) > 0 - + def test_rtl_text_filenames(self): """Test right-to-left text in filenames.""" # Arabic/Hebrew filename rtl_filename = "ملف_اختبار.txt" # Arabic for "test file" - + sanitized = _sanitize_rel_path(rtl_filename) assert sanitized is not None assert len(sanitized) > 0 - + # Should preserve RTL characters assert "ملف" in sanitized - + def test_combining_characters(self): """Test filenames with combining characters.""" # Base character + multiple combining marks filename = "e\u0301\u0302\u0303\u0304.txt" # e + acute + circumflex + tilde + macron - + sanitized = _sanitize_rel_path(filename) assert sanitized is not None assert len(sanitized) > 0 - + def test_surrogate_pairs(self): """Test filenames with surrogate pairs (emoji, etc).""" # Emoji that require surrogate pairs in UTF-16 filename = "test🏴󠁧󠁢󠁥󠁮󠁧󠁿🧑‍💻👨‍👩‍👧‍👦.txt" # Flag, person, family - + sanitized = _sanitize_rel_path(filename) - assert sanitized is not None + assert sanitized is not None assert len(sanitized) > 0 - + # Should handle complex emoji sequences class TestSecurityEdgeCases: """Test security-related edge cases.""" - + def test_null_bytes_in_content(self): """Test handling of null bytes in file content.""" content_with_nulls = "Hello\x00World\x00Test" - + # Should handle gracefully without crashing - encoding_result = detect_and_convert_encoding(content_with_nulls.encode('utf-8')) + encoding_result = detect_and_convert_encoding(content_with_nulls.encode("utf-8")) assert encoding_result is not None - + def test_deeply_nested_json(self): """Test handling of very deeply nested JSON structures.""" # Create deeply nested JSON (potential DoS via recursion) - nested_json = "{" + nested_json = "{" for _ in range(1000): nested_json += '"key": {' nested_json += '"value": "deep"' for _ in range(1000): nested_json += "}" nested_json += "}" - + # Should handle without stack overflow try: parsed = json.loads(nested_json) @@ -483,7 +481,7 @@ def test_deeply_nested_json(self): except (json.JSONDecodeError, RecursionError): # Either parsing fails gracefully or recursion is limited pass - + def test_malformed_uri_handling(self): """Test handling of malformed URIs.""" malformed_uris = [ @@ -493,7 +491,7 @@ def test_malformed_uri_handling(self): "viking://path with spaces", # Unescaped spaces "viking://../../../etc/passwd", # Path traversal ] - + for uri in malformed_uris: try: viking_uri = VikingURI(uri) @@ -506,7 +504,7 @@ def test_malformed_uri_handling(self): class TestBoundaryConditions: """Test various boundary conditions and limits.""" - + def test_is_text_file_edge_cases(self): """Test is_text_file with edge case filenames.""" edge_cases = [ @@ -520,7 +518,7 @@ def test_is_text_file_edge_cases(self): "file.TXT", # Uppercase extension "FILE.txt", # Mixed case ] - + for filename in edge_cases: # Should not crash try: @@ -529,34 +527,47 @@ def test_is_text_file_edge_cases(self): except Exception: # May raise exception for invalid filenames - that's OK pass - - def test_directory_upload_with_circular_symlinks(self, tmp_path): + + @pytest.mark.asyncio + async def test_directory_upload_with_circular_symlinks(self, tmp_path): """Test directory upload with circular symbolic links.""" - if os.name == 'nt': # Skip on Windows due to symlink permissions + if os.name == "nt": # Skip on Windows due to symlink permissions pytest.skip("Symlink test skipped on Windows") - + # Create directories dir_a = tmp_path / "dir_a" dir_b = tmp_path / "dir_b" dir_a.mkdir() dir_b.mkdir() - + # Create circular symlinks (dir_a / "link_to_b").symlink_to(dir_b) (dir_b / "link_to_a").symlink_to(dir_a) - + # Add a regular file (dir_a / "test.txt").write_text("test content") - - mock_fs = MagicMock() - mock_fs.write_file_bytes = AsyncMock() - mock_fs.mkdir = AsyncMock() - + + class FakeAGFS: + def mkdir(self, path: str) -> None: + pass + + def write(self, path: str, content: bytes) -> None: + pass + + class MockFS: + agfs = FakeAGFS() + + def _uri_to_path(self, uri: str) -> str: + return uri + + async def mkdir(self, uri: str, exist_ok: bool = False) -> None: + pass + # Should handle circular links without infinite recursion try: - result = await upload_directory(str(tmp_path), "viking://test/", mock_fs) + result = await upload_directory(tmp_path, "viking://test/", MockFS()) # Should complete without hanging - assert result is None or result is not None # Just checking it returns + assert result is None or result is not None except Exception as e: # Acceptable to raise exception for circular links assert "recursion" in str(e).lower() or "circular" in str(e).lower() @@ -570,4 +581,4 @@ async def upload_text_files(file_paths: List[str], target_uri: str, viking_fs): if path.exists() and path.is_file(): content = path.read_bytes() uri = f"{target_uri.rstrip('/')}/{path.name}" - await viking_fs.write_file_bytes(uri, content) \ No newline at end of file + await viking_fs.write_file_bytes(uri, content) diff --git a/tests/test_upload_utils.py b/tests/test_upload_utils.py index f334a583..6a47803d 100644 --- a/tests/test_upload_utils.py +++ b/tests/test_upload_utils.py @@ -22,12 +22,31 @@ # --------------------------------------------------------------------------- +class FakeAGFS: + """Minimal AGFS mock that stores files and directories by path key.""" + + def __init__(self, storage: Dict[str, bytes]) -> None: + self._storage = storage + self.dirs: List[str] = [] + + def mkdir(self, path: str) -> None: + self.dirs.append(path) + + def write(self, path: str, content: bytes) -> None: + self._storage[path] = content + + class FakeVikingFS: """Minimal VikingFS mock for testing upload functions.""" def __init__(self) -> None: self.files: Dict[str, bytes] = {} self.dirs: List[str] = [] + self.agfs = FakeAGFS(self.files) + + def _uri_to_path(self, uri: str) -> str: + # Simplified: use the URI itself as the storage key so test assertions work. + return uri async def write_file_bytes(self, uri: str, content: bytes) -> None: self.files[uri] = content @@ -328,7 +347,9 @@ async def test_skips_empty_files(self, tmp_dir: Path, viking_fs: FakeVikingFS) - @pytest.mark.asyncio async def test_creates_root_dir(self, tmp_dir: Path, viking_fs: FakeVikingFS) -> None: await upload_directory(tmp_dir, "viking://temp/root", viking_fs) - assert "viking://temp/root" in viking_fs.dirs + # _mkdir_with_parents strips leading slash then re-adds it, so the stored agfs + # path is the _uri_to_path() result with a "/" prefix. + assert any("temp/root" in d for d in viking_fs.agfs.dirs) @pytest.mark.asyncio async def test_custom_ignore_dirs(self, tmp_dir: Path, viking_fs: FakeVikingFS) -> None: @@ -477,10 +498,19 @@ async def mkdir(self, uri: str, exist_ok: bool = False) -> None: class TestUploadDirectoryEdgeCases: @pytest.mark.asyncio async def test_write_failure_produces_warning(self, tmp_path: Path) -> None: - class FailingWriteFS: - async def write_file_bytes(self, uri: str, content: bytes) -> None: + class FailingAGFS: + def mkdir(self, path: str) -> None: + pass + + def write(self, path: str, content: bytes) -> None: raise IOError("write error") + class FailingWriteFS: + agfs = FailingAGFS() + + def _uri_to_path(self, uri: str) -> str: + return uri + async def mkdir(self, uri: str, exist_ok: bool = False) -> None: pass