diff --git a/.gitignore b/.gitignore
index a4d63804..11f7554a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,6 +41,7 @@ models/lightfm/user_embeddings.dill
models/popular_in_category/
models/dssm/
models/bert4rec/
+models/als
# Unit test / coverage reports
htmlcov/
diff --git a/load_models_from_google_drive.sh b/load_models_from_google_drive.sh
index b9ab58d9..c361d88a 100755
--- a/load_models_from_google_drive.sh
+++ b/load_models_from_google_drive.sh
@@ -28,3 +28,10 @@ download "models/dssm" "iid_to_item_id.json" "1-TrGCS_YmRWQkIeuhSXKsN7Xg3Nk_pEn"
download "models/dssm" "uid_to_watched_iids.json" "1-QtArop7useHil5pIeAM2t-d9J1nKOhS"
download "models/bert4rec" "user_id_to_bert4rec_recs.pickle" "15o1tIcsFlkQdmbw2Z9kucMzv724H-Qi8"
+
+download "models/als" "ui_csr.pickle" "1u73wI918JbDnRRGZ19hDVNePoX-E75Vk"
+download "models/als" "user_ext_to_int.pickle" "1-2AN1y039gPAl0oKUIN6KIuT-MkUDqtC"
+download "models/als" "item_int_to_ext.pickle" "1--1QkLMZgTzSiQSrDdLZcKvhAQgqONNa"
+download "models/als" "item_ext_to_int.pickle" "1d_ecNjQfpxNDwJSh6fP1f2HWrDiSy3ao"
+download "models/als" "als_model.pickle" "1-8gJuelBZwJFhq7IDN1U2N2zKjPW1-kw"
+download "models/als" "item_to_title.pickle" "1voF8cCfifroAsbJzFzc5DXC2j0MEGYIh"
diff --git "a/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb" "b/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb"
new file mode 100644
index 00000000..2bc36ab3
--- /dev/null
+++ "b/notebooks/2023-spring. hw-4. ALS_\320\270\320\275\321\202\320\265\321\200\320\277\321\200\320\265\321\202\320\260\321\206\320\270\321\217.ipynb"
@@ -0,0 +1,1108 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4",
+ "collapsed_sections": [
+ "e1uQOgjY75wn",
+ "ekn9E-ih79T0"
+ ]
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "gpuClass": "standard",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "b50b7514d4924d1eb2e3ce4b3d4a964a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_cfb9d9570b4a49a9bf782607056da77a",
+ "IPY_MODEL_2d811c1270224482930b0a5ac4889852",
+ "IPY_MODEL_a4ba3e0e0cda49f1855b18659ad52f73"
+ ],
+ "layout": "IPY_MODEL_076aaa32625f4aca9c01cab7d5c0ad36"
+ }
+ },
+ "cfb9d9570b4a49a9bf782607056da77a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c2c32f0d15614ede8702508c3ede6a9a",
+ "placeholder": "",
+ "style": "IPY_MODEL_125d0a4accb2478184df475818540350",
+ "value": "100%"
+ }
+ },
+ "2d811c1270224482930b0a5ac4889852": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_92962bbd778049d5b512a1f440becdce",
+ "max": 15,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_e2e4182854884016a9ed4f645a59ab13",
+ "value": 15
+ }
+ },
+ "a4ba3e0e0cda49f1855b18659ad52f73": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9a64281c06124ed18386e78310d529ad",
+ "placeholder": "",
+ "style": "IPY_MODEL_c96fb3b88dec43349a9eaf0a87d98d93",
+ "value": " 15/15 [00:04<00:00, 3.30it/s]"
+ }
+ },
+ "076aaa32625f4aca9c01cab7d5c0ad36": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c2c32f0d15614ede8702508c3ede6a9a": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "125d0a4accb2478184df475818540350": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "92962bbd778049d5b512a1f440becdce": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e2e4182854884016a9ed4f645a59ab13": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "9a64281c06124ed18386e78310d529ad": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c96fb3b88dec43349a9eaf0a87d98d93": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "from rectools import Columns\n",
+ "from rectools.models import ImplicitALSWrapperModel\n",
+ "from rectools.dataset import Dataset\n",
+ "from rectools.models.utils import recommend_from_scores\n",
+ "\n",
+ "from implicit.als import AlternatingLeastSquares\n",
+ "\n",
+ "import pandas as pd\n",
+ "from collections import Counter"
+ ],
+ "metadata": {
+ "id": "YEkHPJRZA3h8"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Загрузка данных"
+ ],
+ "metadata": {
+ "id": "e1uQOgjY75wn"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interactions_df = pd.read_csv(\"/content/drive/MyDrive/RecSys MTC/kion/interactions_processed.csv\")\n",
+ "interactions_df.rename(columns={\"last_watch_dt\": Columns.Datetime}, inplace=True)"
+ ],
+ "metadata": {
+ "id": "aoR1FBRMnrS8"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interactions_df.head(3)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 0
+ },
+ "id": "lCgzfC00GF4o",
+ "outputId": "e02cc58a-45e8-48d3-8180-8210e7d93fb0"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " user_id item_id datetime total_dur watched_pct\n",
+ "0 176549 9506 2021-05-11 4250 72\n",
+ "1 699317 1659 2021-05-29 8317 100\n",
+ "2 656683 7107 2021-05-09 10 0"
+ ],
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " datetime | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 656683 | \n",
+ " 7107 | \n",
+ " 2021-05-09 | \n",
+ " 10 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "execution_count": 4
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Предобработка"
+ ],
+ "metadata": {
+ "id": "ekn9E-ih79T0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interactions_df = interactions_df[interactions_df.watched_pct >= 5]"
+ ],
+ "metadata": {
+ "id": "QANVdjKNvdQO"
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# отбросим пользователей с малым числом просмотров\n",
+ "active_users = []\n",
+ "inactive_users = []\n",
+ "c = Counter(interactions_df.user_id)\n",
+ "for user_id, entries in c.items():\n",
+ " if entries >= 4:\n",
+ " active_users.append(user_id)\n",
+ " else:\n",
+ " inactive_users.append(user_id)\n",
+ "\n",
+ "interactions_df = interactions_df[interactions_df.user_id.isin(active_users)]\n",
+ "\n",
+ "len(active_users), len(inactive_users)"
+ ],
+ "metadata": {
+ "id": "vr3mW58NE8TL",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "825a7193-4b28-43f5-9958-f33c3898d071"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(286206, 507802)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Дадим эмпирический рейтинг взаимодействия с айтемом на основании процента просмотра:\n",
+ "\n",
+ "- 1 - просмотрено 0-10%\n",
+ "- 2 - просмотрено 10-30%\n",
+ "- 3 - просмотрено 30-60%\n",
+ "- 4 - просмотрено 60-85%\n",
+ "- 5 - просмотрено 85-100%"
+ ],
+ "metadata": {
+ "id": "9MZPXOGTJ4u4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def watched_pct_to_score(pct: float) -> int:\n",
+ " if 85 <= pct <= 100:\n",
+ " return 5\n",
+ " elif 60 <= pct < 85:\n",
+ " return 4\n",
+ " elif 30 <= pct < 60:\n",
+ " return 3\n",
+ " elif 10 <= pct < 30:\n",
+ " return 2\n",
+ " else:\n",
+ " return 1"
+ ],
+ "metadata": {
+ "id": "aofKPB3VKioS"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interactions_df[Columns.Weight] = interactions_df[\"watched_pct\"].apply(lambda pct: watched_pct_to_score(pct))"
+ ],
+ "metadata": {
+ "id": "KzzGdIzA7cUh"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "interactions_df.tail(3)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 0
+ },
+ "id": "Pp2o6tG6BG6H",
+ "outputId": "b9f580a2-2995-44a1-b232-542cf1ffd8e5"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " user_id item_id datetime total_dur watched_pct weight\n",
+ "5476247 546862 9673 2021-04-13 2308 49 3\n",
+ "5476249 384202 16197 2021-04-19 6203 100 5\n",
+ "5476250 319709 4436 2021-08-15 3921 45 3"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " datetime | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ " weight | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 5476247 | \n",
+ " 546862 | \n",
+ " 9673 | \n",
+ " 2021-04-13 | \n",
+ " 2308 | \n",
+ " 49 | \n",
+ " 3 | \n",
+ "
\n",
+ " \n",
+ " | 5476249 | \n",
+ " 384202 | \n",
+ " 16197 | \n",
+ " 2021-04-19 | \n",
+ " 6203 | \n",
+ " 100 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | 5476250 | \n",
+ " 319709 | \n",
+ " 4436 | \n",
+ " 2021-08-15 | \n",
+ " 3921 | \n",
+ " 45 | \n",
+ " 3 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## ALS рекомендации и их объяснение для активных пользователей"
+ ],
+ "metadata": {
+ "id": "NKkzlx7oJB40"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset = Dataset.construct(\n",
+ " interactions_df[Columns.Interactions]\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "rPp4kjyMJukY"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "k_recs = 10\n",
+ "\n",
+ "N_FACTORS = 64\n",
+ "RANDOM_STATE = 2023\n",
+ "NUM_THREADS = 16\n",
+ "\n",
+ "als_model = ImplicitALSWrapperModel(\n",
+ " model=AlternatingLeastSquares(\n",
+ " factors=N_FACTORS,\n",
+ " random_state=RANDOM_STATE,\n",
+ " num_threads=NUM_THREADS\n",
+ " ), verbose=1,\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "2kK8vVtGnrPe"
+ },
+ "execution_count": 12,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "als_model.fit(dataset)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 66,
+ "referenced_widgets": [
+ "b50b7514d4924d1eb2e3ce4b3d4a964a",
+ "cfb9d9570b4a49a9bf782607056da77a",
+ "2d811c1270224482930b0a5ac4889852",
+ "a4ba3e0e0cda49f1855b18659ad52f73",
+ "076aaa32625f4aca9c01cab7d5c0ad36",
+ "c2c32f0d15614ede8702508c3ede6a9a",
+ "125d0a4accb2478184df475818540350",
+ "92962bbd778049d5b512a1f440becdce",
+ "e2e4182854884016a9ed4f645a59ab13",
+ "9a64281c06124ed18386e78310d529ad",
+ "c96fb3b88dec43349a9eaf0a87d98d93"
+ ]
+ },
+ "id": "ogKYQ0GYx7vF",
+ "outputId": "ef86c801-d417-4bcf-8d20-a23288a6744a"
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " 0%| | 0/15 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b50b7514d4924d1eb2e3ce4b3d4a964a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 13
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ui_csr = dataset.get_user_item_matrix()\n",
+ "\n",
+ "user_ext_to_int = dataset.user_id_map.to_internal.to_dict()\n",
+ "item_int_to_ext = dataset.item_id_map.to_external.to_dict()\n",
+ "item_ext_to_int = {v: k for k, v in item_int_to_ext.items()}"
+ ],
+ "metadata": {
+ "id": "jmdqW12L1NzB"
+ },
+ "execution_count": 14,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = als_model.model"
+ ],
+ "metadata": {
+ "id": "IK8egJB73L3m"
+ },
+ "execution_count": 15,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "items_df = pd.read_csv(\"/content/drive/MyDrive/RecSys MTC/kion/items.csv\")\n",
+ "item_to_title = items_df[[\"item_id\", \"title\"]].set_index(\"item_id\").to_dict()[\"title\"]"
+ ],
+ "metadata": {
+ "id": "9T4jaz5gPREz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pickle\n",
+ "\n",
+ "\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/ui_csr.pickle\", \"wb\") as f:\n",
+ " pickle.dump(ui_csr, f)\n",
+ "\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/user_ext_to_int.pickle\", \"wb\") as f:\n",
+ " pickle.dump(user_ext_to_int, f)\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_int_to_ext.pickle\", \"wb\") as f:\n",
+ " pickle.dump(item_int_to_ext, f)\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_ext_to_int.pickle\", \"wb\") as f:\n",
+ " pickle.dump(item_ext_to_int, f)\n",
+ "\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/als_model.pickle\", \"wb\") as f:\n",
+ " pickle.dump(model, f)\n",
+ "\n",
+ "with open(\"/content/drive/MyDrive/RecSys MTC/prod/prod_models/als/item_to_title.pickle\", \"wb\") as f:\n",
+ " pickle.dump(item_to_title, f)"
+ ],
+ "metadata": {
+ "id": "Ie7NDks53KuU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from math import exp\n",
+ "\n",
+ "\n",
+ "def explain(user_id, item_id, threshold=0.05):\n",
+ " if user_id not in user_ext_to_int or \\\n",
+ " item_id not in item_ext_to_int:\n",
+ " return None, None\n",
+ "\n",
+ " internal_userid = user_ext_to_int[user_id]\n",
+ " internal_itemid = item_ext_to_int[item_id]\n",
+ "\n",
+ " total_score, top_contributions, _ = model.explain(\n",
+ " userid=internal_userid,\n",
+ " user_items=ui_csr,\n",
+ " itemid=internal_itemid,\n",
+ " N=2\n",
+ " )\n",
+ " if total_score < threshold:\n",
+ " return None, None\n",
+ "\n",
+ " p = int((0.5 / (1 + exp(-(total_score * 5 - 1))) + 0.5) * 100)\n",
+ "\n",
+ " title_1 = item_to_title[\n",
+ " item_int_to_ext[top_contributions[0][0]]\n",
+ " ]\n",
+ " explanation = f\"Рекомендуем тем, кому нравится «{title_1}»\"\n",
+ "\n",
+ " if top_contributions[1][1] >= threshold:\n",
+ " title_2 = item_to_title[\n",
+ " item_int_to_ext[top_contributions[1][0]]\n",
+ " ]\n",
+ " explanation += f\" и «{title_2}»\"\n",
+ "\n",
+ " return p, explanation"
+ ],
+ "metadata": {
+ "id": "47kLK90x2nq2"
+ },
+ "execution_count": 90,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "test_users = list(user_ext_to_int.keys())[1000:1005]\n",
+ "test_items = list(item_ext_to_int.keys())[1000:1005]\n",
+ "for user_id in test_users:\n",
+ " for item_id in test_items:\n",
+ " print(explain(user_id, item_id))\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "8F7jaznRyrT6",
+ "outputId": "616d005f-67ae-4421-f836-4c2d4025f379"
+ },
+ "execution_count": 91,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "(None, None)\n",
+ "(66, 'Рекомендуем тем, кому нравится «Медиатор»')\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(81, 'Рекомендуем тем, кому нравится «Балканский рубеж» и «Застава»')\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(84, 'Рекомендуем тем, кому нравится «Коридор бессмертия» и «Легенда № 17»')\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n",
+ "(None, None)\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
diff --git a/service/api/views.py b/service/api/views.py
index 840fa194..e3e5a83d 100644
--- a/service/api/views.py
+++ b/service/api/views.py
@@ -1,3 +1,4 @@
+import random
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, FastAPI, Request
@@ -16,7 +17,7 @@
NotFoundError,
)
from service.log import app_logger
-from service.reco_models import models
+from service.reco_models import als_model, models
class RecoResponse(BaseModel):
@@ -24,6 +25,11 @@ class RecoResponse(BaseModel):
items: List[int]
+class ExplainResponse(BaseModel):
+    # Response body of the /explain endpoint.
+    p: int  # relevance of the item for the user, as a percentage
+    explanation: str  # user-facing text explaining why the item is recommended
+
+
bearer_scheme = HTTPBearer()
router = APIRouter()
@@ -69,19 +75,35 @@ async def get_reco(
else:
raise ModelNotFoundError(error_message=f"Model {model_name} not found")
- # if model_name in ("light_fm_1", "light_fm_2"):
- # reco = (
- # online_fm_all_popular.predict(user_id, k_recs)
- # if model_name == "light_fm_1"
- # else online_fm_part_popular.predict(user_id, k_recs)
- # )
- # if model_name == "ann_lightfm":
- # reco = ann_lightfm.predict(user_id)
-
if not reco:
reco = models["baseline"].predict(user_id, k_recs=k_recs)
return RecoResponse(user_id=user_id, items=reco)
+@router.get(
+ path="/explain/{model_name}/{user_id}/{item_id}",
+ tags=["Explanations"],
+ response_model=ExplainResponse,
+)
+async def explain(request: Request, model_name: str, user_id: int, item_id: int) -> ExplainResponse:
+ """
+ Пользователь переходит на карточку контента, на которой нужно показать
+ процент релевантности этого контента зашедшему пользователю,
+ а также текстовое объяснение почему ему может понравится этот контент.
+
+ :param request: запрос.
+ :param model_name: название модели, для которой нужно получить объяснения.
+ :param user_id: id пользователя, для которого нужны объяснения.
+ :param item_id: id контента, для которого нужны объяснения.
+ :return: Response со значением процента релевантности и текстовым объяснением, понятным пользователю.
+ - "p": "процент релевантности контента item_id для пользователя user_id"
+ - "explanation": "текстовое объяснение почему рекомендован item_id"
+ """
+ p, explanation = als_model.explain(user_id, item_id)
+ if p is None:
+ return ExplainResponse(p=random.randint(50, 80), explanation="Вам может понравится")
+ return ExplainResponse(p=p, explanation=explanation)
+
+
def add_views(app: FastAPI) -> None:
app.include_router(router)
diff --git a/service/reco_models/__init__.py b/service/reco_models/__init__.py
index dca1579e..4b1eb672 100644
--- a/service/reco_models/__init__.py
+++ b/service/reco_models/__init__.py
@@ -1,7 +1,9 @@
from typing import Dict
+from .als_model import ALS
from .bert4rec_model import BERT4Rec
from .configuration import (
+ ALS_PATHS,
OFFLINE_KNN_MODEL_PATH,
ONLINE_KNN_MODEL_PATH,
POPULAR_IN_CATEGORY,
@@ -22,7 +24,7 @@
# from .lightfm_models import ANNLightFM, OnlineFM
from .popular_models import PopularInCategory, SimplePopularModel, TestModel
-__all__ = ("models",)
+__all__ = ("models", "als_model")
test_model = TestModel()
@@ -48,6 +50,8 @@
bert4rec_model = BERT4Rec(model_path=BERT4Rec_model_path)
+als_model = ALS(ALS_PATHS)
+
models: Dict[str, RecommendationModel] = {
"test_model": test_model,
"baseline": baseline_model,
@@ -59,6 +63,7 @@
# "ann_lightfm"
"dssm": dssm_model,
"bert4rec": bert4rec_model,
+ "als": als_model,
}
# ----------------------------------------------------------
diff --git a/service/reco_models/als_model.py b/service/reco_models/als_model.py
new file mode 100644
index 00000000..7c5ff614
--- /dev/null
+++ b/service/reco_models/als_model.py
@@ -0,0 +1,53 @@
+import pickle
+from math import exp
+from typing import Dict, List, Optional, Tuple
+
+from service.reco_models.model import RecommendationModel
+
+
+class ALS(RecommendationModel):
+ def __init__(self, paths: Dict[str, str]) -> None:
+ with open(paths["ui_csr"], "rb") as f:
+ self.ui_csr = pickle.load(f)
+
+ with open(paths["user_ext_to_int"], "rb") as f:
+ self.user_ext_to_int = pickle.load(f)
+ with open(paths["item_int_to_ext"], "rb") as f:
+ self.item_int_to_ext = pickle.load(f)
+ with open(paths["item_ext_to_int"], "rb") as f:
+ self.item_ext_to_int = pickle.load(f)
+
+ with open(paths["als_model"], "rb") as f:
+ self.model = pickle.load(f)
+
+ with open(paths["item_to_title"], "rb") as f:
+ self.item_to_title = pickle.load(f)
+
+ def predict(self, user_id: int, k_recs: int = 10) -> Optional[List[int]]:
+ int_user_id = self.user_ext_to_int[user_id]
+ rec = self.model.recommend(int_user_id, user_items=self.ui_csr, N=k_recs, filter_already_liked_items=True)
+ return [self.item_int_to_ext[item_int_id] for (item_int_id, _) in rec]
+
+ def explain(self, user_id: int, item_id: int, threshold: float = 0.05) -> Tuple[Optional[int], Optional[str]]:
+ if user_id not in self.user_ext_to_int or item_id not in self.item_ext_to_int:
+ return None, None
+
+ internal_userid = self.user_ext_to_int[user_id]
+ internal_itemid = self.item_ext_to_int[item_id]
+
+ total_score, top_contributions, _ = self.model.explain(
+ userid=internal_userid, user_items=self.ui_csr, itemid=internal_itemid, N=2
+ )
+ if total_score < threshold:
+ return None, None
+
+ p = int((0.5 / (1 + exp(-(total_score * 5 - 1))) + 0.5) * 100)
+
+ title_1 = self.item_to_title[self.item_int_to_ext[top_contributions[0][0]]]
+ explanation = f"Рекомендуем тем, кому нравится «{title_1}»"
+
+ if top_contributions[1][1] >= threshold:
+ title_2 = self.item_to_title[self.item_int_to_ext[top_contributions[1][0]]]
+ explanation += f" и «{title_2}»"
+
+ return p, explanation
diff --git a/service/reco_models/configuration.py b/service/reco_models/configuration.py
index 633a3f7e..53eb5fc5 100644
--- a/service/reco_models/configuration.py
+++ b/service/reco_models/configuration.py
@@ -40,3 +40,12 @@
DSSM_ef_s = 50
BERT4Rec_model_path = "models/bert4rec/user_id_to_bert4rec_recs.pickle"
+
+# Locations of the pickled ALS artifacts; the keys are the ones read by
+# ALS.__init__ in service/reco_models/als_model.py.
+ALS_PATHS = {
+    "ui_csr": "models/als/ui_csr.pickle",
+    "user_ext_to_int": "models/als/user_ext_to_int.pickle",
+    "item_int_to_ext": "models/als/item_int_to_ext.pickle",
+    "item_ext_to_int": "models/als/item_ext_to_int.pickle",
+    "als_model": "models/als/als_model.pickle",
+    "item_to_title": "models/als/item_to_title.pickle",
+}
diff --git a/tests/api/test_views.py b/tests/api/test_views.py
index 4cd4189b..9e33e851 100644
--- a/tests/api/test_views.py
+++ b/tests/api/test_views.py
@@ -5,6 +5,7 @@
from service.settings import ServiceConfig
GET_RECO_PATH = "/reco/{model_name}/{user_id}"
+GET_EXPLAIN_PATH = "/explain/{model_name}/{user_id}/{item_id}"
def test_health(
@@ -63,3 +64,18 @@ def test_bearer_failed(
response = client.get(path, headers={"Authorization": f"Bearer {incorrect_bearer}"})
assert response.status_code == HTTPStatus.UNAUTHORIZED
assert response.json()["errors"][0]["error_key"] == "incorrect_bearer_key"
+
+
+def test_als_explain(
+ client: TestClient,
+ service_config: ServiceConfig,
+) -> None:
+ user_id = 662395
+ item_id = 4633
+ path = GET_EXPLAIN_PATH.format(model_name="als", user_id=user_id, item_id=item_id)
+ with client:
+ response = client.get(path, headers={"Authorization": "Bearer Team_5"})
+ assert response.status_code == HTTPStatus.OK
+ response_json = response.json()
+ assert response_json["p"] == 66
+ assert response_json["explanation"] == "Рекомендуем тем, кому нравится «Медиатор»"