diff --git a/nonebot_plugin_chatrecorder/__init__.py b/nonebot_plugin_chatrecorder/__init__.py index fae1c2f..528eab9 100644 --- a/nonebot_plugin_chatrecorder/__init__.py +++ b/nonebot_plugin_chatrecorder/__init__.py @@ -1,7 +1,8 @@ from nonebot import require from nonebot.plugin import PluginMetadata -require("nonebot_plugin_session_orm") +require("nonebot_plugin_orm") +require("nonebot_plugin_uninfo") require("nonebot_plugin_localstore") from . import adapters as adapters diff --git a/nonebot_plugin_chatrecorder/adapters/console.py b/nonebot_plugin_chatrecorder/adapters/console.py index 6db5c5f..bfcb60f 100644 --- a/nonebot_plugin_chatrecorder/adapters/console.py +++ b/nonebot_plugin_chatrecorder/adapters/console.py @@ -1,17 +1,24 @@ +import uuid from dataclasses import asdict from datetime import datetime, timezone -from itertools import count from typing import Any, Optional from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import SessionModel, get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -25,36 +32,21 @@ try: from nonebot.adapters.console import Bot, Message, MessageEvent, MessageSegment from nonechat import ConsoleMessage, Emoji, Text - from sqlalchemy import select - - adapter = SupportedAdapter.console - id = None + adapter = SupportAdapter.console - async def get_id() -> str: - global id - if not id: - statement = ( - select(MessageRecord.message_id) - .where(SessionModel.bot_type == adapter) - .order_by(MessageRecord.message_id.desc()) - .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) - ) - async with get_session() as db_session: - message_id = await db_session.scalar(statement) - id = count(int(message_id) + 1) if message_id else count(0) - return str(next(id)) + def get_id() -> str: + return uuid.uuid4().hex @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( session_persist_id=session_persist_id, time=remove_timezone(event.time.astimezone(timezone.utc)), type=record_type(event), - message_id=await get_id(), + message_id=get_id(), message=serialize_message(adapter, event.get_message()), plain_text=event.get_plaintext(), ) @@ -78,13 +70,11 @@ async def record_send_msg( return session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.console, - level=SessionLevel.LEVEL1, - id1=data.get("user_id"), - id2=None, - id3=None, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.console, + scene=Scene(id=data["user_id"], type=SceneType.PRIVATE), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) @@ -105,7 +95,7 @@ async def record_send_msg( session_persist_id=session_persist_id, time=remove_timezone(datetime.now(timezone.utc)), type="message_sent", - message_id=await get_id(), + message_id=get_id(), message=serialize_message(adapter, message), plain_text=message.extract_plain_text(), ) diff --git a/nonebot_plugin_chatrecorder/adapters/discord.py b/nonebot_plugin_chatrecorder/adapters/discord.py index d7ed442..30ab61b 100644 --- a/nonebot_plugin_chatrecorder/adapters/discord.py +++ b/nonebot_plugin_chatrecorder/adapters/discord.py @@ -3,12 +3,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -23,11 +30,10 @@ from nonebot.adapters.discord import Bot, Message, MessageEvent from nonebot.adapters.discord.api import UNSET, Channel, ChannelType, MessageGet - adapter = SupportedAdapter.discord + adapter = SupportAdapter.discord @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -71,30 +77,26 @@ async def record_send_msg( return channel = await get_channel(bot, result.channel_id) - - level = SessionLevel.LEVEL0 - id1 = None - id2 = str(result.channel_id) - id3 = None + parent = None if channel.type in [ChannelType.DM]: - level = SessionLevel.LEVEL1 - id1 = ( + scene_type = SceneType.PRIVATE + scene_id = ( str(channel.recipients[0].id) if channel.recipients != UNSET and channel.recipients - else None + else "" ) else: - level = SessionLevel.LEVEL3 - id3 = str(channel.guild_id) if channel.guild_id != UNSET else None + scene_type = SceneType.CHANNEL_TEXT + scene_id = str(result.channel_id) + if channel.guild_id != UNSET: + parent = Scene(id=str(channel.guild_id), type=SceneType.GUILD) session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.discord, - level=level, - id1=id1, - id2=id2, - id3=id3, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.discord, + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/dodo.py b/nonebot_plugin_chatrecorder/adapters/dodo.py index bba0b29..84bab2a 100644 --- a/nonebot_plugin_chatrecorder/adapters/dodo.py +++ b/nonebot_plugin_chatrecorder/adapters/dodo.py @@ -4,12 +4,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -24,11 +31,10 @@ from nonebot.adapters.dodo import Bot, Message, MessageEvent from nonebot.adapters.dodo.models import MessageReturn - adapter = SupportedAdapter.dodo + adapter = SupportAdapter.dodo @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -58,29 +64,25 @@ async def record_send_msg( if e or not result or not isinstance(result, MessageReturn): return - island_source_id = None - channel_id = None if api == "set_channel_message_send": - level = SessionLevel.LEVEL3 - channel_id = data["channel_id"] - dodo_source_id = data.get("dodo_source_id") + scene_type = SceneType.CHANNEL_TEXT + scene_id = data["channel_id"] + parent = None elif api == "set_personal_message_send": - level = SessionLevel.LEVEL1 - island_source_id = data["island_source_id"] - dodo_source_id = data["dodo_source_id"] + scene_type = SceneType.PRIVATE + scene_id = data["dodo_source_id"] + parent = Scene(id=data["island_source_id"], type=SceneType.GUILD) else: return message = Message.from_message_body(data["message_body"]) session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.dodo, - level=level, - id1=dodo_source_id, - id2=channel_id, - id3=island_source_id, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.dodo, + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/feishu.py b/nonebot_plugin_chatrecorder/adapters/feishu.py index 9b3907c..0cde891 100644 --- a/nonebot_plugin_chatrecorder/adapters/feishu.py +++ b/nonebot_plugin_chatrecorder/adapters/feishu.py @@ -5,12 +5,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -24,11 +31,10 @@ try: from nonebot.adapters.feishu import Bot, Message, MessageEvent - adapter = SupportedAdapter.feishu + adapter = SupportAdapter.feishu @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -82,24 +88,17 @@ async def record_send_msg( resp = await get_chat_info(bot, chat_id) chat_mode = resp["data"]["chat_mode"] - level = SessionLevel.LEVEL0 - id1 = None - id2 = None if chat_mode == "p2p": - level = SessionLevel.LEVEL1 - id1 = resp["data"]["owner_id"] + scene_type = SceneType.PRIVATE elif chat_mode == "group": - level = SessionLevel.LEVEL2 - id2 = chat_id + scene_type = SceneType.GROUP session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.feishu, - level=level, - id1=id1, - id2=id2, - id3=None, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.feishu, + scene=Scene(id=chat_id, type=scene_type), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/kaiheila.py b/nonebot_plugin_chatrecorder/adapters/kaiheila.py index 741a189..f0af22b 100644 --- a/nonebot_plugin_chatrecorder/adapters/kaiheila.py +++ b/nonebot_plugin_chatrecorder/adapters/kaiheila.py @@ -4,12 +4,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -25,11 +32,10 @@ from nonebot.adapters.kaiheila.api.model import MessageCreateReturn from nonebot.adapters.kaiheila.event import MessageEvent - adapter = SupportedAdapter.kaiheila + adapter = SupportAdapter.kook @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -67,14 +73,11 @@ async def record_send_msg( ): return + scene_id = data["target_id"] if api == "message_create": - level = SessionLevel.LEVEL3 - channel_id = data["target_id"] - user_id = data.get("temp_target_id") + scene_type = SceneType.CHANNEL_TEXT elif api == "directMessage_create": - level = SessionLevel.LEVEL1 - channel_id = None - user_id = data["target_id"] + scene_type = SceneType.PRIVATE else: return @@ -99,13 +102,11 @@ async def record_send_msg( message = Message(message) session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.kaiheila, - level=level, - id1=user_id, - id2=None, - id3=channel_id, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.kook, + scene=Scene(id=scene_id, type=scene_type), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/onebot_v11.py b/nonebot_plugin_chatrecorder/adapters/onebot_v11.py index 01749c1..2ab9892 100644 --- a/nonebot_plugin_chatrecorder/adapters/onebot_v11.py +++ b/nonebot_plugin_chatrecorder/adapters/onebot_v11.py @@ -7,18 +7,20 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import ( - IMAGE_CACHE_DIR, - RECORD_CACHE_DIR, - VIDEO_CACHE_DIR, - SupportedAdapter, - SupportedPlatform, -) +from ..consts import IMAGE_CACHE_DIR, RECORD_CACHE_DIR, VIDEO_CACHE_DIR from ..message import ( JsonMsg, MessageDeserializer, @@ -33,11 +35,10 @@ try: from nonebot.adapters.onebot.v11 import Bot, Message, MessageEvent, MessageSegment - adapter = SupportedAdapter.onebot_v11 + adapter = SupportAdapter.onebot11 @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -76,18 +77,18 @@ async def record_send_msg( or (data.get("message_type") == None and data.get("group_id")) ) ): - level = SessionLevel.LEVEL2 + scene_id = str(data["group_id"]) + scene_type = SceneType.GROUP else: - level = SessionLevel.LEVEL1 + scene_id = str(data["user_id"]) + scene_type = SceneType.PRIVATE session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.qq, - level=level, - id1=str(data.get("user_id", "")) or None, - id2=str(data.get("group_id", "")) or None, - id3=None, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.qq_client, + scene=Scene(id=scene_id, type=scene_type), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/onebot_v12.py b/nonebot_plugin_chatrecorder/adapters/onebot_v12.py index 9abbe03..f732637 100644 --- a/nonebot_plugin_chatrecorder/adapters/onebot_v12.py +++ b/nonebot_plugin_chatrecorder/adapters/onebot_v12.py @@ -4,12 +4,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter from ..message import ( MessageDeserializer, MessageSerializer, @@ -18,16 +25,15 @@ serialize_message, ) from ..model import MessageRecord -from ..utils import format_platform, record_type, remove_timezone +from ..utils import record_type, remove_timezone try: from nonebot.adapters.onebot.v12 import Bot, Message, MessageEvent - adapter = SupportedAdapter.onebot_v12 + adapter = SupportAdapter.onebot12 @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -59,23 +65,31 @@ async def record_send_msg( if api not in ["send_message"]: return + parent = None detail_type = data["detail_type"] - level = SessionLevel.LEVEL0 if detail_type == "channel": - level = SessionLevel.LEVEL3 + scene_type = SceneType.CHANNEL_TEXT + scene_id = data["channel_id"] + parent = ( + Scene(id=data["guild_id"], type=SceneType.GUILD) + if data.get("guild_id") + else None + ) elif detail_type == "group": - level = SessionLevel.LEVEL2 + scene_type = SceneType.GROUP + scene_id = data["group_id"] elif detail_type == "private": - level = SessionLevel.LEVEL1 + scene_type = SceneType.PRIVATE + scene_id = data["user_id"] + else: + return session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=format_platform(bot.platform), - level=level, - id1=data.get("user_id"), - id2=data.get("group_id") or data.get("channel_id"), - id3=data.get("guild_id"), + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.ensure_ob12(bot.platform), + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/qq.py b/nonebot_plugin_chatrecorder/adapters/qq.py index 18e0f0c..0326e3b 100644 --- a/nonebot_plugin_chatrecorder/adapters/qq.py +++ b/nonebot_plugin_chatrecorder/adapters/qq.py @@ -4,12 +4,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -28,13 +35,12 @@ PostGroupMessagesReturn, ) - adapter = SupportedAdapter.qq + adapter = SupportAdapter.qq @event_postprocessor async def record_recv_msg( - bot: Bot, event: Union[GuildMessageEvent, QQMessageEvent] + event: Union[GuildMessageEvent, QQMessageEvent], session: Uninfo ): - session = extract_session(bot, event) session_persist_id = await get_session_persist_id(session) if isinstance(event, QQMessageEvent): @@ -83,43 +89,36 @@ async def record_send_msg( ): return - id1 = None - id2 = None - id3 = None - level = SessionLevel.LEVEL0 - platform = SupportedPlatform.qqguild + parent = None if api == "post_messages": assert isinstance(result, GuildMessage) - level = SessionLevel.LEVEL3 - id3 = result.guild_id - id2 = result.channel_id + scene_type = SceneType.CHANNEL_TEXT + scene_id = result.channel_id + parent = Scene(id=result.guild_id, type=SceneType.GUILD) elif api == "post_dms_messages": assert isinstance(result, GuildMessage) - level = SessionLevel.LEVEL1 - id3 = data["guild_id"] + scene_type = SceneType.PRIVATE + scene_id = result.channel_id + parent = Scene(id=result.guild_id, type=SceneType.GUILD) elif api == "post_c2c_messages": assert isinstance(result, PostC2CMessagesReturn) - level = SessionLevel.LEVEL1 - id1 = data["openid"] - platform = SupportedPlatform.qq + scene_type = SceneType.PRIVATE + scene_id = data["openid"] elif api == "post_group_messages": assert isinstance(result, PostGroupMessagesReturn) - level = SessionLevel.LEVEL2 - id2 = data["group_openid"] - platform = SupportedPlatform.qq + scene_type = SceneType.GROUP + scene_id = data["group_openid"] session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=platform, - level=level, - id1=id1, - id2=id2, - id3=id3, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.qq_api, + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/satori.py b/nonebot_plugin_chatrecorder/adapters/satori.py index 9d1965a..6c2fd02 100644 --- a/nonebot_plugin_chatrecorder/adapters/satori.py +++ b/nonebot_plugin_chatrecorder/adapters/satori.py @@ -4,12 +4,19 @@ from nonebot.adapters import Bot as BaseBot from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter from ..message import ( MessageDeserializer, MessageSerializer, @@ -18,18 +25,18 @@ serialize_message, ) from ..model import MessageRecord -from ..utils import format_platform, record_type, remove_timezone +from ..utils import record_type, remove_timezone try: from nonebot.adapters.satori import Bot, Message from nonebot.adapters.satori.event import MessageCreatedEvent from nonebot.adapters.satori.models import MessageObject + from nonebot_plugin_uninfo.adapters.satori.main import TYPE_MAPPING - adapter = SupportedAdapter.satori + adapter = SupportAdapter.satori @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageCreatedEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageCreatedEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -64,27 +71,46 @@ async def record_send_msg( if not isinstance(res, MessageObject): return result_messages = cast(list[MessageObject], result) - result_message = result_messages[0] - level = SessionLevel.LEVEL0 - if result_message.guild: - level = SessionLevel.LEVEL3 - elif result_message.member: - level = SessionLevel.LEVEL2 - elif result_message.user: - level = SessionLevel.LEVEL1 - id1 = data["channel_id"] if level == SessionLevel.LEVEL1 else None - id2 = result_message.channel.id if result_message.channel else None - id3 = result_message.guild.id if result_message.guild else None + + parent = None + + if result_message.guild and result_message.channel: + scene_type = TYPE_MAPPING[result_message.channel.type] + scene_id = result_message.channel.id + parent = Scene(id=result_message.guild.id, type=SceneType.GUILD) + if ( + "guild.plain" in bot._self_info.features + or result_message.guild.id == result_message.channel.id + ): + scene_type = SceneType.GROUP + parent.type = SceneType.GROUP + + elif result_message.guild: + scene_type = ( + SceneType.GROUP + if "guild.plain" in bot._self_info.features + else SceneType.GUILD + ) + scene_id = result_message.guild.id + + elif result_message.channel: + scene_type = ( + SceneType.GROUP + if "guild.plain" in bot._self_info.features + else SceneType.GUILD + ) + scene_id = result_message.channel.id + + else: + return session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=format_platform(bot.platform), - level=level, - id1=id1, - id2=id2, - id3=id3, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.ensure_satori(bot.platform), + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/adapters/telegram.py b/nonebot_plugin_chatrecorder/adapters/telegram.py index baafdea..4480a29 100644 --- a/nonebot_plugin_chatrecorder/adapters/telegram.py +++ b/nonebot_plugin_chatrecorder/adapters/telegram.py @@ -5,12 +5,19 @@ from nonebot.compat import type_validate_python from nonebot.message import event_postprocessor from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionLevel, extract_session -from nonebot_plugin_session_orm import get_session_persist_id +from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + Uninfo, + User, +) +from nonebot_plugin_uninfo.orm import get_session_persist_id from typing_extensions import override from ..config import plugin_config -from ..consts import SupportedAdapter, SupportedPlatform from ..message import ( MessageDeserializer, MessageSerializer, @@ -26,11 +33,10 @@ from nonebot.adapters.telegram.event import MessageEvent from nonebot.adapters.telegram.model import Message as TGMessage - adapter = SupportedAdapter.telegram + adapter = SupportAdapter.telegram @event_postprocessor - async def record_recv_msg(bot: Bot, event: MessageEvent): - session = extract_session(bot, event) + async def record_recv_msg(event: MessageEvent, session: Uninfo): session_persist_id = await get_session_persist_id(session) record = MessageRecord( @@ -97,28 +103,24 @@ async def record_send_msg( message_thread_id = tg_message.message_thread_id chat_id = tg_message.chat.id - id1 = None - id2 = None - id3 = None + parent = None if message_thread_id: - id3 = str(chat_id) - id2 = str(message_thread_id) - level = SessionLevel.LEVEL3 + scene_type = SceneType.CHANNEL_TEXT + scene_id = str(message_thread_id) + parent = Scene(id=str(chat_id), type=SceneType.GUILD) elif chat.type == "private": - id1 = str(chat_id) - level = SessionLevel.LEVEL1 + scene_type = SceneType.PRIVATE + scene_id = str(chat_id) else: - id2 = str(chat_id) - level = SessionLevel.LEVEL2 + scene_type = SceneType.GROUP + scene_id = str(chat_id) session = Session( - bot_id=bot.self_id, - bot_type=bot.type, - platform=SupportedPlatform.telegram, - level=level, - id1=id1, - id2=id2, - id3=id3, + self_id=bot.self_id, + adapter=adapter, + scope=SupportScope.telegram, + scene=Scene(id=scene_id, type=scene_type, parent=parent), + user=User(id=bot.self_id), ) session_persist_id = await get_session_persist_id(session) diff --git a/nonebot_plugin_chatrecorder/consts.py b/nonebot_plugin_chatrecorder/consts.py index 484f8e1..e732aea 100644 --- a/nonebot_plugin_chatrecorder/consts.py +++ b/nonebot_plugin_chatrecorder/consts.py @@ -1,7 +1,7 @@ from nonebot_plugin_localstore import get_cache_dir -from strenum import StrEnum CACHE_DIR = get_cache_dir("nonebot_plugin_chatrecorder") + IMAGE_CACHE_DIR = CACHE_DIR / "images" RECORD_CACHE_DIR = CACHE_DIR / "records" VIDEO_CACHE_DIR = CACHE_DIR / "videos" @@ -9,28 +9,3 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) RECORD_CACHE_DIR.mkdir(parents=True, exist_ok=True) VIDEO_CACHE_DIR.mkdir(parents=True, exist_ok=True) - - -class SupportedAdapter(StrEnum): - console = "Console" - discord = "Discord" - dodo = "DoDo" - feishu = "Feishu" - kaiheila = "Kaiheila" - onebot_v11 = "OneBot V11" - onebot_v12 = "OneBot V12" - qq = "QQ" - satori = "Satori" - telegram = "Telegram" - - -class SupportedPlatform(StrEnum): - console = "console" - discord = "discord" - dodo = "dodo" - feishu = "feishu" - kaiheila = "kaiheila" - qq = "qq" - qqguild = "qqguild" - telegram = "telegram" - unknown = "unknown" diff --git a/nonebot_plugin_chatrecorder/message.py b/nonebot_plugin_chatrecorder/message.py index 4c16e43..59012f0 100644 --- a/nonebot_plugin_chatrecorder/message.py +++ b/nonebot_plugin_chatrecorder/message.py @@ -3,8 +3,8 @@ from nonebot.adapters import Bot, Message from nonebot.compat import type_validate_python +from nonebot_plugin_uninfo import SupportAdapter -from .consts import SupportedAdapter from .exception import AdapterNotInstalled, AdapterNotSupported JsonMsg = list[dict[str, Any]] @@ -28,42 +28,42 @@ def deserialize(cls, msg: JsonMsg) -> TM: return type_validate_python(cls.get_message_class(), msg) -_serializers: dict[SupportedAdapter, type[MessageSerializer]] = {} -_deserializers: dict[SupportedAdapter, type[MessageDeserializer]] = {} +_serializers: dict[SupportAdapter, type[MessageSerializer]] = {} +_deserializers: dict[SupportAdapter, type[MessageDeserializer]] = {} -def get_adapter_type(bot_type: str) -> SupportedAdapter: - for adapter in SupportedAdapter: +def get_adapter_type(bot_type: str) -> SupportAdapter: + for adapter in SupportAdapter: if bot_type == adapter.value: return adapter raise AdapterNotSupported(bot_type) -def get_serializer(adapter: SupportedAdapter) -> type[MessageSerializer]: +def get_serializer(adapter: SupportAdapter) -> type[MessageSerializer]: if adapter not in _serializers: raise AdapterNotInstalled(adapter.value) return _serializers[adapter] -def get_deserializer(adapter: SupportedAdapter) -> type[MessageDeserializer]: +def get_deserializer(adapter: SupportAdapter) -> type[MessageDeserializer]: if adapter not in _deserializers: raise AdapterNotInstalled(adapter.value) return _deserializers[adapter] -def register_serializer(adapter: SupportedAdapter, serializer: type[MessageSerializer]): +def register_serializer(adapter: SupportAdapter, serializer: type[MessageSerializer]): _serializers[adapter] = serializer def register_deserializer( - adapter: SupportedAdapter, deserializer: type[MessageDeserializer] + adapter: SupportAdapter, deserializer: type[MessageDeserializer] ): _deserializers[adapter] = deserializer def serialize_message( - bot_type: Union[Bot, SupportedAdapter, str], msg: Message + bot_type: Union[Bot, SupportAdapter, str], msg: Message ) -> JsonMsg: if isinstance(bot_type, Bot): bot_type = bot_type.type @@ -73,7 +73,7 @@ def serialize_message( def deserialize_message( - bot_type: Union[Bot, SupportedAdapter, str], msg: JsonMsg + bot_type: Union[Bot, SupportAdapter, str], msg: JsonMsg ) -> Message: if isinstance(bot_type, Bot): bot_type = bot_type.type diff --git a/nonebot_plugin_chatrecorder/migrations/0f0a7bc40f3c_message_id_length.py b/nonebot_plugin_chatrecorder/migrations/0f0a7bc40f3c_message_id_length.py deleted file mode 100644 index 6cf9a6e..0000000 --- a/nonebot_plugin_chatrecorder/migrations/0f0a7bc40f3c_message_id_length.py +++ /dev/null @@ -1,53 +0,0 @@ -"""message_id_length - -迁移 ID: 0f0a7bc40f3c -父迁移: 46327b837dd8 -创建时间: 2024-10-11 23:36:23.677012 - -""" - -from __future__ import annotations - -from collections.abc import Sequence - -import sqlalchemy as sa -from alembic import op - -revision: str = "0f0a7bc40f3c" -down_revision: str | Sequence[str] | None = "46327b837dd8" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade(name: str = "") -> None: - if name: - return - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table( - "nonebot_plugin_chatrecorder_messagerecord", schema=None - ) as batch_op: - batch_op.alter_column( - "message_id", - existing_type=sa.String(length=64), - type_=sa.String(length=255), - existing_nullable=False, - ) - - # ### end Alembic commands ### - - -def downgrade(name: str = "") -> None: - if name: - return - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table( - "nonebot_plugin_chatrecorder_messagerecord", schema=None - ) as batch_op: - batch_op.alter_column( - "message_id", - existing_type=sa.String(length=255), - type_=sa.String(length=64), - existing_nullable=False, - ) - - # ### end Alembic commands ### diff --git a/nonebot_plugin_chatrecorder/migrations/46327b837dd8_data_migrate.py b/nonebot_plugin_chatrecorder/migrations/46327b837dd8_data_migrate.py deleted file mode 100644 index cbb9900..0000000 --- a/nonebot_plugin_chatrecorder/migrations/46327b837dd8_data_migrate.py +++ /dev/null @@ -1,138 +0,0 @@ -"""data_migrate - -修订 ID: 46327b837dd8 -父修订: e6460fccaf90 -创建时间: 2023-10-12 15:32:47.496268 - -""" - -from __future__ import annotations - -import math -from collections.abc import Sequence - -from alembic import op -from alembic.op import run_async -from nonebot import logger, require -from sqlalchemy import Connection, insert, inspect, select -from sqlalchemy.ext.asyncio import AsyncConnection -from sqlalchemy.ext.automap import automap_base -from sqlalchemy.orm import Session - -revision: str = "46327b837dd8" -down_revision: str | Sequence[str] | None = "e6460fccaf90" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = "71a72119935f" - - -def _migrate_old_data(ds_conn: Connection): - insp = inspect(ds_conn) - if ( - "nonebot_plugin_chatrecorder_messagerecord" not in insp.get_table_names() - or "nonebot_plugin_chatrecorder_alembic_version" not in insp.get_table_names() - ): - return - - DsBase = automap_base() - DsBase.prepare(autoload_with=ds_conn) - DsMessageRecord = DsBase.classes.nonebot_plugin_chatrecorder_messagerecord - - Base = automap_base() - Base.prepare(autoload_with=op.get_bind()) - MessageRecord = Base.classes.nonebot_plugin_chatrecorder_messagerecord - - ds_session = Session(ds_conn) - session = Session(op.get_bind()) - - count = ds_session.query(DsMessageRecord).count() - if count == 0: - return - - AlembicVersion = DsBase.classes.nonebot_plugin_chatrecorder_alembic_version - version_num = ds_session.scalars(select(AlembicVersion.version_num)).one_or_none() - if not version_num: - return - if version_num not in ["902a51ac4032", "44cce443d2c0"]: - logger.warning( - "chatrecorder: 发现旧版本的数据,请先安装 0.4.2 版本," - "并运行 nb datastore upgrade 完成数据迁移之后再安装新版本" - ) - raise RuntimeError("chatrecorder: 请先安装 0.4.2 版本完成迁移之后再升级") - - logger.warning( - "chatrecorder: 发现来自 datastore 的数据,正在迁移,请不要关闭程序..." - ) - logger.info(f"chatrecorder: 聊天记录数据总数:{count}") - - # 每次迁移的数据量为 10000 条 - migration_limit = 10000 - last_message_id = -1 - - for i in range(math.ceil(count / migration_limit)): - statement = ( - select( - DsMessageRecord.id, - DsMessageRecord.session_id, - DsMessageRecord.time, - DsMessageRecord.type, - DsMessageRecord.message_id, - DsMessageRecord.message, - DsMessageRecord.plain_text, - ) - .order_by(DsMessageRecord.id) - .where(DsMessageRecord.id > last_message_id) - .limit(migration_limit) - ) - records = ds_session.execute(statement).all() - last_message_id = records[-1][0] - - bulk_insert_records = [] - for record in records: - bulk_insert_records.append( - { - "id": record[0], - "session_persist_id": record[1], - "time": record[2], - "type": record[3], - "message_id": record[4], - "message": record[5], - "plain_text": record[6], - } - ) - session.execute(insert(MessageRecord), bulk_insert_records) - logger.info( - f"chatrecorder: 已迁移 {i * migration_limit + len(records)}/{count}" - ) - - session.commit() - logger.warning("chatrecorder: 聊天记录数据迁移完成!") - - -async def data_migrate(conn: AsyncConnection): - from nonebot_plugin_datastore.db import get_engine - - async with get_engine().connect() as ds_conn: - await ds_conn.run_sync(_migrate_old_data) - - -def upgrade(name: str = "") -> None: - if name: - return - # ### commands auto generated by Alembic - please adjust! ### - - try: - require("nonebot_plugin_datastore") - except RuntimeError: - return - - run_async(data_migrate) - - # ### end Alembic commands ### - - -def downgrade(name: str = "") -> None: - if name: - return - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### diff --git a/nonebot_plugin_chatrecorder/migrations/bc43ce947963_data_migrate.py b/nonebot_plugin_chatrecorder/migrations/bc43ce947963_data_migrate.py new file mode 100644 index 0000000..31f0a86 --- /dev/null +++ b/nonebot_plugin_chatrecorder/migrations/bc43ce947963_data_migrate.py @@ -0,0 +1,114 @@ +"""data_migrate + +迁移 ID: bc43ce947963 +父迁移: ea78280f71da +创建时间: 2024-11-30 13:12:22.130296 + +""" + +from __future__ import annotations + +import math +from collections.abc import Sequence + +from alembic import op +from nonebot.log import logger +from sqlalchemy import insert, inspect, select +from sqlalchemy.ext.automap import automap_base +from sqlalchemy.orm import Session + +revision: str = "bc43ce947963" +down_revision: str | Sequence[str] | None = "ea78280f71da" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def data_migrate() -> None: + conn = op.get_bind() + insp = inspect(conn) + table_names = insp.get_table_names() + if "nonebot_plugin_chatrecorder_messagerecord" not in table_names: + return + + Base = automap_base() + Base.prepare(autoload_with=conn) + MessageRecord = Base.classes.nonebot_plugin_chatrecorder_messagerecord + MessageRecordV2 = Base.classes.nonebot_plugin_chatrecorder_messagerecord_v2 + + with Session(conn) as db_session: + count = db_session.query(MessageRecord).count() + if count == 0: + return + + try: + from nonebot_session_to_uninfo import check_tables, get_id_map + except ImportError: + raise ValueError("请安装 `nonebot-session-to-uninfo` 以迁移数据") + + check_tables() + + migration_limit = 10000 # 每次迁移的数据量为 10000 条 + last_message_id = -1 + id_map: dict[int, int] = {} + + logger.warning("chatrecorder: 正在迁移数据,请不要关闭程序...") + + for i in range(math.ceil(count / migration_limit)): + statement = ( + select( + MessageRecord.id, + MessageRecord.session_persist_id, + MessageRecord.time, + MessageRecord.type, + MessageRecord.message_id, + MessageRecord.message, + MessageRecord.plain_text, + ) + .order_by(MessageRecord.id) + .where(MessageRecord.id > last_message_id) + .limit(migration_limit) + ) + records = db_session.execute(statement).all() + last_message_id = records[-1][0] + + session_ids = [record[1] for record in records if record[1] not in id_map] + if session_ids: + id_map.update(get_id_map(session_ids)) + + bulk_insert_records = [] + for record in records: + bulk_insert_records.append( + { + "id": record[0], + "session_persist_id": id_map[record[1]], + "time": record[2], + "type": record[3], + "message_id": record[4], + "message": record[5], + "plain_text": record[6], + } + ) + db_session.execute(insert(MessageRecordV2), bulk_insert_records) + logger.info( + f"chatrecorder: 已迁移 {i * migration_limit + len(records)}/{count}" + ) + + db_session.commit() + + logger.warning("chatrecorder: 数据迁移完成!") + + +def upgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + data_migrate() + # ### end Alembic commands ### + + +def downgrade(name: str = "") -> None: + if name: + return + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/nonebot_plugin_chatrecorder/migrations/e6460fccaf90_init_db.py b/nonebot_plugin_chatrecorder/migrations/ea78280f71da_init_db.py similarity index 71% rename from nonebot_plugin_chatrecorder/migrations/e6460fccaf90_init_db.py rename to nonebot_plugin_chatrecorder/migrations/ea78280f71da_init_db.py index ae0c65f..ee64545 100644 --- a/nonebot_plugin_chatrecorder/migrations/e6460fccaf90_init_db.py +++ b/nonebot_plugin_chatrecorder/migrations/ea78280f71da_init_db.py @@ -1,8 +1,8 @@ """init_db -修订 ID: e6460fccaf90 -父修订: -创建时间: 2023-10-09 21:18:49.711008 +迁移 ID: ea78280f71da +父迁移: +创建时间: 2024-11-30 13:10:23.110088 """ @@ -13,10 +13,10 @@ import sqlalchemy as sa from alembic import op -revision: str = "e6460fccaf90" +revision: str = "ea78280f71da" down_revision: str | Sequence[str] | None = None -branch_labels: str | Sequence[str] | None = ("nonebot_plugin_chatrecorder",) -depends_on: str | Sequence[str] | None = "fff55366306e" +branch_labels: str | Sequence[str] | None = ("nonebot_plugin_chatrecorder_v2",) +depends_on: str | Sequence[str] | None = "14175fde8186" def upgrade(name: str = "") -> None: @@ -24,17 +24,18 @@ def upgrade(name: str = "") -> None: return # ### commands auto generated by Alembic - please adjust! ### op.create_table( - "nonebot_plugin_chatrecorder_messagerecord", + "nonebot_plugin_chatrecorder_messagerecord_v2", sa.Column("id", sa.Integer(), nullable=False), sa.Column("session_persist_id", sa.Integer(), nullable=False), sa.Column("time", sa.DateTime(), nullable=False), sa.Column("type", sa.String(length=32), nullable=False), - sa.Column("message_id", sa.String(length=64), nullable=False), + sa.Column("message_id", sa.String(length=255), nullable=False), sa.Column("message", sa.JSON(), nullable=False), sa.Column("plain_text", sa.TEXT(), nullable=False), sa.PrimaryKeyConstraint( - "id", name=op.f("pk_nonebot_plugin_chatrecorder_messagerecord") + "id", name=op.f("pk_nonebot_plugin_chatrecorder_messagerecord_v2") ), + info={"bind_key": "nonebot_plugin_chatrecorder"}, ) # ### end Alembic commands ### @@ -43,5 +44,5 @@ def downgrade(name: str = "") -> None: if name: return # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("nonebot_plugin_chatrecorder_messagerecord") + op.drop_table("nonebot_plugin_chatrecorder_messagerecord_v2") # ### end Alembic commands ### diff --git a/nonebot_plugin_chatrecorder/model.py b/nonebot_plugin_chatrecorder/model.py index cc7f7c8..dabee05 100644 --- a/nonebot_plugin_chatrecorder/model.py +++ b/nonebot_plugin_chatrecorder/model.py @@ -10,6 +10,7 @@ class MessageRecord(Model): """消息记录""" + __tablename__ = "nonebot_plugin_chatrecorder_messagerecord_v2" __table_args__ = {"extend_existing": True} id: Mapped[int] = mapped_column(primary_key=True) diff --git a/nonebot_plugin_chatrecorder/record.py b/nonebot_plugin_chatrecorder/record.py index e03e74f..a9d7536 100644 --- a/nonebot_plugin_chatrecorder/record.py +++ b/nonebot_plugin_chatrecorder/record.py @@ -4,33 +4,36 @@ from nonebot.adapters import Message from nonebot_plugin_orm import get_session -from nonebot_plugin_session import Session, SessionIdType, SessionLevel -from nonebot_plugin_session_orm import SessionModel +from nonebot_plugin_uninfo import SceneType, Session, SupportAdapter, SupportScope +from nonebot_plugin_uninfo.orm import BotModel, SceneModel, SessionModel, UserModel from sqlalchemy import or_, select from sqlalchemy.sql import ColumnElement from .message import deserialize_message from .model import MessageRecord -from .utils import remove_timezone +from .utils import adapter_value, remove_timezone, scene_type_value, scope_value def filter_statement( *, session: Optional[Session] = None, - id_type: SessionIdType = SessionIdType.GROUP_USER, - include_platform: bool = True, - include_bot_type: bool = True, - include_bot_id: bool = True, - bot_ids: Optional[Iterable[str]] = None, - bot_types: Optional[Iterable[str]] = None, - platforms: Optional[Iterable[str]] = None, - levels: Optional[Iterable[Union[int, SessionLevel]]] = None, - id1s: Optional[Iterable[str]] = None, - id2s: Optional[Iterable[str]] = None, - id3s: Optional[Iterable[str]] = None, - exclude_id1s: Optional[Iterable[str]] = None, - exclude_id2s: Optional[Iterable[str]] = None, - exclude_id3s: Optional[Iterable[str]] = None, + filter_self_id: bool = True, + filter_adapter: bool = True, + filter_scope: bool = True, + filter_scene: bool = True, + filter_user: bool = True, + self_ids: Optional[Iterable[str]] = None, + adapters: Optional[Iterable[Union[str, SupportAdapter]]] = None, + scopes: Optional[Iterable[Union[str, SupportScope]]] = None, + scene_types: Optional[Iterable[Union[int, SceneType]]] = None, + scene_ids: Optional[Iterable[str]] = None, + user_ids: Optional[Iterable[str]] = None, + exclude_self_ids: Optional[Iterable[str]] = None, + exclude_adapters: Optional[Iterable[Union[str, SupportAdapter]]] = None, + exclude_scopes: Optional[Iterable[Union[str, SupportScope]]] = None, + exclude_scene_types: Optional[Iterable[Union[int, SceneType]]] = None, + exclude_scene_ids: Optional[Iterable[str]] = None, + exclude_user_ids: Optional[Iterable[str]] = None, time_start: Optional[datetime] = None, time_stop: Optional[datetime] = None, types: Optional[Iterable[Literal["message", "message_sent"]]] = None, @@ -40,19 +43,23 @@ def filter_statement( 参数: * ``session: Optional[Session]``: 会话模型,传入时会根据 `session` 中的字段筛选 * ``id_type: SessionIdType``: 会话 id 类型,仅在传入 `session` 时有效 - * ``include_platform: bool``: 是否限制平台类型,仅在传入 `session` 时有效 - * ``include_bot_type: bool``: 是否限制适配器类型,仅在传入 `session` 时有效 - * ``include_bot_id: bool``: 是否限制 bot id,仅在传入 `session` 时有效 - * ``bot_ids: Optional[Iterable[str]]``: bot id 列表,为空表示所有 bot id - * ``bot_types: Optional[Iterable[str]]``: 协议适配器类型列表,为空表示所有适配器 - * ``platforms: Optional[Iterable[str]]``: 平台类型列表,为空表示所有平台 - * ``levels: Optional[Iterable[Union[str, SessionLevel]]]``: 会话级别列表,为空表示所有级别 - * ``id1s: Optional[Iterable[str]]``: 会话 id1(用户级 id)列表,为空表示所有 id - * ``id2s: Optional[Iterable[str]]``: 会话 id2(群组级 id)列表,为空表示所有 id - * ``id3s: Optional[Iterable[str]]``: 会话 id3(两级群组级 id)列表,为空表示所有 id - * ``exclude_id1s: Optional[Iterable[str]]``: 不包含的会话 id1(用户级 id)列表,为空表示不限制 - * ``exclude_id2s: Optional[Iterable[str]]``: 不包含的会话 id2(群组级 id)列表,为空表示不限制 - * ``exclude_id3s: Optional[Iterable[str]]``: 不包含的会话 id3(两级群组级 id)列表,为空表示不限制 + * ``filter_self_id: bool``: 是否筛选 bot id,仅在传入 `session` 时有效 + * ``filter_adapter: bool``: 是否筛选适配器类型,仅在传入 `session` 时有效 + * ``filter_scope: bool``: 是否筛选平台类型,仅在传入 `session` 时有效 + * ``filter_scene: bool``: 是否筛选事件场景,仅在传入 `session` 时有效 + * ``filter_user: bool``: 是否筛选用户,仅在传入 `session` 时有效 + * ``self_ids: Optional[Iterable[str]]``: bot id 列表,为空表示所有 bot id + * ``adapters: Optional[Iterable[Union[str, SupportAdapter]]]``: 适配器类型列表,为空表示所有适配器 + * ``scopes: Optional[Iterable[Union[str, SupportScope]]]``: 平台类型列表,为空表示所有平台 + * ``scene_types: Optional[Iterable[Union[str, SceneType]]]``: 事件场景类型列表,为空表示所有类型 + * ``scene_ids: Optional[Iterable[str]]``: 事件场景 id 列表,为空表示所有 id + * ``user_ids: Optional[Iterable[str]]``: 用户 id 列表,为空表示所有 id + * ``exclude_self_ids: Optional[Iterable[str]]``: 不包含的 bot id 列表,为空表示不限制 + * ``exclude_adapters: Optional[Iterable[Union[str, SupportAdapter]]]``: 不包含的适配器类型列表,为空表示不限制 + * ``exclude_scopes: Optional[Iterable[Union[str, SupportScope]]]``: 不包含的平台类型列表,为空表示不限制 + * ``exclude_scene_types: Optional[Iterable[Union[str, SceneType]]]``: 不包含的事件场景类型列表,为空表示不限制 + * ``exclude_scene_ids: Optional[Iterable[str]]``: 不包含的事件场景 id 列表,为空表示不限制 + * ``exclude_user_ids: Optional[Iterable[str]]``: 不包含的用户 id 列表,为空表示不限制 * ``time_start: Optional[datetime]``: 起始时间,为空表示不限制起始时间(传入带时区的时间或 UTC 时间) * ``time_stop: Optional[datetime]``: 结束时间,为空表示不限制结束时间(传入带时区的时间或 UTC 时间) * ``types: Optional[Iterable[Literal["message", "message_sent"]]]``: 消息事件类型列表,为空表示所有类型 @@ -63,41 +70,61 @@ def filter_statement( whereclause: list[ColumnElement[bool]] = [] if session: - whereclause = SessionModel.filter_statement( - session, - id_type, - include_platform=include_platform, - include_bot_type=include_bot_type, - include_bot_id=include_bot_id, + if filter_self_id: + whereclause.append(BotModel.self_id == session.self_id) + if filter_adapter: + whereclause.append(BotModel.adapter == adapter_value(session.adapter)) + if filter_scope: + whereclause.append(BotModel.scope == scope_value(session.scope)) + if filter_scene: + whereclause.append(SceneModel.scene_id == session.scene.id) + whereclause.append(SceneModel.scene_type == session.scene.type.value) + if filter_user: + whereclause.append(UserModel.user_id == session.user.id) + + if self_ids: + whereclause.append(or_(*[BotModel.self_id == self_id for self_id in self_ids])) + if adapters: + whereclause.append( + or_(*[BotModel.adapter == adapter_value(adapter) for adapter in adapters]) ) - - if bot_types: + if scopes: + whereclause.append( + or_(*[BotModel.scope == scope_value(scope) for scope in scopes]) + ) + if scene_types: whereclause.append( - or_(*[SessionModel.bot_type == bot_type for bot_type in bot_types]) + or_( + *[ + SceneModel.scene_type == scene_type_value(scene_type) + for scene_type in scene_types + ] + ) ) - if bot_ids: - whereclause.append(or_(*[SessionModel.bot_id == bot_id for bot_id in bot_ids])) - if platforms: + if scene_ids: whereclause.append( - or_(*[SessionModel.platform == platform for platform in platforms]) + or_(*[SceneModel.scene_id == scene_id for scene_id in scene_ids]) ) - if levels: - whereclause.append(or_(*[SessionModel.level == level for level in levels])) - if id1s: - whereclause.append(or_(*[SessionModel.id1 == id1 for id1 in id1s])) - if id2s: - whereclause.append(or_(*[SessionModel.id2 == id2 for id2 in id2s])) - if id3s: - whereclause.append(or_(*[SessionModel.id3 == id3 for id3 in id3s])) - if exclude_id1s: - for id1 in exclude_id1s: - whereclause.append(SessionModel.id1 != id1) - if exclude_id2s: - for id2 in exclude_id2s: - whereclause.append(SessionModel.id2 != id2) - if exclude_id3s: - for id3 in exclude_id3s: - whereclause.append(SessionModel.id3 != id3) + if user_ids: + whereclause.append(or_(*[UserModel.user_id == user_id for user_id in user_ids])) + if exclude_self_ids: + for self_id in exclude_self_ids: + whereclause.append(BotModel.self_id != self_id) + if exclude_adapters: + for adapter in exclude_adapters: + whereclause.append(BotModel.adapter != adapter_value(adapter)) + if exclude_scopes: + for scope in exclude_scopes: + whereclause.append(BotModel.scope != scope_value(scope)) + if exclude_scene_types: + for scene_type in exclude_scene_types: + whereclause.append(SceneModel.scene_type != scene_type_value(scene_type)) + if exclude_scene_ids: + for scene_id in exclude_scene_ids: + whereclause.append(SceneModel.scene_id != scene_id) + if exclude_user_ids: + for user_id in exclude_user_ids: + whereclause.append(UserModel.user_id != user_id) if time_start: whereclause.append(MessageRecord.time >= remove_timezone(time_start)) if time_stop: @@ -121,6 +148,9 @@ async def get_message_records(**kwargs) -> Sequence[MessageRecord]: select(MessageRecord) .where(*whereclause) .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) + .join(BotModel, BotModel.id == SessionModel.bot_persist_id) + .join(SceneModel, SceneModel.id == SessionModel.scene_persist_id) + .join(UserModel, UserModel.id == SessionModel.user_persist_id) ) async with get_session() as db_session: records = (await db_session.scalars(statement)).all() @@ -138,9 +168,12 @@ async def get_messages(**kwargs) -> list[Message]: """ whereclause = filter_statement(**kwargs) statement = ( - select(MessageRecord.message, SessionModel.bot_type) + select(MessageRecord.message, BotModel.adapter) .where(*whereclause) .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) + .join(BotModel, BotModel.id == SessionModel.bot_persist_id) + .join(SceneModel, SceneModel.id == SessionModel.scene_persist_id) + .join(UserModel, UserModel.id == SessionModel.user_persist_id) ) async with get_session() as db_session: results = (await db_session.execute(statement)).all() @@ -161,6 +194,9 @@ async def get_messages_plain_text(**kwargs) -> Sequence[str]: select(MessageRecord.plain_text) .where(*whereclause) .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) + .join(BotModel, BotModel.id == SessionModel.bot_persist_id) + .join(SceneModel, SceneModel.id == SessionModel.scene_persist_id) + .join(UserModel, UserModel.id == SessionModel.user_persist_id) ) async with get_session() as db_session: records = (await db_session.scalars(statement)).all() diff --git a/nonebot_plugin_chatrecorder/record.pyi b/nonebot_plugin_chatrecorder/record.pyi index ba7b7d5..df308ae 100644 --- a/nonebot_plugin_chatrecorder/record.pyi +++ b/nonebot_plugin_chatrecorder/record.pyi @@ -3,7 +3,7 @@ from datetime import datetime from typing import Literal from nonebot.adapters import Message -from nonebot_plugin_session import Session, SessionIdType, SessionLevel +from nonebot_plugin_uninfo import SceneType, Session, SupportAdapter, SupportScope from sqlalchemy.sql import ColumnElement from .model import MessageRecord @@ -11,20 +11,23 @@ from .model import MessageRecord def filter_statement( *, session: Session | None = None, - id_type: SessionIdType = ..., - include_platform: bool = True, - include_bot_type: bool = True, - include_bot_id: bool = True, - bot_ids: Iterable[str] | None = None, - bot_types: Iterable[str] | None = None, - platforms: Iterable[str] | None = None, - levels: Iterable[int | SessionLevel] | None = None, - id1s: Iterable[str] | None = None, - id2s: Iterable[str] | None = None, - id3s: Iterable[str] | None = None, - exclude_id1s: Iterable[str] | None = None, - exclude_id2s: Iterable[str] | None = None, - exclude_id3s: Iterable[str] | None = None, + filter_self_id: bool = True, + filter_adapter: bool = True, + filter_scope: bool = True, + filter_scene: bool = True, + filter_user: bool = True, + self_ids: Iterable[str] | None = None, + adapters: Iterable[str | SupportAdapter] | None = None, + scopes: Iterable[str | SupportScope] | None = None, + scene_types: Iterable[int | SceneType] | None = None, + scene_ids: Iterable[str] | None = None, + user_ids: Iterable[str] | None = None, + exclude_self_ids: Iterable[str] | None = None, + exclude_adapters: Iterable[str | SupportAdapter] | None = None, + exclude_scopes: Iterable[str | SupportScope] | None = None, + exclude_scene_types: Iterable[int | SceneType] | None = None, + exclude_scene_ids: Iterable[str] | None = None, + exclude_user_ids: Iterable[str] | None = None, time_start: datetime | None = None, time_stop: datetime | None = None, types: Iterable[Literal["message", "message_sent"]] | None = None, @@ -32,20 +35,23 @@ def filter_statement( async def get_message_records( *, session: Session | None = None, - id_type: SessionIdType = ..., - include_platform: bool = True, - include_bot_type: bool = True, - include_bot_id: bool = True, - bot_ids: Iterable[str] | None = None, - bot_types: Iterable[str] | None = None, - platforms: Iterable[str] | None = None, - levels: Iterable[int | SessionLevel] | None = None, - id1s: Iterable[str] | None = None, - id2s: Iterable[str] | None = None, - id3s: Iterable[str] | None = None, - exclude_id1s: Iterable[str] | None = None, - exclude_id2s: Iterable[str] | None = None, - exclude_id3s: Iterable[str] | None = None, + filter_self_id: bool = True, + filter_adapter: bool = True, + filter_scope: bool = True, + filter_scene: bool = True, + filter_user: bool = True, + self_ids: Iterable[str] | None = None, + adapters: Iterable[str | SupportAdapter] | None = None, + scopes: Iterable[str | SupportScope] | None = None, + scene_types: Iterable[int | SceneType] | None = None, + scene_ids: Iterable[str] | None = None, + user_ids: Iterable[str] | None = None, + exclude_self_ids: Iterable[str] | None = None, + exclude_adapters: Iterable[str | SupportAdapter] | None = None, + exclude_scopes: Iterable[str | SupportScope] | None = None, + exclude_scene_types: Iterable[int | SceneType] | None = None, + exclude_scene_ids: Iterable[str] | None = None, + exclude_user_ids: Iterable[str] | None = None, time_start: datetime | None = None, time_stop: datetime | None = None, types: Iterable[Literal["message", "message_sent"]] | None = None, @@ -53,20 +59,23 @@ async def get_message_records( async def get_messages( *, session: Session | None = None, - id_type: SessionIdType = ..., - include_platform: bool = True, - include_bot_type: bool = True, - include_bot_id: bool = True, - bot_ids: Iterable[str] | None = None, - bot_types: Iterable[str] | None = None, - platforms: Iterable[str] | None = None, - levels: Iterable[int | SessionLevel] | None = None, - id1s: Iterable[str] | None = None, - id2s: Iterable[str] | None = None, - id3s: Iterable[str] | None = None, - exclude_id1s: Iterable[str] | None = None, - exclude_id2s: Iterable[str] | None = None, - exclude_id3s: Iterable[str] | None = None, + filter_self_id: bool = True, + filter_adapter: bool = True, + filter_scope: bool = True, + filter_scene: bool = True, + filter_user: bool = True, + self_ids: Iterable[str] | None = None, + adapters: Iterable[str | SupportAdapter] | None = None, + scopes: Iterable[str | SupportScope] | None = None, + scene_types: Iterable[int | SceneType] | None = None, + scene_ids: Iterable[str] | None = None, + user_ids: Iterable[str] | None = None, + exclude_self_ids: Iterable[str] | None = None, + exclude_adapters: Iterable[str | SupportAdapter] | None = None, + exclude_scopes: Iterable[str | SupportScope] | None = None, + exclude_scene_types: Iterable[int | SceneType] | None = None, + exclude_scene_ids: Iterable[str] | None = None, + exclude_user_ids: Iterable[str] | None = None, time_start: datetime | None = None, time_stop: datetime | None = None, types: Iterable[Literal["message", "message_sent"]] | None = None, @@ -74,20 +83,23 @@ async def get_messages( async def get_messages_plain_text( *, session: Session | None = None, - id_type: SessionIdType = ..., - include_platform: bool = True, - include_bot_type: bool = True, - include_bot_id: bool = True, - bot_ids: Iterable[str] | None = None, - bot_types: Iterable[str] | None = None, - platforms: Iterable[str] | None = None, - levels: Iterable[int | SessionLevel] | None = None, - id1s: Iterable[str] | None = None, - id2s: Iterable[str] | None = None, - id3s: Iterable[str] | None = None, - exclude_id1s: Iterable[str] | None = None, - exclude_id2s: Iterable[str] | None = None, - exclude_id3s: Iterable[str] | None = None, + filter_self_id: bool = True, + filter_adapter: bool = True, + filter_scope: bool = True, + filter_scene: bool = True, + filter_user: bool = True, + self_ids: Iterable[str] | None = None, + adapters: Iterable[str | SupportAdapter] | None = None, + scopes: Iterable[str | SupportScope] | None = None, + scene_types: Iterable[int | SceneType] | None = None, + scene_ids: Iterable[str] | None = None, + user_ids: Iterable[str] | None = None, + exclude_self_ids: Iterable[str] | None = None, + exclude_adapters: Iterable[str | SupportAdapter] | None = None, + exclude_scopes: Iterable[str | SupportScope] | None = None, + exclude_scene_types: Iterable[int | SceneType] | None = None, + exclude_scene_ids: Iterable[str] | None = None, + exclude_user_ids: Iterable[str] | None = None, time_start: datetime | None = None, time_stop: datetime | None = None, types: Iterable[Literal["message", "message_sent"]] | None = None, diff --git a/nonebot_plugin_chatrecorder/utils.py b/nonebot_plugin_chatrecorder/utils.py index 224392f..51440b5 100644 --- a/nonebot_plugin_chatrecorder/utils.py +++ b/nonebot_plugin_chatrecorder/utils.py @@ -1,8 +1,8 @@ from datetime import datetime, timezone +from typing import Union from nonebot.adapters import Event - -from .consts import SupportedPlatform +from nonebot_plugin_uninfo import SceneType, SupportAdapter, SupportScope def remove_timezone(dt: datetime) -> datetime: @@ -14,19 +14,21 @@ def remove_timezone(dt: datetime) -> datetime: return dt.replace(tzinfo=None) -def format_platform(platform: str) -> str: - if platform in ("onebot", "red", "chronocat"): - return SupportedPlatform.qq - elif platform in ("kook",): - return SupportedPlatform.kaiheila - elif platform not in list(SupportedPlatform): - return SupportedPlatform.unknown - return platform - - def is_fake_event(event: Event) -> bool: return hasattr(event, "_is_fake") and event._is_fake() # type: ignore def record_type(event: Event) -> str: return "fake" if is_fake_event(event) else "message" + + +def scene_type_value(scene_type: Union[int, SceneType]) -> int: + return scene_type.value if isinstance(scene_type, SceneType) else scene_type + + +def adapter_value(adapter: Union[str, SupportAdapter]) -> str: + return adapter.value if isinstance(adapter, SupportAdapter) else adapter + + +def scope_value(scope: Union[str, SupportScope]) -> str: + return scope.value if isinstance(scope, SupportScope) else scope diff --git a/poetry.lock b/poetry.lock index 16923e9..d2c27d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1366,35 +1366,19 @@ psycopg = ["sqlalchemy[postgresql-psycopgbinary]"] sqlite = ["sqlalchemy[aiosqlite]"] [[package]] -name = "nonebot-plugin-session" -version = "0.3.2" -description = "Nonebot2 会话信息提取与会话id定义" +name = "nonebot-plugin-uninfo" +version = "0.6.1" +description = "Universal Information Model for Nonebot2" optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "nonebot_plugin_session-0.3.2-py3-none-any.whl", hash = "sha256:785e74ff656e46d84c4dbbac125f9a9adebcc4b9ff7db700527e6e47d47aabd6"}, - {file = "nonebot_plugin_session-0.3.2.tar.gz", hash = "sha256:da0dabe9108151052a6e83d9923068d2ec43a93810ad37b264ed65aee2c69f93"}, -] - -[package.dependencies] -nonebot2 = ">=2.3.0,<3.0.0" -strenum = ">=0.4.15,<0.5.0" - -[[package]] -name = "nonebot-plugin-session-orm" -version = "0.2.0" -description = "session 插件 orm 扩展" -optional = false -python-versions = ">=3.8,<4.0" +python-versions = ">=3.9" files = [ - {file = "nonebot_plugin_session_orm-0.2.0-py3-none-any.whl", hash = "sha256:e0a6009803feccf4e98db483e1f9072122ddbca40df0b2d1b45741c8e5332acb"}, - {file = "nonebot_plugin_session_orm-0.2.0.tar.gz", hash = "sha256:420e210898a3f348cebbb4ea0816bab66f3299c7b0c01e929a01967d60cc438c"}, + {file = "nonebot_plugin_uninfo-0.6.1-py3-none-any.whl", hash = "sha256:c6fef664af82955a50e1fd91040c2f180eaa9ee0fd09aa6cb6b385e11aede590"}, + {file = "nonebot_plugin_uninfo-0.6.1.tar.gz", hash = "sha256:f5eefa4992ab8a5a2fad9d848f1b2077303d825f61ee237f0131968506c32708"}, ] [package.dependencies] -nonebot-plugin-orm = ">=0.7.0,<1.0.0" -nonebot-plugin-session = ">=0.3.0,<0.4.0" -nonebot2 = {version = ">=2.2.0,<3.0.0", extras = ["fastapi"]} +importlib-metadata = ">=4.13.0" +nonebot2 = ">=2.3.0" [[package]] name = "nonebot2" @@ -2142,31 +2126,15 @@ typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\"" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] -[[package]] -name = "strenum" -version = "0.4.15" -description = "An Enum that inherits from str." -optional = false -python-versions = "*" -files = [ - {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, - {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, -] - -[package.extras] -docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] -release = ["twine"] -test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] - [[package]] name = "textual" -version = "0.87.1" +version = "0.88.0" description = "Modern Text User Interface framework" optional = false python-versions = "<4.0.0,>=3.8.1" files = [ - {file = "textual-0.87.1-py3-none-any.whl", hash = "sha256:026d1368cd10610a72a9d3de7a56692a17e7e8dffa0468147eb8e186ba0ff0c0"}, - {file = "textual-0.87.1.tar.gz", hash = "sha256:daf4e248ba3d890831ff2617099535eb835863a2e3609c8ce00af0f6d55ed123"}, + {file = "textual-0.88.0-py3-none-any.whl", hash = "sha256:87a1085a403e3a95aa4b954c530d46947d830e9ad4b8c15490104c0b4a452b6a"}, + {file = "textual-0.88.0.tar.gz", hash = "sha256:bf9cc3ec9d34957c361eabf739e59272295323478cc822633fb0a7b7cc2a0ac3"}, ] [package.dependencies] @@ -2631,4 +2599,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4b888819c48a220b6908a90b359f4ac463ce5ad34d01a0e96412026f8f656365" +content-hash = "2997deba3ecbf419a35dd7f2ced1eb2c0ae0bdd34ec8329d2a2c1cdd83d3cafa" diff --git a/pyproject.toml b/pyproject.toml index 480d6a4..b5fa916 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,8 @@ repository = "https://github.com/noneplugin/nonebot-plugin-chatrecorder" [tool.poetry.dependencies] python = "^3.9" nonebot2 = "^2.3.0" -nonebot-plugin-session = "^0.3.0" nonebot-plugin-orm = ">=0.7.0,<1.0.0" -nonebot-plugin-session-orm = "^0.2.0" +nonebot-plugin-uninfo = ">=0.6.1,<1.0.0" nonebot-plugin-localstore = ">=0.6.0,<1.0.0" [tool.poetry.group.dev.dependencies] @@ -45,6 +44,7 @@ nonebot-adapter-qq = "^1.5.0" [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" [tool.nonebot] plugins = ["nonebot_plugin_chatrecorder"] diff --git a/tests/conftest.py b/tests/conftest.py index bf65a80..8ba4e97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,12 +11,13 @@ from nonebot.adapters.satori import Adapter as SatoriAdapter from nonebot.adapters.telegram import Adapter as TelegramAdapter from nonebug import NONEBOT_INIT_KWARGS, App -from sqlalchemy import delete +from sqlalchemy import StaticPool, delete def pytest_configure(config: pytest.Config) -> None: config.stash[NONEBOT_INIT_KWARGS] = { "sqlalchemy_database_url": "sqlite+aiosqlite:///:memory:", + "sqlalchemy_engine_options": {"poolclass": StaticPool}, "alembic_startup_check": False, "driver": "~fastapi+~websockets+~httpx", } @@ -27,7 +28,7 @@ async def app(): nonebot.require("nonebot_plugin_chatrecorder") from nonebot_plugin_orm import get_session, init_orm - from nonebot_plugin_session_orm import SessionModel + from nonebot_plugin_uninfo.orm import BotModel, SceneModel, SessionModel, UserModel from nonebot_plugin_chatrecorder.model import MessageRecord @@ -38,11 +39,14 @@ async def app(): async with get_session() as db_session: await db_session.execute(delete(MessageRecord)) await db_session.execute(delete(SessionModel)) + await db_session.execute(delete(UserModel)) + await db_session.execute(delete(SceneModel)) + await db_session.execute(delete(BotModel)) await db_session.commit() @pytest.fixture(scope="session", autouse=True) -def load_adapters(nonebug_init: None): +def after_nonebot_init(after_nonebot_init: None): driver = nonebot.get_driver() driver.register_adapter(ConsoleAdapter) driver.register_adapter(DiscordAdapter) diff --git a/tests/test_console.py b/tests/test_console.py index ed1d455..ce5b2ea 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +from unittest.mock import patch from nonebot import get_driver from nonebot.adapters.console import Adapter, Bot, Message, MessageEvent @@ -28,31 +29,37 @@ class Config: async def test_record_recv_msg(app: App): # 测试记录收到的消息 + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + + from nonebot_plugin_chatrecorder.adapters.console import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message time = 1000000 user_id = "User" message = Message("test_record_recv_msg") - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot(base=Bot, adapter=adapter, self_id="Bot") - event = fake_message_event( - time=datetime.fromtimestamp(time, timezone.utc), - user=User(id=user_id), - message=message, - ) - ctx.receive_event(bot, event) + event = fake_message_event( + time=datetime.fromtimestamp(time, timezone.utc), + user=User(id=user_id), + message=message, + ) + session = Session( + self_id="Bot", + adapter="Console", + scope="Console", + scene=Scene(id=user_id, type=SceneType.PRIVATE), + user=UninfoUser(id=user_id), + ) + with patch("nonebot_plugin_chatrecorder.adapters.console.get_id", return_value="0"): + await record_recv_msg(event, session) await check_record( - "Bot", - "Console", - "console", - 1, - user_id, - None, - None, + session, datetime.fromtimestamp(time, timezone.utc), "message", "0", @@ -63,6 +70,9 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): # 测试记录发送的消息 + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + from nonebot_plugin_chatrecorder.adapters.console import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -73,17 +83,19 @@ async def test_record_send_msg(app: App): user_id = "User" elements = ConsoleMessage([Text("test_record_send_msg")]) message = Message("test_record_send_msg") - await record_send_msg( - bot, None, "send_msg", {"user_id": user_id, "message": elements}, None - ) + + with patch("nonebot_plugin_chatrecorder.adapters.console.get_id", return_value="1"): + await record_send_msg( + bot, None, "send_msg", {"user_id": user_id, "message": elements}, None + ) await check_record( - "Bot", - "Console", - "console", - 1, - user_id, - None, - None, + Session( + self_id="Bot", + adapter="Console", + scope="Console", + scene=Scene(id=user_id, type=SceneType.PRIVATE), + user=UninfoUser(id="Bot"), + ), None, "message_sent", "1", diff --git a/tests/test_discord.py b/tests/test_discord.py index 1574cbc..491c42e 100644 --- a/tests/test_discord.py +++ b/tests/test_discord.py @@ -95,6 +95,10 @@ def fake_direct_message_event(content: str, msg_id: int) -> DirectMessageCreateE async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + + from nonebot_plugin_chatrecorder.adapters.discord import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message guild_msg = "test guild message" @@ -103,26 +107,27 @@ async def test_record_recv_msg(app: App): direct_msg = "test direct message" direct_msg_id = 11235 - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, adapter=adapter, self_id="2233", bot_info=BotInfo(token="1234") ) - event = fake_guild_message_event(guild_msg, guild_msg_id) - ctx.receive_event(bot, event) - - event = fake_direct_message_event(direct_msg, direct_msg_id) - ctx.receive_event(bot, event) - + event = fake_guild_message_event(guild_msg, guild_msg_id) + session = Session( + self_id="2233", + adapter="Discord", + scope="Discord", + scene=Scene( + id="5566", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="6677", type=SceneType.GUILD), + ), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Discord", - "discord", - 3, - "3344", - "5566", - "6677", + session, datetime.fromtimestamp(123456, timezone.utc), "message", str(guild_msg_id), @@ -130,14 +135,17 @@ async def test_record_recv_msg(app: App): guild_msg, ) + event = fake_direct_message_event(direct_msg, direct_msg_id) + session = Session( + self_id="2233", + adapter="Discord", + scope="Discord", + scene=Scene(id="5566", type=SceneType.PRIVATE), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Discord", - "discord", - 1, - "3344", - "5566", - None, + session, datetime.fromtimestamp(123456, timezone.utc), "message", str(direct_msg_id), @@ -148,6 +156,9 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + from nonebot_plugin_chatrecorder.adapters.discord import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -216,13 +227,17 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "Discord", - "discord", - 3, - None, - "5566", - "6677", + Session( + self_id="2233", + adapter="Discord", + scope="Discord", + scene=Scene( + id="5566", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="6677", type=SceneType.GUILD), + ), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(123456, timezone.utc), "message_sent", "11236", @@ -298,13 +313,13 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "Discord", - "discord", - 1, - "3344", - "5555", - None, + Session( + self_id="2233", + adapter="Discord", + scope="Discord", + scene=Scene(id="5555", type=SceneType.PRIVATE), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(123456, timezone.utc), "message_sent", "11237", diff --git a/tests/test_dodo.py b/tests/test_dodo.py index 22e5524..c51fe8e 100644 --- a/tests/test_dodo.py +++ b/tests/test_dodo.py @@ -63,6 +63,9 @@ def fake_personal_message_event(content: str, message_id: str) -> PersonalMessag async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.dodo import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message channel_msg = "test channel message" @@ -71,7 +74,7 @@ async def test_record_recv_msg(app: App): personal_msg = "test personal message" personal_msg_id = "123457" - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, @@ -80,20 +83,21 @@ async def test_record_recv_msg(app: App): bot_config=BotConfig(client_id="1234", token="xxxx"), ) - event = fake_channel_message_event(channel_msg, channel_msg_id) - ctx.receive_event(bot, event) - - event = fake_personal_message_event(personal_msg, personal_msg_id) - ctx.receive_event(bot, event) - + event = fake_channel_message_event(channel_msg, channel_msg_id) + session = Session( + self_id="2233", + adapter="DoDo", + scope="DoDo", + scene=Scene( + id="5566", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="7788", type=SceneType.GUILD), + ), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "DoDo", - "dodo", - 3, - "3344", - "5566", - "7788", + session, datetime.fromtimestamp(12345678, timezone.utc), "message", channel_msg_id, @@ -101,14 +105,21 @@ async def test_record_recv_msg(app: App): channel_msg, ) + event = fake_personal_message_event(personal_msg, personal_msg_id) + session = Session( + self_id="2233", + adapter="DoDo", + scope="DoDo", + scene=Scene( + id="3344", + type=SceneType.PRIVATE, + parent=Scene(id="7788", type=SceneType.GUILD), + ), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "DoDo", - "dodo", - 1, - "3344", - None, - "7788", + session, datetime.fromtimestamp(12345678, timezone.utc), "message", personal_msg_id, @@ -119,6 +130,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + from nonebot_plugin_chatrecorder.adapters.dodo import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -144,13 +157,13 @@ async def test_record_send_msg(app: App): MessageReturn(message_id="123458"), ) await check_record( - "2233", - "DoDo", - "dodo", - 3, - None, - "5566", - None, + Session( + self_id="2233", + adapter="DoDo", + scope="DoDo", + scene=Scene(id="5566", type=SceneType.CHANNEL_TEXT), + user=User(id="2233"), + ), None, "message_sent", "123458", @@ -171,13 +184,17 @@ async def test_record_send_msg(app: App): MessageReturn(message_id="123459"), ) await check_record( - "2233", - "DoDo", - "dodo", - 1, - "3344", - None, - "7788", + Session( + self_id="2233", + adapter="DoDo", + scope="DoDo", + scene=Scene( + id="3344", + type=SceneType.PRIVATE, + parent=Scene(id="7788", type=SceneType.GUILD), + ), + user=User(id="2233"), + ), None, "message_sent", "123459", diff --git a/tests/test_fake_event.py b/tests/test_fake_event.py index 1b5397b..955b8f3 100644 --- a/tests/test_fake_event.py +++ b/tests/test_fake_event.py @@ -39,6 +39,9 @@ def _is_fake(self) -> bool: async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.onebot_v11 import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message time = 1000000 @@ -46,23 +49,23 @@ async def test_record_recv_msg(app: App): message_id = 1145141919810 message = Message("test private message") - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot(base=Bot, adapter=adapter, self_id="11") - event = fake_private_message_event( - time=time, user_id=user_id, message_id=message_id, message=message - ) - ctx.receive_event(bot, event) - + event = fake_private_message_event( + time=time, user_id=user_id, message_id=message_id, message=message + ) + session = Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(user_id), type=SceneType.PRIVATE), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "11", - "OneBot V11", - "qq", - 1, - str(user_id), - None, - None, + session, datetime.fromtimestamp(time, timezone.utc), "fake", str(message_id), diff --git a/tests/test_feishu.py b/tests/test_feishu.py index 23cc750..9ff5814 100644 --- a/tests/test_feishu.py +++ b/tests/test_feishu.py @@ -112,6 +112,9 @@ def fake_group_message_event( async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.feishu import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message msg_type = "text" @@ -120,7 +123,7 @@ async def test_record_recv_msg(app: App): group_msg_id = "om_2" message = Message.deserialize(content, None, msg_type) - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, @@ -130,20 +133,17 @@ async def test_record_recv_msg(app: App): bot_info=BOT_INFO, ) - event = fake_private_message_event(msg_type, content, private_msg_id) - ctx.receive_event(bot, event) - - event = fake_group_message_event(msg_type, content, group_msg_id) - ctx.receive_event(bot, event) - + event = fake_private_message_event(msg_type, content, private_msg_id) + session = Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Feishu", - "feishu", - 1, - "3344", - None, - None, + session, datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message", private_msg_id, @@ -151,14 +151,17 @@ async def test_record_recv_msg(app: App): message.extract_plain_text(), ) + event = fake_group_message_event(msg_type, content, group_msg_id) + session = Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="1122", type=SceneType.GROUP), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Feishu", - "feishu", - 2, - "3344", - "1122", - None, + session, datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message", group_msg_id, @@ -169,6 +172,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + from nonebot_plugin_chatrecorder.adapters.feishu import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -233,13 +238,13 @@ async def test_record_send_msg(app: App): }, ) await check_record( - "2233", - "Feishu", - "feishu", - 2, - None, - "oc_123", - None, + Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="oc_123", type=SceneType.GROUP), + user=User(id="2233"), + ), datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message_sent", "om_3", @@ -294,13 +299,13 @@ async def test_record_send_msg(app: App): }, ) await check_record( - "2233", - "Feishu", - "feishu", - 1, - "3344", - None, - None, + Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="2233"), + ), datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message_sent", "om_4", @@ -339,13 +344,13 @@ async def test_record_send_msg(app: App): }, ) await check_record( - "2233", - "Feishu", - "feishu", - 2, - None, - "oc_123", - None, + Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="oc_123", type=SceneType.GROUP), + user=User(id="2233"), + ), datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message_sent", "om_5", @@ -384,13 +389,13 @@ async def test_record_send_msg(app: App): }, ) await check_record( - "2233", - "Feishu", - "feishu", - 1, - "3344", - None, - None, + Session( + self_id="2233", + adapter="Feishu", + scope="Feishu", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="2233"), + ), datetime.fromtimestamp(123456000 / 1000, timezone.utc), "message_sent", "om_6", diff --git a/tests/test_get_records.py b/tests/test_get_records.py index a1b2792..62f6987 100644 --- a/tests/test_get_records.py +++ b/tests/test_get_records.py @@ -14,8 +14,15 @@ async def test_get_message_records(app: App): """测试获取消息记录""" from nonebot_plugin_orm import get_session - from nonebot_plugin_session import Session, SessionIdType, SessionLevel - from nonebot_plugin_session_orm import get_session_persist_id + from nonebot_plugin_uninfo import ( + Scene, + SceneType, + Session, + SupportAdapter, + SupportScope, + User, + ) + from nonebot_plugin_uninfo.orm import get_session_persist_id from nonebot_plugin_chatrecorder.message import serialize_message from nonebot_plugin_chatrecorder.model import MessageRecord @@ -39,49 +46,43 @@ async def test_get_message_records(app: App): sessions = [ Session( - bot_id="100", - bot_type="OneBot V11", - platform="qq", - level=SessionLevel.LEVEL1, - id1="1000", - id2=None, - id3=None, + self_id="100", + adapter=SupportAdapter.onebot11, + scope=SupportScope.qq_client, + scene=Scene(id="1000", type=SceneType.PRIVATE), + user=User(id="1000"), ), Session( - bot_id="101", - bot_type="OneBot V11", - platform="qq", - level=SessionLevel.LEVEL2, - id1="1000", - id2="10000", - id3=None, + self_id="101", + adapter=SupportAdapter.onebot11, + scope=SupportScope.qq_client, + scene=Scene(id="10000", type=SceneType.GROUP), + user=User(id="1000"), ), Session( - bot_id="100", - bot_type="OneBot V12", - platform="qq", - level=SessionLevel.LEVEL1, - id1="1001", - id2=None, - id3=None, + self_id="100", + adapter=SupportAdapter.onebot12, + scope=SupportScope.qq_client, + scene=Scene(id="1001", type=SceneType.PRIVATE), + user=User(id="1001"), ), Session( - bot_id="102", - bot_type="OneBot V12", - platform="telegram", - level=SessionLevel.LEVEL2, - id1="1002", - id2="10001", - id3=None, + self_id="102", + adapter=SupportAdapter.onebot12, + scope=SupportScope.telegram, + scene=Scene(id="10001", type=SceneType.GROUP), + user=User(id="1002"), ), Session( - bot_id="103", - bot_type="OneBot V12", - platform="kook", - level=SessionLevel.LEVEL3, - id1="1003", - id2="10002", - id3="100000", + self_id="103", + adapter=SupportAdapter.onebot12, + scope=SupportScope.kook, + scene=Scene( + id="10002", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="100000", type=SceneType.GUILD), + ), + user=User(id="1003"), ), ] session_persist_ids: list[int] = [] @@ -146,12 +147,12 @@ async def test_get_message_records(app: App): for msg in msgs: assert isinstance(msg, Message) - msgs = await get_messages(bot_types=["OneBot V11"]) + msgs = await get_messages(adapters=[SupportAdapter.onebot11]) assert len(msgs) == 2 for msg in msgs: assert isinstance(msg, V11Msg) - msgs = await get_messages(bot_types=["OneBot V12"]) + msgs = await get_messages(adapters=[SupportAdapter.onebot12]) assert len(msgs) == 3 for msg in msgs: assert isinstance(msg, V12Msg) @@ -161,19 +162,19 @@ async def test_get_message_records(app: App): for msg in msgs: assert isinstance(msg, str) - msgs = await get_message_records(bot_types=["OneBot V11"]) + msgs = await get_message_records(adapters=[SupportAdapter.onebot11]) assert len(msgs) == 2 - msgs = await get_message_records(bot_types=["OneBot V12"]) + msgs = await get_message_records(adapters=[SupportAdapter.onebot12]) assert len(msgs) == 3 - msgs = await get_message_records(bot_ids=["100"]) + msgs = await get_message_records(self_ids=["100"]) assert len(msgs) == 2 - msgs = await get_message_records(bot_ids=["101", "102", "103"]) + msgs = await get_message_records(self_ids=["101", "102", "103"]) assert len(msgs) == 3 - msgs = await get_message_records(platforms=["qq"]) + msgs = await get_message_records(scopes=[SupportScope.qq_client]) assert len(msgs) == 3 - msgs = await get_message_records(platforms=["telegram", "kook"]) + msgs = await get_message_records(scopes=[SupportScope.telegram, SupportScope.kook]) assert len(msgs) == 2 msgs = await get_message_records( @@ -199,74 +200,81 @@ async def test_get_message_records(app: App): msgs = await get_message_records(types=["message_sent"]) assert len(msgs) == 1 - msgs = await get_message_records(levels=[1]) + msgs = await get_message_records(scene_types=[SceneType.PRIVATE]) assert len(msgs) == 2 - msgs = await get_message_records(levels=[2]) + msgs = await get_message_records(scene_types=[SceneType.GROUP]) assert len(msgs) == 2 - msgs = await get_message_records(levels=[3]) + msgs = await get_message_records(scene_types=[SceneType.CHANNEL_TEXT]) assert len(msgs) == 1 - msgs = await get_message_records(id1s=["1000"]) + msgs = await get_message_records(user_ids=["1000"]) assert len(msgs) == 2 - msgs = await get_message_records(exclude_id1s=["1000"]) + msgs = await get_message_records(exclude_user_ids=["1000"]) assert len(msgs) == 3 - msgs = await get_message_records(id2s=["10000"]) + msgs = await get_message_records(scene_ids=["10000"]) assert len(msgs) == 1 - msgs = await get_message_records(exclude_id2s=["10000"]) + msgs = await get_message_records(exclude_scene_ids=["10000"]) assert len(msgs) == 4 - msgs = await get_message_records(id3s=["100000"]) - assert len(msgs) == 1 - msgs = await get_message_records(exclude_id3s=["100000"]) - assert len(msgs) == 4 + # msgs = await get_message_records(scene_ids=["100000"]) + # assert len(msgs) == 1 + # msgs = await get_message_records(exclude_scene_ids=["100000"]) + # assert len(msgs) == 4 - msgs = await get_message_records(session=sessions[1], id_type=SessionIdType.GROUP) + msgs = await get_message_records( + session=sessions[1], filter_scene=True, filter_user=False + ) assert len(msgs) == 1 msgs = await get_message_records( - session=sessions[0], include_bot_type=False, id1s=["1001"] + session=sessions[0], filter_adapter=False, user_ids=["1001"] ) assert len(msgs) == 0 msgs = await get_message_records( - session=sessions[0], include_bot_type=False, id1s=["1000", "1001"] + session=sessions[0], filter_adapter=False, user_ids=["1000", "1001"] ) assert len(msgs) == 1 msgs = await get_message_records( - session=sessions[1], id_type=SessionIdType.GROUP_USER + session=sessions[1], filter_scene=True, filter_user=True ) assert len(msgs) == 1 - msgs = await get_message_records(session=sessions[1], id_type=SessionIdType.USER) + msgs = await get_message_records( + session=sessions[1], filter_scene=False, filter_user=True + ) assert len(msgs) == 1 msgs = await get_message_records( - session=sessions[1], id_type=SessionIdType.USER, include_bot_id=False + session=sessions[1], filter_scene=False, filter_user=True, filter_self_id=False ) assert len(msgs) == 2 - msgs = await get_message_records(session=sessions[1], id_type=SessionIdType.GLOBAL) + msgs = await get_message_records( + session=sessions[1], filter_scene=False, filter_user=False + ) assert len(msgs) == 1 msgs = await get_message_records( - session=sessions[0], id_type=SessionIdType.GLOBAL, include_bot_type=False + session=sessions[0], filter_scene=False, filter_user=False, filter_adapter=False ) assert len(msgs) == 2 msgs = await get_message_records( session=sessions[0], - id_type=SessionIdType.GLOBAL, - include_bot_type=False, - exclude_id1s=["1000"], + filter_scene=False, + filter_user=False, + filter_adapter=False, + exclude_user_ids=["1000"], ) assert len(msgs) == 1 - msgs = await get_messages(session=sessions[1], id_type=SessionIdType.GROUP) + msgs = await get_messages(session=sessions[1], filter_scene=True, filter_user=False) assert len(msgs) == 1 msgs = await get_messages_plain_text( - session=sessions[1], id_type=SessionIdType.GROUP + session=sessions[1], filter_scene=True, filter_user=False ) assert len(msgs) == 1 diff --git a/tests/test_kaiheila.py b/tests/test_kaiheila.py index 3048538..5babd70 100644 --- a/tests/test_kaiheila.py +++ b/tests/test_kaiheila.py @@ -88,6 +88,10 @@ def fake_channel_message_event(content: str, msg_id: str) -> ChannelMessageEvent async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + + from nonebot_plugin_chatrecorder.adapters.kaiheila import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message private_msg = "test private message" @@ -96,26 +100,23 @@ async def test_record_recv_msg(app: App): channel_msg = "test channel message" channel_msg_id = "4456" - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, adapter=adapter, self_id="2233", name="Bot", token="" ) - event = fake_private_message_event(private_msg, private_msg_id) - ctx.receive_event(bot, event) - - event = fake_channel_message_event(channel_msg, channel_msg_id) - ctx.receive_event(bot, event) - + event = fake_private_message_event(private_msg, private_msg_id) + session = Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 1, - "3344", - None, - None, + session, datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message", private_msg_id, @@ -123,14 +124,21 @@ async def test_record_recv_msg(app: App): private_msg, ) + event = fake_channel_message_event(channel_msg, channel_msg_id) + session = Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 3, - "3344", - "6677", - "5566", + session, datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message", channel_msg_id, @@ -141,6 +149,9 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + from nonebot_plugin_chatrecorder.adapters.kaiheila import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -158,13 +169,13 @@ async def test_record_send_msg(app: App): MessageCreateReturn(msg_id="4457", msg_timestamp=1234000, nonce="xxx"), ) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 3, - None, - None, - "6677", + Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene(id="6677", type=SceneType.CHANNEL_TEXT), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message_sent", "4457", @@ -180,13 +191,13 @@ async def test_record_send_msg(app: App): MessageCreateReturn(msg_id="4458", msg_timestamp=1234000, nonce="xxx"), ) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 1, - "3344", - None, - None, + Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message_sent", "4458", @@ -202,13 +213,13 @@ async def test_record_send_msg(app: App): MessageCreateReturn(msg_id="4459", msg_timestamp=1234000, nonce="xxx"), ) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 3, - None, - None, - "6677", + Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene(id="6677", type=SceneType.CHANNEL_TEXT), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message_sent", "4459", @@ -251,13 +262,13 @@ async def test_record_send_msg(app: App): MessageCreateReturn(msg_id="4460", msg_timestamp=1234000, nonce="xxx"), ) await check_record( - "2233", - "Kaiheila", - "kaiheila", - 3, - None, - None, - "6677", + Session( + self_id="2233", + adapter="Kaiheila", + scope="Kaiheila", + scene=Scene(id="6677", type=SceneType.CHANNEL_TEXT), + user=UninfoUser(id="2233"), + ), datetime.fromtimestamp(1234000 / 1000, timezone.utc), "message_sent", "4460", diff --git a/tests/test_onebot_v11.py b/tests/test_onebot_v11.py index 5bec838..fc3a453 100644 --- a/tests/test_onebot_v11.py +++ b/tests/test_onebot_v11.py @@ -71,6 +71,9 @@ class Config: async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.onebot_v11 import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message time = 1000000 @@ -82,32 +85,27 @@ async def test_record_recv_msg(app: App): private_msg = Message("test private message") private_msg_id = 11451422222 - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot(base=Bot, adapter=adapter, self_id="11") - event = fake_group_message_event( - time=time, - user_id=user_id, - group_id=group_id, - message_id=group_msg_id, - message=group_msg, - ) - ctx.receive_event(bot, event) - - event = fake_private_message_event( - time=time, user_id=user_id, message_id=private_msg_id, message=private_msg - ) - ctx.receive_event(bot, event) - + event = fake_group_message_event( + time=time, + user_id=user_id, + group_id=group_id, + message_id=group_msg_id, + message=group_msg, + ) + session = Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(group_id), type=SceneType.GROUP), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "11", - "OneBot V11", - "qq", - 2, - str(user_id), - str(group_id), - None, + session, datetime.fromtimestamp(time, timezone.utc), "message", str(group_msg_id), @@ -115,14 +113,19 @@ async def test_record_recv_msg(app: App): group_msg.extract_plain_text(), ) + event = fake_private_message_event( + time=time, user_id=user_id, message_id=private_msg_id, message=private_msg + ) + session = Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(user_id), type=SceneType.PRIVATE), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "11", - "OneBot V11", - "qq", - 1, - str(user_id), - None, - None, + session, datetime.fromtimestamp(time, timezone.utc), "message", str(private_msg_id), @@ -133,6 +136,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + from nonebot_plugin_chatrecorder.adapters.onebot_v11 import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -153,13 +158,13 @@ async def test_record_send_msg(app: App): {"message_id": message_id}, ) await check_record( - "11", - "OneBot V11", - "qq", - 2, - None, - str(group_id), - None, + Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(group_id), type=SceneType.GROUP), + user=User(id="11"), + ), None, "message_sent", str(message_id), @@ -177,13 +182,13 @@ async def test_record_send_msg(app: App): {"message_id": message_id}, ) await check_record( - "11", - "OneBot V11", - "qq", - 2, - None, - str(group_id), - None, + Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(group_id), type=SceneType.GROUP), + user=User(id="11"), + ), None, "message_sent", str(message_id), @@ -201,13 +206,13 @@ async def test_record_send_msg(app: App): {"message_id": message_id}, ) await check_record( - "11", - "OneBot V11", - "qq", - 1, - str(user_id), - None, - None, + Session( + self_id="11", + adapter="OneBot V11", + scope="QQClient", + scene=Scene(id=str(user_id), type=SceneType.PRIVATE), + user=User(id="11"), + ), None, "message_sent", str(message_id), diff --git a/tests/test_onebot_v12.py b/tests/test_onebot_v12.py index d5d9bcf..a2297d7 100644 --- a/tests/test_onebot_v12.py +++ b/tests/test_onebot_v12.py @@ -91,6 +91,9 @@ class Config: async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.onebot_v12 import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message time = datetime.fromtimestamp(1000000, timezone.utc) @@ -108,44 +111,29 @@ async def test_record_recv_msg(app: App): channel_msg = Message("test channel message") channel_msg_id = "11451433333" - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, adapter=adapter, self_id="12", platform="qq", impl="walle-q" ) - event = fake_group_message_event( - time=time, - user_id=user_id, - group_id=group_id, - message_id=group_msg_id, - message=group_msg, - ) - ctx.receive_event(bot, event) - - event = fake_private_message_event( - time=time, user_id=user_id, message_id=private_msg_id, message=private_msg - ) - ctx.receive_event(bot, event) - - event = fake_channel_message_event_v12( - time=time, - user_id=user_id, - guild_id=guild_id, - channel_id=channel_id, - message_id=channel_msg_id, - message=channel_msg, - ) - ctx.receive_event(bot, event) - + event = fake_group_message_event( + time=time, + user_id=user_id, + group_id=group_id, + message_id=group_msg_id, + message=group_msg, + ) + session = Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene(id=str(group_id), type=SceneType.GROUP), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "12", - "OneBot V12", - "qq", - 2, - str(user_id), - str(group_id), - None, + session, time, "message", str(group_msg_id), @@ -153,14 +141,19 @@ async def test_record_recv_msg(app: App): group_msg.extract_plain_text(), ) + event = fake_private_message_event( + time=time, user_id=user_id, message_id=private_msg_id, message=private_msg + ) + session = Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene(id=str(user_id), type=SceneType.PRIVATE), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "12", - "OneBot V12", - "qq", - 1, - str(user_id), - None, - None, + session, time, "message", str(private_msg_id), @@ -168,14 +161,28 @@ async def test_record_recv_msg(app: App): private_msg.extract_plain_text(), ) + event = fake_channel_message_event_v12( + time=time, + user_id=user_id, + guild_id=guild_id, + channel_id=channel_id, + message_id=channel_msg_id, + message=channel_msg, + ) + session = Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene( + id=str(channel_id), + type=SceneType.CHANNEL_TEXT, + parent=Scene(id=str(guild_id), type=SceneType.GUILD), + ), + user=User(id=str(user_id)), + ) + await record_recv_msg(event, session) await check_record( - "12", - "OneBot V12", - "qq", - 3, - str(user_id), - str(channel_id), - str(guild_id), + session, time, "message", str(channel_msg_id), @@ -186,6 +193,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + from nonebot_plugin_chatrecorder.adapters.onebot_v12 import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -215,13 +224,13 @@ async def test_record_send_msg(app: App): {"message_id": message_id, "time": time}, ) await check_record( - "12", - "OneBot V12", - "qq", - 2, - None, - str(group_id), - None, + Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene(id=str(group_id), type=SceneType.GROUP), + user=User(id="12"), + ), datetime.fromtimestamp(time, timezone.utc), "message_sent", str(message_id), @@ -243,13 +252,13 @@ async def test_record_send_msg(app: App): {"message_id": message_id, "time": time}, ) await check_record( - "12", - "OneBot V12", - "qq", - 1, - str(user_id), - None, - None, + Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene(id=str(user_id), type=SceneType.PRIVATE), + user=User(id="12"), + ), datetime.fromtimestamp(time, timezone.utc), "message_sent", str(message_id), @@ -272,13 +281,17 @@ async def test_record_send_msg(app: App): {"message_id": message_id, "time": time}, ) await check_record( - "12", - "OneBot V12", - "qq", - 3, - None, - str(channel_id), - str(guild_id), + Session( + self_id="12", + adapter="OneBot V12", + scope="QQClient", + scene=Scene( + id=str(channel_id), + type=SceneType.CHANNEL_TEXT, + parent=Scene(id=str(guild_id), type=SceneType.GUILD), + ), + user=User(id="12"), + ), datetime.fromtimestamp(time, timezone.utc), "message_sent", str(message_id), diff --git a/tests/test_qq.py b/tests/test_qq.py index 0905488..70921c8 100644 --- a/tests/test_qq.py +++ b/tests/test_qq.py @@ -78,6 +78,10 @@ def fake_c2c_message_create_event(content: str, id: str) -> C2CMessageCreateEven async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + + from nonebot_plugin_chatrecorder.adapters.qq import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message msg = "test message create event" @@ -92,7 +96,7 @@ async def test_record_recv_msg(app: App): c2c_msg = "test c2c message create event" c2c_msg_id = "1237" - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, @@ -101,26 +105,21 @@ async def test_record_recv_msg(app: App): bot_info=BotInfo(id="2233", token="", secret=""), ) - event = fake_message_create_event(msg, msg_id) - ctx.receive_event(bot, event) - - event = fake_direct_message_create_event(direct_msg, direct_msg_id) - ctx.receive_event(bot, event) - - event = fake_group_at_message_create_event(group_at_msg, group_at_msg_id) - ctx.receive_event(bot, event) - - event = fake_c2c_message_create_event(c2c_msg, c2c_msg_id) - ctx.receive_event(bot, event) - + event = fake_message_create_event(msg, msg_id) + session = Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "QQ", - "qqguild", - 3, - "3344", - "6677", - "5566", + session, datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message", msg_id, @@ -128,14 +127,21 @@ async def test_record_recv_msg(app: App): msg, ) + event = fake_direct_message_create_event(direct_msg, direct_msg_id) + session = Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene( + id="3344", + type=SceneType.PRIVATE, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "QQ", - "qqguild", - 1, - "3344", - "6677", - "5566", + session, datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message", direct_msg_id, @@ -143,14 +149,17 @@ async def test_record_recv_msg(app: App): direct_msg, ) + event = fake_group_at_message_create_event(group_at_msg, group_at_msg_id) + session = Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene(id="195747FDF0D845E98CF3886C5C7ED328", type=SceneType.GROUP), + user=UninfoUser(id="8BE608110EAA4328A1883DEF239F5580"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "QQ", - "qq", - 2, - "8BE608110EAA4328A1883DEF239F5580", - "195747FDF0D845E98CF3886C5C7ED328", - None, + session, datetime.fromisoformat("2023-11-06T13:37:18+08:00"), "message", group_at_msg_id, @@ -158,14 +167,17 @@ async def test_record_recv_msg(app: App): group_at_msg, ) + event = fake_c2c_message_create_event(c2c_msg, c2c_msg_id) + session = Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene(id="451368C569A1401D87172E9435EE8663", type=SceneType.PRIVATE), + user=UninfoUser(id="451368C569A1401D87172E9435EE8663"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "QQ", - "qq", - 1, - "451368C569A1401D87172E9435EE8663", - None, - None, + session, datetime.fromisoformat("2023-11-06T13:37:18+08:00"), "message", c2c_msg_id, @@ -176,6 +188,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser from nonebot_plugin_chatrecorder.adapters.qq import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -211,13 +225,17 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "QQ", - "qqguild", - 3, - None, - "6677", - "5566", + Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="2233"), + ), datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message_sent", "1238", @@ -247,13 +265,17 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "QQ", - "qqguild", - 1, - None, - None, - "5566", + Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene( + id="3344", + type=SceneType.PRIVATE, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="2233"), + ), datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message_sent", "1239", @@ -281,13 +303,13 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "QQ", - "qq", - 1, - "87E469B751CD4520B0B18D826CC94B71", - None, - None, + Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene(id="87E469B751CD4520B0B18D826CC94B71", type=SceneType.PRIVATE), + user=UninfoUser(id="2233"), + ), datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message_sent", "1241", @@ -315,13 +337,13 @@ async def test_record_send_msg(app: App): ), ) await check_record( - "2233", - "QQ", - "qq", - 2, - None, - "1CC5DF4814E54834B0A7F5D553BB25CC", - None, + Session( + self_id="2233", + adapter="QQ", + scope="QQAPI", + scene=Scene(id="1CC5DF4814E54834B0A7F5D553BB25CC", type=SceneType.GROUP), + user=UninfoUser(id="2233"), + ), datetime(2023, 7, 30, 0, 0, 0, 0, tzinfo=timezone.utc), "message_sent", "1243", diff --git a/tests/test_satori.py b/tests/test_satori.py index 6e8c293..336f9e5 100644 --- a/tests/test_satori.py +++ b/tests/test_satori.py @@ -22,7 +22,7 @@ from nonebot.compat import type_validate_python from nonebug.app import App -from .utils import assert_no_record, check_record +from .utils import check_record def fake_public_message_created_event( @@ -117,6 +117,10 @@ def fake_public_message_updated_event( async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + + from nonebot_plugin_chatrecorder.adapters.satori import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message public_msg = "test public message created" @@ -125,10 +129,7 @@ async def test_record_recv_msg(app: App): private_msg = "test private message created" private_msg_id = "56163f81-de30-4c39-b4c4-3a205d0be9db" - msg_deleted_id = "56163f81-de30-4c39-b4c4-3a205d0be9dc" - msg_updated_id = "56163f81-de30-4c39-b4c4-3a205d0be9dd" - - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, @@ -147,26 +148,21 @@ async def test_record_recv_msg(app: App): info=ClientInfo(port=5140), ) - event = fake_public_message_created_event(public_msg, public_msg_id) - ctx.receive_event(bot, event) - - event = fake_private_message_created_event(private_msg, private_msg_id) - ctx.receive_event(bot, event) - - event = fake_public_message_deleted_event("msg deleted", msg_deleted_id) - ctx.receive_event(bot, event) - - event = fake_public_message_updated_event("msg updated", msg_updated_id) - ctx.receive_event(bot, event) - + event = fake_public_message_created_event(public_msg, public_msg_id) + session = Session( + self_id="2233", + adapter="Satori", + scope="Kaiheila", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Satori", - "kaiheila", - 3, - "3344", - "6677", - "5566", + session, datetime.fromtimestamp(17000000000 / 1000, timezone.utc), "message", public_msg_id, @@ -174,14 +170,17 @@ async def test_record_recv_msg(app: App): public_msg, ) + event = fake_private_message_created_event(private_msg, private_msg_id) + session = Session( + self_id="2233", + adapter="Satori", + scope="Kaiheila", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=UninfoUser(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Satori", - "kaiheila", - 1, - "3344", - "6677", - None, + session, datetime.fromtimestamp(17000000000 / 1000, timezone.utc), "message", private_msg_id, @@ -189,12 +188,12 @@ async def test_record_recv_msg(app: App): private_msg, ) - await assert_no_record(msg_deleted_id) - await assert_no_record(msg_updated_id) - async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session + from nonebot_plugin_uninfo import User as UninfoUser + from nonebot_plugin_chatrecorder.adapters.satori import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -249,13 +248,17 @@ async def test_record_send_msg(app: App): ], ) await check_record( - "2233", - "Satori", - "kaiheila", - 3, - None, - "6677", - "5566", + Session( + self_id="2233", + adapter="Satori", + scope="Kaiheila", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=UninfoUser(id="2233"), + ), None, "message_sent", "6b701984-c185-4da9-9808-549dc9947b85", diff --git a/tests/test_telegram.py b/tests/test_telegram.py index c812746..36e6ab0 100644 --- a/tests/test_telegram.py +++ b/tests/test_telegram.py @@ -72,6 +72,9 @@ def fake_forum_topic_message_event( async def test_record_recv_msg(app: App): """测试记录收到的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + + from nonebot_plugin_chatrecorder.adapters.telegram import record_recv_msg from nonebot_plugin_chatrecorder.message import serialize_message private_msg = "test private message" @@ -83,7 +86,7 @@ async def test_record_recv_msg(app: App): forum_msg = "test forum topic message" forum_msg_id = "1236" - async with app.test_matcher() as ctx: + async with app.test_api() as ctx: adapter = get_driver()._adapters[Adapter.get_name()] bot = ctx.create_bot( base=Bot, @@ -92,23 +95,17 @@ async def test_record_recv_msg(app: App): config=BotConfig(token="2233:xxx"), ) - event = fake_private_message_event(private_msg, private_msg_id) - ctx.receive_event(bot, event) - - event = fake_group_message_event(group_msg, group_msg_id) - ctx.receive_event(bot, event) - - event = fake_forum_topic_message_event(forum_msg, forum_msg_id) - ctx.receive_event(bot, event) - + event = fake_private_message_event(private_msg, private_msg_id) + session = Session( + self_id="2233", + adapter="Telegram", + scope="Telegram", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Telegram", - "telegram", - 1, - "3344", - None, - None, + session, datetime.fromtimestamp(1122, timezone.utc), "message", "3344_1234", @@ -116,14 +113,17 @@ async def test_record_recv_msg(app: App): private_msg, ) + event = fake_group_message_event(group_msg, group_msg_id) + session = Session( + self_id="2233", + adapter="Telegram", + scope="Telegram", + scene=Scene(id="5566", type=SceneType.GROUP), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Telegram", - "telegram", - 2, - "3344", - "5566", - None, + session, datetime.fromtimestamp(1122, timezone.utc), "message", "5566_1235", @@ -131,14 +131,21 @@ async def test_record_recv_msg(app: App): group_msg, ) + event = fake_forum_topic_message_event(forum_msg, forum_msg_id) + session = Session( + self_id="2233", + adapter="Telegram", + scope="Telegram", + scene=Scene( + id="6677", + type=SceneType.CHANNEL_TEXT, + parent=Scene(id="5566", type=SceneType.GUILD), + ), + user=User(id="3344"), + ) + await record_recv_msg(event, session) await check_record( - "2233", - "Telegram", - "telegram", - 3, - "3344", - "6677", - "5566", + session, datetime.fromtimestamp(1122, timezone.utc), "message", "5566_1236", @@ -149,6 +156,8 @@ async def test_record_recv_msg(app: App): async def test_record_send_msg(app: App): """测试记录发送的消息""" + from nonebot_plugin_uninfo import Scene, SceneType, Session, User + from nonebot_plugin_chatrecorder.adapters.telegram import record_send_msg from nonebot_plugin_chatrecorder.message import serialize_message @@ -197,13 +206,13 @@ async def test_record_send_msg(app: App): }, ) await check_record( - "2233", - "Telegram", - "telegram", - 1, - "3344", - None, - None, + Session( + self_id="2233", + adapter="Telegram", + scope="Telegram", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="2233"), + ), datetime.fromtimestamp(1122, timezone.utc), "message_sent", "3344_1237", @@ -295,13 +304,13 @@ async def test_record_send_msg(app: App): ], ) await check_record( - "2233", - "Telegram", - "telegram", - 1, - "3344", - None, - None, + Session( + self_id="2233", + adapter="Telegram", + scope="Telegram", + scene=Scene(id="3344", type=SceneType.PRIVATE), + user=User(id="2233"), + ), datetime.fromtimestamp(1122, timezone.utc), "message_sent", "3344_1238_1239", diff --git a/tests/utils.py b/tests/utils.py index dc56920..5bc9cdd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,15 +1,18 @@ from datetime import datetime -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional + +from nonebot.log import logger + +if TYPE_CHECKING: + from nonebot_plugin_uninfo import Session + + +def session_id(session: "Session") -> str: + return f"{session.self_id}_{session.adapter}_{session.scope}_{session.id}" async def check_record( - bot_id: str, - bot_type: str, - platform: str, - level: int, - id1: Optional[str], - id2: Optional[str], - id3: Optional[str], + session: Optional["Session"], time: Optional[datetime], type: str, message_id: str, @@ -17,7 +20,7 @@ async def check_record( plain_text: str, ): from nonebot_plugin_orm import get_session - from nonebot_plugin_session_orm import get_session_by_persist_id + from nonebot_plugin_uninfo.orm import get_session_model from sqlalchemy import select from nonebot_plugin_chatrecorder.model import MessageRecord @@ -27,17 +30,14 @@ async def check_record( async with get_session() as db_session: records = (await db_session.scalars(statement)).all() + logger.warning(f"len records: {len(records)}") assert len(records) == 1 record = records[0] session_persist_id = record.session_persist_id - session = await get_session_by_persist_id(session_persist_id) - assert session.bot_id == bot_id - assert session.bot_type == bot_type - assert session.platform == platform - assert session.level == level - assert session.id1 == id1 - assert session.id2 == id2 - assert session.id3 == id3 + session_model = await get_session_model(session_persist_id) + record_session = await session_model.to_session() + if session: + assert session_id(record_session) == session_id(session) assert record.type == type if time: assert record.time == remove_timezone(time)