From 637bfb8fa27d118c6310a17f1412d51780fc5ee6 Mon Sep 17 00:00:00 2001 From: Farzin Kazemzadeh Date: Tue, 27 May 2025 13:24:51 +0330 Subject: [PATCH 1/4] refactor(core): move plugin related methods to its own class Signed-off-by: Farzin Kazemzadeh --- bot/bot.py | 198 ++----------------------------------- bot/core/__init__.py | 8 ++ bot/core/plugin_manager.py | 198 +++++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 189 deletions(-) create mode 100644 bot/core/__init__.py create mode 100644 bot/core/plugin_manager.py diff --git a/bot/bot.py b/bot/bot.py index 493ffe2..ae7252f 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,13 +1,6 @@ -import os -import logging -import importlib -from pathlib import Path +from .core import Core from pyrogram import Client -from sqlalchemy import select -from sqlalchemy.orm import Session -from typing import Generator, Optional -from pyrogram.handlers.handler import Handler -from config import Config, DataBase, PluginDatabase +from config import Config, DataBase class BotMeta(type): @@ -17,188 +10,15 @@ def __call__(cls, *args, **kwargs): return instance -class Bot(Client, metaclass=BotMeta): +class Bot(Core, Client, metaclass=BotMeta): def _post_init(self): - self.builtin_plugin = "bot/builtin_plugins" DataBase.metadata.create_all(Config.engine) - self.load_plugins(folder=self.builtin_plugin) - DataBase.metadata.create_all(Config.engine) - - def modules_list( - self, folder: Optional[str | list[str]] = None - ) -> list[Path]: - modules = [] - - folders = ( - folder - if isinstance(folder, list) - else [folder or self.plugins["root"]] - ) - - for f in folders: - for root, _, files in os.walk( - f.replace(".", "/"), followlinks=True - ): - for file in files: - if not file.endswith(".py"): - continue - path = Path(root) / file - modules.append(path) - - return sorted(modules) - - def get_plugins( - self, folder: Optional[str | list[str]] = None - ) -> list[str] | list[Path]: - plugins = [] - - folders = ( - folder - if isinstance(folder, list) - else [folder or self.plugins["root"]] - ) - for path in self.modules_list(folders): - module_path = ".".join(path.with_suffix("").parts) - module = importlib.import_module(module_path) - if getattr(module, "__plugin__", False): - plugins.append(path.stem) - - return sorted(plugins) - - def get_handlers( - self, - plugins: Optional[str | list[str]] = None, - folder: Optional[str | list[str]] = None, - ) -> Generator[tuple[str, str] | tuple[Handler, int], None, None]: - if isinstance(plugins, str): - plugins = plugins.split(",") - - group_offset = 0 if folder == self.builtin_plugin else 1 - _plugins = self.get_plugins(folder=folder) - - if plugins: - for plugin in plugins: - if plugin not in _plugins: - yield (plugin, "Plugin not found") - - for path in self.modules_list(folder=folder): - if plugins and path.stem not in plugins: + for base_class in reversed(self.__class__.__mro__): + if base_class is self.__class__ or base_class is object: continue - module_path = ".".join(path.parent.parts + (path.stem,)) - module = importlib.import_module(module_path) - # TODO: reload the module after import - - for name in vars(module).keys(): - target_attr = getattr(module, name) - if hasattr(target_attr, "handlers"): - for handler, group in target_attr.handlers: - if isinstance(handler, Handler) and isinstance( - group, int - ): - if group < 0 and group_offset != 0: - yield (handler, 0) - else: - yield (handler, group + group_offset) - - def handler_is_loaded(self, handler: Handler, group: int = 0) -> bool: - if group not in self.dispatcher.groups: - return False - return handler in self.dispatcher.groups[group] - - def set_plugin_status(self, plugin: str, enabled: bool = True): - with Session(Config.engine) as session: - session.merge(PluginDatabase(name=plugin, enabled=enabled)) - session.commit() - - def get_plugin_status(self, plugin: str) -> bool: - with Session(Config.engine) as session: - enabled = session.execute( - select(PluginDatabase.enabled).where( - PluginDatabase.name == plugin - ) - ).scalar() - return enabled or False - - def load_plugins( - self, - plugins: Optional[str | list[str]] = None, - folder: Optional[str | list[str]] = None, - force_load: bool = False, - ) -> dict[str, str]: - result = {} - if isinstance(plugins, str): - plugins = plugins.split(",") - - _plugins = self.get_plugins(folder=folder) - plugins = plugins or _plugins - - for plugin in plugins: - if plugin in _plugins: - with Session(Config.engine) as session: - if ( - session.execute( - select(PluginDatabase.enabled).where( - PluginDatabase.name == plugin - ) - ).scalar() - is False - and not force_load - ): - plugins.remove(plugin) - else: - self.set_plugin_status(plugin, True) - - for handler in self.get_handlers(plugins, folder=folder): - if isinstance(handler[0], str): - result[handler[0]] = handler[1] - logging.warning(handler[1]) - else: - callback_name = handler[0].callback.__name__ - if not self.handler_is_loaded(*handler): - self.add_handler(*handler) - result[callback_name] = "Handler loaded" - logging.info(f"{callback_name} handler has been loaded") - else: - result[callback_name] = "Failed to load handler" - logging.warning( - f"Failed to load {callback_name} handler, " - "because it is already loaded" - ) - DataBase.metadata.create_all(Config.engine) - return result - - def unload_plugins( - self, - plugins: Optional[str | list[str]] = None, - folder: Optional[str | list[str]] = None, - ): - result = {} - if isinstance(plugins, str): - plugins = plugins.split(",") - - _plugins = self.get_plugins(folder=folder) - plugins = plugins or _plugins - - for plugin in plugins: - if plugin in _plugins: - self.set_plugin_status(plugin, False) - - for handler in self.get_handlers(plugins, folder=folder): - if isinstance(handler[0], str): - result[handler[0]] = handler[1] - logging.warning(handler[1]) - else: - callback_name = handler[0].callback.__name__ - if self.handler_is_loaded(*handler): - self.remove_handler(*handler) - result[callback_name] = "Handler unloaded" - logging.info(f"{callback_name} handler has been unloaded") - else: - result[callback_name] = "Failed to unload handler" - logging.warning( - f"Failed to unload {callback_name} handler, " - "it is not loaded already." - ) - return result + if hasattr(base_class, "_post_init"): + _post_init = getattr(base_class, "_post_init") + if callable(_post_init): + _post_init(self) diff --git a/bot/core/__init__.py b/bot/core/__init__.py new file mode 100644 index 0000000..a590b79 --- /dev/null +++ b/bot/core/__init__.py @@ -0,0 +1,8 @@ +from .plugin_manager import PluginManager + + +class Core(PluginManager): + pass + + +__all__ = ["Core"] diff --git a/bot/core/plugin_manager.py b/bot/core/plugin_manager.py new file mode 100644 index 0000000..25982c3 --- /dev/null +++ b/bot/core/plugin_manager.py @@ -0,0 +1,198 @@ +import os +import logging +import importlib +from pathlib import Path +from pyrogram import Client +from sqlalchemy import select +from sqlalchemy.orm import Session +from typing import Generator, Optional +from pyrogram.handlers.handler import Handler +from config import Config, DataBase, PluginDatabase + + +class PluginManager(Client): + plugins: dict + builtin_plugin: str = "bot/builtin_plugins" + + def _post_init(self): + self.load_plugins(folder=self.builtin_plugin) + DataBase.metadata.create_all(Config.engine) + + def modules_list( + self, folder: Optional[str | list[str]] = None + ) -> list[Path]: + modules = [] + + folders = ( + folder + if isinstance(folder, list) + else [folder or self.plugins["root"]] + ) + + for f in folders: + for root, _, files in os.walk( + f.replace(".", "/"), followlinks=True + ): + for file in files: + if not file.endswith(".py"): + continue + path = Path(root) / file + modules.append(path) + + return sorted(modules) + + def get_plugins( + self, folder: Optional[str | list[str]] = None + ) -> list[str] | list[Path]: + plugins = [] + + folders = ( + folder + if isinstance(folder, list) + else [folder or self.plugins["root"]] + ) + + for path in self.modules_list(folders): + module_path = ".".join(path.with_suffix("").parts) + module = importlib.import_module(module_path) + if getattr(module, "__plugin__", False): + plugins.append(path.stem) + + return sorted(plugins) + + def get_handlers( + self, + plugins: Optional[str | list[str]] = None, + folder: Optional[str | list[str]] = None, + ) -> Generator[tuple[str, str] | tuple[Handler, int], None, None]: + if isinstance(plugins, str): + plugins = plugins.split(",") + + group_offset = 0 if folder == self.builtin_plugin else 1 + _plugins = self.get_plugins(folder=folder) + + if plugins: + for plugin in plugins: + if plugin not in _plugins: + yield (plugin, "Plugin not found") + + for path in self.modules_list(folder=folder): + if plugins and path.stem not in plugins: + continue + + module_path = ".".join(path.parent.parts + (path.stem,)) + module = importlib.import_module(module_path) + # TODO: reload the module after import + + for name in vars(module).keys(): + target_attr = getattr(module, name) + if hasattr(target_attr, "handlers"): + for handler, group in target_attr.handlers: + if isinstance(handler, Handler) and isinstance( + group, int + ): + if group < 0 and group_offset != 0: + yield (handler, 0) + else: + yield (handler, group + group_offset) + + def handler_is_loaded(self, handler: Handler, group: int = 0) -> bool: + if group not in self.dispatcher.groups: + return False + return handler in self.dispatcher.groups[group] + + def set_plugin_status(self, plugin: str, enabled: bool = True): + with Session(Config.engine) as session: + session.merge(PluginDatabase(name=plugin, enabled=enabled)) + session.commit() + + def get_plugin_status(self, plugin: str) -> bool: + with Session(Config.engine) as session: + enabled = session.execute( + select(PluginDatabase.enabled).where( + PluginDatabase.name == plugin + ) + ).scalar() + return enabled or False + + def load_plugins( + self, + plugins: Optional[str | list[str]] = None, + folder: Optional[str | list[str]] = None, + force_load: bool = False, + ) -> dict[str, str]: + result = {} + if isinstance(plugins, str): + plugins = plugins.split(",") + + _plugins = self.get_plugins(folder=folder) + plugins = plugins or _plugins + + for plugin in plugins: + if plugin in _plugins: + with Session(Config.engine) as session: + if ( + session.execute( + select(PluginDatabase.enabled).where( + PluginDatabase.name == plugin + ) + ).scalar() + is False + and not force_load + ): + plugins.remove(plugin) + else: + self.set_plugin_status(plugin, True) + + for handler in self.get_handlers(plugins, folder=folder): + if isinstance(handler[0], str): + result[handler[0]] = handler[1] + logging.warning(handler[1]) + else: + callback_name = handler[0].callback.__name__ + if not self.handler_is_loaded(*handler): + self.add_handler(*handler) + result[callback_name] = "Handler loaded" + logging.info(f"{callback_name} handler has been loaded") + else: + result[callback_name] = "Failed to load handler" + logging.warning( + f"Failed to load {callback_name} handler, " + "because it is already loaded" + ) + DataBase.metadata.create_all(Config.engine) + return result + + def unload_plugins( + self, + plugins: Optional[str | list[str]] = None, + folder: Optional[str | list[str]] = None, + ): + result = {} + if isinstance(plugins, str): + plugins = plugins.split(",") + + _plugins = self.get_plugins(folder=folder) + plugins = plugins or _plugins + + for plugin in plugins: + if plugin in _plugins: + self.set_plugin_status(plugin, False) + + for handler in self.get_handlers(plugins, folder=folder): + if isinstance(handler[0], str): + result[handler[0]] = handler[1] + logging.warning(handler[1]) + else: + callback_name = handler[0].callback.__name__ + if self.handler_is_loaded(*handler): + self.remove_handler(*handler) + result[callback_name] = "Handler unloaded" + logging.info(f"{callback_name} handler has been unloaded") + else: + result[callback_name] = "Failed to unload handler" + logging.warning( + f"Failed to unload {callback_name} handler, " + "it is not loaded already." + ) + return result From 8bd51ac84a5423b350e0e4a610b1dad3a75eeae2 Mon Sep 17 00:00:00 2001 From: Farzin Kazemzadeh Date: Tue, 27 May 2025 15:18:58 +0330 Subject: [PATCH 2/4] feat(core): add `is_public_use` decorator [1/2] With this, handlers can be enabled/disabled for general public use. Signed-off-by: Farzin Kazemzadeh --- bot/core/__init__.py | 3 ++- bot/core/is_public_use.py | 27 +++++++++++++++++++++++++++ settings.py | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 bot/core/is_public_use.py diff --git a/bot/core/__init__.py b/bot/core/__init__.py index a590b79..7627b88 100644 --- a/bot/core/__init__.py +++ b/bot/core/__init__.py @@ -1,7 +1,8 @@ +from .is_public_use import IsPublicUse from .plugin_manager import PluginManager -class Core(PluginManager): +class Core(PluginManager, IsPublicUse): pass diff --git a/bot/core/is_public_use.py b/bot/core/is_public_use.py new file mode 100644 index 0000000..1b1cf49 --- /dev/null +++ b/bot/core/is_public_use.py @@ -0,0 +1,27 @@ +import bot +import functools +from typing import Callable +from sqlalchemy import select +from sqlalchemy.orm import Session +from config import Config, PluginDatabase +from pyrogram.types import Message, InlineQuery + + +class IsPublicUse: + def is_public_use(func: Callable): + @functools.wraps(func) + async def decorator(client: "bot.Bot", update: Message | InlineQuery): + if await Config.IS_ADMIN(client, update): + return await func(client, update) + + with Session(Config.engine) as session: + if session.execute( + select(PluginDatabase.is_public_use).where( + PluginDatabase.name == func.__module__.split(".")[-1] + ) + ).scalar(): + return await func(client, update) + + return None + + return decorator diff --git a/settings.py b/settings.py index 6d12389..dc989e9 100644 --- a/settings.py +++ b/settings.py @@ -175,3 +175,4 @@ class PluginDatabase(DataBase): name: Mapped[str] = mapped_column(String(40), primary_key=True) enabled: Mapped[bool] = mapped_column(Boolean()) custom_data: Mapped[JSON] = mapped_column(JSON(), default=dict()) + is_public_use: Mapped[bool] = mapped_column(Boolean(), default=False) From db831dd0962f49789a45bf6ea4b122b50788172c Mon Sep 17 00:00:00 2001 From: Farzin Kazemzadeh Date: Tue, 27 May 2025 16:08:00 +0330 Subject: [PATCH 3/4] feat(manager): add `is_public_use` decorator [2/2] Signed-off-by: Farzin Kazemzadeh --- bot/builtin_plugins/manager.py | 118 ++++++++++++++++++++++++++------- bot/core/plugin_manager.py | 16 +++-- 2 files changed, 106 insertions(+), 28 deletions(-) diff --git a/bot/builtin_plugins/manager.py b/bot/builtin_plugins/manager.py index 68fa3f0..1000e9c 100644 --- a/bot/builtin_plugins/manager.py +++ b/bot/builtin_plugins/manager.py @@ -1,4 +1,5 @@ from bot import Bot +from typing import Optional from pyrogram import filters from pyrogram.types import ( Message, @@ -6,27 +7,79 @@ InlineKeyboardButton, CallbackQuery, ) +from sqlalchemy import select +from sqlalchemy.orm import Session -from config import Config +from config import Config, PluginDatabase -def plugins_keyboard(app: Bot): - plugins = app.get_plugins() - keyboard = [ - [ - InlineKeyboardButton( - plugin.replace("-", " ").replace("_", " "), f"plugins {plugin}" - ), - InlineKeyboardButton( - {True: "✅", False: "❌"}[app.get_plugin_status(plugin)], - f"plugins {plugin}", - ), +def plugins_keyboard(client: Bot, plugin: Optional[str] = None): + + with Session(Config.engine) as session: + if plugin: + enabled, public = session.execute( + select( + PluginDatabase.enabled, + PluginDatabase.is_public_use, + ).where(PluginDatabase.name == plugin) + ).one() + + return [ + [ + InlineKeyboardButton( + "Status", f"plugins {plugin} status2" + ), + InlineKeyboardButton( + {True: "✅", False: "❌"}[enabled], + f"plugins {plugin} status2", + ), + ], + [ + InlineKeyboardButton( + "Public Use", f"plugins {plugin} public" + ), + InlineKeyboardButton( + {True: "✅", False: "❌"}[public], + f"plugins {plugin} public", + ), + ], + [ + InlineKeyboardButton("Back", "plugins"), + ], + ] + + keyboard = [[InlineKeyboardButton("No were plugin found.", "None")]] + _plugins = client.get_plugins() + plugins = session.execute( + select( + PluginDatabase.name, + PluginDatabase.enabled, + PluginDatabase.is_public_use, + ) + ).all() + plugins: dict[str, list[bool]] = { + plugin[0]: plugin[1] for plugin in plugins + } + + if len(plugins) == 0: + return keyboard + + keyboard = [ + [ + InlineKeyboardButton( + plugin.replace("-", " ").replace("_", " "), + f"plugins {plugin}", + ), + InlineKeyboardButton( + {True: "✅", False: "❌"}[plugins[plugin]], + f"plugins {plugin} status1", + ), + ] + for plugin in plugins + if plugin in _plugins ] - for plugin in plugins - ] - return keyboard or [ - [InlineKeyboardButton("No were plugin found.", "None")] - ] + + return keyboard @Bot.on_message( @@ -40,17 +93,36 @@ async def plugins(app: Bot, message: Message): @Bot.on_callback_query( - Config.IS_ADMIN & filters.regex(r"^plugins (?P[\w\-]+)$") + Config.IS_ADMIN + & filters.regex(r"^plugins(?: (?P[\w\-]+))?(?: (?P\w+))?$") ) async def plugins_callback(app: Bot, query: CallbackQuery): plugin: str = query.matches[0].group("plugin") - if app.get_plugin_status(plugin): - app.unload_plugins(plugin) + mode: str = query.matches[0].group("mode") + text = "**Plugins**:" + + if not plugin: + keyboard = plugins_keyboard(app) + elif mode == "status1": + if app.get_plugin_status(plugin): + app.unload_plugins(plugin) + else: + app.load_plugins(plugin, force_load=True) + keyboard = plugins_keyboard(app) else: - app.load_plugins(plugin, force_load=True) + text = f"Plugin **{plugin}**:" + if mode == "status2": + if app.get_plugin_status(plugin): + app.unload_plugins(plugin) + else: + app.load_plugins(plugin, force_load=True) + elif mode == "public": + is_public_use = app.get_plugin_data(plugin, "is_public_use") + app.set_plugin_data(plugin, "is_public_use", not is_public_use) + keyboard = plugins_keyboard(app, plugin=plugin) + await query.edit_message_text( - "**Plugins**:", - reply_markup=InlineKeyboardMarkup(plugins_keyboard(app)), + text, reply_markup=InlineKeyboardMarkup(keyboard) ) diff --git a/bot/core/plugin_manager.py b/bot/core/plugin_manager.py index 25982c3..ea52c46 100644 --- a/bot/core/plugin_manager.py +++ b/bot/core/plugin_manager.py @@ -5,7 +5,7 @@ from pyrogram import Client from sqlalchemy import select from sqlalchemy.orm import Session -from typing import Generator, Optional +from typing import Any, Generator, Optional from pyrogram.handlers.handler import Handler from config import Config, DataBase, PluginDatabase @@ -101,20 +101,26 @@ def handler_is_loaded(self, handler: Handler, group: int = 0) -> bool: return False return handler in self.dispatcher.groups[group] - def set_plugin_status(self, plugin: str, enabled: bool = True): + def set_plugin_data(self, plugin: str, key: str, value: Any): with Session(Config.engine) as session: - session.merge(PluginDatabase(name=plugin, enabled=enabled)) + session.merge(PluginDatabase(**{"name": plugin, key: value})) session.commit() - def get_plugin_status(self, plugin: str) -> bool: + def get_plugin_data(self, plugin: str, key: str) -> bool: with Session(Config.engine) as session: enabled = session.execute( - select(PluginDatabase.enabled).where( + select(getattr(PluginDatabase, key)).where( PluginDatabase.name == plugin ) ).scalar() return enabled or False + def set_plugin_status(self, plugin: str, enabled: bool = True): + self.set_plugin_data(plugin, "enabled", enabled) + + def get_plugin_status(self, plugin: str) -> bool: + return self.get_plugin_data(plugin, "enabled") + def load_plugins( self, plugins: Optional[str | list[str]] = None, From 487c54c5dc0d46acd252cc39451057592d939f79 Mon Sep 17 00:00:00 2001 From: Farzin Kazemzadeh Date: Tue, 27 May 2025 16:21:36 +0330 Subject: [PATCH 4/4] fix(settings): default `is_public_use` to True By having it default as False, plugins will not work out-of-the-box Signed-off-by: Farzin Kazemzadeh --- settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/settings.py b/settings.py index dc989e9..abbd36d 100644 --- a/settings.py +++ b/settings.py @@ -175,4 +175,4 @@ class PluginDatabase(DataBase): name: Mapped[str] = mapped_column(String(40), primary_key=True) enabled: Mapped[bool] = mapped_column(Boolean()) custom_data: Mapped[JSON] = mapped_column(JSON(), default=dict()) - is_public_use: Mapped[bool] = mapped_column(Boolean(), default=False) + is_public_use: Mapped[bool] = mapped_column(Boolean(), default=True)