diff --git a/.gitignore b/.gitignore index a4d63804..11f7554a 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ models/lightfm/user_embeddings.dill models/popular_in_category/ models/dssm/ models/bert4rec/ +models/als # Unit test / coverage reports htmlcov/ diff --git a/load_models_from_google_drive.sh b/load_models_from_google_drive.sh index b9ab58d9..c361d88a 100755 --- a/load_models_from_google_drive.sh +++ b/load_models_from_google_drive.sh @@ -28,3 +28,10 @@ download "models/dssm" "iid_to_item_id.json" "1-TrGCS_YmRWQkIeuhSXKsN7Xg3Nk_pEn" download "models/dssm" "uid_to_watched_iids.json" "1-QtArop7useHil5pIeAM2t-d9J1nKOhS" download "models/bert4rec" "user_id_to_bert4rec_recs.pickle" "15o1tIcsFlkQdmbw2Z9kucMzv724H-Qi8" + +download "models/als" "ui_csr.pickle" "1u73wI918JbDnRRGZ19hDVNePoX-E75Vk" +download "models/als" "user_ext_to_int.pickle" "1-2AN1y039gPAl0oKUIN6KIuT-MkUDqtC" +download "models/als" "item_int_to_ext.pickle" "1--1QkLMZgTzSiQSrDdLZcKvhAQgqONNa" +download "models/als" "item_ext_to_int.pickle" "1d_ecNjQfpxNDwJSh6fP1f2HWrDiSy3ao" +download "models/als" "als_model.pickle" "1-8gJuelBZwJFhq7IDN1U2N2zKjPW1-kw" +download "models/als" "item_to_title.pickle" "1voF8cCfifroAsbJzFzc5DXC2j0MEGYIh" diff --git "a/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb" "b/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb" new file mode 100644 index 00000000..2bc36ab3 --- /dev/null +++ "b/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb" @@ -0,0 +1,1108 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "collapsed_sections": [ + "e1uQOgjY75wn", + "ekn9E-ih79T0" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "gpuClass": "standard", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "b50b7514d4924d1eb2e3ce4b3d4a964a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cfb9d9570b4a49a9bf782607056da77a", + "IPY_MODEL_2d811c1270224482930b0a5ac4889852", + "IPY_MODEL_a4ba3e0e0cda49f1855b18659ad52f73" + ], + "layout": "IPY_MODEL_076aaa32625f4aca9c01cab7d5c0ad36" + } + }, + "cfb9d9570b4a49a9bf782607056da77a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c2c32f0d15614ede8702508c3ede6a9a", + "placeholder": "​", + "style": "IPY_MODEL_125d0a4accb2478184df475818540350", + "value": "100%" + } + }, + "2d811c1270224482930b0a5ac4889852": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_92962bbd778049d5b512a1f440becdce", + "max": 15, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e2e4182854884016a9ed4f645a59ab13", + "value": 15 + } + }, + "a4ba3e0e0cda49f1855b18659ad52f73": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9a64281c06124ed18386e78310d529ad", + "placeholder": "​", + "style": "IPY_MODEL_c96fb3b88dec43349a9eaf0a87d98d93", + "value": " 15/15 [00:04<00:00, 3.30it/s]" + } + }, + "076aaa32625f4aca9c01cab7d5c0ad36": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c2c32f0d15614ede8702508c3ede6a9a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "125d0a4accb2478184df475818540350": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "92962bbd778049d5b512a1f440becdce": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e2e4182854884016a9ed4f645a59ab13": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "9a64281c06124ed18386e78310d529ad": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c96fb3b88dec43349a9eaf0a87d98d93": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "from rectools import Columns\n", + "from rectools.models import ImplicitALSWrapperModel\n", + "from rectools.dataset import Dataset\n", + "from rectools.models.utils import recommend_from_scores\n", + "\n", + "from implicit.als import AlternatingLeastSquares\n", + "\n", + "import pandas as pd\n", + "from collections import Counter" + ], + "metadata": { + "id": "YEkHPJRZA3h8" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Загрузка данных" + ], + "metadata": { + "id": "e1uQOgjY75wn" + } + }, + { + "cell_type": "code", + "source": [ + "interactions_df = pd.read_csv(\"/content/drive/MyDrive/RecSys MTC/kion/interactions_processed.csv\")\n", + "interactions_df.rename(columns={\"last_watch_dt\": Columns.Datetime}, inplace=True)" + ], + "metadata": { + "id": "aoR1FBRMnrS8" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "interactions_df.head(3)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "lCgzfC00GF4o", + "outputId": "e02cc58a-45e8-48d3-8180-8210e7d93fb0" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " user_id item_id datetime total_dur watched_pct\n", + "0 176549 9506 2021-05-11 4250 72\n", + "1 699317 1659 2021-05-29 8317 100\n", + "2 656683 7107 2021-05-09 10 0" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_iddatetimetotal_durwatched_pct
017654995062021-05-11425072
169931716592021-05-298317100
265668371072021-05-09100
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Предобработка" + ], + "metadata": { + "id": "ekn9E-ih79T0" + } + }, + { + "cell_type": "code", + "source": [ + "interactions_df = interactions_df[interactions_df.watched_pct >= 5]" + ], + "metadata": { + "id": "QANVdjKNvdQO" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# отбросим пользователей с малым числом просмотров\n", + "active_users = []\n", + "inactive_users = []\n", + "c = Counter(interactions_df.user_id)\n", + "for user_id, entries in c.items():\n", + " if entries >= 4:\n", + " active_users.append(user_id)\n", + " else:\n", + " inactive_users.append(user_id)\n", + "\n", + "interactions_df = interactions_df[interactions_df.user_id.isin(active_users)]\n", + "\n", + "len(active_users), len(inactive_users)" + ], + "metadata": { + "id": "vr3mW58NE8TL", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "825a7193-4b28-43f5-9958-f33c3898d071" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(286206, 507802)" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Дадим эмпирический рейтинг взаимодействия с айтемом на основании процента просмотра:\n", + "\n", + "- 1 - просмотрено 0-10%\n", + "- 2 - просмотрено 10-30%\n", + "- 3 - просмотрено 30-60%\n", + "- 4 - просмотрено 60-85%\n", + "- 5 - просмотрено 85-100%" + ], + "metadata": { + "id": "9MZPXOGTJ4u4" + } + }, + { + "cell_type": "code", + "source": [ + "def watched_pct_to_score(pct: float) -> int:\n", + " if 85 <= pct <= 100:\n", + " return 5\n", + " elif 60 <= pct < 85:\n", + " return 4\n", + " elif 30 <= pct < 60:\n", + " return 3\n", + " elif 10 <= pct < 30:\n", + " return 2\n", + " else:\n", + " return 1" + ], + "metadata": { + "id": "aofKPB3VKioS" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "interactions_df[Columns.Weight] = interactions_df[\"watched_pct\"].apply(lambda pct: watched_pct_to_score(pct))" + ], + "metadata": { + "id": "KzzGdIzA7cUh" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "interactions_df.tail(3)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "Pp2o6tG6BG6H", + "outputId": "b9f580a2-2995-44a1-b232-542cf1ffd8e5" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " user_id item_id datetime total_dur watched_pct weight\n", + "5476247 546862 9673 2021-04-13 2308 49 3\n", + "5476249 384202 16197 2021-04-19 6203 100 5\n", + "5476250 319709 4436 2021-08-15 3921 45 3" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_iddatetimetotal_durwatched_pctweight
547624754686296732021-04-132308493
5476249384202161972021-04-1962031005
547625031970944362021-08-153921453
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 10 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## ALS рекомендации и их объяснение для активных пользователей" + ], + "metadata": { + "id": "NKkzlx7oJB40" + } + }, + { + "cell_type": "code", + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df[Columns.Interactions]\n", + ")" + ], + "metadata": { + "id": "rPp4kjyMJukY" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "k_recs = 10\n", + "\n", + "N_FACTORS = 64\n", + "RANDOM_STATE = 2023\n", + "NUM_THREADS = 16\n", + "\n", + "als_model = ImplicitALSWrapperModel(\n", + " model=AlternatingLeastSquares(\n", + " factors=N_FACTORS,\n", + " random_state=RANDOM_STATE,\n", + " num_threads=NUM_THREADS\n", + " ), verbose=1,\n", + ")" + ], + "metadata": { + "id": "2kK8vVtGnrPe" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "als_model.fit(dataset)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 66, + "referenced_widgets": [ + "b50b7514d4924d1eb2e3ce4b3d4a964a", + "cfb9d9570b4a49a9bf782607056da77a", + "2d811c1270224482930b0a5ac4889852", + "a4ba3e0e0cda49f1855b18659ad52f73", + "076aaa32625f4aca9c01cab7d5c0ad36", + "c2c32f0d15614ede8702508c3ede6a9a", + "125d0a4accb2478184df475818540350", + "92962bbd778049d5b512a1f440becdce", + "e2e4182854884016a9ed4f645a59ab13", + "9a64281c06124ed18386e78310d529ad", + "c96fb3b88dec43349a9eaf0a87d98d93" + ] + }, + "id": "ogKYQ0GYx7vF", + "outputId": "ef86c801-d417-4bcf-8d20-a23288a6744a" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + " 0%| | 0/15 [00:00" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ] + }, + { + "cell_type": "code", + "source": [ + "ui_csr = dataset.get_user_item_matrix()\n", + "\n", + "user_ext_to_int = dataset.user_id_map.to_internal.to_dict()\n", + "item_int_to_ext = dataset.item_id_map.to_external.to_dict()\n", + "item_ext_to_int = {v: k for k, v in item_int_to_ext.items()}" + ], + "metadata": { + "id": "jmdqW12L1NzB" + }, + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model = als_model.model" + ], + "metadata": { + "id": "IK8egJB73L3m" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "items_df = pd.read_csv(\"/content/drive/MyDrive/RecSys MTC/kion/items.csv\")\n", + "item_to_title = items_df[[\"item_id\", \"title\"]].set_index(\"item_id\").to_dict()[\"title\"]" + ], + "metadata": { + "id": "9T4jaz5gPREz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import pickle\n", + "\n", + "\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/ui_csr.pickle\", \"wb\") as f:\n", + " pickle.dump(ui_csr, f)\n", + "\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/user_ext_to_int.pickle\", \"wb\") as f:\n", + " pickle.dump(user_ext_to_int, f)\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_int_to_ext.pickle\", \"wb\") as f:\n", + " pickle.dump(item_int_to_ext, f)\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_ext_to_int.pickle\", \"wb\") as f:\n", + " pickle.dump(item_ext_to_int, f)\n", + "\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/als_model.pickle\", \"wb\") as f:\n", + " pickle.dump(model, f)\n", + "\n", + "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_to_title.pickle\", \"wb\") as f:\n", + " pickle.dump(item_to_title, f)" + ], + "metadata": { + "id": "Ie7NDks53KuU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from math import exp\n", + "\n", + "\n", + "def explain(user_id, item_id, threshold=0.05):\n", + " if user_id not in user_ext_to_int or \\\n", + " item_id not in item_ext_to_int:\n", + " return None, None\n", + "\n", + " internal_userid = user_ext_to_int[user_id]\n", + " internal_itemid = item_ext_to_int[item_id]\n", + "\n", + " total_score, top_contributions, _ = model.explain(\n", + " userid=internal_userid,\n", + " user_items=ui_csr,\n", + " itemid=internal_itemid,\n", + " N=2\n", + " )\n", + " if total_score < threshold:\n", + " return None, None\n", + "\n", + " p = int((0.5 / (1 + exp(-(total_score * 5 - 1))) + 0.5) * 100)\n", + "\n", + " title_1 = item_to_title[\n", + " item_int_to_ext[top_contributions[0][0]]\n", + " ]\n", + " explanation = f\"Рекомендуем тем, кому нравится «{title_1}»\"\n", + "\n", + " if top_contributions[1][1] >= threshold:\n", + " title_2 = item_to_title[\n", + " item_int_to_ext[top_contributions[1][0]]\n", + " ]\n", + " explanation += f\" и «{title_2}»\"\n", + "\n", + " return p, explanation" + ], + "metadata": { + "id": "47kLK90x2nq2" + }, + "execution_count": 90, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "test_users = list(user_ext_to_int.keys())[1000:1005]\n", + "test_items = list(item_ext_to_int.keys())[1000:1005]\n", + "for user_id in test_users:\n", + " for item_id in test_items:\n", + " print(explain(user_id, item_id))\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8F7jaznRyrT6", + "outputId": "616d005f-67ae-4421-f836-4c2d4025f379" + }, + "execution_count": 91, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(None, None)\n", + "(66, 'Рекомендуем тем, кому нравится «Медиатор»')\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(81, 'Рекомендуем тем, кому нравится «Балканский рубеж» и «Застава»')\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(84, 'Рекомендуем тем, кому нравится «Коридор бессмертия» и «Легенда № 17»')\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n", + "(None, None)\n" + ] + } + ] + } + ] +} diff --git a/service/api/views.py b/service/api/views.py index 840fa194..e3e5a83d 100644 --- a/service/api/views.py +++ b/service/api/views.py @@ -1,3 +1,4 @@ +import random from typing import Any, Dict, List from fastapi import APIRouter, Depends, FastAPI, Request @@ -16,7 +17,7 @@ NotFoundError, ) from service.log import app_logger -from service.reco_models import models +from service.reco_models import als_model, models class RecoResponse(BaseModel): @@ -24,6 +25,11 @@ class RecoResponse(BaseModel): items: List[int] +class ExplainResponse(BaseModel): + p: int + explanation: str + + bearer_scheme = HTTPBearer() router = APIRouter() @@ -69,19 +75,35 @@ async def get_reco( else: raise ModelNotFoundError(error_message=f"Model {model_name} not found") - # if model_name in ("light_fm_1", "light_fm_2"): - # reco = ( - # online_fm_all_popular.predict(user_id, k_recs) - # if model_name == "light_fm_1" - # else online_fm_part_popular.predict(user_id, k_recs) - # ) - # if model_name == "ann_lightfm": - # reco = ann_lightfm.predict(user_id) - if not reco: reco = models["baseline"].predict(user_id, k_recs=k_recs) return RecoResponse(user_id=user_id, items=reco) +@router.get( + path="/explain/{model_name}/{user_id}/{item_id}", + tags=["Explanations"], + response_model=ExplainResponse, +) +async def explain(request: Request, model_name: str, user_id: int, item_id: int) -> ExplainResponse: + """ + Пользователь переходит на карточку контента, на которой нужно показать + процент релевантности этого контента зашедшему пользователю, + а также текстовое объяснение почему ему может понравится этот контент. + + :param request: запрос. + :param model_name: название модели, для которой нужно получить объяснения. + :param user_id: id пользователя, для которого нужны объяснения. + :param item_id: id контента, для которого нужны объяснения. + :return: Response со значением процента релевантности и текстовым объяснением, понятным пользователю. + - "p": "процент релевантности контента item_id для пользователя user_id" + - "explanation": "текстовое объяснение почему рекомендован item_id" + """ + p, explanation = als_model.explain(user_id, item_id) + if p is None: + return ExplainResponse(p=random.randint(50, 80), explanation="Вам может понравится") + return ExplainResponse(p=p, explanation=explanation) + + def add_views(app: FastAPI) -> None: app.include_router(router) diff --git a/service/reco_models/__init__.py b/service/reco_models/__init__.py index dca1579e..4b1eb672 100644 --- a/service/reco_models/__init__.py +++ b/service/reco_models/__init__.py @@ -1,7 +1,9 @@ from typing import Dict +from .als_model import ALS from .bert4rec_model import BERT4Rec from .configuration import ( + ALS_PATHS, OFFLINE_KNN_MODEL_PATH, ONLINE_KNN_MODEL_PATH, POPULAR_IN_CATEGORY, @@ -22,7 +24,7 @@ # from .lightfm_models import ANNLightFM, OnlineFM from .popular_models import PopularInCategory, SimplePopularModel, TestModel -__all__ = ("models",) +__all__ = ("models", "als_model") test_model = TestModel() @@ -48,6 +50,8 @@ bert4rec_model = BERT4Rec(model_path=BERT4Rec_model_path) +als_model = ALS(ALS_PATHS) + models: Dict[str, RecommendationModel] = { "test_model": test_model, "baseline": baseline_model, @@ -59,6 +63,7 @@ # "ann_lightfm" "dssm": dssm_model, "bert4rec": bert4rec_model, + "als": als_model, } # ---------------------------------------------------------- diff --git a/service/reco_models/als_model.py b/service/reco_models/als_model.py new file mode 100644 index 00000000..7c5ff614 --- /dev/null +++ b/service/reco_models/als_model.py @@ -0,0 +1,53 @@ +import pickle +from math import exp +from typing import Dict, List, Optional, Tuple + +from service.reco_models.model import RecommendationModel + + +class ALS(RecommendationModel): + def __init__(self, paths: Dict[str, str]) -> None: + with open(paths["ui_csr"], "rb") as f: + self.ui_csr = pickle.load(f) + + with open(paths["user_ext_to_int"], "rb") as f: + self.user_ext_to_int = pickle.load(f) + with open(paths["item_int_to_ext"], "rb") as f: + self.item_int_to_ext = pickle.load(f) + with open(paths["item_ext_to_int"], "rb") as f: + self.item_ext_to_int = pickle.load(f) + + with open(paths["als_model"], "rb") as f: + self.model = pickle.load(f) + + with open(paths["item_to_title"], "rb") as f: + self.item_to_title = pickle.load(f) + + def predict(self, user_id: int, k_recs: int = 10) -> Optional[List[int]]: + int_user_id = self.user_ext_to_int[user_id] + rec = self.model.recommend(int_user_id, user_items=self.ui_csr, N=k_recs, filter_already_liked_items=True) + return [self.item_int_to_ext[item_int_id] for (item_int_id, _) in rec] + + def explain(self, user_id: int, item_id: int, threshold: float = 0.05) -> Tuple[Optional[int], Optional[str]]: + if user_id not in self.user_ext_to_int or item_id not in self.item_ext_to_int: + return None, None + + internal_userid = self.user_ext_to_int[user_id] + internal_itemid = self.item_ext_to_int[item_id] + + total_score, top_contributions, _ = self.model.explain( + userid=internal_userid, user_items=self.ui_csr, itemid=internal_itemid, N=2 + ) + if total_score < threshold: + return None, None + + p = int((0.5 / (1 + exp(-(total_score * 5 - 1))) + 0.5) * 100) + + title_1 = self.item_to_title[self.item_int_to_ext[top_contributions[0][0]]] + explanation = f"Рекомендуем тем, кому нравится «{title_1}»" + + if top_contributions[1][1] >= threshold: + title_2 = self.item_to_title[self.item_int_to_ext[top_contributions[1][0]]] + explanation += f" и «{title_2}»" + + return p, explanation diff --git a/service/reco_models/configuration.py b/service/reco_models/configuration.py index 633a3f7e..53eb5fc5 100644 --- a/service/reco_models/configuration.py +++ b/service/reco_models/configuration.py @@ -40,3 +40,12 @@ DSSM_ef_s = 50 BERT4Rec_model_path = "models/bert4rec/user_id_to_bert4rec_recs.pickle" + +ALS_PATHS = { + "ui_csr": "models/als/ui_csr.pickle", + "user_ext_to_int": "models/als/user_ext_to_int.pickle", + "item_int_to_ext": "models/als/item_int_to_ext.pickle", + "item_ext_to_int": "models/als/item_ext_to_int.pickle", + "als_model": "models/als/als_model.pickle", + "item_to_title": "models/als/item_to_title.pickle", +} diff --git a/tests/api/test_views.py b/tests/api/test_views.py index 4cd4189b..9e33e851 100644 --- a/tests/api/test_views.py +++ b/tests/api/test_views.py @@ -5,6 +5,7 @@ from service.settings import ServiceConfig GET_RECO_PATH = "/reco/{model_name}/{user_id}" +GET_EXPLAIN_PATH = "/explain/{model_name}/{user_id}/{item_id}" def test_health( @@ -63,3 +64,18 @@ def test_bearer_failed( response = client.get(path, headers={"Authorization": f"Bearer {incorrect_bearer}"}) assert response.status_code == HTTPStatus.UNAUTHORIZED assert response.json()["errors"][0]["error_key"] == "incorrect_bearer_key" + + +def test_als_explain( + client: TestClient, + service_config: ServiceConfig, +) -> None: + user_id = 662395 + item_id = 4633 + path = GET_EXPLAIN_PATH.format(model_name="als", user_id=user_id, item_id=item_id) + with client: + response = client.get(path, headers={"Authorization": "Bearer Team_5"}) + assert response.status_code == HTTPStatus.OK + response_json = response.json() + assert response_json["p"] == 66 + assert response_json["explanation"] == "Рекомендуем тем, кому нравится «Медиатор»"