diff --git a/README.md b/README.md index 0bd079b..7b50ab1 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,28 @@ docker compose logs -f rule-bot | `ALLOWED_GROUP_IDS` | 群组模式允许的群组 ID,逗号分隔 | 空 | | `ADMIN_USER_IDS` | 管理员 Telegram 用户 ID,逗号分隔 | 空 | | `TZ` | 时区 | `Asia/Shanghai` | +| `DNS_CACHE_TTL` | DNS A 记录缓存秒数 | `60` | +| `DNS_CACHE_SIZE` | DNS A 记录缓存上限 | `1024` | +| `NS_CACHE_TTL` | DNS NS 记录缓存秒数 | `300` | +| `NS_CACHE_SIZE` | DNS NS 记录缓存上限 | `512` | +| `DNS_MAX_CONCURRENCY` | DNS 并发限制 | `20` | +| `DNS_CONN_LIMIT` | DNS 全局连接池上限 | `30` | +| `DNS_CONN_LIMIT_PER_HOST` | DNS 单主机连接上限 | `10` | +| `DNS_TIMEOUT_TOTAL` | DNS 请求总超时 | `10` | +| `DNS_TIMEOUT_CONNECT` | DNS 连接超时 | `3` | +| `GEOSITE_CACHE_TTL` | GeoSite 查询缓存秒数 | `3600` | +| `GEOSITE_CACHE_SIZE` | GeoSite 查询缓存上限 | `2048` | +| `GEOIP_CACHE_TTL` | GeoIP 缓存秒数 | `21600` | +| `GEOIP_CACHE_SIZE` | GeoIP 缓存上限 | `4096` | +| `GITHUB_FILE_CACHE_TTL` | 规则文件缓存秒数 | `60` | +| `GITHUB_FILE_CACHE_SIZE` | 规则文件缓存上限 | `4` | +| `METRICS_ENABLED` | 开启 metrics 导出 | `false` | +| `METRICS_EXPORT_PATH` | metrics 输出路径 | `/tmp/rule-bot-metrics.json` | +| `METRICS_EXPORT_INTERVAL` | metrics 导出间隔秒数 | `30` | +| `METRICS_RESET_ON_EXPORT` | 导出后清零 | `false` | +| `MEMORY_SOFT_LIMIT_MB` | 进程软限制 MB | `256` | +| `MEMORY_HARD_LIMIT_MB` | 进程硬限制 MB | `512` | +| `MEMORY_TRIM_ENABLED` | 启用内存修剪 | `true` | @@ -188,6 +210,41 @@ docker compose logs -f rule-bot - Python 3.12+ - `pip install -r requirements.txt` +## 🧪 1C1G 运行建议 + +建议对容器设置内存上限与 CPU 配额,避免全机抖动: + +```yaml +services: + rule-bot: + mem_limit: 256m + mem_reservation: 128m + cpus: "0.5" +``` + +如果你使用的是 Swarm 模式,可改用 `deploy.resources` 语法。 + +**Swap/RSS 观测** + +```bash +pid=$(pgrep -f "python -m src.main" | head -n 1) +grep -E "VmRSS|VmSize|VmSwap" /proc/$pid/status +``` + +**性能采样(需开启 metrics 导出)** + +```bash +export METRICS_ENABLED=1 +export METRICS_EXPORT_INTERVAL=30 +python tools/profile_runtime.py --process-name "python -m src.main" --duration 300 --interval 5 +``` + +**10 分钟压力模拟** + +```bash +python tools/stress_sim.py --duration 600 --concurrency 4 --pause 0.5 +``` + ## 📄 许可证 GPLv3 diff --git a/src/bot.py b/src/bot.py index b5fae4f..7eb4f24 100644 --- a/src/bot.py +++ b/src/bot.py @@ -15,6 +15,7 @@ from .config import Config from .data_manager import DataManager from .handlers import HandlerManager, GroupHandler +from .utils.metrics import EXPORTER class RuleBot: @@ -26,18 +27,36 @@ def __init__(self, config: Config, data_manager: DataManager): self.app: Optional[Application] = None self.handler_manager = None # 延迟初始化 self.group_handler = None # 群组处理器 + self._metrics_task = None async def stop(self): """停止机器人""" logger.info("正在停止机器人...") if self.handler_manager: await self.handler_manager.stop() + if self._metrics_task: + await EXPORTER.stop() if self.app: - await self.app.stop() - await self.app.shutdown() + try: + if self.app.updater and self.app.updater.running: + await self.app.updater.stop() + except Exception as e: + logger.debug(f"停止 updater 失败: {e}") + try: + if self.app.running: + await self.app.stop() + except Exception as e: + logger.debug(f"停止 app 失败: {e}") + try: + if self.app.initialized: + await self.app.shutdown() + except Exception as e: + logger.debug(f"关闭 app 失败: {e}") + if self.data_manager: + await self.data_manager.close() logger.info("机器人已停止") - def start(self): + async def start(self): """启动机器人""" try: # 创建应用 @@ -54,30 +73,23 @@ def start(self): # 启动轮询 logger.info("机器人启动成功,开始轮询...") - - # 在新的事件循环中运行机器人 - import asyncio - - async def run_bot(): - try: - async with self.app: - await 
self.handler_manager.start() # 显式启动服务(如 DNS Session) - await self.app.start() - await self.app.updater.start_polling( - allowed_updates=Update.ALL_TYPES, - drop_pending_updates=True # 丢弃待处理的更新,避免发送旧消息 - ) - # 保持运行 - await asyncio.Event().wait() - finally: - await self.stop() - - # 使用新的事件循环运行 - asyncio.run(run_bot()) + + async with self.app: + await self.handler_manager.start() # 显式启动服务(如 DNS Session) + await self.app.start() + self._metrics_task = EXPORTER.start() + await self.app.updater.start_polling( + allowed_updates=Update.ALL_TYPES, + drop_pending_updates=True # 丢弃待处理的更新,避免发送旧消息 + ) + # 保持运行 + await asyncio.Event().wait() except Exception as e: logger.error(f"机器人启动失败: {e}") raise + finally: + await self.stop() def _register_handlers(self): """注册所有处理器""" diff --git a/src/config.py b/src/config.py index 04f81b6..dd727c5 100644 --- a/src/config.py +++ b/src/config.py @@ -31,6 +31,34 @@ def __init__(self): # 数据目录(可选) self.DATA_DIR = os.getenv("DATA_DIR", "").strip() + + # 性能与缓存配置 + self.DNS_CACHE_TTL = self._parse_int_env("DNS_CACHE_TTL", 60, min_value=0) + self.DNS_CACHE_SIZE = self._parse_int_env("DNS_CACHE_SIZE", 1024, min_value=0) + self.NS_CACHE_TTL = self._parse_int_env("NS_CACHE_TTL", 300, min_value=0) + self.NS_CACHE_SIZE = self._parse_int_env("NS_CACHE_SIZE", 512, min_value=0) + self.DNS_MAX_CONCURRENCY = self._parse_int_env("DNS_MAX_CONCURRENCY", 20, min_value=1) + self.DNS_CONN_LIMIT = self._parse_int_env("DNS_CONN_LIMIT", 30, min_value=1) + self.DNS_CONN_LIMIT_PER_HOST = self._parse_int_env("DNS_CONN_LIMIT_PER_HOST", 10, min_value=1) + self.DNS_TIMEOUT_TOTAL = self._parse_int_env("DNS_TIMEOUT_TOTAL", 10, min_value=1) + self.DNS_TIMEOUT_CONNECT = self._parse_int_env("DNS_TIMEOUT_CONNECT", 3, min_value=1) + + self.GEOSITE_CACHE_TTL = self._parse_int_env("GEOSITE_CACHE_TTL", 3600, min_value=0) + self.GEOSITE_CACHE_SIZE = self._parse_int_env("GEOSITE_CACHE_SIZE", 2048, min_value=0) + self.GEOIP_CACHE_TTL = self._parse_int_env("GEOIP_CACHE_TTL", 21600, min_value=0) + self.GEOIP_CACHE_SIZE = self._parse_int_env("GEOIP_CACHE_SIZE", 4096, min_value=0) + + self.GITHUB_FILE_CACHE_TTL = self._parse_int_env("GITHUB_FILE_CACHE_TTL", 60, min_value=0) + self.GITHUB_FILE_CACHE_SIZE = self._parse_int_env("GITHUB_FILE_CACHE_SIZE", 4, min_value=0) + + # Metrics 配置 + self.METRICS_ENABLED = self._parse_bool_env("METRICS_ENABLED", False) + self.METRICS_EXPORT_PATH = os.getenv("METRICS_EXPORT_PATH", "/tmp/rule-bot-metrics.json") + self.METRICS_EXPORT_INTERVAL = self._parse_int_env("METRICS_EXPORT_INTERVAL", 30, min_value=1) + self.METRICS_RESET_ON_EXPORT = self._parse_bool_env("METRICS_RESET_ON_EXPORT", False) + + # 内存修剪(glibc malloc_trim) + self.MEMORY_TRIM_ENABLED = self._parse_bool_env("MEMORY_TRIM_ENABLED", True) # 群组验证配置(用于私聊模式下验证用户是否在群组中) required_group_id_raw = os.getenv("REQUIRED_GROUP_ID", "").strip() @@ -201,3 +229,32 @@ def _parse_doh_servers(self, value: str, defaults: Dict[str, str]) -> Dict[str, return defaults return servers + + def _parse_int_env( + self, + key: str, + default: int, + min_value: Optional[int] = None, + max_value: Optional[int] = None + ) -> int: + raw = os.getenv(key, "").strip() + if not raw: + return default + try: + value = int(raw) + except ValueError: + logger.warning(f"无效的 {key}: {raw},使用默认值 {default}") + return default + if min_value is not None and value < min_value: + logger.warning(f"{key} 小于最小值 {min_value},使用默认值 {default}") + return default + if max_value is not None and value > max_value: + logger.warning(f"{key} 大于最大值 {max_value},使用默认值 {default}") + 
return default + return value + + def _parse_bool_env(self, key: str, default: bool = False) -> bool: + raw = os.getenv(key, "").strip().lower() + if not raw: + return default + return raw in ("1", "true", "yes", "on") diff --git a/src/data_manager.py b/src/data_manager.py index 2ee52b5..396183f 100644 --- a/src/data_manager.py +++ b/src/data_manager.py @@ -5,16 +5,21 @@ import asyncio import aiohttp +import hashlib +import json import re import threading import time import tempfile from datetime import datetime, timedelta from pathlib import Path -from typing import Set, List, Pattern +from typing import Set, List, Pattern, Optional, Tuple, Dict, Any from loguru import logger from .config import Config +from .utils.cache import TTLCache +from .utils.metrics import METRICS +from .utils.memory import trim_memory class DataManager: @@ -28,12 +33,22 @@ def __init__(self, config: Config): self.geosite_includes: List[str] = [] self._data_lock = threading.RLock() self._update_lock = threading.Lock() + self._scheduler_task: Optional[asyncio.Task] = None + self._session: Optional[aiohttp.ClientSession] = None + self._geosite_cache = TTLCache( + config.GEOSITE_CACHE_SIZE, + config.GEOSITE_CACHE_TTL + ) + self._geosite_stamp: Optional[Tuple[int, int]] = None # 默认使用容器内目录,不强制持久化 self.data_dir = self._resolve_data_dir() logger.info("数据目录: {}", self.data_dir) self.geoip_file = self.data_dir / "geoip" / "Country-without-asn.mmdb" self.cn_ipv4_file = self.data_dir / "geoip" / "cn-ipv4.txt" self.geosite_file = self.data_dir / "geosite" / "direct-list.txt" + self.geoip_meta = self.geoip_file.with_suffix(self.geoip_file.suffix + ".meta.json") + self.cn_ipv4_meta = self.cn_ipv4_file.with_suffix(self.cn_ipv4_file.suffix + ".meta.json") + self.geosite_meta = self.geosite_file.with_suffix(self.geosite_file.suffix + ".meta.json") # 确保目录存在 self.data_dir.mkdir(parents=True, exist_ok=True) @@ -58,6 +73,36 @@ def _resolve_data_dir(self) -> Path: fallback = Path(tempfile.gettempdir()) / "rule-bot" fallback.mkdir(parents=True, exist_ok=True) return fallback + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session and not self._session.closed: + return self._session + connector = aiohttp.TCPConnector( + limit=4, + limit_per_host=2, + ttl_dns_cache=300, + use_dns_cache=True + ) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=60, connect=10) + ) + return self._session + + async def close(self): + """关闭后台任务与共享 Session""" + if self._scheduler_task and not self._scheduler_task.done(): + self._scheduler_task.cancel() + try: + await self._scheduler_task + except asyncio.CancelledError: + # 预期中的异常:关闭时取消后台调度任务,无需向上传播 + logger.debug("后台调度任务在关闭过程中被取消") + self._scheduler_task = None + + if self._session and not self._session.closed: + await self._session.close() + self._session = None async def initialize(self): """初始化数据管理器""" @@ -91,20 +136,26 @@ async def _download_initial_data(self): self.config.DATA_UPDATE_INTERVAL ) + geoip_changed = False + cn_ipv4_changed = False + geosite_changed = False + if need_geoip: logger.info("下载 GeoIP 数据...") - await self._download_geoip() + geoip_changed = await self._download_geoip() if need_cn_ipv4: logger.info("下载中国 IPv4 CIDR 数据...") - await self._download_cn_ipv4() + cn_ipv4_changed = await self._download_cn_ipv4() if need_geosite: logger.info("下载 GeoSite 数据...") - await self._download_geosite() + geosite_changed = await self._download_geosite() # 加载 GeoSite 数据到内存 - await self._load_geosite_data() + await 
self._load_geosite_data(force=True) + if geosite_changed or geoip_changed or cn_ipv4_changed: + trim_memory("初始化后内存修剪") except Exception as e: logger.error(f"初始数据下载失败: {e}") @@ -113,10 +164,11 @@ async def _download_initial_data(self): async def _download_geoip(self): """下载 GeoIP 数据""" try: - await self._download_with_fallback( + return await self._download_with_fallback( self.config.GEOIP_URLS, self.geoip_file, - "GeoIP" + "geoip", + self.geoip_meta ) except Exception as e: logger.error(f"GeoIP 数据下载失败: {e}") @@ -125,10 +177,11 @@ async def _download_geoip(self): async def _download_cn_ipv4(self): """下载中国 IPv4 CIDR 数据""" try: - await self._download_with_fallback( + return await self._download_with_fallback( self.config.CN_IPV4_URLS, self.cn_ipv4_file, - "中国 IPv4 CIDR" + "cn_ipv4", + self.cn_ipv4_meta ) except Exception as e: logger.error(f"中国 IPv4 CIDR 数据下载失败: {e}") @@ -137,25 +190,32 @@ async def _download_cn_ipv4(self): async def _download_geosite(self): """下载 GeoSite 数据""" try: - async with aiohttp.ClientSession() as session: - async with session.get(self.config.GEOSITE_URL) as response: - if response.status == 200: - with open(self.geosite_file, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - logger.info("GeoSite 数据下载完成") - else: - raise Exception(f"下载失败,状态码: {response.status}") + return await self._download_with_fallback( + [self.config.GEOSITE_URL], + self.geosite_file, + "geosite", + self.geosite_meta + ) except Exception as e: logger.error(f"GeoSite 数据下载失败: {e}") raise - async def _load_geosite_data(self): + async def _load_geosite_data(self, force: bool = False): """加载 GeoSite 数据到内存""" try: if not self.geosite_file.exists(): logger.warning("GeoSite 文件不存在,跳过加载") return + + try: + stat = self.geosite_file.stat() + stamp = (int(stat.st_mtime_ns), int(stat.st_size)) + except Exception: + stamp = None + + if not force and stamp and self._geosite_stamp == stamp: + logger.info("GeoSite 文件未变化,跳过加载") + return logger.info("加载 GeoSite 数据到内存...") domains: Set[str] = set() @@ -202,6 +262,9 @@ async def _load_geosite_data(self): self.geosite_keywords = keywords self.geosite_regex_patterns = regex_patterns self.geosite_includes = includes + if stamp: + self._geosite_stamp = stamp + self._geosite_cache.clear() logger.info( "GeoSite 数据加载完成,域名: {}, 关键字: {}, 正则: {}, include: {}", @@ -224,6 +287,12 @@ async def is_domain_in_geosite(self, domain: str) -> bool: if not domain: return False + cached = self._geosite_cache.get(domain) + if cached is not None: + METRICS.inc("geosite.cache.hit") + return cached + METRICS.inc("geosite.cache.miss") + with self._data_lock: domain_set = self.geosite_domains keywords = self.geosite_keywords @@ -231,6 +300,7 @@ async def is_domain_in_geosite(self, domain: str) -> bool: # 1. 直接检查完整域名 if domain in domain_set: + self._geosite_cache.set(domain, True) return True # 2. 
检查是否为 GeoSite 中域名的子域名 @@ -239,18 +309,22 @@ async def is_domain_in_geosite(self, domain: str) -> bool: for i in range(1, len(parts)): parent_domain = '.'.join(parts[i:]) if parent_domain in domain_set: + self._geosite_cache.set(domain, True) return True for keyword in keywords: if keyword and keyword in domain: + self._geosite_cache.set(domain, True) return True for pattern in regex_patterns: if pattern.search(domain): + self._geosite_cache.set(domain, True) return True # 注意:不做反向检查,因为 GeoSite 通常只包含具体域名,不需要检查子域名覆盖父域名的情况 + self._geosite_cache.set(domain, False) return False except Exception as e: @@ -267,26 +341,33 @@ def _is_file_outdated(self, file_path: Path, max_age_seconds: int) -> bool: def _start_scheduled_updates(self): """启动定时更新任务""" - def run_scheduler(): - update_interval = self.config.DATA_UPDATE_INTERVAL - while True: - time.sleep(update_interval) - self._update_data_sync() - - # 在单独线程中运行调度器 - scheduler_thread = threading.Thread(target=run_scheduler, daemon=True) - scheduler_thread.start() + if self._scheduler_task and not self._scheduler_task.done(): + logger.info("定时更新任务已在运行") + return + self._scheduler_task = asyncio.create_task(self._scheduled_update_loop()) logger.info("定时更新任务已启动") - - def _update_data_sync(self): - """同步版本的数据更新(用于scheduler)""" + + async def _scheduled_update_loop(self): + """异步定时更新循环(与主事件循环一致)""" + update_interval = self.config.DATA_UPDATE_INTERVAL + while True: + try: + await asyncio.sleep(update_interval) + await self._update_data_guarded() + except asyncio.CancelledError: + logger.info("定时更新任务已停止") + raise + except Exception as e: + logger.error(f"定时更新循环异常: {e}") + await asyncio.sleep(1) + + async def _update_data_guarded(self): + """带并发保护的数据更新""" if not self._update_lock.acquire(blocking=False): logger.info("已有更新任务在执行,跳过本次更新") return try: - asyncio.run(self._update_data()) - except Exception as e: - logger.error(f"同步更新执行失败: {e}") + await self._update_data() finally: self._update_lock.release() @@ -296,34 +377,120 @@ async def _update_data(self): logger.info("开始定时更新数据...") # 下载新数据 - await self._download_geoip() - await self._download_cn_ipv4() - await self._download_geosite() + geoip_changed = await self._download_geoip() + cn_ipv4_changed = await self._download_cn_ipv4() + geosite_changed = await self._download_geosite() - # 重新加载 GeoSite 数据 - await self._load_geosite_data() + # 重新加载 GeoSite 数据(仅文件变化时) + if geosite_changed: + await self._load_geosite_data() + trim_memory("geosite 更新后内存修剪") + elif geoip_changed or cn_ipv4_changed: + trim_memory("数据更新后内存修剪") logger.info("定时更新完成") except Exception as e: logger.error(f"定时更新失败: {e}") + + def _load_meta(self, meta_path: Path) -> Dict[str, Any]: + if not meta_path.exists(): + return {} + try: + with meta_path.open("r", encoding="utf-8") as handle: + return json.load(handle) + except Exception: + return {} - async def _download_with_fallback(self, urls: List[str], dest_path: Path, label: str): - """按顺序尝试多个 URL 下载数据""" + def _save_meta(self, meta_path: Path, meta: Dict[str, Any]) -> None: + try: + meta_path.parent.mkdir(parents=True, exist_ok=True) + with meta_path.open("w", encoding="utf-8") as handle: + json.dump(meta, handle, ensure_ascii=False, indent=2) + except Exception as e: + logger.debug(f"保存 meta 失败: {e}") + + async def _download_with_fallback( + self, + urls: List[str], + dest_path: Path, + label: str, + meta_path: Path + ) -> bool: + """按顺序尝试多个 URL 下载数据,支持条件更新和变更检测""" last_error = None - async with aiohttp.ClientSession() as session: - for url in urls: - try: - async with session.get(url) as 
response: - if response.status == 200: - with open(dest_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - logger.info("{} 数据下载完成: {}", label, url) - return + session = await self._get_session() + current_meta = self._load_meta(meta_path) + conditional_headers = {} + if current_meta.get("etag"): + conditional_headers["If-None-Match"] = current_meta["etag"] + if current_meta.get("last_modified"): + conditional_headers["If-Modified-Since"] = current_meta["last_modified"] + + for idx, url in enumerate(urls): + try: + # 仅在首选源使用条件请求头,避免不同镜像之间复用 ETag/Last-Modified。 + headers = conditional_headers if idx == 0 else {} + start_ts = time.perf_counter() + async with session.get(url, headers=headers) as response: + if response.status == 304: + logger.info("{} 数据未更新(304): {}", label, url) + METRICS.record_request( + f"data.download.{label}", + (time.perf_counter() - start_ts) * 1000, + success=True + ) + return False + if response.status != 200: last_error = f"下载失败,状态码: {response.status}" logger.warning("{} 数据下载失败: {} (状态码: {})", label, url, response.status) - except Exception as e: - last_error = str(e) - logger.warning("{} 数据下载失败: {} ({})", label, url, e) + continue + + tmp_path = dest_path.with_suffix(dest_path.suffix + ".tmp") + digest = hashlib.sha256() + size = 0 + with tmp_path.open("wb") as handle: + async for chunk in response.content.iter_chunked(8192): + handle.write(chunk) + digest.update(chunk) + size += len(chunk) + + new_hash = digest.hexdigest() + old_hash = current_meta.get("sha256") + changed = True + if dest_path.exists() and old_hash and old_hash == new_hash: + changed = False + tmp_path.unlink(missing_ok=True) + else: + tmp_path.replace(dest_path) + + meta = { + "etag": response.headers.get("ETag"), + "last_modified": response.headers.get("Last-Modified"), + "sha256": new_hash, + "size": size, + "updated_at": datetime.utcnow().isoformat() + "Z", + "source": url + } + self._save_meta(meta_path, meta) + + if changed: + logger.info("{} 数据下载完成: {}", label, url) + else: + logger.info("{} 数据未变化(hash 相同): {}", label, url) + METRICS.record_request( + f"data.download.{label}", + (time.perf_counter() - start_ts) * 1000, + success=True + ) + return changed + except Exception as e: + last_error = str(e) + logger.warning("{} 数据下载失败: {} ({})", label, url, e) + + METRICS.record_request( + f"data.download.{label}", + 0.0, + success=False + ) raise Exception(f"{label} 数据下载失败: {last_error or '所有地址不可用'}") diff --git a/src/handlers/handler_manager.py b/src/handlers/handler_manager.py index baa436c..cab356b 100644 --- a/src/handlers/handler_manager.py +++ b/src/handlers/handler_manager.py @@ -30,10 +30,24 @@ def __init__(self, config: Config, data_manager: DataManager, application=None): self.data_manager = data_manager # 初始化服务 - self.dns_service = DNSService(config.DOH_SERVERS, config.NS_DOH_SERVERS) + self.dns_service = DNSService( + config.DOH_SERVERS, + config.NS_DOH_SERVERS, + cache_size=config.DNS_CACHE_SIZE, + cache_ttl=config.DNS_CACHE_TTL, + ns_cache_size=config.NS_CACHE_SIZE, + ns_cache_ttl=config.NS_CACHE_TTL, + max_concurrency=config.DNS_MAX_CONCURRENCY, + conn_limit=config.DNS_CONN_LIMIT, + conn_limit_per_host=config.DNS_CONN_LIMIT_PER_HOST, + timeout_total=config.DNS_TIMEOUT_TOTAL, + timeout_connect=config.DNS_TIMEOUT_CONNECT, + ) self.geoip_service = GeoIPService( str(data_manager.geoip_file), - str(data_manager.cn_ipv4_file) + str(data_manager.cn_ipv4_file), + cache_size=config.GEOIP_CACHE_SIZE, + cache_ttl=config.GEOIP_CACHE_TTL ) 
self.github_service = GitHubService(config) self.domain_checker = DomainChecker(self.dns_service, self.geoip_service) @@ -53,6 +67,7 @@ async def start(self): # 用户限制管理 self.user_add_history: Dict[int, list] = defaultdict(list) # 用户添加历史 {user_id: [timestamp1, timestamp2, ...]} + self._last_history_cleanup = 0 self.MAX_DESCRIPTION_LENGTH = 20 # 域名说明最大字符数 self.MAX_ADDS_PER_HOUR = 50 # 每小时最多添加域名数 self.MAX_DETAIL_LINES = 6 # 检查详情最大行数 @@ -178,6 +193,7 @@ def check_user_add_limit(self, user_id: int) -> tuple[bool, int]: Returns: tuple: (是否可以添加, 剩余次数) """ + self._maybe_cleanup_user_history() current_time = time.time() one_hour_ago = current_time - 3600 # 1小时前的时间戳 @@ -192,6 +208,19 @@ def check_user_add_limit(self, user_id: int) -> tuple[bool, int]: remaining = self.MAX_ADDS_PER_HOUR - current_count return current_count < self.MAX_ADDS_PER_HOUR, remaining + + def _maybe_cleanup_user_history(self) -> None: + now = time.time() + if now - self._last_history_cleanup < 600: + return + cutoff = now - 3600 + for uid, timestamps in list(self.user_add_history.items()): + filtered = [ts for ts in timestamps if ts > cutoff] + if filtered: + self.user_add_history[uid] = filtered + else: + self.user_add_history.pop(uid, None) + self._last_history_cleanup = now def record_user_add(self, user_id: int): """记录用户添加操作""" diff --git a/src/main.py b/src/main.py index cf0e937..f8a60fc 100644 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,7 @@ from .bot import RuleBot from .config import Config from .data_manager import DataManager +from .utils.memory import trim_memory def _configure_logging(): @@ -47,13 +48,14 @@ def _configure_logging(): def set_memory_limit(): - """设置内存限制为 256 MB(软限制,超出时给出警告)""" + """设置内存限制(默认软限制 256 MB,硬限制 512 MB)""" try: - # 256 MB = 256 * 1024 * 1024 bytes - memory_limit = 256 * 1024 * 1024 - # 设置软限制为 256 MB,硬限制为 512 MB(给一些缓冲空间) - resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit * 2)) - logger.info("已设置内存软限制为 256 MB,硬限制为 512 MB") + soft_mb = int(os.getenv("MEMORY_SOFT_LIMIT_MB", "256")) + hard_mb = int(os.getenv("MEMORY_HARD_LIMIT_MB", str(soft_mb * 2))) + memory_soft = soft_mb * 1024 * 1024 + memory_hard = hard_mb * 1024 * 1024 + resource.setrlimit(resource.RLIMIT_AS, (memory_soft, memory_hard)) + logger.info(f"已设置内存软限制为 {soft_mb} MB,硬限制为 {hard_mb} MB") # 记录当前内存使用情况 try: @@ -119,6 +121,44 @@ def log_memory_usage(): except Exception as e: logger.warning(f"获取内存使用情况失败: {e}") +async def _run(): + """异步主流程(全进程单事件循环)""" + # 初始化配置 + config = Config() + + logger.info("Rule-Bot 正在启动...") + + # 初始化数据管理器(与机器人运行保持同一事件循环) + data_manager = DataManager(config) + await data_manager.initialize() + + # 记录数据加载后的内存使用 + log_memory_usage() + trim_memory("初始化完成后内存修剪") + + # 初始化机器人 + bot = RuleBot(config, data_manager) + + # 启动机器人 + logger.info("启动 Telegram 机器人...") + + # 启动定期内存检查(每 10 分钟检查一次) + import threading + + def memory_monitor(): + while True: + try: + time.sleep(600) # 10 分钟 + log_memory_usage() + except Exception as e: + logger.warning(f"内存监控出错: {e}") + time.sleep(60) # 出错后等待 1 分钟再继续 + + monitor_thread = threading.Thread(target=memory_monitor, daemon=True) + monitor_thread.start() + + await bot.start() + def main(): """主程序入口""" try: @@ -127,46 +167,9 @@ def main(): # 设置内存限制 set_memory_limit() - - # 初始化配置 - config = Config() - - logger.info("Rule-Bot 正在启动...") - - # 初始化数据管理器(在新的事件循环中) - async def init_data(): - data_manager = DataManager(config) - await data_manager.initialize() - return data_manager - - data_manager = asyncio.run(init_data()) - - # 记录数据加载后的内存使用 - log_memory_usage() - - # 
初始化机器人 - bot = RuleBot(config, data_manager) - - # 启动机器人 - logger.info("启动 Telegram 机器人...") - - # 启动定期内存检查(每 10 分钟检查一次) - import threading - - def memory_monitor(): - while True: - try: - time.sleep(600) # 10 分钟 - log_memory_usage() - except Exception as e: - logger.warning(f"内存监控出错: {e}") - time.sleep(60) # 出错后等待 1 分钟再继续 - - monitor_thread = threading.Thread(target=memory_monitor, daemon=True) - monitor_thread.start() - - bot.start() - + + asyncio.run(_run()) + except KeyboardInterrupt: logger.info("收到停止信号,正在关闭...") except Exception as e: diff --git a/src/services/dns_service.py b/src/services/dns_service.py index 04ea62c..43424cc 100644 --- a/src/services/dns_service.py +++ b/src/services/dns_service.py @@ -8,28 +8,58 @@ import base64 import struct import socket +import time from typing import List, Optional, Dict, Any from loguru import logger +from ..utils.cache import TTLCache +from ..utils.metrics import METRICS + class DNSService: """DNS 服务""" - def __init__(self, doh_servers: Dict[str, str], ns_doh_servers: Dict[str, str] = None): + def __init__( + self, + doh_servers: Dict[str, str], + ns_doh_servers: Dict[str, str] = None, + cache_size: int = 1024, + cache_ttl: int = 60, + ns_cache_size: int = 512, + ns_cache_ttl: int = 300, + max_concurrency: int = 20, + conn_limit: int = 30, + conn_limit_per_host: int = 10, + timeout_total: int = 10, + timeout_connect: int = 3, + ): self.doh_servers = doh_servers self.ns_doh_servers = ns_doh_servers or doh_servers self.session: Optional[aiohttp.ClientSession] = None + self._a_cache = TTLCache(cache_size, cache_ttl) + self._ns_cache = TTLCache(ns_cache_size, ns_cache_ttl) + self._semaphore = asyncio.Semaphore(max_concurrency) + self._conn_limit = conn_limit + self._conn_limit_per_host = conn_limit_per_host + self._timeout_total = timeout_total + self._timeout_connect = timeout_connect async def start(self): """启动 DNS 服务,初始化共享 Session""" if not self.session or self.session.closed: connector = aiohttp.TCPConnector( - limit=100, # 增加连接限制 - limit_per_host=10, + limit=self._conn_limit, + limit_per_host=self._conn_limit_per_host, ttl_dns_cache=300, use_dns_cache=True ) - self.session = aiohttp.ClientSession(connector=connector) + self.session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout( + total=self._timeout_total, + connect=self._timeout_connect + ) + ) logger.info("DNS 服务已启动,Session 已初始化") async def close(self): @@ -40,6 +70,13 @@ async def close(self): async def query_a_record(self, domain: str, use_edns_china: bool = True) -> List[str]: """查询 A 记录,返回 IP 地址列表(并发查询所有 DoH 服务器)""" + cache_key = (domain, use_edns_china) + cached = self._a_cache.get(cache_key) + if cached is not None: + METRICS.inc("dns.cache.a.hit") + return cached + METRICS.inc("dns.cache.a.miss") + start_ts = time.perf_counter() try: # 确保 Session 已启动 if not self.session or self.session.closed: @@ -63,24 +100,46 @@ async def query_a_record(self, domain: str, use_edns_china: bool = True) -> List ips = await future if ips: logger.debug(f"DoH 查询 {domain} 成功,获得 {len(ips)} 个 IP") + self._a_cache.set(cache_key, ips) # 取消其他未完成的任务 for task in tasks: if not task.done(): task.cancel() + METRICS.record_request( + "dns.query_a", + (time.perf_counter() - start_ts) * 1000, + success=True + ) return ips except Exception: # 单个任务失败不影响其他任务 continue logger.warning(f"所有 DoH 服务器查询域名 {domain} 都失败") + METRICS.record_request( + "dns.query_a", + (time.perf_counter() - start_ts) * 1000, + success=False + ) return [] except Exception as e: logger.error(f"DNS 查询失败: {e}") + 
METRICS.record_request( + "dns.query_a", + (time.perf_counter() - start_ts) * 1000, + success=False + ) return [] async def query_ns_records(self, domain: str) -> List[str]: """查询 NS 记录,返回权威域名服务器列表(并发查询)""" + cached = self._ns_cache.get(domain) + if cached is not None: + METRICS.inc("dns.cache.ns.hit") + return cached + METRICS.inc("dns.cache.ns.miss") + start_ts = time.perf_counter() try: # 确保 Session 已启动 if not self.session or self.session.closed: @@ -103,10 +162,16 @@ async def query_ns_records(self, domain: str) -> List[str]: ns_servers = await future if ns_servers: logger.debug(f"DoH 查询 {domain} NS 记录成功") + self._ns_cache.set(domain, ns_servers) # 取消其他任务 for task in tasks: if not task.done(): task.cancel() + METRICS.record_request( + "dns.query_ns", + (time.perf_counter() - start_ts) * 1000, + success=True + ) return ns_servers except Exception: continue @@ -116,13 +181,29 @@ async def query_ns_records(self, domain: str) -> List[str]: ns_servers = await self._query_ns_system_dns(domain) if ns_servers: logger.debug(f"使用系统 DNS 查询 {domain} NS 记录成功") + self._ns_cache.set(domain, ns_servers) + METRICS.record_request( + "dns.query_ns", + (time.perf_counter() - start_ts) * 1000, + success=True + ) return ns_servers logger.warning(f"所有 NS 记录查询方法都失败,域名: {domain}") + METRICS.record_request( + "dns.query_ns", + (time.perf_counter() - start_ts) * 1000, + success=False + ) return [] except Exception as e: logger.error(f"NS 记录查询失败: {e}") + METRICS.record_request( + "dns.query_ns", + (time.perf_counter() - start_ts) * 1000, + success=False + ) return [] async def _query_ns_system_dns(self, domain: str) -> List[str]: @@ -197,32 +278,38 @@ def _build_dns_query(self, domain: str, use_edns_china: bool = True, record_type logger.error(f"构建 DNS 查询包失败: {e}") return b'' - async def _perform_doh_query(self, server_name: str, server_url: str, query_data: bytes, parser_func) -> List[str]: + async def _perform_doh_query( + self, + server_name: str, + server_url: str, + query_data: bytes, + parser_func + ) -> List[str]: """执行 DoH 查询通用方法""" max_retries = 2 for attempt in range(max_retries): try: - encoded_query = base64.urlsafe_b64encode(query_data).decode().rstrip('=') - url = f"{server_url}?dns={encoded_query}" - - # 使用共享的session - async with self.session.get( - url, - headers={ - 'Accept': 'application/dns-message', - 'User-Agent': 'Rule-Bot DNS Client/1.0' - }, - timeout=aiohttp.ClientTimeout(total=10, connect=3) - ) as response: - if response.status == 200: - response_data = await response.read() - result = parser_func(response_data) - if result: - return result - # 如果解析结果为空但状态码200,可能是没有该记录,不一定是错误,但也重试一下 - else: - # logger.warning(f"{server_name} HTTP error: {response.status}") - pass + async with self._semaphore: + encoded_query = base64.urlsafe_b64encode(query_data).decode().rstrip('=') + url = f"{server_url}?dns={encoded_query}" + + # 使用共享的session + async with self.session.get( + url, + headers={ + 'Accept': 'application/dns-message', + 'User-Agent': 'Rule-Bot DNS Client/1.0' + } + ) as response: + if response.status == 200: + response_data = await response.read() + result = parser_func(response_data) + if result: + return result + # 如果解析结果为空但状态码200,可能是没有该记录,不一定是错误,但也重试一下 + else: + # logger.warning(f"{server_name} HTTP error: {response.status}") + pass except asyncio.CancelledError: raise # 允许被取消 diff --git a/src/services/geoip_service.py b/src/services/geoip_service.py index 9fea1d4..ef903ad 100644 --- a/src/services/geoip_service.py +++ b/src/services/geoip_service.py @@ -21,12 +21,21 @@ class 
GeoIPService: """GeoIP 服务""" - def __init__(self, geoip_file_path: str, cn_ipv4_file_path: Optional[str] = None): + def __init__( + self, + geoip_file_path: str, + cn_ipv4_file_path: Optional[str] = None, + cache_size: int = 4096, + cache_ttl: int = 21600 + ): self.geoip_file = Path(geoip_file_path) self.cn_ipv4_file = Path(cn_ipv4_file_path) if cn_ipv4_file_path else None self.reader = None self._cn_ipv4_ranges = [] self._cn_ipv4_range_starts = [] + self._cache_size = cache_size + self._cache_ttl = cache_ttl + self._location_cache = None self._load_data() def _load_data(self): @@ -42,15 +51,31 @@ def _load_data(self): logger.info(f"GeoIP 数据库加载成功: {self.geoip_file}") self._load_cn_ipv4() + self._init_cache() except Exception as e: logger.error(f"加载 GeoIP 数据失败: {e}") + + def _init_cache(self): + try: + if self._cache_size <= 0 or self._cache_ttl <= 0: + self._location_cache = None + return + from ..utils.cache import TTLCache + self._location_cache = TTLCache(self._cache_size, self._cache_ttl) + except Exception: + self._location_cache = None def get_country_code(self, ip: str) -> Optional[str]: """获取 IP 的国家代码""" try: # 验证 IP 格式 socket.inet_aton(ip) + + if self._location_cache: + cached = self._location_cache.get(ip) + if cached is not None: + return cached.get("country_code") # 如果有真实的 GeoIP2 数据库 if self.reader: @@ -144,6 +169,10 @@ def is_china_ip(self, ip: str) -> bool: def get_location_info(self, ip: str) -> Dict[str, Any]: """获取 IP 的详细位置信息""" try: + if self._location_cache: + cached = self._location_cache.get(ip) + if cached is not None: + return cached country_code = self.get_country_code(ip) # 如果使用真实数据库且找到结果 @@ -152,12 +181,15 @@ def get_location_info(self, ip: str) -> Dict[str, Any]: response = self.reader.country(ip) country_name = response.country.names.get('zh-CN') or response.country.name or "未知" - return { + result = { "ip": ip, "country_code": country_code, "country_name": country_name, "is_china": country_code == "CN" } + if self._location_cache: + self._location_cache.set(ip, result) + return result except Exception: pass @@ -175,21 +207,27 @@ def get_location_info(self, ip: str) -> Dict[str, Any]: "FR": "法国", } - return { + result = { "ip": ip, "country_code": country_code, "country_name": country_names.get(country_code, "未知" if country_code else "未知"), "is_china": country_code == "CN" if country_code else False } + if self._location_cache: + self._location_cache.set(ip, result) + return result except Exception as e: logger.error(f"获取 IP 位置信息失败: {e}") - return { + result = { "ip": ip, "country_code": None, "country_name": "未知", "is_china": False } + if self._location_cache: + self._location_cache.set(ip, result) + return result def __del__(self): """关闭数据库连接""" diff --git a/src/services/github_service.py b/src/services/github_service.py index cfc1924..a3d47ae 100644 --- a/src/services/github_service.py +++ b/src/services/github_service.py @@ -5,12 +5,16 @@ import asyncio import base64 +import io +import time from datetime import datetime from typing import Optional, List, Dict, Any from loguru import logger from github import Github, GithubException, InputGitAuthor from ..config import Config +from ..utils.cache import TTLCache +from ..utils.metrics import METRICS class GitHubService: @@ -20,6 +24,10 @@ def __init__(self, config: Config): self.config = config self.github = Github(config.GITHUB_TOKEN) self.repo = None + self._file_cache = TTLCache( + getattr(config, "GITHUB_FILE_CACHE_SIZE", 0), + getattr(config, "GITHUB_FILE_CACHE_TTL", 0) + ) self._initialize_repo() def 
_initialize_repo(self): @@ -80,37 +88,69 @@ def test_connection(self) -> Dict[str, Any]: "error": str(e) } - async def get_rule_file_content(self, file_path: str) -> Optional[str]: + async def get_rule_file_content(self, file_path: str, use_cache: bool = True) -> Optional[str]: """获取规则文件内容""" try: logger.debug(f"正在获取文件内容: {file_path}") + if use_cache: + cached = self._file_cache.get(file_path) + if cached and "content" in cached: + METRICS.inc("github.cache.hit") + return cached["content"] + METRICS.inc("github.cache.miss") + + start_ts = time.perf_counter() # 使用 asyncio.to_thread 在线程池中执行阻塞IO file_content = await asyncio.to_thread(self.repo.get_contents, file_path) content = base64.b64decode(file_content.content).decode('utf-8') + self._file_cache.set(file_path, {"content": content, "sha": getattr(file_content, "sha", None)}) + METRICS.record_request( + "github.get_contents", + (time.perf_counter() - start_ts) * 1000, + success=True + ) logger.debug(f"成功获取文件内容: {file_path}, 长度: {len(content)} 字符") return content except GithubException as e: logger.error(f"GitHub API 获取文件失败: {file_path}, status={getattr(e, 'status', 'unknown')}, message={getattr(e, 'data', {}).get('message', str(e))}") + METRICS.record_request("github.get_contents", 0.0, success=False) return None except Exception as e: logger.error(f"获取文件内容失败: {file_path}, {type(e).__name__}: {e}", exc_info=True) + METRICS.record_request("github.get_contents", 0.0, success=False) return None - async def get_rule_file_data(self, file_path: str) -> Optional[Dict[str, Any]]: + async def get_rule_file_data(self, file_path: str, use_cache: bool = True) -> Optional[Dict[str, Any]]: """获取规则文件内容和 SHA""" try: logger.debug(f"正在获取文件内容和 SHA: {file_path}") + if use_cache: + cached = self._file_cache.get(file_path) + if cached and "content" in cached and cached.get("sha"): + METRICS.inc("github.cache.hit") + return {"content": cached["content"], "sha": cached["sha"]} + METRICS.inc("github.cache.miss") + + start_ts = time.perf_counter() file_content = await asyncio.to_thread(self.repo.get_contents, file_path) content = base64.b64decode(file_content.content).decode('utf-8') + self._file_cache.set(file_path, {"content": content, "sha": file_content.sha}) + METRICS.record_request( + "github.get_contents", + (time.perf_counter() - start_ts) * 1000, + success=True + ) return {"content": content, "sha": file_content.sha} except GithubException as e: logger.error( f"GitHub API 获取文件失败: {file_path}, status={getattr(e, 'status', 'unknown')}, " f"message={getattr(e, 'data', {}).get('message', str(e))}" ) + METRICS.record_request("github.get_contents", 0.0, success=False) return None except Exception as e: logger.error(f"获取文件内容和 SHA 失败: {file_path}, {type(e).__name__}: {e}", exc_info=True) + METRICS.record_request("github.get_contents", 0.0, success=False) return None async def check_domain_in_rules(self, domain: str, file_path: str = None) -> Dict[str, Any]: @@ -125,11 +165,10 @@ async def check_domain_in_rules(self, domain: str, file_path: str = None) -> Dic # CPU密集型操作也在线程池中执行,避免阻塞事件循环 def _process_content(): - lines = content.split('\n') domain_lower = domain.lower() found_rules = [] - for line_num, line in enumerate(lines, 1): + for line_num, line in enumerate(io.StringIO(content), 1): line = line.strip() if line and not line.startswith('#'): # 检查 DOMAIN-SUFFIX 格式 @@ -149,7 +188,13 @@ def _process_content(): }) return found_rules + start_ts = time.perf_counter() found_rules = await asyncio.to_thread(_process_content) + METRICS.record_request( + 
"github.check_rules", + (time.perf_counter() - start_ts) * 1000, + success=True + ) return { "exists": len(found_rules) > 0, @@ -159,6 +204,7 @@ def _process_content(): except Exception as e: logger.error(f"检查域名规则失败: {e}") + METRICS.record_request("github.check_rules", 0.0, success=False) return {"exists": False, "error": str(e)} async def add_domain_to_rules( @@ -187,7 +233,7 @@ async def add_domain_to_rules( for attempt in range(1, max_retries + 1): # 获取当前文件内容和 SHA logger.debug(f"开始添加域名 {domain} 到文件 {file_path} (尝试 {attempt}/{max_retries})") - file_data = await self.get_rule_file_data(file_path) + file_data = await self.get_rule_file_data(file_path, use_cache=(attempt == 1)) if not file_data: error_msg = f"无法获取规则文件内容: {file_path}。请检查文件是否存在,仓库访问权限是否正确。" logger.error(error_msg) @@ -292,8 +338,15 @@ def _perform_commit(): ) try: + start_ts = time.perf_counter() commit_result = await asyncio.to_thread(_perform_commit) + METRICS.record_request( + "github.update_file", + (time.perf_counter() - start_ts) * 1000, + success=True + ) except GithubException as e: + METRICS.record_request("github.update_file", 0.0, success=False) if getattr(e, "status", None) == 409 and attempt < max_retries: logger.warning("GitHub 更新冲突,准备重试") await asyncio.sleep(0.5 * attempt) @@ -305,6 +358,7 @@ def _perform_commit(): commit_url = f"https://github.com/{self.config.GITHUB_REPO}/commit/{commit_sha}" logger.info(f"成功添加域名 {domain} 到规则文件,commit: {commit_sha}") + self._file_cache.pop(file_path) return { "success": True, @@ -340,7 +394,7 @@ async def remove_domain_from_rules(self, domain: str, user_name: str, file_path: max_retries = 3 for attempt in range(1, max_retries + 1): - file_data = await self.get_rule_file_data(file_path) + file_data = await self.get_rule_file_data(file_path, use_cache=(attempt == 1)) if not file_data: return {"success": False, "error": "无法获取文件内容"} @@ -407,8 +461,15 @@ def _perform_commit(): ) try: + start_ts = time.perf_counter() commit_result = await asyncio.to_thread(_perform_commit) + METRICS.record_request( + "github.update_file", + (time.perf_counter() - start_ts) * 1000, + success=True + ) except GithubException as e: + METRICS.record_request("github.update_file", 0.0, success=False) if getattr(e, "status", None) == 409 and attempt < max_retries: logger.warning("GitHub 更新冲突,准备重试") await asyncio.sleep(0.5 * attempt) @@ -420,6 +481,7 @@ def _perform_commit(): commit_url = f"https://github.com/{self.config.GITHUB_REPO}/commit/{commit_sha}" logger.info(f"成功删除域名 {domain} 从规则文件,commit: {commit_sha}") + self._file_cache.pop(file_path) return { "success": True, @@ -449,11 +511,12 @@ async def get_file_stats(self, file_path: str = None) -> Dict[str, Any]: if not content: return {"error": "无法获取文件内容"} - lines = content.split('\n') rule_count = 0 comment_count = 0 + total_lines = 0 - for line in lines: + for line in io.StringIO(content): + total_lines += 1 line = line.strip() if line: if line.startswith('#'): @@ -463,7 +526,7 @@ async def get_file_stats(self, file_path: str = None) -> Dict[str, Any]: return { "file_path": file_path, - "total_lines": len(lines), + "total_lines": total_lines, "rule_count": rule_count, "comment_count": comment_count } diff --git a/src/utils/cache.py b/src/utils/cache.py new file mode 100644 index 0000000..c3a7df2 --- /dev/null +++ b/src/utils/cache.py @@ -0,0 +1,59 @@ +""" +Simple bounded TTL cache utilities. 
+""" + +from __future__ import annotations + +from collections import OrderedDict +from time import monotonic +from typing import Generic, Optional, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class TTLCache(Generic[K, V]): + def __init__(self, maxsize: int, ttl_seconds: float): + self.maxsize = max(0, int(maxsize)) + self.ttl_seconds = float(ttl_seconds) + self._data: OrderedDict[K, tuple[float, V]] = OrderedDict() + + def get(self, key: K) -> Optional[V]: + if self.maxsize <= 0: + return None + now = monotonic() + item = self._data.get(key) + if not item: + return None + expires_at, value = item + if expires_at <= now: + self._data.pop(key, None) + return None + self._data.move_to_end(key) + return value + + def set(self, key: K, value: V) -> None: + if self.maxsize <= 0: + return + now = monotonic() + expires_at = now + self.ttl_seconds if self.ttl_seconds > 0 else now + self._data[key] = (expires_at, value) + self._data.move_to_end(key) + self._evict(now) + + def pop(self, key: K) -> None: + self._data.pop(key, None) + + def clear(self) -> None: + self._data.clear() + + def __len__(self) -> int: + return len(self._data) + + def _evict(self, now: float) -> None: + if self.ttl_seconds > 0 and self._data: + expired_keys = [k for k, (exp, _) in list(self._data.items()) if exp <= now] + for key in expired_keys: + self._data.pop(key, None) + while len(self._data) > self.maxsize: + self._data.popitem(last=False) diff --git a/src/utils/memory.py b/src/utils/memory.py new file mode 100644 index 0000000..90662ef --- /dev/null +++ b/src/utils/memory.py @@ -0,0 +1,47 @@ +""" +Memory trimming helpers (Linux/glibc). +""" + +from __future__ import annotations + +import ctypes +import ctypes.util +import os + +from loguru import logger + + +def _env_bool(key: str, default: bool = True) -> bool: + raw = os.getenv(key, "").strip().lower() + if not raw: + return default + return raw in ("1", "true", "yes", "on") + + +MEMORY_TRIM_ENABLED = _env_bool("MEMORY_TRIM_ENABLED", True) +_LIBC = None +_HAS_MALLOC_TRIM = False + +try: + libc_name = ctypes.util.find_library("c") + if libc_name: + _LIBC = ctypes.CDLL(libc_name) + _HAS_MALLOC_TRIM = hasattr(_LIBC, "malloc_trim") +except Exception: + _LIBC = None + _HAS_MALLOC_TRIM = False + + +def trim_memory(reason: str = "") -> bool: + if not MEMORY_TRIM_ENABLED: + return False + if not _LIBC or not _HAS_MALLOC_TRIM: + return False + try: + result = _LIBC.malloc_trim(0) + if result and reason: + logger.debug(f"已触发内存回收: {reason}") + return bool(result) + except Exception as e: + logger.debug(f"触发内存回收失败: {e}") + return False diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 0000000..5f242f9 --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,176 @@ +""" +Lightweight in-process metrics with optional periodic export. 
+""" + +from __future__ import annotations + +import asyncio +import json +import os +import threading +import time +from bisect import bisect_right +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, Optional + +from loguru import logger + + +DEFAULT_LATENCY_BUCKETS_MS = (5, 10, 25, 50, 100, 250, 500, 1000, 2000, 5000) + + +def _env_bool(key: str, default: bool = False) -> bool: + raw = os.getenv(key, "").strip().lower() + if not raw: + return default + return raw in ("1", "true", "yes", "on") + + +def _env_int(key: str, default: int) -> int: + raw = os.getenv(key, "").strip() + if not raw: + return default + try: + return int(raw) + except ValueError: + return default + + +@dataclass +class HistogramSnapshot: + buckets: Iterable[int] + counts: Iterable[int] + count: int + total: float + + +class Histogram: + def __init__(self, buckets_ms: Iterable[int] = DEFAULT_LATENCY_BUCKETS_MS): + self.buckets = tuple(sorted(buckets_ms)) + self.counts = [0] * (len(self.buckets) + 1) + self.count = 0 + self.total = 0.0 + + def observe(self, value_ms: float) -> None: + idx = bisect_right(self.buckets, value_ms) + self.counts[idx] += 1 + self.count += 1 + self.total += value_ms + + def snapshot(self) -> HistogramSnapshot: + return HistogramSnapshot( + buckets=self.buckets, + counts=tuple(self.counts), + count=self.count, + total=self.total, + ) + + +class MetricsStore: + def __init__(self, enabled: bool): + self.enabled = enabled + self._lock = threading.Lock() + self._counters: Dict[str, int] = {} + self._histograms: Dict[str, Histogram] = {} + self._start_time = time.monotonic() + + def inc(self, name: str, value: int = 1) -> None: + if not self.enabled: + return + with self._lock: + self._counters[name] = self._counters.get(name, 0) + value + + def observe(self, name: str, value_ms: float) -> None: + if not self.enabled: + return + with self._lock: + hist = self._histograms.get(name) + if not hist: + hist = Histogram() + self._histograms[name] = hist + hist.observe(value_ms) + + def record_request(self, name: str, duration_ms: float, success: bool = True) -> None: + if not self.enabled: + return + status = "ok" if success else "fail" + self.inc(f"{name}.count") + self.inc(f"{name}.{status}") + self.observe(f"{name}.latency_ms", duration_ms) + + def snapshot(self, reset: bool = False) -> Dict[str, object]: + if not self.enabled: + return {} + with self._lock: + counters = dict(self._counters) + histograms = { + key: hist.snapshot().__dict__ for key, hist in self._histograms.items() + } + uptime = time.monotonic() - self._start_time + if reset: + self._counters.clear() + self._histograms.clear() + return { + "counters": counters, + "histograms": histograms, + "uptime_s": round(uptime, 3), + } + + +def _atomic_write_json(path: Path, data: Dict[str, object]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(data, handle, ensure_ascii=False, indent=2) + tmp_path.replace(path) + + +class MetricsExporter: + def __init__(self, metrics: MetricsStore, path: Path, interval: int, reset: bool = False): + self.metrics = metrics + self.path = path + self.interval = max(1, int(interval)) + self.reset = reset + self._task: Optional[asyncio.Task] = None + + def start(self) -> Optional[asyncio.Task]: + if not self.metrics.enabled: + return None + if self._task and not self._task.done(): + return self._task + self._task = 
asyncio.create_task(self._run()) + return self._task + + async def stop(self) -> None: + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + # stop() 主动取消任务,CancelledError 属于预期。 + logger.debug("metrics 导出任务在停止时被取消") + + async def _run(self) -> None: + while True: + await asyncio.sleep(self.interval) + snapshot = self.metrics.snapshot(reset=self.reset) + if not snapshot: + continue + snapshot["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + snapshot["asyncio_task_count"] = len(asyncio.all_tasks()) + try: + _atomic_write_json(self.path, snapshot) + except Exception as e: + logger.debug(f"写入 metrics 失败: {e}") + + +METRICS_ENABLED = _env_bool("METRICS_ENABLED", False) +METRICS_EXPORT_PATH = Path(os.getenv("METRICS_EXPORT_PATH", "/tmp/rule-bot-metrics.json")) +METRICS_EXPORT_INTERVAL = _env_int("METRICS_EXPORT_INTERVAL", 30) +METRICS_RESET_ON_EXPORT = _env_bool("METRICS_RESET_ON_EXPORT", False) + +METRICS = MetricsStore(enabled=METRICS_ENABLED) +EXPORTER = MetricsExporter( + METRICS, METRICS_EXPORT_PATH, METRICS_EXPORT_INTERVAL, METRICS_RESET_ON_EXPORT +) diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py new file mode 100644 index 0000000..7f2ba09 --- /dev/null +++ b/tests/test_data_manager.py @@ -0,0 +1,65 @@ +import asyncio +import os +import sys +import tempfile +import unittest +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +# Add repo root to path +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from src.data_manager import DataManager + + +def _build_config(data_dir: str, interval: float = 0.05) -> SimpleNamespace: + return SimpleNamespace( + DATA_DIR=data_dir, + GEOSITE_CACHE_SIZE=32, + GEOSITE_CACHE_TTL=60, + DATA_UPDATE_INTERVAL=interval, + GEOIP_URLS=[], + CN_IPV4_URLS=[], + GEOSITE_URL="", + ) + + +class TestDataManagerScheduling(unittest.IsolatedAsyncioTestCase): + async def test_scheduler_runs_and_stops_cleanly(self): + with tempfile.TemporaryDirectory() as temp_dir: + config = _build_config(temp_dir, interval=0.05) + manager = DataManager(config) + + update_called = asyncio.Event() + + async def _fake_update(): + update_called.set() + + with patch.object(manager, "_download_initial_data", AsyncMock()): + with patch.object(manager, "_update_data", AsyncMock(side_effect=_fake_update)): + await manager.initialize() + await asyncio.wait_for(update_called.wait(), timeout=0.5) + self.assertIsNotNone(manager._scheduler_task) + self.assertFalse(manager._scheduler_task.done()) + + await manager.close() + self.assertIsNone(manager._scheduler_task) + + async def test_session_lifecycle(self): + with tempfile.TemporaryDirectory() as temp_dir: + config = _build_config(temp_dir, interval=3600) + manager = DataManager(config) + + session1 = await manager._get_session() + session2 = await manager._get_session() + + self.assertIs(session1, session2) + self.assertFalse(session1.closed) + + await manager.close() + self.assertTrue(session1.closed) + self.assertIsNone(manager._session) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils_cache_metrics.py b/tests/test_utils_cache_metrics.py new file mode 100644 index 0000000..a97b7b8 --- /dev/null +++ b/tests/test_utils_cache_metrics.py @@ -0,0 +1,44 @@ +import os +import sys +import time +import unittest + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +from src.utils.cache import TTLCache +from src.utils.metrics import MetricsStore + + +class 
TestTTLCache(unittest.TestCase): + def test_ttl_cache_eviction(self): + cache = TTLCache(maxsize=2, ttl_seconds=60) + cache.set("a", 1) + cache.set("b", 2) + cache.set("c", 3) + self.assertIsNone(cache.get("a")) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("c"), 3) + + def test_ttl_cache_expire(self): + cache = TTLCache(maxsize=2, ttl_seconds=0.01) + cache.set("a", 1) + time.sleep(0.05) + self.assertIsNone(cache.get("a")) + + +class TestMetricsStore(unittest.TestCase): + def test_metrics_snapshot(self): + metrics = MetricsStore(enabled=True) + metrics.inc("counter.test") + metrics.observe("latency.test", 12.5) + metrics.record_request("req.test", 5.0, success=True) + + snap = metrics.snapshot() + self.assertIn("counters", snap) + self.assertIn("histograms", snap) + self.assertGreaterEqual(snap["counters"].get("counter.test", 0), 1) + self.assertIn("req.test.count", snap["counters"]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/profile_runtime.py b/tools/profile_runtime.py new file mode 100644 index 0000000..975c46e --- /dev/null +++ b/tools/profile_runtime.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Runtime profiler for Rule-Bot. +Collects RSS, VMS, VmSwap, CPU%, threads, asyncio tasks, and request metrics. +""" + +from __future__ import annotations + +import argparse +import json +import os +import time +from pathlib import Path +from typing import Dict, Optional + +import psutil + + +def read_vmswap_kb(pid: int) -> int: + status_path = Path(f"/proc/{pid}/status") + if not status_path.exists(): + return 0 + try: + with status_path.open("r", encoding="utf-8") as handle: + for line in handle: + if line.startswith("VmSwap:"): + parts = line.split() + if len(parts) >= 2: + return int(parts[1]) + except Exception: + return 0 + return 0 + + +def pick_pid_by_name(name: str) -> Optional[int]: + candidates = [] + for proc in psutil.process_iter(["pid", "name", "cmdline", "memory_info"]): + try: + pname = proc.info.get("name") or "" + cmdline = " ".join(proc.info.get("cmdline") or []) + if name in pname or name in cmdline: + rss = proc.info.get("memory_info").rss if proc.info.get("memory_info") else 0 + candidates.append((rss, proc.info["pid"])) + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + if not candidates: + return None + candidates.sort(reverse=True) + return candidates[0][1] + + +def load_metrics(path: Path) -> Dict[str, object]: + if not path.exists(): + return {} + try: + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + except Exception: + return {} + + +def format_histogram(hist: Dict[str, object]) -> str: + buckets = hist.get("buckets", []) + counts = hist.get("counts", []) + if not buckets or not counts: + return "" + parts = [] + for idx, upper in enumerate(buckets): + parts.append(f"<= {upper}ms: {counts[idx]}") + parts.append(f"> {buckets[-1]}ms: {counts[-1]}") + return "; ".join(parts) + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--pid", type=int, help="Target PID") + parser.add_argument("--process-name", default="python", help="Process name/cmdline match") + parser.add_argument("--interval", type=float, default=5.0, help="Sample interval seconds") + parser.add_argument("--duration", type=float, default=60.0, help="Total duration seconds") + parser.add_argument("--output", type=Path, help="CSV output path") + parser.add_argument( + "--metrics-path", + type=Path, + default=Path(os.getenv("METRICS_EXPORT_PATH", "/tmp/rule-bot-metrics.json")), + ) 
+ args = parser.parse_args() + + pid = args.pid or pick_pid_by_name(args.process_name) + if not pid: + print("No process found.") + return 1 + + proc = psutil.Process(pid) + proc.cpu_percent(interval=None) + + header = [ + "ts", + "rss_mb", + "vms_mb", + "vmswap_mb", + "cpu_percent", + "threads", + "asyncio_tasks", + ] + + out_handle = None + if args.output: + out_handle = args.output.open("w", encoding="utf-8") + out_handle.write(",".join(header) + "\n") + + start = time.time() + while time.time() - start <= args.duration: + time.sleep(args.interval) + try: + mem = proc.memory_info() + except psutil.NoSuchProcess: + break + + rss_mb = mem.rss / 1024 / 1024 + vms_mb = mem.vms / 1024 / 1024 + swap_mb = read_vmswap_kb(pid) / 1024 + cpu = proc.cpu_percent(interval=None) + threads = proc.num_threads() + + metrics = load_metrics(args.metrics_path) + asyncio_tasks = metrics.get("asyncio_task_count", 0) + + ts = time.strftime("%H:%M:%S") + row = [ts, f"{rss_mb:.1f}", f"{vms_mb:.1f}", f"{swap_mb:.1f}", f"{cpu:.1f}", str(threads), str(asyncio_tasks)] + line = ",".join(row) + print(line) + if out_handle: + out_handle.write(line + "\n") + + if out_handle: + out_handle.close() + + metrics = load_metrics(args.metrics_path) + if metrics: + print("\n[Metrics Snapshot]") + counters = metrics.get("counters", {}) + for key, value in sorted(counters.items()): + print(f"{key}: {value}") + histograms = metrics.get("histograms", {}) + for key, hist in histograms.items(): + summary = format_histogram(hist) + if summary: + print(f"{key}: {summary}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/stress_sim.py b/tools/stress_sim.py new file mode 100644 index 0000000..d72b97b --- /dev/null +++ b/tools/stress_sim.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +""" +10-minute stress simulation for Rule-Bot core logic. +Exercises GeoSite checks, DNS/NS resolution, and optional GitHub rule checks. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import random +import time +from types import SimpleNamespace +from typing import List + +from loguru import logger + +from src.data_manager import DataManager +from src.services.dns_service import DNSService +from src.services.geoip_service import GeoIPService +from src.services.domain_checker import DomainChecker +from src.services.github_service import GitHubService + + +DEFAULT_DOMAINS = [ + "example.com", + "www.google.com", + "www.bing.com", + "www.cloudflare.com", + "github.com", + "openai.com", + "www.baidu.com", + "www.qq.com", + "www.taobao.com", + "www.jd.com", +] + + +def load_domains(path: str) -> List[str]: + if not path: + return DEFAULT_DOMAINS + domains = [] + with open(path, "r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line or line.startswith("#"): + continue + domains.append(line) + return domains or DEFAULT_DOMAINS + + +def build_data_config() -> SimpleNamespace: + return SimpleNamespace( + DATA_DIR=os.getenv("DATA_DIR", "/tmp/rule-bot-data"), + DATA_UPDATE_INTERVAL=int(os.getenv("DATA_UPDATE_INTERVAL", "21600")), + GEOIP_URLS=[ + "https://gcore.jsdelivr.net/gh/Aethersailor/geoip@release/Country-without-asn.mmdb", + "https://testingcf.jsdelivr.net/gh/Aethersailor/geoip@release/Country-without-asn.mmdb", + "https://raw.githubusercontent.com/Aethersailor/geoip/release/Country-without-asn.mmdb", + ], + CN_IPV4_URLS=[ + "https://raw.githubusercontent.com/Aethersailor/geoip/refs/heads/release/text/cn-ipv4.txt", + "https://gcore.jsdelivr.net/gh/Aethersailor/geoip@release/text/cn-ipv4.txt", + "https://testingcf.jsdelivr.net/gh/Aethersailor/geoip@release/text/cn-ipv4.txt", + ], + GEOSITE_URL="https://raw.githubusercontent.com/Loyalsoldier/v2ray-rules-dat/refs/heads/release/direct-list.txt", + GEOSITE_CACHE_TTL=int(os.getenv("GEOSITE_CACHE_TTL", "3600")), + GEOSITE_CACHE_SIZE=int(os.getenv("GEOSITE_CACHE_SIZE", "2048")), + ) + + +def build_dns_config() -> SimpleNamespace: + return SimpleNamespace( + DOH_SERVERS={ + "alibaba": "https://dns.alidns.com/dns-query", + "tencent": "https://doh.pub/dns-query", + "cloudflare": "https://cloudflare-dns.com/dns-query", + }, + NS_DOH_SERVERS={ + "cloudflare": "https://cloudflare-dns.com/dns-query", + "google": "https://dns.google/dns-query", + "quad9": "https://dns.quad9.net/dns-query", + }, + DNS_CACHE_TTL=int(os.getenv("DNS_CACHE_TTL", "60")), + DNS_CACHE_SIZE=int(os.getenv("DNS_CACHE_SIZE", "1024")), + NS_CACHE_TTL=int(os.getenv("NS_CACHE_TTL", "300")), + NS_CACHE_SIZE=int(os.getenv("NS_CACHE_SIZE", "512")), + DNS_MAX_CONCURRENCY=int(os.getenv("DNS_MAX_CONCURRENCY", "20")), + DNS_CONN_LIMIT=int(os.getenv("DNS_CONN_LIMIT", "30")), + DNS_CONN_LIMIT_PER_HOST=int(os.getenv("DNS_CONN_LIMIT_PER_HOST", "10")), + DNS_TIMEOUT_TOTAL=int(os.getenv("DNS_TIMEOUT_TOTAL", "10")), + DNS_TIMEOUT_CONNECT=int(os.getenv("DNS_TIMEOUT_CONNECT", "3")), + GEOIP_CACHE_SIZE=int(os.getenv("GEOIP_CACHE_SIZE", "4096")), + GEOIP_CACHE_TTL=int(os.getenv("GEOIP_CACHE_TTL", "21600")), + ) + + +def build_github_service() -> GitHubService | None: + token = os.getenv("GITHUB_TOKEN", "").strip() + repo = os.getenv("GITHUB_REPO", "").strip() + direct_file = os.getenv("DIRECT_RULE_FILE", "").strip() + if not token or not repo or not direct_file: + return None + config = SimpleNamespace( + GITHUB_TOKEN=token, + GITHUB_REPO=repo, + DIRECT_RULE_FILE=direct_file, + GITHUB_COMMIT_NAME="Rule-Bot", + GITHUB_COMMIT_EMAIL=os.getenv("GITHUB_COMMIT_EMAIL", 
"noreply@users.noreply.github.com"), + GITHUB_FILE_CACHE_TTL=int(os.getenv("GITHUB_FILE_CACHE_TTL", "60")), + GITHUB_FILE_CACHE_SIZE=int(os.getenv("GITHUB_FILE_CACHE_SIZE", "4")), + ) + return GitHubService(config) + + +async def worker( + worker_id: int, + domains: List[str], + end_time: float, + data_manager: DataManager, + checker: DomainChecker, + github_service: GitHubService | None, + pause: float, +): + while time.time() < end_time: + domain = random.choice(domains) + try: + await data_manager.is_domain_in_geosite(domain) + await checker.check_domain_comprehensive(domain) + if github_service: + await github_service.check_domain_in_rules(domain) + except Exception as e: + logger.warning("worker {} 处理域名 {} 失败: {}", worker_id, domain, e) + await asyncio.sleep(pause) + + +async def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--duration", type=int, default=600, help="Duration seconds") + parser.add_argument("--concurrency", type=int, default=4, help="Concurrent workers") + parser.add_argument("--pause", type=float, default=0.5, help="Pause between iterations") + parser.add_argument("--domains-file", type=str, default="", help="Optional domain list file") + args = parser.parse_args() + + domains = load_domains(args.domains_file) + + data_cfg = build_data_config() + dns_cfg = build_dns_config() + data_manager = DataManager(data_cfg) + await data_manager.initialize() + + dns_service = DNSService( + dns_cfg.DOH_SERVERS, + dns_cfg.NS_DOH_SERVERS, + cache_size=dns_cfg.DNS_CACHE_SIZE, + cache_ttl=dns_cfg.DNS_CACHE_TTL, + ns_cache_size=dns_cfg.NS_CACHE_SIZE, + ns_cache_ttl=dns_cfg.NS_CACHE_TTL, + max_concurrency=dns_cfg.DNS_MAX_CONCURRENCY, + conn_limit=dns_cfg.DNS_CONN_LIMIT, + conn_limit_per_host=dns_cfg.DNS_CONN_LIMIT_PER_HOST, + timeout_total=dns_cfg.DNS_TIMEOUT_TOTAL, + timeout_connect=dns_cfg.DNS_TIMEOUT_CONNECT, + ) + await dns_service.start() + geoip_service = GeoIPService( + str(data_manager.geoip_file), + str(data_manager.cn_ipv4_file), + cache_size=dns_cfg.GEOIP_CACHE_SIZE, + cache_ttl=dns_cfg.GEOIP_CACHE_TTL, + ) + checker = DomainChecker(dns_service, geoip_service) + github_service = build_github_service() + + end_time = time.time() + args.duration + tasks = [ + asyncio.create_task( + worker(i, domains, end_time, data_manager, checker, github_service, args.pause) + ) + for i in range(args.concurrency) + ] + + await asyncio.gather(*tasks, return_exceptions=True) + await dns_service.close() + await data_manager.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main()))