diff --git a/radis/chats/utils/chat_client.py b/radis/chats/utils/chat_client.py index b310247a..c9c1f9dd 100644 --- a/radis/chats/utils/chat_client.py +++ b/radis/chats/utils/chat_client.py @@ -16,53 +16,171 @@ def _get_base_url() -> str: return base_url -class AsyncChatClient: +def _validate_completion_response(completion) -> str: + """ + Validates that the LLM completion response contains valid content. + + Args: + completion: The completion response from the LLM + + Returns: + The message content as a string + + Raises: + ValueError: If the response is empty or invalid + """ + if not completion.choices: + logger.error("LLM returned empty choices list") + raise ValueError("LLM returned no response choices") + + answer = completion.choices[0].message.content + if answer is None: + logger.error("LLM returned None for message content") + raise ValueError("LLM returned empty response content") + + return answer + + +def _validate_parsed_response(completion) -> BaseModel: + """ + Validates that the LLM completion response contains valid parsed data. + + Args: + completion: The completion response from the LLM + + Returns: + The parsed BaseModel instance + + Raises: + ValueError: If the response is empty or invalid + """ + if not completion.choices: + logger.error("LLM returned empty choices list") + raise ValueError("LLM returned no response choices") + + parsed = completion.choices[0].message.parsed + if parsed is None: + logger.error("LLM returned None for parsed message") + raise ValueError("LLM returned empty parsed response") + + return parsed + + +def _handle_api_error(error: openai.APIError, operation: str) -> None: + """ + Logs and re-raises API errors with consistent error messages. + + Args: + error: The API error from OpenAI + operation: Description of the operation that failed (e.g., "chat", "data extraction") + + Raises: + RuntimeError: Always raises with a user-friendly error message + """ + logger.error(f"OpenAI API error during {operation}: {error}") + raise RuntimeError(f"Failed to communicate with LLM service: {error}") from error + + +class _BaseChatClient: + """Base class containing shared chat client logic.""" + def __init__(self): - base_url = _get_base_url() - api_key = settings.EXTERNAL_LLM_PROVIDER_API_KEY - self._client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) + self._base_url = _get_base_url() + self._api_key = settings.EXTERNAL_LLM_PROVIDER_API_KEY self._model_name = settings.LLM_MODEL_NAME - async def chat( + def _build_chat_request( self, messages: Iterable[ChatCompletionMessageParam], max_completion_tokens: int | None = None, - ) -> str: - logger.debug(f"Sending messages to LLM for chat:\n{messages}") - + ) -> dict: + """Build the request dictionary for chat completion.""" request = { "model": self._model_name, "messages": messages, } if max_completion_tokens is not None: request["max_completion_tokens"] = max_completion_tokens + return request - completion = await self._client.chat.completions.create(**request) - answer = completion.choices[0].message.content - assert answer is not None + def _log_request(self, messages: Iterable[ChatCompletionMessageParam]) -> None: + """Log the outgoing request.""" + logger.debug(f"Sending messages to LLM for chat:\n{messages}") + + def _log_response(self, answer: str) -> None: + """Log the response from LLM.""" logger.debug("Received from LLM: %s", answer) + + +class AsyncChatClient(_BaseChatClient): + def __init__(self): + super().__init__() + self._client = openai.AsyncOpenAI(base_url=self._base_url, api_key=self._api_key) + + async def chat( + self, + messages: Iterable[ChatCompletionMessageParam], + max_completion_tokens: int | None = None, + ) -> str: + self._log_request(messages) + request = self._build_chat_request(messages, max_completion_tokens) + + try: + completion = await self._client.chat.completions.create(**request) + except openai.APIError as e: + _handle_api_error(e, "chat") + + answer = _validate_completion_response(completion) + self._log_response(answer) return answer -class ChatClient: +class ChatClient(_BaseChatClient): def __init__(self) -> None: - base_url = _get_base_url() - api_key = settings.EXTERNAL_LLM_PROVIDER_API_KEY + super().__init__() + self._client = openai.OpenAI(base_url=self._base_url, api_key=self._api_key) + + def chat( + self, + messages: Iterable[ChatCompletionMessageParam], + max_completion_tokens: int | None = None, + ) -> str: + """ + Send messages to LLM and return the response text. - self._client = openai.OpenAI(base_url=base_url, api_key=api_key) - self._llm_model_name = settings.LLM_MODEL_NAME + Args: + messages: List of message dictionaries with 'role' and 'content' + max_completion_tokens: Optional maximum tokens to generate + + Returns: + The LLM's response as a string + """ + self._log_request(messages) + request = self._build_chat_request(messages, max_completion_tokens) + + try: + completion = self._client.chat.completions.create(**request) + except openai.APIError as e: + _handle_api_error(e, "chat") + + answer = _validate_completion_response(completion) + self._log_response(answer) + return answer def extract_data(self, prompt: str, schema: type[BaseModel]) -> BaseModel: logger.debug("Sending prompt and schema to LLM to extract data.") logger.debug("Prompt:\n%s", prompt) logger.debug("Schema:\n%s", schema.model_json_schema()) - completion = self._client.beta.chat.completions.parse( - model=self._llm_model_name, - messages=[{"role": "system", "content": prompt}], - response_format=schema, - ) - event = completion.choices[0].message.parsed - assert event + try: + completion = self._client.beta.chat.completions.parse( + model=self._model_name, + messages=[{"role": "system", "content": prompt}], + response_format=schema, + ) + except openai.APIError as e: + _handle_api_error(e, "data extraction") + + event = _validate_parsed_response(completion) logger.debug("Received from LLM: %s", event) return event diff --git a/radis/core/constants.py b/radis/core/constants.py index db1b048b..8e1d74bd 100644 --- a/radis/core/constants.py +++ b/radis/core/constants.py @@ -2,3 +2,7 @@ "de": "German", "en": "English", } + +MIN_AGE = 0 +MAX_AGE = 120 +AGE_STEP = 10 diff --git a/radis/core/form_fields.py b/radis/core/form_fields.py new file mode 100644 index 00000000..1ce2a311 --- /dev/null +++ b/radis/core/form_fields.py @@ -0,0 +1,177 @@ +""" +Reusable form field factories for RADIS forms. + +This module provides factory functions for commonly used form fields +to reduce duplication across the codebase. +""" + +from typing import Literal, overload + +from django import forms + +from radis.core.constants import AGE_STEP, LANGUAGE_LABELS, MAX_AGE, MIN_AGE +from radis.reports.models import Language, Modality + + +@overload +def create_language_field( + required: bool = False, + empty_label: str | None = None, + use_pk: Literal[True] = True, +) -> forms.ModelChoiceField: ... + + +@overload +def create_language_field( + required: bool = False, + empty_label: str | None = None, + use_pk: Literal[False] = False, +) -> forms.ChoiceField: ... + + +def create_language_field( + required: bool = False, + empty_label: str | None = None, + use_pk: bool = True, +) -> forms.ModelChoiceField | forms.ChoiceField: + """ + Create a language choice field with consistent configuration. + + Args: + required: Whether the field is required + empty_label: Label for empty option (None = no empty option) + use_pk: If True, returns ModelChoiceField with Language objects; + if False, returns ChoiceField with code strings + + Returns: + ModelChoiceField (if use_pk=True) or ChoiceField (if use_pk=False) + + Example: + # For extraction forms (uses ModelChoiceField, returns Language objects) + self.fields["language"] = create_language_field() + + # For subscription forms (uses ModelChoiceField, allows "All") + self.fields["language"] = create_language_field(empty_label="All") + + # For search forms (uses ChoiceField with codes) + self.fields["language"] = create_language_field(use_pk=False) + """ + languages = Language.objects.order_by("code") + + if use_pk: + # Return ModelChoiceField - cleaned_data will contain Language objects + field = forms.ModelChoiceField( + queryset=languages, + required=required, + empty_label=empty_label, + ) + field.label_from_instance = lambda obj: LANGUAGE_LABELS[obj.code] + return field + else: + # Return ChoiceField - cleaned_data will contain code strings + field = forms.ChoiceField(required=required) + field.choices = [(lang.code, LANGUAGE_LABELS[lang.code]) for lang in languages] + if empty_label is not None: + field.empty_label = empty_label # type: ignore + return field + + +@overload +def create_modality_field( + required: bool = False, + widget_size: int = 6, + use_pk: Literal[True] = True, +) -> forms.ModelMultipleChoiceField: ... + + +@overload +def create_modality_field( + required: bool = False, + widget_size: int = 6, + use_pk: Literal[False] = False, +) -> forms.MultipleChoiceField: ... + + +def create_modality_field( + required: bool = False, + widget_size: int = 6, + use_pk: bool = True, +) -> forms.ModelMultipleChoiceField | forms.MultipleChoiceField: + """ + Create a modality multiple choice field with consistent configuration. + + Args: + required: Whether the field is required + widget_size: Height of the select widget + use_pk: If True, returns ModelMultipleChoiceField with Modality objects; + if False, returns MultipleChoiceField with code strings + + Returns: + ModelMultipleChoiceField (if use_pk=True) or MultipleChoiceField (if use_pk=False) + + Example: + # For extraction forms (uses ModelMultipleChoiceField, returns Modality objects) + self.fields["modalities"] = create_modality_field() + + # For search forms (uses MultipleChoiceField with codes) + self.fields["modalities"] = create_modality_field(use_pk=False) + """ + modalities = Modality.objects.filter(filterable=True).order_by("code") + + if use_pk: + # Return ModelMultipleChoiceField - cleaned_data will contain Modality objects + field = forms.ModelMultipleChoiceField( + queryset=modalities, + required=required, + ) + # Display just the code for each modality + field.label_from_instance = lambda obj: obj.code + field.widget.attrs["size"] = widget_size + return field + else: + # Return MultipleChoiceField - cleaned_data will contain code strings + field = forms.MultipleChoiceField(required=required) + field.choices = [(mod.code, mod.code) for mod in modalities] + field.widget.attrs["size"] = widget_size + return field + + +def create_age_range_fields() -> tuple[forms.IntegerField, forms.IntegerField]: + """ + Create age_from and age_till fields with consistent configuration. + + Returns: + Tuple of (age_from_field, age_till_field) + + Example: + age_from, age_till = create_age_range_fields() + self.fields["age_from"] = age_from + self.fields["age_till"] = age_till + """ + age_from = forms.IntegerField( + required=False, + min_value=MIN_AGE, + max_value=MAX_AGE, + widget=forms.NumberInput( + attrs={ + "type": "range", + "step": AGE_STEP, + "value": MIN_AGE, + } + ), + ) + + age_till = forms.IntegerField( + required=False, + min_value=MIN_AGE, + max_value=MAX_AGE, + widget=forms.NumberInput( + attrs={ + "type": "range", + "step": AGE_STEP, + "value": MAX_AGE, + } + ), + ) + + return age_from, age_till diff --git a/radis/core/static/core/core.js b/radis/core/static/core/core.js index cfff69a0..31468d10 100644 --- a/radis/core/static/core/core.js +++ b/radis/core/static/core/core.js @@ -46,6 +46,23 @@ document.addEventListener("alpine:init", () => { }); }); +document.addEventListener("DOMContentLoaded", () => { + const preventAttr = "[data-prevent-enter-submit]"; + document.querySelectorAll(preventAttr).forEach((formEl) => { + formEl.addEventListener("keydown", (event) => { + if (event.key !== "Enter") { + return; + } + const target = event.target; + const isTextInput = + target instanceof HTMLInputElement && + !["submit", "button"].includes(target.type); + if (isTextInput) { + event.preventDefault(); + } + }); + }); +}); /** * An Alpine component that controls a Django FormSet * @@ -66,15 +83,14 @@ function FormSet(rootEl) { formCount: parseInt(totalForms.value), minForms: parseInt(minForms.value), maxForms: parseInt(maxForms.value), - init() { - console.log(this.formCount); - }, + init() {}, addForm() { - const newForm = template.content.cloneNode(true); + if (!template) { + return; + } const idx = totalForms.value; - container.append(newForm); - const lastForm = container.querySelector(".formset-form:last-child"); - lastForm.innerHTML = lastForm.innerHTML.replace(/__prefix__/g, idx); + const html = template.innerHTML.replace(/__prefix__/g, idx); + container.insertAdjacentHTML("beforeend", html); totalForms.value = (parseInt(idx) + 1).toString(); this.formCount = parseInt(totalForms.value); }, diff --git a/radis/core/templates/cotton/formset.html b/radis/core/templates/cotton/formset.html index 20c6af7b..95ea87bd 100644 --- a/radis/core/templates/cotton/formset.html +++ b/radis/core/templates/cotton/formset.html @@ -7,7 +7,9 @@ {% crispy formset.empty_form %}
- {% for form in formset %}{{ form|crispy }}{% endfor %} + {% for form in formset %} + {% crispy form %} + {% endfor %}
{% if add_form_label %}
diff --git a/radis/extractions/constants.py b/radis/extractions/constants.py new file mode 100644 index 00000000..bacfd27c --- /dev/null +++ b/radis/extractions/constants.py @@ -0,0 +1 @@ +MAX_SELECTION_OPTIONS = 7 diff --git a/radis/extractions/factories.py b/radis/extractions/factories.py index 6994a8cd..1aa0c931 100644 --- a/radis/extractions/factories.py +++ b/radis/extractions/factories.py @@ -1,15 +1,16 @@ +from typing import cast + import factory +from adit_radis_shared.accounts.factories import GroupFactory, UserFactory +from adit_radis_shared.accounts.models import User +from adit_radis_shared.common.utils.testing_helpers import add_user_to_group +from django.contrib.auth.models import Group +from factory.declarations import SKIP from faker import Faker from radis.reports.factories import ModalityFactory -from .models import ( - ExtractionInstance, - ExtractionJob, - ExtractionTask, - OutputField, - OutputType, -) +from .models import ExtractionInstance, ExtractionJob, ExtractionTask, OutputField, OutputType fake = Faker() @@ -26,8 +27,9 @@ class ExtractionJobFactory(BaseDjangoModelFactory): class Meta: model = ExtractionJob + owner = factory.SubFactory(UserFactory) title = factory.Faker("sentence", nb_words=3) - group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory") + group = factory.SubFactory(GroupFactory) query = factory.Faker("word") language = factory.SubFactory("radis.reports.factories.LanguageFactory") study_date_from = factory.Faker("date") @@ -56,15 +58,34 @@ def modalities(self, create, extracted, **kwargs): # django_get_or_create would not be respected then self.modalities.add(ModalityFactory(code=modality)) # type: ignore + @factory.post_generation + def ensure_owner_in_group(obj, create, extracted, **kwargs): + owner = cast(User, obj.owner) + group = cast(Group, obj.group) + + if not create: + return + + add_user_to_group(owner, group) + class OutputFieldFactory(BaseDjangoModelFactory[OutputField]): class Meta: model = OutputField - job = factory.SubFactory("radis.extractions.factories.ExtractionJobFactory") + # Use factory.Maybe to conditionally create job only when subscription is None + job = factory.Maybe( + factory.SelfAttribute("subscription"), + yes_declaration=SKIP, # If subscription exists, skip job creation + no_declaration=factory.SubFactory("radis.extractions.factories.ExtractionJobFactory"), # type: ignore[arg-type] + ) + subscription = None name = factory.Sequence(lambda n: f"output_field_{n}") description = factory.Faker("sentence", nb_words=10) output_type = factory.Faker("random_element", elements=[a[0] for a in OutputType.choices]) + selection_options = factory.LazyAttribute( + lambda obj: ["Option 1", "Option 2"] if obj.output_type == OutputType.SELECTION else [] + ) class ExtractionTaskFactory(BaseDjangoModelFactory[ExtractionTask]): diff --git a/radis/extractions/forms.py b/radis/extractions/forms.py index 7a1d559d..0d9c6162 100644 --- a/radis/extractions/forms.py +++ b/radis/extractions/forms.py @@ -1,21 +1,27 @@ +import json from typing import Any, cast from adit_radis_shared.accounts.models import User from crispy_forms.helper import FormHelper -from crispy_forms.layout import Column, Layout, Row, Submit +from crispy_forms.layout import HTML, Column, Div, Field, Layout, Row, Submit from django import forms from django.conf import settings from django.db.models import QuerySet -from radis.core.constants import LANGUAGE_LABELS +from radis.core.form_fields import ( + create_age_range_fields, + create_language_field, + create_modality_field, +) from radis.core.layouts import RangeSlider from radis.reports.models import Language, Modality -from radis.search.forms import AGE_STEP, MAX_AGE, MIN_AGE from radis.search.site import Search, SearchFilters from radis.search.utils.query_parser import QueryParser -from .models import ExtractionJob, OutputField +from .constants import MAX_SELECTION_OPTIONS +from .models import ExtractionJob, OutputField, OutputType from .site import extraction_retrieval_provider +from .utils.validation import validate_selection_options class SearchForm(forms.ModelForm): @@ -35,7 +41,14 @@ class Meta: ] help_texts = { "title": "Title of the extraction job", - "query": "A query to find reports for further analysis", + "query": ( + "Search query to filter reports. " + "This query was auto-generated from your extraction fields" + " - you can edit or refine it." + ), + } + widgets = { + "query": forms.TextInput(attrs={"placeholder": "Auto-generated query (editable)"}), } def __init__(self, *args, **kwargs): @@ -43,41 +56,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields["language"].choices = [ # type: ignore - (language.pk, LANGUAGE_LABELS[language.code]) - for language in Language.objects.order_by("code") - ] - self.fields["modalities"].choices = [ # type: ignore - (modality.pk, modality.code) - for modality in Modality.objects.filter(filterable=True).order_by("code") - ] - self.fields["modalities"].widget.attrs["size"] = 6 + self.fields["query"].required = True + self.fields["language"] = create_language_field() + self.fields["modalities"] = create_modality_field() self.fields["study_date_from"].widget = forms.DateInput(attrs={"type": "date"}) self.fields["study_date_till"].widget = forms.DateInput(attrs={"type": "date"}) - self.fields["age_from"] = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MIN_AGE, - } - ), - ) - self.fields["age_till"] = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MAX_AGE, - } - ), - ) + age_from, age_till = create_age_range_fields() + self.fields["age_from"] = age_from + self.fields["age_till"] = age_till self.helper = FormHelper() self.helper.form_tag = False @@ -89,8 +75,12 @@ def build_layout(self): Row( Column( "title", + # Query generation section (async HTMX) + HTML('{% include "extractions/_query_generation_section.html" %}'), "query", - Submit("next", "Next Step (Output Fields)", css_class="btn-primary"), + # Preview div from template include + HTML('{% include "extractions/_search_preview_form_section.html" %}'), + Submit("next", "Next Step (Summary)", css_class="btn-primary"), ), Column( "language", @@ -107,29 +97,37 @@ def build_layout(self): ) def clean_query(self) -> str: - query = self.cleaned_data["query"] - query_node, _ = QueryParser().parse(query) + query = self.cleaned_data["query"].strip() + if not query: + raise forms.ValidationError( + "A search query is required. " + "Please enter a query or go back to regenerate from fields." + ) + query_node, fixes = QueryParser().parse(query) if query_node is None: - raise forms.ValidationError("Invalid empty query") + raise forms.ValidationError("Invalid query syntax") + else: + self.cleaned_data["query_node"] = query_node + if len(fixes) > 0: + query = QueryParser.unparse(query_node) return query def clean(self) -> dict[str, Any] | None: cleaned_data = super().clean() assert cleaned_data + # If query validation failed, query_node won't exist - exit early + if "query_node" not in cleaned_data: + return cleaned_data + active_group = self.user.active_group language = cast(Language, cleaned_data["language"]) modalities = cast(QuerySet[Modality], cleaned_data["modalities"]) - query_node, fixes = QueryParser().parse(cleaned_data["query"]) - assert query_node - - if len(fixes) > 0: - cleaned_data["fixed_query"] = QueryParser.unparse(query_node) - + # Calculate retrieval count with inline Search construction search = Search( - query=query_node, + query=cleaned_data["query_node"], offset=0, limit=0, filters=SearchFilters( @@ -147,14 +145,16 @@ def clean(self) -> dict[str, Any] | None: if extraction_retrieval_provider is None: raise forms.ValidationError("Extraction retrieval provider is not configured.") + retrieval_count = extraction_retrieval_provider.count(search) cleaned_data["retrieval_count"] = retrieval_count + # Validate against limits if retrieval_count > settings.EXTRACTION_MAXIMUM_REPORTS_COUNT: raise forms.ValidationError( f"Your search returned more results ({retrieval_count}) than the extraction " f"pipeline can handle (max. {settings.EXTRACTION_MAXIMUM_REPORTS_COUNT}). " - "Please refine your search." + "Please refine your search query." ) if ( @@ -170,14 +170,135 @@ def clean(self) -> dict[str, Any] | None: class OutputFieldForm(forms.ModelForm): + """Hidden field to store selection options and array flag as JSON string. + This is done because the selection options are dynamic and the array toggle + is an alpine component that needs to be re-rendered on every change.""" + + selection_options = forms.CharField( + required=False, + widget=forms.HiddenInput(), + ) + is_array = forms.CharField( + required=False, + widget=forms.HiddenInput(), + ) + class Meta: model = OutputField fields = [ "name", "description", "output_type", + "selection_options", + "is_array", ] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fields["name"].required = True + self.fields["description"].required = True + self.fields["description"].widget = forms.Textarea(attrs={"rows": 3}) + self.fields["selection_options"].widget.attrs.update( + { + "data-selection-input": "true", + "data-max-selection-options": str(MAX_SELECTION_OPTIONS), + } + ) + self.fields["is_array"].widget.attrs.update( + { + "data-array-input": "true", + } + ) + + initial_options = self.instance.selection_options if self.instance.pk else [] + self.initial["selection_options"] = json.dumps(initial_options) + self.initial["is_array"] = "true" if self.instance.is_array else "false" + + self.helper = FormHelper() + self.helper.form_tag = False + self.helper.disable_csrf = True + + # Build the layout for selection options and array toggle button using crispy. + fields = [ + Field("id", type="hidden"), + Row( + Column("name", css_class="col-md-7 col-12"), + Column("output_type", css_class="col-md-4 col-10"), + Column( + HTML( + ( + '' + ) + ), + css_class=( + "col-md-1 col-2 d-flex align-items-center " + "justify-content-end array-toggle-field" + ), + ), + css_class="g-3 align-items-center", + ), + "description", + # Include the selection options widget partial template here. + Div( + HTML('{% include "extractions/_selection_options_field.html" %}'), + css_class="selection-options-wrapper", + ), + ] + + if "DELETE" in self.fields: + fields.insert(1, Field("DELETE", type="hidden")) + + self.helper.layout = Layout(Div(*fields)) + + def clean_selection_options(self) -> list[str]: + raw_value = self.cleaned_data.get("selection_options") or "" + raw_value = raw_value.strip() + if raw_value == "": + return [] + + try: + parsed = json.loads(raw_value) + except json.JSONDecodeError as exc: + raise forms.ValidationError("Invalid selection data.") from exc + + return validate_selection_options(parsed) + + def clean_is_array(self) -> bool: + raw_value = (self.cleaned_data.get("is_array") or "").strip().lower() + if raw_value in {"1", "true", "on"}: + return True + return False + + def clean(self): + cleaned_data = super().clean() + if not cleaned_data: + return cleaned_data + + output_type = cleaned_data.get("output_type") + selection_options: list[str] = cleaned_data.get("selection_options") or [] + + if output_type == OutputType.SELECTION: + if not selection_options: + self.add_error( + "selection_options", + "Add at least one selection to use the Selection type.", + ) + else: + if selection_options: + self.add_error( + "selection_options", + "Selections are only allowed when Output Type is Selection.", + ) + cleaned_data["selection_options"] = [] + + return cleaned_data + OutputFieldFormSet = forms.inlineformset_factory( ExtractionJob, diff --git a/radis/extractions/migrations/0004_outputfield_selection_options.py b/radis/extractions/migrations/0004_outputfield_selection_options.py new file mode 100644 index 00000000..f01877a8 --- /dev/null +++ b/radis/extractions/migrations/0004_outputfield_selection_options.py @@ -0,0 +1,29 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("extractions", "0003_alter_extractionjob_options_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="outputfield", + name="selection_options", + field=models.JSONField(blank=True, default=list), + ), + migrations.AlterField( + model_name="outputfield", + name="output_type", + field=models.CharField( + choices=[ + ("T", "Text"), + ("N", "Numeric"), + ("B", "Boolean"), + ("S", "Selection"), + ], + default="T", + max_length=1, + ), + ), + ] diff --git a/radis/extractions/migrations/0005_outputfield_is_array.py b/radis/extractions/migrations/0005_outputfield_is_array.py new file mode 100644 index 00000000..568a02c3 --- /dev/null +++ b/radis/extractions/migrations/0005_outputfield_is_array.py @@ -0,0 +1,15 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("extractions", "0004_outputfield_selection_options"), + ] + + operations = [ + migrations.AddField( + model_name="outputfield", + name="is_array", + field=models.BooleanField(default=False), + ), + ] diff --git a/radis/extractions/migrations/0004_remove_extractionjob_provider.py b/radis/extractions/migrations/0006_remove_extractionjob_provider.py similarity index 80% rename from radis/extractions/migrations/0004_remove_extractionjob_provider.py rename to radis/extractions/migrations/0006_remove_extractionjob_provider.py index 24640eaa..84c59f4c 100644 --- a/radis/extractions/migrations/0004_remove_extractionjob_provider.py +++ b/radis/extractions/migrations/0006_remove_extractionjob_provider.py @@ -6,7 +6,7 @@ class Migration(migrations.Migration): dependencies = [ - ("extractions", "0003_alter_extractionjob_options_and_more"), + ("extractions", "0005_outputfield_is_array"), ] operations = [ diff --git a/radis/extractions/migrations/0007_remove_outputfield_unique.py b/radis/extractions/migrations/0007_remove_outputfield_unique.py new file mode 100644 index 00000000..bc12db2d --- /dev/null +++ b/radis/extractions/migrations/0007_remove_outputfield_unique.py @@ -0,0 +1,33 @@ +# Generated by Django 5.2.8 on 2025-11-17 23:19 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("extractions", "0006_remove_extractionjob_provider"), + ] + + operations = [ + migrations.RemoveConstraint( + model_name="outputfield", + name="unique_output_field_name_per_job", + ), + migrations.RemoveField( + model_name="outputfield", + name="optional", + ), + migrations.AlterField( + model_name="outputfield", + name="job", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="output_fields", + to="extractions.extractionjob", + ), + ), + ] diff --git a/radis/extractions/migrations/0008_outputfield_subscription_and_more.py b/radis/extractions/migrations/0008_outputfield_subscription_and_more.py new file mode 100644 index 00000000..11027f91 --- /dev/null +++ b/radis/extractions/migrations/0008_outputfield_subscription_and_more.py @@ -0,0 +1,53 @@ +# Generated by Django 5.2.8 on 2025-11-17 23:39 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("extractions", "0007_remove_outputfield_unique"), + ("subscriptions", "0011_rename_answers_subscribeditem_filter_results_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="outputfield", + name="subscription", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="output_fields", + to="subscriptions.subscription", + ), + ), + migrations.AddConstraint( + model_name="outputfield", + constraint=models.UniqueConstraint( + condition=models.Q(("job__isnull", False)), + fields=("name", "job_id"), + name="unique_output_field_name_per_job", + ), + ), + migrations.AddConstraint( + model_name="outputfield", + constraint=models.UniqueConstraint( + condition=models.Q(("subscription__isnull", False)), + fields=("name", "subscription_id"), + name="unique_output_field_name_per_subscription", + ), + ), + migrations.AddConstraint( + model_name="outputfield", + constraint=models.CheckConstraint( + condition=models.Q( + models.Q(("job__isnull", False), ("subscription__isnull", True)), + models.Q(("job__isnull", True), ("subscription__isnull", False)), + _connector="OR", + ), + name="output_field_exactly_one_parent", + ), + ), + ] diff --git a/radis/extractions/models.py b/radis/extractions/models.py index 97e6fa98..dc0f972e 100644 --- a/radis/extractions/models.py +++ b/radis/extractions/models.py @@ -4,6 +4,7 @@ from django.conf import settings from django.contrib.auth.models import Group from django.db import models +from django.db.models import Q from django.urls import reverse from procrastinate.contrib.django import app from procrastinate.contrib.django.models import ProcrastinateJob @@ -75,6 +76,7 @@ class OutputType(models.TextChoices): TEXT = "T", "Text" NUMERIC = "N", "Numeric" BOOLEAN = "B", "Boolean" + SELECTION = "S", "Selection" class OutputField(models.Model): @@ -84,22 +86,64 @@ class OutputField(models.Model): max_length=1, choices=OutputType.choices, default=OutputType.TEXT ) get_output_type_display: Callable[[], str] - optional = models.BooleanField(default=False) + selection_options = models.JSONField(default=list, blank=True) + is_array = models.BooleanField(default=False) job = models.ForeignKey[ExtractionJob]( - ExtractionJob, on_delete=models.CASCADE, related_name="output_fields" + ExtractionJob, null=True, blank=True, on_delete=models.CASCADE, related_name="output_fields" + ) + subscription = models.ForeignKey( + "subscriptions.Subscription", + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="output_fields", ) class Meta: constraints = [ models.UniqueConstraint( fields=["name", "job_id"], + condition=Q(job__isnull=False), name="unique_output_field_name_per_job", - ) + ), + models.UniqueConstraint( + fields=["name", "subscription_id"], + condition=Q(subscription__isnull=False), + name="unique_output_field_name_per_subscription", + ), + models.CheckConstraint( + condition=( + Q(job__isnull=False, subscription__isnull=True) + | Q(job__isnull=True, subscription__isnull=False) + ), + name="output_field_exactly_one_parent", + ), ] def __str__(self) -> str: return f'Output Field "{self.name}" [{self.pk}]' + def clean(self) -> None: + from django.core.exceptions import ValidationError + + from radis.extractions.utils.validation import validate_selection_options + + super().clean() + + if self.output_type == OutputType.SELECTION: + if not self.selection_options: + raise ValidationError({"selection_options": "Add at least one selection option."}) + try: + self.selection_options = validate_selection_options(self.selection_options) + except ValidationError as e: + raise ValidationError({"selection_options": e.message}) + else: + if self.selection_options: + raise ValidationError( + {"selection_options": "Selections are only allowed for the Selection type."} + ) + self.selection_options = [] + class ExtractionTask(AnalysisTask): job = models.ForeignKey[ExtractionJob]( diff --git a/radis/extractions/processors.py b/radis/extractions/processors.py index e4c7d886..ee2e7a31 100644 --- a/radis/extractions/processors.py +++ b/radis/extractions/processors.py @@ -46,11 +46,12 @@ def process_instance(self, instance: ExtractionInstance) -> None: def process_output_fields(self, instance: ExtractionInstance) -> None: job = instance.task.job - Schema = generate_output_fields_schema(job.output_fields) + output_fields = list(job.output_fields.order_by("pk")) + Schema = generate_output_fields_schema(output_fields) prompt = Template(settings.OUTPUT_FIELDS_SYSTEM_PROMPT).substitute( { "report": instance.text, - "fields": generate_output_fields_prompt(job.output_fields), + "fields": generate_output_fields_prompt(output_fields), } ) result = self.client.extract_data(prompt.strip(), Schema) diff --git a/radis/extractions/static/extractions/extractions.css b/radis/extractions/static/extractions/extractions.css index 17c43b36..c2164fdc 100644 --- a/radis/extractions/static/extractions/extractions.css +++ b/radis/extractions/static/extractions/extractions.css @@ -1,3 +1,38 @@ #filters .asteriskField { display: none; } + +/* + * Selection Options Component + */ + +.selection-options-controls, +.selection-options-actions { + gap: 0.5rem; +} + +.array-toggle-field { + display: flex; + align-items: center; + justify-content: flex-end; + padding-top: 14px; +} + +.array-toggle-btn { + width: calc(2.5rem - 6px); + height: calc(2.5rem - 6px); + border-radius: 50%; + font-weight: 600; + display: inline-flex; + align-items: center; + justify-content: center; + padding: 0; + line-height: 1; + font-size: 1rem; +} + +.array-toggle-btn.active { + background-color: var(--bs-primary); + border-color: var(--bs-primary); + color: #fff; +} diff --git a/radis/extractions/static/extractions/extractions.js b/radis/extractions/static/extractions/extractions.js index e69de29b..0d400738 100644 --- a/radis/extractions/static/extractions/extractions.js +++ b/radis/extractions/static/extractions/extractions.js @@ -0,0 +1,141 @@ +/** + * Manages the dynamic selection options input for extraction output fields. + * + * @param {HTMLElement} rootEl + * @returns {Object} + */ +function SelectionOptions(rootEl) { + const hiddenInput = rootEl.querySelector("[data-selection-input]"); + const arrayInput = rootEl.querySelector("[data-array-input]"); + const formContainer = + rootEl.closest(".formset-form") ?? rootEl.closest("form") ?? rootEl; + const outputTypeField = + formContainer.querySelector('select[name$="-output_type"]') ?? + formContainer.querySelector('select[name="output_type"]'); + const arrayToggleButton = + formContainer.querySelector("[data-array-toggle]") ?? null; + const parseArrayValue = (value) => { + if (!value) { + return false; + } + const normalized = value.trim().toLowerCase(); + return normalized === "true" || normalized === "1" || normalized === "on"; + }; + const parseMaxOptions = () => { + const datasetValue = + hiddenInput?.dataset.maxSelectionOptions ?? + rootEl.dataset.maxSelectionOptions ?? + ""; + const parsed = Number.parseInt(datasetValue, 10); + return Number.isNaN(parsed) ? 0 : parsed; + }; + const initialMaxOptions = parseMaxOptions(); + + return { + options: [], + maxOptions: initialMaxOptions, + supportsSelection: false, + isArray: parseArrayValue(arrayInput?.value), + lastSelectionOptions: [], + init() { + this.options = this.parseOptions(hiddenInput?.value); + this.isArray = parseArrayValue(arrayInput?.value); + this.lastSelectionOptions = [...this.options]; + this.updateSupports(); + if (arrayToggleButton) { + arrayToggleButton.addEventListener("click", (event) => { + event.preventDefault(); + this.toggleArray(); + }); + this.updateArrayButton(); + } + if (outputTypeField) { + outputTypeField.addEventListener("change", () => { + const wasSelection = this.supportsSelection; + this.updateSupports(); + if (!this.supportsSelection) { + this.lastSelectionOptions = [...this.options]; + this.options = []; + } else if (!wasSelection && this.options.length === 0) { + if (this.lastSelectionOptions.length > 0) { + this.options = [...this.lastSelectionOptions]; + } else { + this.options = this.parseOptions(hiddenInput?.value); + } + } + }); + } + }, + parseOptions(value) { + if (!value) { + return []; + } + try { + const parsed = JSON.parse(value); + if (Array.isArray(parsed)) { + return parsed + .map((opt) => (typeof opt === "string" ? opt : "")) + .filter((opt) => opt !== ""); + } + } catch (err) { + console.warn("Invalid selection options payload", err); + } + return []; + }, + updateSupports() { + this.supportsSelection = outputTypeField + ? outputTypeField.value === "S" + : false; + }, + syncOptions() { + if (!hiddenInput) { + return; + } + const sanitized = this.options + .map((opt) => (typeof opt === "string" ? opt.trim() : "")) + .filter((opt) => opt !== ""); + hiddenInput.value = JSON.stringify(sanitized); + this.lastSelectionOptions = [...sanitized]; + }, + syncArray() { + if (!arrayInput) { + return; + } + arrayInput.value = this.isArray ? "true" : "false"; + }, + syncState() { + this.syncOptions(); + this.syncArray(); + this.updateArrayButton(); + }, + addOption() { + if (!this.supportsSelection || this.options.length >= this.maxOptions) { + return; + } + this.options.push(""); + this.$nextTick(() => { + const inputs = rootEl.querySelectorAll("[data-selection-option-input]"); + const lastInput = inputs[inputs.length - 1]; + if (lastInput instanceof HTMLInputElement) { + lastInput.focus(); + } + }); + }, + removeOption(index) { + this.options.splice(index, 1); + }, + toggleArray() { + this.isArray = !this.isArray; + }, + updateArrayButton() { + if (!arrayToggleButton) { + return; + } + arrayToggleButton.classList.toggle("active", this.isArray); + arrayToggleButton.setAttribute( + "aria-pressed", + this.isArray ? "true" : "false" + ); + }, + }; +} diff --git a/radis/extractions/templates/extractions/_query_generation_result.html b/radis/extractions/templates/extractions/_query_generation_result.html new file mode 100644 index 00000000..6205dca0 --- /dev/null +++ b/radis/extractions/templates/extractions/_query_generation_result.html @@ -0,0 +1,44 @@ +{% load bootstrap_icon from common_extras %} +{% if error %} + +{% elif generated_query %} + + {# Update the query input field using out-of-band swap #} + +{% else %} + +{% endif %} diff --git a/radis/extractions/templates/extractions/_query_generation_section.html b/radis/extractions/templates/extractions/_query_generation_section.html new file mode 100644 index 00000000..1a73b3b2 --- /dev/null +++ b/radis/extractions/templates/extractions/_query_generation_section.html @@ -0,0 +1,21 @@ +{% load bootstrap_icon from common_extras %} +
+ +
+
+ +
+
{% bootstrap_icon "magic" %} Generating Search Query...
+

+ Using AI to create a search query from your extraction fields... +

+
+
+
+
diff --git a/radis/extractions/templates/extractions/_search_preview.html b/radis/extractions/templates/extractions/_search_preview.html new file mode 100644 index 00000000..d5b4354e --- /dev/null +++ b/radis/extractions/templates/extractions/_search_preview.html @@ -0,0 +1,36 @@ +{% load bootstrap_icon from common_extras %} + +{% if error %} + +{% elif count is None %} + +{% else %} + +{% endif %} diff --git a/radis/extractions/templates/extractions/_search_preview_form_section.html b/radis/extractions/templates/extractions/_search_preview_form_section.html new file mode 100644 index 00000000..5b73b2a0 --- /dev/null +++ b/radis/extractions/templates/extractions/_search_preview_form_section.html @@ -0,0 +1,15 @@ +{% load bootstrap_icon from common_extras %} + +
+ +
diff --git a/radis/extractions/templates/extractions/_selection_options_field.html b/radis/extractions/templates/extractions/_selection_options_field.html new file mode 100644 index 00000000..e1c81a9b --- /dev/null +++ b/radis/extractions/templates/extractions/_selection_options_field.html @@ -0,0 +1,46 @@ +{% load bootstrap_icon from common_extras %} +
+ {{ form.selection_options }} + {{ form.is_array }} +
+
+ + +
+
+
+ +

No selections defined yet.

+
+

+ Choose the “Selection” output type to define enumerated values. +

+ {% if form.selection_options.errors %} +
+ {% for error in form.selection_options.errors %}{{ error }}{% endfor %} +
+ {% endif %} +
diff --git a/radis/extractions/templates/extractions/extraction_job_detail.html b/radis/extractions/templates/extractions/extraction_job_detail.html index 942a2fc8..3a083766 100644 --- a/radis/extractions/templates/extractions/extraction_job_detail.html +++ b/radis/extractions/templates/extractions/extraction_job_detail.html @@ -114,6 +114,20 @@
Output Fields
{{ field.output_type|human_readable_output_type }}
+ {% if field.is_array %} +
Array Output
+
+ Yes — return multiple {{ field.output_type|human_readable_output_type|lower }} values +
+ {% endif %} + {% if field.selection_options %} +
Selections
+
+ +
+ {% endif %} {% endfor %} @@ -121,11 +135,18 @@
Output Fields
{% if not job.is_preparing %} - - {% bootstrap_icon "eye" %} - View Results - + {% crispy filter.form %} diff --git a/radis/extractions/templates/extractions/extraction_job_output_fields_form.html b/radis/extractions/templates/extractions/extraction_job_output_fields_form.html index 4ab54deb..83fbca03 100644 --- a/radis/extractions/templates/extractions/extraction_job_output_fields_form.html +++ b/radis/extractions/templates/extractions/extraction_job_output_fields_form.html @@ -5,7 +5,7 @@ New Extraction Job {% endblock title %} {% block heading %} - + - + class="btn btn-secondary" + disabled>Previous Step +
- {% if fixed_query %} - - {% endif %}

- Your search will hit approximately {{ retrieval_count }} report{{ retrieval_count|pluralize }}. - The more reports you analyze the longer it will take. If you like to refine your search you can go back to - the previous step and adjust the query and/or filters. + Step 1: Define the fields you want to extract from medical reports.

- If the result count sounds reasonable then provide the fields to extract. Choose a short and concise name and more detailed description. + Choose a short and concise name and more detailed description for each field. You can add up to 5 fields.

+

+ In the next step, we will automatically generate a search query based on these fields, + which you can then review and refine. +

Example for an output field:
diff --git a/radis/extractions/templates/extractions/extraction_job_search_form.html b/radis/extractions/templates/extractions/extraction_job_search_form.html index dc59fd95..999f752f 100644 --- a/radis/extractions/templates/extractions/extraction_job_search_form.html +++ b/radis/extractions/templates/extractions/extraction_job_search_form.html @@ -5,7 +5,7 @@ New Extraction Job {% endblock title %} {% block heading %} - + + class="btn btn-secondary">Previous Step (Search Query)

diff --git a/radis/extractions/templates/extractions/extraction_result_list.html b/radis/extractions/templates/extractions/extraction_result_list.html index a0abdb4e..b44fd9c5 100644 --- a/radis/extractions/templates/extractions/extraction_result_list.html +++ b/radis/extractions/templates/extractions/extraction_result_list.html @@ -7,11 +7,18 @@ {% block heading %} - - {% bootstrap_icon "arrow-return-left" %} - View Job - + {% endblock heading %} diff --git a/radis/extractions/tests/test_forms.py b/radis/extractions/tests/test_forms.py new file mode 100644 index 00000000..d91fc909 --- /dev/null +++ b/radis/extractions/tests/test_forms.py @@ -0,0 +1,186 @@ +import json + +import pytest +from django.core.exceptions import ValidationError + +from radis.extractions.factories import ExtractionJobFactory +from radis.extractions.forms import OutputFieldForm +from radis.extractions.models import OutputField, OutputType + + +@pytest.mark.django_db +def test_output_field_form_accepts_selection_options(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.SELECTION, + "selection_options": json.dumps(["Grade 1", "Grade 2"]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert form.is_valid() + instance = form.save(commit=False) + + assert instance.selection_options == ["Grade 1", "Grade 2"] + + +@pytest.mark.django_db +def test_output_field_form_requires_options_for_selection(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.SELECTION, + "selection_options": json.dumps([]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert not form.is_valid() + assert "selection_options" in form.errors + + +@pytest.mark.django_db +def test_output_field_form_rejects_options_for_non_selection(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.TEXT, + "selection_options": json.dumps(["Grade 1"]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert not form.is_valid() + assert "selection_options" in form.errors + + +@pytest.mark.django_db +def test_output_field_clean_trims_selection_options(): + job = ExtractionJobFactory.create() + field = OutputField( + job=job, + name="tumor_grade", + description="Classified tumor grade.", + output_type=OutputType.SELECTION, + selection_options=[" Grade 1 ", "Grade 2 "], + ) + + field.full_clean() + + assert field.selection_options == ["Grade 1", "Grade 2"] + + +@pytest.mark.django_db +def test_output_field_clean_rejects_selection_options_for_other_types(): + job = ExtractionJobFactory.create() + field = OutputField( + job=job, + name="tumor_grade", + description="Classified tumor grade.", + output_type=OutputType.TEXT, + selection_options=["Grade 1"], + ) + + with pytest.raises(ValidationError): + field.full_clean() + + +@pytest.mark.django_db +def test_output_field_form_handles_array_toggle(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "measurements", + "description": "Multiple numeric values.", + "output_type": OutputType.NUMERIC, + "selection_options": json.dumps([]), + "is_array": "true", + }, + instance=OutputField(job=job), + ) + + assert form.is_valid() + instance = form.save(commit=False) + assert instance.is_array is True + + +@pytest.mark.django_db +def test_output_field_form_rejects_duplicate_selection_options(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.SELECTION, + "selection_options": json.dumps(["Grade 1", "Grade 1"]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert not form.is_valid() + assert "selection_options" in form.errors + + +@pytest.mark.django_db +def test_output_field_clean_rejects_duplicate_selection_options(): + job = ExtractionJobFactory.create() + field = OutputField( + job=job, + name="tumor_grade", + description="Classified tumor grade.", + output_type=OutputType.SELECTION, + selection_options=["Grade 1", "Grade 1"], + ) + + with pytest.raises(ValidationError): + field.full_clean() + + +@pytest.mark.django_db +def test_output_field_form_rejects_whitespace_only_selection_option(): + job = ExtractionJobFactory.create() + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.SELECTION, + "selection_options": json.dumps(["Grade 1", " "]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert not form.is_valid() + assert "selection_options" in form.errors + + +@pytest.mark.django_db +def test_output_field_form_accepts_unicode_and_long_selection_options(): + job = ExtractionJobFactory.create() + unicode_option = "Grädé αβγ測試" + long_option = "Grade " + ("X" * 150) + form = OutputFieldForm( + data={ + "name": "tumor_grade", + "description": "Classified tumor grade.", + "output_type": OutputType.SELECTION, + "selection_options": json.dumps([unicode_option, long_option]), + "is_array": "false", + }, + instance=OutputField(job=job), + ) + + assert form.is_valid() + instance = form.save(commit=False) + assert instance.selection_options == [unicode_option.strip(), long_option.strip()] diff --git a/radis/extractions/tests/test_query_generator.py b/radis/extractions/tests/test_query_generator.py new file mode 100644 index 00000000..7a7808a2 --- /dev/null +++ b/radis/extractions/tests/test_query_generator.py @@ -0,0 +1,122 @@ +"""Unit tests for the AsyncQueryGenerator class.""" + +from unittest.mock import AsyncMock, patch + +from django.test import TestCase, override_settings + +from radis.extractions.models import OutputField, OutputType +from radis.extractions.utils.query_generator import AsyncQueryGenerator + + +class AsyncQueryGeneratorTest(TestCase): + """Test cases for AsyncQueryGenerator.""" + + def setUp(self): + """Set up test fixtures.""" + self.generator = AsyncQueryGenerator() + + async def test_generate_from_empty_fields(self): + """Test query generation with no fields returns None.""" + fields = [] + query, metadata = await self.generator.generate_from_fields(fields) + + assert query is None + assert metadata["generation_method"] is None + assert metadata["field_count"] == 0 + assert metadata["success"] is False + assert metadata["error"] == "No fields provided" + + @patch("radis.chats.utils.chat_client.AsyncChatClient") + async def test_llm_generation_success(self, mock_chat_client_class): + """Test successful LLM query generation.""" + # Mock the LLM response using AsyncMock for async methods + mock_client = AsyncMock() + mock_client.chat.return_value = '("lung nodule" OR "pulmonary nodule") AND size' + mock_chat_client_class.return_value = mock_client + + fields = [ + OutputField( + name="nodule_size", + description="size of lung nodule in millimeters", + output_type=OutputType.NUMERIC, + ) + ] + + with override_settings(ENABLE_AUTO_QUERY_GENERATION=True): + generator = AsyncQueryGenerator() # Create new instance with mocked client + query, metadata = await generator.generate_from_fields(fields) + + assert query != "" + assert metadata["success"] is True + assert query is not None + assert "nodule" in query.lower() + + def test_validate_and_fix_valid_query(self): + """Test validation of a valid query.""" + query = "lung AND nodule" + fixed_query, fixes = self.generator.validate_and_fix_query(query) + + assert fixed_query == query + assert len(fixes) == 0 + + def test_validate_and_fix_empty_query(self): + """Test validation of an empty query.""" + query = "" + fixed_query, fixes = self.generator.validate_and_fix_query(query) + + assert fixed_query == "" + assert len(fixes) == 0 + + def test_validate_and_fix_query_with_quotes(self): + """Test validation of a query with quoted phrases.""" + query = '"lung nodule" AND size' + fixed_query, fixes = self.generator.validate_and_fix_query(query) + + assert fixed_query != "" + assert "lung nodule" in fixed_query or "lung" in fixed_query + + def test_format_fields_for_prompt(self): + """Test formatting of fields for LLM prompt.""" + fields = [ + OutputField( + name="test_field", + description="test description", + output_type=OutputType.TEXT, + ) + ] + + formatted = self.generator._format_fields_for_prompt(fields) + + assert "test_field" in formatted + assert "test description" in formatted + assert "Text" in formatted + + def test_extract_query_from_response_simple(self): + """Test extraction of query from simple LLM response.""" + response = "lung AND nodule" + query = self.generator._extract_query_from_response(response) + + assert query == "lung AND nodule" + + def test_extract_query_from_response_with_prefix(self): + """Test extraction when LLM response has a prefix.""" + response = "Query: lung AND nodule" + query = self.generator._extract_query_from_response(response) + + assert query == "lung AND nodule" + + def test_extract_query_from_response_with_quotes(self): + """Test extraction when response is wrapped in quotes.""" + response = '"lung AND nodule"' + query = self.generator._extract_query_from_response(response) + + assert query == "lung AND nodule" + + def test_extract_query_from_response_multiline(self): + """Test extraction when LLM adds explanation on additional lines.""" + response = "lung AND nodule\n\nThis query will find..." + query = self.generator._extract_query_from_response(response) + + # Should only take the first line + assert query == "lung AND nodule" + assert "This query" not in query diff --git a/radis/extractions/tests/test_views.py b/radis/extractions/tests/test_views.py index bb297864..af099215 100644 --- a/radis/extractions/tests/test_views.py +++ b/radis/extractions/tests/test_views.py @@ -1,13 +1,14 @@ import pytest from adit_radis_shared.accounts.factories import GroupFactory, UserFactory from django.contrib.auth.models import Permission -from django.test import Client +from django.test import Client, override_settings from radis.core.models import AnalysisTask from radis.extractions.factories import ( ExtractionInstanceFactory, ExtractionJobFactory, ExtractionTaskFactory, + OutputFieldFactory, ) from radis.extractions.models import ExtractionJob from radis.reports.factories import LanguageFactory, ReportFactory @@ -29,6 +30,21 @@ def create_test_extraction_task(job=None): return ExtractionTaskFactory.create(job=job) +def _hide_toolbar(_request): + return False + + +def _collect_csv(response) -> str: + chunks: list[bytes] = [] + for chunk in response.streaming_content: + if isinstance(chunk, bytes): + chunks.append(chunk) + else: + chunks.append(chunk.encode("utf-8")) + csv_bytes = b"".join(chunks) + return csv_bytes.decode("utf-8-sig") + + @pytest.mark.django_db def test_extraction_job_list_view(client: Client): user = UserFactory.create(is_active=True) @@ -202,6 +218,89 @@ def test_extraction_result_list_view(client: Client): assert response.status_code == 200 +@override_settings(DEBUG_TOOLBAR_CONFIG={"SHOW_TOOLBAR_CALLBACK": _hide_toolbar}) +@pytest.mark.django_db +def test_extraction_result_download_view(client: Client): + user = UserFactory.create(is_active=True) + job = create_test_extraction_job(owner=user) + + OutputFieldFactory.create(job=job, name="field_one") + OutputFieldFactory.create(job=job, name="field_two") + OutputFieldFactory.create(job=job, name="field_bool") + + task = create_test_extraction_task(job=job) + language = LanguageFactory.create(code="en") + report = ReportFactory.create(language=language) + instance = ExtractionInstanceFactory.create( + task=task, + report=report, + is_processed=True, + output={"field_one": "value", "field_two": 42, "field_bool": False}, + ) + + client.force_login(user) + response = client.get(f"/extractions/jobs/{job.pk}/results/download/") + assert response.status_code == 200 + assert response["Content-Type"].startswith("text/csv") + assert f"extraction_job_{job.pk}" in response["Content-Disposition"] + + csv_text = _collect_csv(response) + lines = [line.strip() for line in csv_text.strip().splitlines()] + assert lines[0] == "instance_id,report_id,is_processed,field_one,field_two,field_bool" + assert lines[1] == f"{instance.pk},{instance.report.pk},yes,value,42,no" + + +@override_settings(DEBUG_TOOLBAR_CONFIG={"SHOW_TOOLBAR_CALLBACK": _hide_toolbar}) +@pytest.mark.django_db +def test_extraction_result_download_view_unauthorized(client: Client): + owner = UserFactory.create(is_active=True) + other_user = UserFactory.create(is_active=True) + job = create_test_extraction_job(owner=owner) + client.force_login(other_user) + response = client.get(f"/extractions/jobs/{job.pk}/results/download/") + assert response.status_code == 404 + + +@override_settings(DEBUG_TOOLBAR_CONFIG={"SHOW_TOOLBAR_CALLBACK": _hide_toolbar}) +@pytest.mark.django_db +def test_extraction_result_download_view_no_instances(client: Client): + user = UserFactory.create(is_active=True) + job = create_test_extraction_job(owner=user) + OutputFieldFactory.create(job=job, name="field_one") + + client.force_login(user) + response = client.get(f"/extractions/jobs/{job.pk}/results/download/") + assert response.status_code == 200 + + csv_text = _collect_csv(response) + assert csv_text.strip() == "instance_id,report_id,is_processed,field_one" + + +@override_settings(DEBUG_TOOLBAR_CONFIG={"SHOW_TOOLBAR_CALLBACK": _hide_toolbar}) +@pytest.mark.django_db +def test_extraction_result_download_view_no_output_fields(client: Client): + user = UserFactory.create(is_active=True) + job = create_test_extraction_job(owner=user) + task = create_test_extraction_task(job=job) + language = LanguageFactory.create(code="en") + report = ReportFactory.create(language=language) + instance = ExtractionInstanceFactory.create( + task=task, + report=report, + is_processed=False, + output={}, + ) + + client.force_login(user) + response = client.get(f"/extractions/jobs/{job.pk}/results/download/") + assert response.status_code == 200 + + csv_text = _collect_csv(response) + lines = csv_text.strip().splitlines() + assert lines[0] == "instance_id,report_id,is_processed" + assert lines[1] == f"{instance.pk},{instance.report.pk},no" + + @pytest.mark.django_db def test_extraction_task_detail_view(client: Client): user = UserFactory.create(is_active=True) diff --git a/radis/extractions/tests/unit/test_processor_utils.py b/radis/extractions/tests/unit/test_processor_utils.py new file mode 100644 index 00000000..45fbfbb1 --- /dev/null +++ b/radis/extractions/tests/unit/test_processor_utils.py @@ -0,0 +1,48 @@ +from typing import Literal, get_args, get_origin + +import pytest + +from radis.extractions.factories import ExtractionJobFactory, OutputFieldFactory +from radis.extractions.models import OutputType +from radis.extractions.utils.processor_utils import generate_output_fields_schema + + +@pytest.mark.django_db +def test_generate_output_fields_schema_uses_literal_for_selection_fields(): + job = ExtractionJobFactory.create() + field = OutputFieldFactory( + job=job, + name="grade", + output_type=OutputType.SELECTION, + ) + field.selection_options = ["Grade 1", "Grade 2"] + field.save() + + schema = generate_output_fields_schema(job.output_fields.all()) + + grade_field = schema.model_fields["grade"] + annotation = grade_field.annotation + assert get_origin(annotation) is Literal + assert set(get_args(annotation)) == {"Grade 1", "Grade 2"} + + +@pytest.mark.django_db +def test_generate_output_fields_schema_wraps_literal_in_list_for_array_selections(): + job = ExtractionJobFactory.create() + field = OutputFieldFactory( + job=job, + name="grade_multi", + output_type=OutputType.SELECTION, + ) + field.selection_options = ["High", "Low"] + field.is_array = True + field.save() + + schema = generate_output_fields_schema(job.output_fields.all()) + + grade_field = schema.model_fields["grade_multi"] + annotation = grade_field.annotation + assert get_origin(annotation) is list + (inner_annotation,) = get_args(annotation) + assert get_origin(inner_annotation) is Literal + assert set(get_args(inner_annotation)) == {"High", "Low"} diff --git a/radis/extractions/urls.py b/radis/extractions/urls.py index df459e25..6b9ef946 100644 --- a/radis/extractions/urls.py +++ b/radis/extractions/urls.py @@ -12,7 +12,10 @@ ExtractionJobRetryView, ExtractionJobVerifyView, ExtractionJobWizardView, + ExtractionQueryGeneratorView, + ExtractionResultDownloadView, ExtractionResultListView, + ExtractionSearchPreviewView, ExtractionTaskDeleteView, ExtractionTaskDetailView, ExtractionTaskResetView, @@ -39,6 +42,16 @@ ExtractionJobWizardView.as_view(), name="extraction_job_create", ), + path( + "jobs/new/search-preview/", + ExtractionSearchPreviewView.as_view(), + name="extraction_search_preview", + ), + path( + "jobs/new/generate-query/", + ExtractionQueryGeneratorView.as_view(), + name="extraction_generate_query", + ), path( "jobs//", ExtractionJobDetailView.as_view(), @@ -99,6 +112,11 @@ ExtractionResultListView.as_view(), name="extraction_result_list", ), + path( + "jobs//results/download/", + ExtractionResultDownloadView.as_view(), + name="extraction_result_download", + ), path( "instances//", ExtractionInstanceDetailView.as_view(), diff --git a/radis/extractions/utils/__init__.py b/radis/extractions/utils/__init__.py index e69de29b..c32e9b87 100644 --- a/radis/extractions/utils/__init__.py +++ b/radis/extractions/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility helpers for extraction workflows.""" + +from .csv_export import iter_extraction_result_rows + +__all__ = ["iter_extraction_result_rows"] diff --git a/radis/extractions/utils/csv_export.py b/radis/extractions/utils/csv_export.py new file mode 100644 index 00000000..e51da187 --- /dev/null +++ b/radis/extractions/utils/csv_export.py @@ -0,0 +1,54 @@ +"""Helpers for exporting extraction results in CSV format.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import Any + +from radis.extractions.models import ExtractionInstance, ExtractionJob + + +def _format_cell(value: Any) -> str: + """Format a single output value for CSV export.""" + if value is None: + return "" + if isinstance(value, bool): + return "yes" if value else "no" + return str(value) + + +def iter_extraction_result_rows(job: ExtractionJob) -> Iterable[Sequence[str]]: + """Yield rows for the extraction results CSV. + + Args: + job: The extraction job whose results should be exported. + + Yields: + Sequences of stringified cell values suitable for csv.writer. + """ + + field_names: list[str] = list(job.output_fields.order_by("pk").values_list("name", flat=True)) + + header = ["instance_id", "report_id", "is_processed"] + header.extend(field_names) + yield header + + instances = ( + ExtractionInstance.objects.filter(task__job=job) + .order_by("pk") + .values_list("pk", "report_id", "is_processed", "output") + ) + + for instance_id, report_id, is_processed, output in instances.iterator(): + row: list[str] = [ + str(instance_id), + str(report_id) if report_id else "", + "yes" if is_processed else "no", + ] + + output_dict: dict[str, Any] = output or {} + for field_name in field_names: + value = output_dict.get(field_name) + row.append(_format_cell(value)) + + yield row diff --git a/radis/extractions/utils/processor_utils.py b/radis/extractions/utils/processor_utils.py index e7f12bca..3e338351 100644 --- a/radis/extractions/utils/processor_utils.py +++ b/radis/extractions/utils/processor_utils.py @@ -1,33 +1,55 @@ -from typing import Any +from collections.abc import Iterable +from typing import Any, Literal -from django.db.models import QuerySet from pydantic import BaseModel, create_model from ..models import OutputField, OutputType type Numeric = float | int +"""Build a Pydantic model that describes the structure the extractor must output""" -def generate_output_fields_schema(fields: QuerySet[OutputField]) -> type[BaseModel]: + +def generate_output_fields_schema(fields: Iterable[OutputField]) -> type[BaseModel]: field_definitions: dict[str, Any] = {} - for field in fields.all(): + for field in fields: if field.output_type == OutputType.TEXT: output_type = str elif field.output_type == OutputType.NUMERIC: output_type = Numeric elif field.output_type == OutputType.BOOLEAN: output_type = bool + elif field.output_type == OutputType.SELECTION: + options = tuple(field.selection_options) + if not options: + raise ValueError("Selection output requires at least one option.") + output_type = Literal[*options] else: raise ValueError(f"Unknown data type: {field.output_type}") + if field.is_array: + # If the field stores multiple values, use a list[...] of the base type above. + output_type = list[output_type] + field_definitions[field.name] = (output_type, ...) return create_model("OutputFieldsModel", **field_definitions) -def generate_output_fields_prompt(fields: QuerySet[OutputField]) -> str: +def generate_output_fields_prompt(fields: Iterable[OutputField]) -> str: + # Build a human-readable prompt that mirrors the same selection/array rules. prompt = "" - for field in fields.all(): - prompt += f"{field.name}: {field.description}\n" + for field in fields: + description = field.description + if OutputType(field.output_type) == OutputType.SELECTION and field.selection_options: + description = ( + f"{description} (Allowed selections: {', '.join(field.selection_options)})" + ) + if field.is_array: + description = ( + f"{description} (Return an array of " + f"{field.get_output_type_display().lower()} values.)" + ) + prompt += f"{field.name}: {description}\n" return prompt diff --git a/radis/extractions/utils/query_generator.py b/radis/extractions/utils/query_generator.py new file mode 100644 index 00000000..95615fcb --- /dev/null +++ b/radis/extractions/utils/query_generator.py @@ -0,0 +1,190 @@ +""" +Query Generator for Automated Query Creation from Extraction Fields + +This module provides functionality to automatically generate search queries +from user-defined extraction fields using LLM with fallback strategies. +""" + +import logging +import re +from string import Template +from typing import Any, Iterable + +from django.conf import settings + +from radis.extractions.models import OutputField +from radis.search.utils.query_parser import QueryParser + +logger = logging.getLogger(__name__) + + +class AsyncQueryGenerator: + """Async version of QueryGenerator for use in async views.""" + + def __init__(self): + """Initialize the async query generator with an async LLM client.""" + from radis.chats.utils.chat_client import AsyncChatClient + + self.client = AsyncChatClient() + self.parser = QueryParser() + + async def generate_from_fields( + self, fields: Iterable[OutputField] + ) -> tuple[str | None, dict[str, Any]]: + """ + Async version of generate_from_fields. + + Args: + fields: Iterable of OutputField objects to generate query from + + Returns: + Tuple of (query_string, metadata_dict) + Same structure as synchronous version + """ + fields_list = list(fields) + field_count = len(fields_list) + + metadata = { + "field_count": field_count, + "success": False, + "generation_method": None, + "error": None, + } + + if field_count == 0: + logger.warning("No fields provided for async query generation") + metadata["error"] = "No fields provided" + return None, metadata + + if settings.ENABLE_AUTO_QUERY_GENERATION: + try: + query = await self._call_llm(fields_list) + if query: + validated_query, fixes = self.validate_and_fix_query(query) + if validated_query: + logger.info( + f"Successfully generated query from {field_count} fields using LLM" + ) + metadata["generation_method"] = "llm" + metadata["success"] = True + metadata["fixes_applied"] = len(fixes) > 0 + return validated_query, metadata + else: + logger.warning("LLM generated invalid query") + metadata["error"] = "LLM generated invalid query" + except Exception as e: + logger.error(f"Error during async LLM query generation: {e}", exc_info=True) + metadata["error"] = str(e) + + logger.warning(f"Async query generation failed for {field_count} fields") + metadata["error"] = metadata.get("error") or "All generation methods failed" + metadata["success"] = False + return None, metadata + + async def _call_llm(self, fields: list[OutputField]) -> str | None: + """ + Async version of _call_llm. + + Args: + fields: List of OutputField objects + + Returns: + Generated query string, or None if failed + """ + fields_formatted = self._format_fields_for_prompt(fields) + + prompt = Template(settings.QUERY_GENERATION_SYSTEM_PROMPT).substitute( + fields=fields_formatted + ) + + try: + response = await self.client.chat([{"role": "user", "content": prompt}]) + + if not response: + logger.warning("LLM returned empty response") + return None + + query = self._extract_query_from_response(response) + logger.debug(f"Async LLM generated query: {query}") + return query + + except Exception as e: + logger.error(f"Async LLM call failed: {e}") + return None + + def _format_fields_for_prompt(self, fields: list[OutputField]) -> str: + """ + Format extraction fields for inclusion in LLM prompt. + Reuses synchronous version logic. + + Args: + fields: List of OutputField objects + + Returns: + Formatted string representation of fields + """ + formatted_fields = [] + for field in fields: + field_dict = { + "name": field.name, + "description": field.description, + "type": field.get_output_type_display(), + } + formatted_fields.append(str(field_dict)) + + return "\n".join(formatted_fields) + + def _extract_query_from_response(self, response: str) -> str: + """ + Extract query from LLM response. + Reuses synchronous version logic. + + Args: + response: Raw LLM response + + Returns: + Cleaned query string + """ + cleaned = re.sub( + r"^(query|search|generated query|result):\s*", "", response.strip(), flags=re.IGNORECASE + ) + + if cleaned.startswith('"') and cleaned.endswith('"'): + cleaned = cleaned[1:-1] + elif cleaned.startswith("'") and cleaned.endswith("'"): + cleaned = cleaned[1:-1] + + cleaned = cleaned.split("\n")[0].strip() + return cleaned + + def validate_and_fix_query(self, query: str) -> tuple[str, list[str]]: + """ + Validate and fix a query using QueryParser. + Reuses synchronous version logic. + + Args: + query: Query string to validate + + Returns: + Tuple of (fixed_query, list_of_fixes_applied) + """ + if not query or not query.strip(): + return "", [] + + try: + query_node, fixes = self.parser.parse(query) + + if query_node is None: + logger.warning(f"Query validation failed for: {query}") + return "", [] + + if len(fixes) > 0: + fixed_query = QueryParser.unparse(query_node) + logger.debug(f"Applied {len(fixes)} fixes to query: {fixes}") + return fixed_query, fixes + + return query, [] + + except Exception as e: + logger.error(f"Error validating query '{query}': {e}") + return "", [] diff --git a/radis/extractions/utils/testing_helpers.py b/radis/extractions/utils/testing_helpers.py index 5e864c3b..443af8b9 100644 --- a/radis/extractions/utils/testing_helpers.py +++ b/radis/extractions/utils/testing_helpers.py @@ -25,7 +25,8 @@ def create_extraction_task( add_user_to_group(user, group) job = ExtractionJobFactory.create( status=ExtractionJob.Status.PENDING, - owner_id=user.id, + owner=user, + group=group, language=language, ) diff --git a/radis/extractions/utils/validation.py b/radis/extractions/utils/validation.py new file mode 100644 index 00000000..c7c8bfbc --- /dev/null +++ b/radis/extractions/utils/validation.py @@ -0,0 +1,48 @@ +"""Shared validation utilities for the extractions app.""" + +from django.core.exceptions import ValidationError + +from radis.extractions.constants import MAX_SELECTION_OPTIONS + + +def validate_selection_options(options: list) -> list[str]: + """ + Validates selection options for output fields. + + Args: + options: A list of selection options to validate + + Returns: + A list of cleaned (stripped) selection option strings + + Raises: + ValidationError: If validation fails for any of these reasons: + - Options is not a list + - Any option is not a string + - Any option is empty after stripping + - Too many options (exceeds MAX_SELECTION_OPTIONS) + - Options are not unique + """ + if not isinstance(options, list): + raise ValidationError("Selection options must be a list.") + + cleaned_options = [] + for option in options: + if not isinstance(option, str): + raise ValidationError("All selection options must be text.") + + stripped = option.strip() + if not stripped: + raise ValidationError("Selection options cannot be empty strings.") + + cleaned_options.append(stripped) + + if len(cleaned_options) > MAX_SELECTION_OPTIONS: + raise ValidationError( + f"Provide at most {MAX_SELECTION_OPTIONS} selection options." + ) + + if len(set(cleaned_options)) != len(cleaned_options): + raise ValidationError("Selection options must be unique.") + + return cleaned_options diff --git a/radis/extractions/views.py b/radis/extractions/views.py index 78f4e3c2..a0f2222a 100644 --- a/radis/extractions/views.py +++ b/radis/extractions/views.py @@ -1,5 +1,9 @@ -from typing import Any, Type, cast +import csv +from collections.abc import Generator +from typing import Any, Type, Union, cast +from urllib.parse import urlencode +from adit_radis_shared.accounts.models import User from adit_radis_shared.common.mixins import ( PageSizeSelectMixin, ) @@ -14,9 +18,11 @@ from django.db import transaction from django.db.models import QuerySet from django.forms import BaseInlineFormSet -from django.shortcuts import redirect -from django.urls import reverse_lazy -from django.views.generic import DetailView +from django.http import QueryDict, StreamingHttpResponse +from django.shortcuts import redirect, render +from django.urls import reverse, reverse_lazy +from django.utils.text import slugify +from django.views.generic import DetailView, View from django_tables2 import SingleTableMixin, tables from formtools.wizard.views import SessionWizardView @@ -34,6 +40,7 @@ AnalysisTaskResetView, BaseUpdatePreferencesView, ) +from radis.search.site import Search, SearchFilters from radis.search.utils.query_parser import QueryParser from .filters import ExtractionInstanceFilter, ExtractionJobFilter, ExtractionTaskFilter @@ -44,12 +51,14 @@ ) from .mixins import ExtractionsLockedMixin from .models import ExtractionInstance, ExtractionJob, ExtractionTask +from .site import extraction_retrieval_provider from .tables import ( ExtractionInstanceTable, ExtractionJobTable, ExtractionResultsTable, ExtractionTaskTable, ) +from .utils.csv_export import iter_extraction_result_rows EXTRACTIONS_SEARCH_PROVIDER = "extractions_search_provider" @@ -70,11 +79,11 @@ class ExtractionJobListView(ExtractionsLockedMixin, AnalysisJobListView): class ExtractionJobWizardView( LoginRequiredMixin, PermissionRequiredMixin, UserPassesTestMixin, SessionWizardView ): - SEARCH_STEP = "0" - OUTPUT_FIELDS_STEP = "1" + OUTPUT_FIELDS_STEP = "0" + SEARCH_STEP = "1" SUMMARY_STEP = "2" - form_list = [SearchForm, OutputFieldFormSet, SummaryForm] + form_list = [OutputFieldFormSet, SearchForm, SummaryForm] permission_required = "extractions.add_extractionjob" permission_denied_message = "You must be logged in and have an active group" request: AuthenticatedHttpRequest @@ -88,40 +97,140 @@ def get_form_kwargs(self, step=None): kwargs["user"] = self.request.user return kwargs + def process_step(self, form): + """Process validated form data and trigger query generation after output fields step.""" + step_data = self.get_form_step_data(form) + + # After output fields are submitted, store field data for async query generation + if self.steps.current == ExtractionJobWizardView.OUTPUT_FIELDS_STEP: + # Extract and serialize output fields data for async generation + formset_data = [] + if hasattr(form, "cleaned_data"): + formset_data = form.cleaned_data + + output_fields_data = [ + { + "name": field_data["name"], + "description": field_data["description"], + "output_type": field_data["output_type"], + } + for field_data in formset_data + if not field_data.get("DELETE", False) + ] + + # Store serialized data for async generation - query will be generated via HTMX + self.storage.extra_data["output_fields_data"] = output_fields_data + self.storage.extra_data["query_generation_attempted"] = False + # Clear any previous query to ensure fresh generation + self.storage.extra_data["generated_query"] = "" + self.storage.extra_data["query_metadata"] = {} + + return step_data + + # After search step is submitted, store retrieval count for summary + elif self.steps.current == ExtractionJobWizardView.SEARCH_STEP: + if hasattr(form, "cleaned_data"): + retrieval_count = form.cleaned_data.get("retrieval_count") + if retrieval_count is not None: + self.storage.extra_data["retrieval_count"] = retrieval_count + + return step_data + + # For any other steps (SUMMARY_STEP or future additions) + return step_data + + def get_form_initial(self, step=None): + """Provide initial data for forms, including generated query for search step.""" + initial = super().get_form_initial(step) + + if step == ExtractionJobWizardView.SEARCH_STEP: + # Pre-populate query field with generated query + generated_query = self.storage.extra_data.get("generated_query", "") + if generated_query: + initial["query"] = generated_query + + return initial + + def render(self, form=None, **kwargs): + """Override to validate wizard data integrity before rendering summary step.""" + # Only validate if we're on the summary step + if self.steps.current == ExtractionJobWizardView.SUMMARY_STEP: + output_fields_data = self.get_cleaned_data_for_step( + ExtractionJobWizardView.OUTPUT_FIELDS_STEP + ) + search_data = self.get_cleaned_data_for_step(ExtractionJobWizardView.SEARCH_STEP) + + # If output fields data is missing or invalid, restart wizard + if not output_fields_data or not isinstance(output_fields_data, list): + from django.contrib import messages + + self.storage.reset() + self.storage.current_step = self.steps.first + messages.error( + self.request, + "Wizard data was lost or corrupted. Please start over from step 1.", + ) + return redirect(reverse("extraction_job_create")) + + # If search data is missing, go back to search step + if not search_data or not isinstance(search_data, dict): + from django.contrib import messages + + self.storage.current_step = ExtractionJobWizardView.SEARCH_STEP + messages.error( + self.request, + "Search data is missing. Please complete step 2.", + ) + return redirect(reverse("extraction_job_create")) + + return super().render(form, **kwargs) + def get_context_data(self, form, **kwargs): context = super().get_context_data(form, **kwargs) - if self.steps.current == ExtractionJobWizardView.SEARCH_STEP: - pass - - elif self.steps.current == ExtractionJobWizardView.OUTPUT_FIELDS_STEP: + if self.steps.current == ExtractionJobWizardView.OUTPUT_FIELDS_STEP: + # First step - just show the formset context["formset"] = form - data = self.get_cleaned_data_for_step(ExtractionJobWizardView.SEARCH_STEP) - assert data and isinstance(data, dict) + elif self.steps.current == ExtractionJobWizardView.SEARCH_STEP: + # Second step - show generated query info and retrieval count + context["generated_query"] = self.storage.extra_data.get("generated_query", "") + context["query_metadata"] = self.storage.extra_data.get("query_metadata", {}) - context["fixed_query"] = data.get("fixed_query") - context["retrieval_count"] = data["retrieval_count"] - self.storage.extra_data["retrieval_count"] = data["retrieval_count"] + # Get output fields data to show context + output_fields_data = self.get_cleaned_data_for_step( + ExtractionJobWizardView.OUTPUT_FIELDS_STEP + ) + if output_fields_data: + context["output_fields_count"] = len( + [ + f + for f in output_fields_data + if not cast(dict[str, Any], f).get("DELETE", False) + ] + ) elif self.steps.current == ExtractionJobWizardView.SUMMARY_STEP: - search_data = self.get_cleaned_data_for_step(ExtractionJobWizardView.SEARCH_STEP) - output_fields = self.get_cleaned_data_for_step( + # Final step - show everything for review + # Data integrity validated in render() method + output_fields_data = self.get_cleaned_data_for_step( ExtractionJobWizardView.OUTPUT_FIELDS_STEP ) + search_data = self.get_cleaned_data_for_step(ExtractionJobWizardView.SEARCH_STEP) + context["output_fields"] = output_fields_data context["search"] = search_data - context["output_fields"] = output_fields - context["retrieval_count"] = self.storage.extra_data["retrieval_count"] + context["retrieval_count"] = self.storage.extra_data.get("retrieval_count", 0) + context["query_metadata"] = self.storage.extra_data.get("query_metadata", {}) return context def get_template_names(self) -> list[str]: step = self.steps.current - if step == ExtractionJobWizardView.SEARCH_STEP: - return ["extractions/extraction_job_search_form.html"] - elif step == ExtractionJobWizardView.OUTPUT_FIELDS_STEP: + if step == ExtractionJobWizardView.OUTPUT_FIELDS_STEP: return ["extractions/extraction_job_output_fields_form.html"] + elif step == ExtractionJobWizardView.SEARCH_STEP: + return ["extractions/extraction_job_search_form.html"] elif step == ExtractionJobWizardView.SUMMARY_STEP: return ["extractions/extraction_job_wizard_summary.html"] else: @@ -129,23 +238,26 @@ def get_template_names(self) -> list[str]: def done( self, - form_objs: tuple[SearchForm, BaseInlineFormSet, SummaryForm], + form_objs: tuple[BaseInlineFormSet, SearchForm, SummaryForm], **kwargs, ): user = self.request.user with transaction.atomic(): + output_fields_formset = form_objs[0] + search_form = form_objs[1] summary_form = form_objs[2] - job_form = form_objs[0] + + # The query is always in search_form now (no conditional logic needed) if summary_form.cleaned_data["send_finished_mail"]: - job_form.cleaned_data["send_finished_mail"] = True + search_form.cleaned_data["send_finished_mail"] = True - job: ExtractionJob = job_form.save(commit=False) + job: ExtractionJob = search_form.save(commit=False) + # Parse and normalize the query query = job.query query_node, fixes = QueryParser().parse(query) if len(fixes) > 0: - # The query was already validated that it is not empty by the form assert query_node job.query = QueryParser.unparse(query_node) @@ -157,8 +269,8 @@ def done( job.save() # Save output fields - form_objs[1].instance = job - form_objs[1].save() + output_fields_formset.instance = job + output_fields_formset.save() if user.is_staff or settings.START_EXTRACTION_JOB_UNVERIFIED: job.status = ExtractionJob.Status.PENDING @@ -169,6 +281,149 @@ def done( return redirect(job) +class ExtractionSearchPreviewView(LoginRequiredMixin, View): + """HTMX endpoint for live search preview: count and link.""" + + request: AuthenticatedHttpRequest + + def get(self, request: AuthenticatedHttpRequest): + # Wizard step prefix for form field names + WIZARD_STEP_PREFIX = "1-" + + # Extract wizard data and strip "1-" prefix + wizard_data = { + key.replace(WIZARD_STEP_PREFIX, "", 1): value + for key, value in request.GET.items() + if key.startswith(WIZARD_STEP_PREFIX) + } + + # Build QueryDict for SearchForm (supports getlist for modalities) + query_dict = QueryDict(mutable=True) + for key, value in wizard_data.items(): + if key == "modalities": + # modalities are sent as multiple values + query_dict.setlist(key, request.GET.getlist(f"{WIZARD_STEP_PREFIX}{key}")) + else: + query_dict[key] = str(value) + + # Add dummy title (required field but not used in preview) + query_dict["title"] = "preview" + + # Validate with SearchForm (pass user kwarg) + user = cast("User", request.user) + form = SearchForm(query_dict, user=user) + + if not form.is_valid(): + # Extract error messages + error_messages = [] + for field, errors in form.errors.items(): + # Skip title field errors (dummy field) + if field == "title": + continue + for error in errors: + if field == "__all__": + error_messages.append(str(error)) + else: + error_messages.append(f"{field}: {error}") + + error = "; ".join(error_messages) if error_messages else "Invalid search parameters" + context = { + "count": None, + "search_url": None, + "error": error, + "max_reports_limit": settings.EXTRACTION_MAXIMUM_REPORTS_COUNT, + } + return render(request, "extractions/_search_preview.html", context) + + # Extract validated data from form + # Note: query already validated by clean_query() including QueryParser! + query_str = form.cleaned_data["query"] + query_node = form.cleaned_data["query_node"] # QueryNode object + language = form.cleaned_data["language"] # Language object or None + modalities = form.cleaned_data["modalities"] # QuerySet[Modality] + study_date_from = form.cleaned_data["study_date_from"] # date object or None + study_date_till = form.cleaned_data["study_date_till"] # date object or None + study_description = form.cleaned_data["study_description"] + patient_sex = form.cleaned_data["patient_sex"] + age_from = form.cleaned_data["age_from"] # int or None + age_till = form.cleaned_data["age_till"] # int or None + + # Convert objects to codes for Search object + language_code = language.code if language else "" + modality_codes = list(modalities.values_list("code", flat=True)) + + # Get active group (user already cast above) + active_group = user.active_group + + # Build Search object + search = Search( + query=query_node, + offset=0, + limit=0, + filters=SearchFilters( + group=active_group.pk if active_group else None, + language=language_code, # Converted from Language object + modalities=modality_codes, # Converted from QuerySet + study_date_from=study_date_from, # Already date object from form + study_date_till=study_date_till, # Already date object from form + study_description=study_description, + patient_sex=patient_sex, + patient_age_from=age_from, # Already int from form + patient_age_till=age_till, # Already int from form + ), + ) + + # Calculate count + if extraction_retrieval_provider is None: + context = { + "count": None, + "search_url": None, + "error": "Extraction retrieval provider is not configured", + "max_reports_limit": settings.EXTRACTION_MAXIMUM_REPORTS_COUNT, + } + return render(request, "extractions/_search_preview.html", context) + + retrieval_count = extraction_retrieval_provider.count(search) + + # Generate search URL with codes (FIX: use codes not PKs!) + search_params: dict[str, Union[str, list[str]]] = {"query": query_str} + + if language_code: + search_params["language"] = language_code + + if modality_codes: + search_params["modalities"] = modality_codes # FIX: Use codes not PKs! + + if study_date_from: + search_params["study_date_from"] = study_date_from.strftime("%Y-%m-%d") + + if study_date_till: + search_params["study_date_till"] = study_date_till.strftime("%Y-%m-%d") + + if study_description: + search_params["study_description"] = study_description + + if patient_sex: + search_params["patient_sex"] = patient_sex + + if age_from is not None: + search_params["age_from"] = str(age_from) + + if age_till is not None: + search_params["age_till"] = str(age_till) + + search_url = reverse("search") + "?" + urlencode(search_params, doseq=True) + + # Return context + context = { + "count": retrieval_count, + "search_url": search_url, + "error": None, + "max_reports_limit": settings.EXTRACTION_MAXIMUM_REPORTS_COUNT, + } + return render(request, "extractions/_search_preview.html", context) + + class ExtractionJobDetailView(AnalysisJobDetailView): model = ExtractionJob table_class = ExtractionTaskTable @@ -194,6 +449,101 @@ class ExtractionJobResumeView(ExtractionsLockedMixin, AnalysisJobResumeView): model = ExtractionJob +class ExtractionQueryGeneratorView(LoginRequiredMixin, View): + """HTMX endpoint for async query generation from output fields.""" + + request: AuthenticatedHttpRequest + + async def post(self, request: AuthenticatedHttpRequest): + """Generate query asynchronously and save to wizard session.""" + import logging + + logger = logging.getLogger(__name__) + logger.info("Query generation endpoint called") + + # Access wizard session storage + # Django-formtools stores wizard data in a nested structure: + # session['wizard_extraction_job_wizard_view'] = { + # 'step': '1', + # 'step_data': {...}, + # 'extra_data': {'output_fields_data': [...], ...} + # } + wizard_session_key = "wizard_extraction_job_wizard_view" + wizard_data = request.session.get(wizard_session_key, {}) + + if not wizard_data: + logger.error(f"No wizard session data found for key: {wizard_session_key}") + logger.error(f"Available session keys: {list(request.session.keys())}") + + # Get extra_data from within the wizard data + extra_data = wizard_data.get("extra_data", {}) + output_fields_data = extra_data.get("output_fields_data", []) + + logger.info(f"Found wizard_data: {bool(wizard_data)}, extra_data: {bool(extra_data)}") + logger.info(f"Found {len(output_fields_data)} output fields in session") + + if not output_fields_data: + context = { + "error": "No output fields found. Please go back to step 1.", + "generated_query": "", + "query_metadata": {}, + } + return render(request, "extractions/_query_generation_result.html", context) + + # Reconstruct OutputField objects from stored data + from .models import OutputField + + temp_fields = [ + OutputField( + name=field_data["name"], + description=field_data["description"], + output_type=field_data["output_type"], + ) + for field_data in output_fields_data + ] + + # Generate query using async query generator + from .utils.query_generator import AsyncQueryGenerator + + try: + generator = AsyncQueryGenerator() + generated_query, metadata = await generator.generate_from_fields(temp_fields) + + # Store in wizard session + extra_data["generated_query"] = generated_query or "" + extra_data["query_metadata"] = metadata + extra_data["query_generation_attempted"] = True + + # Update wizard data and save back to session + wizard_data["extra_data"] = extra_data + request.session[wizard_session_key] = wizard_data + request.session.modified = True + logger.info("Saved query to wizard session") + + context = { + "generated_query": generated_query, + "query_metadata": metadata, + "output_fields_count": len(temp_fields), + "error": None if metadata.get("success") else "Query generation failed", + } + + except Exception as e: + logger.error(f"Error during async query generation: {e}", exc_info=True) + context = { + "error": f"Error generating query: {str(e)}", + "generated_query": "", + "query_metadata": {"success": False, "error": str(e)}, + } + extra_data["generated_query"] = "" + extra_data["query_metadata"] = context["query_metadata"] + extra_data["query_generation_attempted"] = True + wizard_data["extra_data"] = extra_data + request.session[wizard_session_key] = wizard_data + request.session.modified = True + + return render(request, "extractions/_query_generation_result.html", context) + + class ExtractionJobRetryView(ExtractionsLockedMixin, AnalysisJobRetryView): model = ExtractionJob @@ -272,3 +622,50 @@ def get_table(self, **kwargs): def get_table_data(self): job = cast(ExtractionJob, self.get_object()) return ExtractionInstance.objects.filter(task__job=job) + + +class _Echo: + """Lightweight write-only buffer for csv.writer.""" + + def write(self, value: str) -> str: + return value + + +class ExtractionResultDownloadView(ExtractionsLockedMixin, LoginRequiredMixin, DetailView): + """Stream extraction results as a CSV download.""" + + model = ExtractionJob + request: AuthenticatedHttpRequest + + def get_queryset(self) -> QuerySet[ExtractionJob]: + """Return the accessible extraction jobs for the current user.""" + assert self.model + model = cast(Type[ExtractionJob], self.model) + if self.request.user.is_staff: + return model.objects.all() + return model.objects.filter(owner=self.request.user) + + def get(self, request: AuthenticatedHttpRequest, *args, **kwargs) -> StreamingHttpResponse: + """Stream the CSV file response.""" + job = cast(ExtractionJob, self.get_object()) + filename = self._build_filename(job) + + response = StreamingHttpResponse( + self._stream_rows(job), + content_type="text/csv", + ) + response["Content-Disposition"] = f'attachment; filename="{filename}"' + return response + + def _stream_rows(self, job: ExtractionJob) -> Generator[str, None, None]: + """Yield serialized CSV rows for the response.""" + pseudo_buffer = _Echo() + writer = csv.writer(pseudo_buffer) + yield "\ufeff" + for row in iter_extraction_result_rows(job): + yield writer.writerow(row) + + def _build_filename(self, job: ExtractionJob) -> str: + """Generate a descriptive CSV filename for the extraction job.""" + slug = slugify(job.title) or "results" + return f"extraction_job_{job.pk}_{slug}.csv" diff --git a/radis/search/forms.py b/radis/search/forms.py index 6bef0c2d..3c83f3de 100644 --- a/radis/search/forms.py +++ b/radis/search/forms.py @@ -4,23 +4,21 @@ from crispy_forms.layout import Button, Div, Field from django import forms -from radis.core.constants import LANGUAGE_LABELS +from radis.core.constants import AGE_STEP +from radis.core.form_fields import ( + create_age_range_fields, + create_language_field, + create_modality_field, +) from radis.core.layouts import RangeSlider -from radis.reports.models import Language, Modality from .layouts import QueryInput -MIN_AGE = 0 -MAX_AGE = 120 -AGE_STEP = 10 - class SearchForm(forms.Form): # Query fields query = forms.CharField(required=False, label=False) # type: ignore - # Filter fields - language = forms.ChoiceField(required=False, choices=[]) - modalities = forms.MultipleChoiceField(required=False, choices=[]) + # Filter fields - language, modalities, and age fields created in __init__ study_date_from = forms.DateField( required=False, widget=forms.DateInput(attrs={"type": "date"}) ) @@ -31,43 +29,16 @@ class SearchForm(forms.Form): patient_sex = forms.ChoiceField( required=False, choices=[("", "All"), ("M", "Male"), ("F", "Female")] ) - age_from = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MIN_AGE, - } - ), - ) - age_till = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MAX_AGE, - } - ), - ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields["language"].choices = [ # type: ignore - (language.code, LANGUAGE_LABELS[language.code]) - for language in Language.objects.order_by("code") - ] - self.fields["modalities"].choices = [ # type: ignore - (modality.code, modality.code) - for modality in Modality.objects.filter(filterable=True).order_by("code") - ] - self.fields["modalities"].widget.attrs["size"] = 6 + # Create fields using factory functions (use codes, not PKs) + self.fields["language"] = create_language_field(use_pk=False) + self.fields["modalities"] = create_modality_field(use_pk=False) + age_from, age_till = create_age_range_fields() + self.fields["age_from"] = age_from + self.fields["age_till"] = age_till self.query_helper = FormHelper() self.query_helper.template = "search/form_elements/form_part.html" # type: ignore diff --git a/radis/search/site.py b/radis/search/site.py index 5e0ac4b6..10552260 100644 --- a/radis/search/site.py +++ b/radis/search/site.py @@ -45,7 +45,7 @@ class SearchFilters: - patient_age_till: Filter only reports where the patient is at most this age """ - group: int # TODO: Rename to group_id + group: int | None = None # TODO: Rename to group_id language: str = "" # TODO: Rename to language_code modalities: list[str] = field(default_factory=list) study_date_from: date | None = None diff --git a/radis/settings/base.py b/radis/settings/base.py index 1f9b1c6b..627dac44 100644 --- a/radis/settings/base.py +++ b/radis/settings/base.py @@ -357,20 +357,34 @@ """ # Subscription -QUESTIONS_SYSTEM_PROMPT = """ +SUBSCRIPTION_FILTER_PROMPT = """ You are an AI medical assistant with extensive knowledge in radiology and general medicine. You have been trained on a wide range of medical literature, including the latest research and guidelines in radiological practices. -Answer the following questions from the given radiology report. The report and questions can -be given in any language. -Base your answers only on the information provided in the report. Don't hallucinate. -Return the answer in JSON format. Answer with 'true' for 'yes' and 'false' for 'no'. +Answer the following filter questions about the radiology report. The questions can be in any +language. Base your answers strictly on the contents of the report. Return the answers in JSON +format using the provided field identifiers. Answer with `true` for "yes" and `false` for "no". Radiology Report: $report Questions: $questions + +""" + +SUBSCRIPTION_EXTRACTION_PROMPT = """ +You are an AI medical assistant with extensive knowledge in radiology and general medicine. +Extract the requested information from the radiology report. Only provide data that is explicitly +mentioned in the report and respect the expected data type. If the report does not contain the +requested information, respond with null. Return the extracted information in JSON format using +the provided field identifiers. + +Radiology Report: +$report + +Fields to extract: +$fields """ # Extraction @@ -390,6 +404,42 @@ $fields """ +# Query Generation from Extraction Fields +ENABLE_AUTO_QUERY_GENERATION = env.bool("ENABLE_AUTO_QUERY_GENERATION", default=True) +QUERY_GENERATION_TIMEOUT = env.int("QUERY_GENERATION_TIMEOUT", default=10) +QUERY_GENERATION_MAX_RETRIES = env.int("QUERY_GENERATION_MAX_RETRIES", default=2) + +QUERY_GENERATION_SYSTEM_PROMPT = """You are an AI assistant specialized in medical informatics and +radiology report retrieval. Your task is to generate an effective search query to find radiology +reports that would contain information relevant to the specified data extraction fields. +Given extraction fields with their descriptions and types, generate a boolean search query that +will retrieve reports likely to contain the requested information. + +Guidelines: +1. Use medical terminology and common synonyms +2. Prefer broader terms to avoid missing relevant reports +3. Use boolean operators: AND, OR, NOT +4. Use quotes for exact phrases: "lung nodule" +5. Consider anatomical variations and medical abbreviations +6. Keep the query concise but comprehensive (max 150 characters recommended) +7. Focus on key concepts that would appear in reports containing this data + +Output format: Return ONLY the search query as a single line, without explanation. + +Examples: +Fields: [{"name": "nodule_size", "description": "size of lung nodule in millimeters", +"type": "NUMERIC"}] +Query: ("lung nodule" OR "pulmonary nodule") AND (size OR measurement OR diameter) + +Fields: [{"name": "fracture_type", "description": "type of bone fracture", "type": "TEXT"}, +{"name": "bone", "description": "which bone is fractured", "type": "TEXT"}] +Query: fracture AND bone + +Extraction Fields: +$fields + +Generate the search query:""" + # The maximum number of reports that can be extracted by one extraction job. EXTRACTION_MAXIMUM_REPORTS_COUNT = 25000 @@ -412,7 +462,7 @@ # Subscription SUBSCRIPTION_DEFAULT_PRIORITY = 3 SUBSCRIPTION_URGENT_PRIORITY = 4 -SUBSCRIPTION_CRON = "* * * * *" +SUBSCRIPTION_CRON = "* * * * *" # Disabled - using handler-based approach SUBSCRIPTION_REFRESH_TASK_BATCH_SIZE = 100 # The priority for stalled jobs that are retried. diff --git a/radis/subscriptions/apps.py b/radis/subscriptions/apps.py index dee690a6..28680c4d 100644 --- a/radis/subscriptions/apps.py +++ b/radis/subscriptions/apps.py @@ -7,6 +7,7 @@ class SubscriptionsConfig(AppConfig): def ready(self): register_app() + register_reports_handler() # Put calls to db stuff in this signal handler post_migrate.connect(init_db, sender=self) @@ -28,3 +29,17 @@ def init_db(**kwargs): if not SubscriptionsAppSettings.objects.exists(): SubscriptionsAppSettings.objects.create() + + +def register_reports_handler(): + """Register handler to trigger subscriptions when reports are created.""" + from radis.reports.site import ReportsCreatedHandler, reports_created_handlers + + from .handlers import handle_reports_created + + reports_created_handlers.append( + ReportsCreatedHandler( + name="subscription_launcher", + handle=handle_reports_created, + ) + ) diff --git a/radis/subscriptions/factories.py b/radis/subscriptions/factories.py index a26b23a5..da57a0c7 100644 --- a/radis/subscriptions/factories.py +++ b/radis/subscriptions/factories.py @@ -4,7 +4,7 @@ from radis.reports.factories import LanguageFactory, ReportFactory -from .models import Question, SubscribedItem, Subscription, SubscriptionJob, SubscriptionTask +from .models import FilterQuestion, SubscribedItem, Subscription, SubscriptionJob, SubscriptionTask class SubscriptionFactory(BaseDjangoModelFactory[Subscription]): @@ -15,7 +15,6 @@ class Meta: owner = factory.SubFactory(UserFactory) group = factory.SubFactory(GroupFactory) patient_id = factory.Faker("numerify", text="##########") - query = factory.Faker("sentence", nb_words=3) language = factory.SubFactory(LanguageFactory, code="en") study_description = factory.Faker("sentence", nb_words=4) patient_sex = factory.Faker("random_element", elements=["M", "F", ""]) @@ -24,9 +23,9 @@ class Meta: send_finished_mail = factory.Faker("boolean") -class QuestionFactory(BaseDjangoModelFactory[Question]): +class FilterQuestionFactory(BaseDjangoModelFactory[FilterQuestion]): class Meta: - model = Question + model = FilterQuestion subscription = factory.SubFactory(SubscriptionFactory) question = factory.Faker("sentence", nb_words=6, variable_nb_words=True) @@ -56,4 +55,3 @@ class Meta: SubscriptionJobFactory, subscription=factory.SelfAttribute("..subscription") ) report = factory.SubFactory(ReportFactory) - answers = factory.LazyFunction(lambda: {"question_1": "answer_1", "question_2": "answer_2"}) diff --git a/radis/subscriptions/filters.py b/radis/subscriptions/filters.py index 0ef57e17..6e3cddb4 100644 --- a/radis/subscriptions/filters.py +++ b/radis/subscriptions/filters.py @@ -1,9 +1,14 @@ import django_filters from adit_radis_shared.common.forms import SingleFilterFieldFormHelper from adit_radis_shared.common.types import with_form_helper +from crispy_forms.helper import FormHelper +from crispy_forms.layout import HTML, Div, Field, Layout, Submit +from django import forms from django.http import HttpRequest -from .models import Subscription +from radis.reports.models import Modality + +from .models import SubscribedItem, Subscription class SubscriptionFilter(django_filters.FilterSet): @@ -18,3 +23,58 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) with_form_helper(self.form).helper = SingleFilterFieldFormHelper(self.request.GET, "name") + + +class SubscribedItemFilter(django_filters.FilterSet): + patient_id = django_filters.CharFilter( + label="Patient ID", + field_name="report__patient_id", + lookup_expr="icontains", + ) + study_description = django_filters.CharFilter( + label="Study Description", + field_name="report__study_description", + lookup_expr="icontains", + ) + study_date_from = django_filters.DateFilter( + label="Study Date From", + field_name="report__study_datetime", + lookup_expr="date__gte", + widget=forms.DateInput(attrs={"type": "date"}), + ) + study_date_till = django_filters.DateFilter( + label="Study Date Till", + field_name="report__study_datetime", + lookup_expr="date__lte", + widget=forms.DateInput(attrs={"type": "date"}), + ) + modalities = django_filters.ModelMultipleChoiceFilter( + queryset=Modality.objects.order_by("code"), + field_name="report__modalities__code", + to_field_name="code", + ) + request: HttpRequest + + class Meta: + model = SubscribedItem + fields = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + form_helper = FormHelper() + form_helper.form_tag = False + form_helper.disable_csrf = True + form_helper.layout = Layout( + Field("patient_id", css_class="form-control-sm"), + Field("study_description", css_class="form-control-sm"), + Field("study_date_from", css_class="form-control-sm"), + Field("study_date_till", css_class="form-control-sm"), + Field("modalities", css_class="form-select-sm"), + Div( + Submit("submit", "Apply Filters", css_class="btn btn-sm btn-primary"), + HTML("Reset"), + css_class="d-flex justify-content-center gap-2", + ), + ) + with_form_helper(self.form).helper = form_helper diff --git a/radis/subscriptions/forms.py b/radis/subscriptions/forms.py index 90376fb5..d272df12 100644 --- a/radis/subscriptions/forms.py +++ b/radis/subscriptions/forms.py @@ -4,12 +4,17 @@ from crispy_forms.layout import Column, Div, Field, Layout, Row from django import forms -from radis.core.constants import LANGUAGE_LABELS +from radis.core.constants import AGE_STEP +from radis.core.form_fields import ( + create_age_range_fields, + create_language_field, + create_modality_field, +) from radis.core.layouts import Formset, RangeSlider -from radis.reports.models import Language, Modality -from radis.search.forms import AGE_STEP, MAX_AGE, MIN_AGE +from radis.extractions.forms import OutputFieldForm +from radis.extractions.models import OutputField -from .models import Question, Subscription +from .models import FilterQuestion, Subscription class SubscriptionForm(forms.ModelForm): @@ -17,7 +22,6 @@ class Meta: model = Subscription fields = [ "name", - "query", "language", "modalities", "study_description", @@ -28,48 +32,15 @@ class Meta: "send_finished_mail", ] labels = {"patient_id": "Patient ID"} - help_texts = { - "name": "Name of the Subscription", - "query": "A query to filter reports", - } def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.fields["language"].choices = [ # type: ignore - (language.pk, LANGUAGE_LABELS[language.code]) - for language in Language.objects.order_by("code") - ] - self.fields["language"].empty_label = "All" # type: ignore - self.fields["modalities"].choices = [ # type: ignore - (modality.pk, modality.code) - for modality in Modality.objects.filter(filterable=True).order_by("code") - ] - self.fields["modalities"].widget.attrs["size"] = 6 - self.fields["age_from"] = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MIN_AGE, - } - ), - ) - self.fields["age_till"] = forms.IntegerField( - required=False, - min_value=MIN_AGE, - max_value=MAX_AGE, - widget=forms.NumberInput( - attrs={ - "type": "range", - "step": AGE_STEP, - "value": MAX_AGE, - } - ), - ) + self.fields["language"] = create_language_field(empty_label="All") + self.fields["modalities"] = create_modality_field() + age_from, age_till = create_age_range_fields() + self.fields["age_from"] = age_from + self.fields["age_till"] = age_till self.fields["send_finished_mail"].label = "Notify me via mail of new reports" self.helper = FormHelper() @@ -81,9 +52,17 @@ def build_layout(self): Row( Column( "name", - "query", "send_finished_mail", - Formset("formset", legend="Questions", add_form_label="Add Question"), + Formset( + "filter_formset", + legend="Filter Questions", + add_form_label="Add Filter Question", + ), + Formset( + "output_formset", + legend="Extraction Fields", + add_form_label="Add Extraction Field", + ), ), Column( "patient_id", @@ -120,35 +99,74 @@ def clean(self) -> dict[str, Any] | None: return super().clean() -class QuestionForm(forms.ModelForm): +class FilterQuestionForm(forms.ModelForm): class Meta: - model = Question - fields = ["question"] + model = FilterQuestion + fields = ["question", "expected_answer"] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields["question"].required = False + self.fields["expected_answer"].required = False + self.fields["expected_answer"].choices = [ # type: ignore[attr-defined] + ("", "Select the expected answer"), + *FilterQuestion.ExpectedAnswer.choices, + ] + self.fields["expected_answer"].label = "Accept when answer is" self.helper = FormHelper() self.helper.form_tag = False self.helper.disable_csrf = True - self.helper.layout = Layout( - Div( - Field("id", type="hidden"), - Field("DELETE", type="hidden"), - "question", - ), - ) + fields = [Field("id", type="hidden"), "question", "expected_answer"] + if "DELETE" in self.fields: + fields.insert(1, Field("DELETE", type="hidden")) + self.helper.layout = Layout(Div(*fields)) + + def has_changed(self) -> bool: + if not self.is_bound: + return super().has_changed() + + question = (self.data.get(self.add_prefix("question")) or "").strip() + expected_answer = self.data.get(self.add_prefix("expected_answer")) or "" + + if not question and not expected_answer: + return False + + return super().has_changed() + def clean(self) -> dict[str, Any]: + cleaned_data = super().clean() + assert cleaned_data -QuestionFormSet = forms.inlineformset_factory( + question = cleaned_data.get("question") + expected_answer = cleaned_data.get("expected_answer") + + if (question and not expected_answer) or (expected_answer and not question): + raise forms.ValidationError("You must provide both a question and an expected answer.") + + return cleaned_data + + +FilterQuestionFormSet = forms.inlineformset_factory( Subscription, - Question, - form=QuestionForm, + FilterQuestion, + form=FilterQuestionForm, extra=1, min_num=0, max_num=3, validate_max=True, can_delete=False, ) + +OutputFieldFormSet = forms.inlineformset_factory( + Subscription, + OutputField, + form=OutputFieldForm, + fk_name="subscription", + extra=1, + min_num=0, + max_num=10, + validate_max=True, + can_delete=False, +) diff --git a/radis/subscriptions/handlers.py b/radis/subscriptions/handlers.py new file mode 100644 index 00000000..c0bd70b1 --- /dev/null +++ b/radis/subscriptions/handlers.py @@ -0,0 +1,21 @@ +import logging +import time + +from radis.reports.models import Report +from radis.subscriptions.tasks import subscription_launcher + +logger = logging.getLogger(__name__) + + +def handle_reports_created(reports: list[Report]) -> None: + """ + Handler called when reports are created via API. + Triggers subscription processing for bulk imports. + """ + if not reports: + return + + # Trigger subscriptions for any report creation + # The subscription_launcher will filter by last_refreshed timestamp + logger.info(f"Triggering subscription processing for {len(reports)} new report(s)") + subscription_launcher.defer(timestamp=int(time.time())) diff --git a/radis/subscriptions/migrations/0011_rename_answers_subscribeditem_filter_results_and_more.py b/radis/subscriptions/migrations/0011_rename_answers_subscribeditem_filter_results_and_more.py new file mode 100644 index 00000000..f7b3e518 --- /dev/null +++ b/radis/subscriptions/migrations/0011_rename_answers_subscribeditem_filter_results_and_more.py @@ -0,0 +1,56 @@ +# Generated by Django 5.2.8 on 2025-11-17 23:36 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("subscriptions", "0010_remove_subscription_provider"), + ] + + operations = [ + migrations.RenameField( + model_name="subscribeditem", + old_name="answers", + new_name="filter_results", + ), + migrations.AddField( + model_name="subscribeditem", + name="extraction_results", + field=models.JSONField(blank=True, null=True), + ), + migrations.CreateModel( + name="FilterQuestion", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("question", models.CharField(max_length=300)), + ( + "expected_answer", + models.CharField( + choices=[("Y", "Yes"), ("N", "No")], default="Y", max_length=1 + ), + ), + ( + "subscription", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="filter_questions", + to="subscriptions.subscription", + ), + ), + ], + ), + migrations.DeleteModel( + name="Question", + ), + ] diff --git a/radis/subscriptions/migrations/0012_procrastinate_on_delete.py b/radis/subscriptions/migrations/0012_procrastinate_on_delete.py new file mode 100644 index 00000000..198142f2 --- /dev/null +++ b/radis/subscriptions/migrations/0012_procrastinate_on_delete.py @@ -0,0 +1,23 @@ +# Generated by Django 5.1.6 on 2025-11-18 08:51 + +from django.db import migrations + +from adit_radis_shared.common.utils.migration_utils import procrastinate_on_delete_sql + +class Migration(migrations.Migration): + + dependencies = [ + ("subscriptions", "0011_rename_answers_subscribeditem_filter_results_and_more"), + ("procrastinate", "0028_add_cancel_states"), + ] + + operations = [ + migrations.RunSQL( + sql=procrastinate_on_delete_sql("subscriptions", "subscriptionjob"), + reverse_sql=procrastinate_on_delete_sql("subscriptions", "subscriptionjob", reverse=True), + ), + migrations.RunSQL( + sql=procrastinate_on_delete_sql("subscriptions", "subscriptiontask"), + reverse_sql=procrastinate_on_delete_sql("subscriptions", "subscriptiontask", reverse=True), + ), + ] diff --git a/radis/subscriptions/migrations/0013_subscription_last_viewed_at.py b/radis/subscriptions/migrations/0013_subscription_last_viewed_at.py new file mode 100644 index 00000000..cda71caa --- /dev/null +++ b/radis/subscriptions/migrations/0013_subscription_last_viewed_at.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.8 on 2025-11-28 10:07 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("subscriptions", "0012_procrastinate_on_delete"), + ] + + operations = [ + migrations.AddField( + model_name="subscription", + name="last_viewed_at", + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/radis/subscriptions/migrations/0014_remove_subscription_query.py b/radis/subscriptions/migrations/0014_remove_subscription_query.py new file mode 100644 index 00000000..9bea7412 --- /dev/null +++ b/radis/subscriptions/migrations/0014_remove_subscription_query.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.8 on 2025-12-12 01:49 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("subscriptions", "0013_subscription_last_viewed_at"), + ] + + operations = [ + migrations.RemoveField( + model_name="subscription", + name="query", + ), + ] diff --git a/radis/subscriptions/models.py b/radis/subscriptions/models.py index ae38c095..abdad5fd 100644 --- a/radis/subscriptions/models.py +++ b/radis/subscriptions/models.py @@ -8,6 +8,7 @@ from procrastinate.contrib.django.models import ProcrastinateJob from radis.core.models import AnalysisJob, AnalysisTask +from radis.extractions.models import OutputField from radis.reports.models import Language, Modality, Report @@ -31,7 +32,6 @@ class Subscription(models.Model): group = models.ForeignKey[Group](Group, on_delete=models.CASCADE, related_name="+") patient_id = models.CharField(max_length=100, blank=True) - query = models.CharField(max_length=200, blank=True) language = models.ForeignKey[Language]( Language, on_delete=models.SET_NULL, blank=True, null=True, related_name="+" ) @@ -48,8 +48,10 @@ class Subscription(models.Model): created_at = models.DateTimeField(auto_now_add=True) last_refreshed = models.DateTimeField(auto_now_add=True) + last_viewed_at = models.DateTimeField(null=True, blank=True) - questions: models.QuerySet["Question"] + filter_questions: models.QuerySet["FilterQuestion"] + output_fields: models.QuerySet[OutputField] items: models.QuerySet["SubscribedItem"] send_finished_mail = models.BooleanField(default=False) @@ -67,17 +69,30 @@ def __str__(self): return f"Subscription {self.name} [{self.pk}]" -class Question(models.Model): +class FilterQuestion(models.Model): + class ExpectedAnswer(models.TextChoices): + YES = "Y", "Yes" + NO = "N", "No" + subscription = models.ForeignKey[Subscription]( - Subscription, on_delete=models.CASCADE, related_name="questions" + Subscription, on_delete=models.CASCADE, related_name="filter_questions" ) question = models.CharField(max_length=300) + expected_answer = models.CharField( + max_length=1, + choices=ExpectedAnswer.choices, + default=ExpectedAnswer.YES, + ) def __str__(self) -> str: max_length = 30 - if len(self.question) > max_length: - return f'Question "{self.question[:max_length]}..." [{self.pk}]' - return f'Question "{self.question}" [{self.pk}]' + truncated = self.question[:max_length] + suffix = "..." if len(self.question) > max_length else "" + return f'Filter Question "{truncated}{suffix}" [{self.pk}]' + + @property + def expected_answer_bool(self) -> bool: + return self.expected_answer == self.ExpectedAnswer.YES class SubscribedItem(models.Model): @@ -88,7 +103,8 @@ class SubscribedItem(models.Model): "SubscriptionJob", null=True, on_delete=models.SET_NULL, related_name="items" ) report = models.ForeignKey[Report](Report, on_delete=models.CASCADE, related_name="+") - answers = models.JSONField(null=True, blank=True) + filter_results = models.JSONField(null=True, blank=True) + extraction_results = models.JSONField(null=True, blank=True) created_at = models.DateTimeField(auto_now_add=True) class Meta: diff --git a/radis/subscriptions/processors.py b/radis/subscriptions/processors.py index 842d0e91..5c76a517 100644 --- a/radis/subscriptions/processors.py +++ b/radis/subscriptions/processors.py @@ -1,13 +1,19 @@ import logging from concurrent.futures import Future, ThreadPoolExecutor from string import Template +from typing import Any from adit_radis_shared.common.types import User from django import db from django.conf import settings +from pydantic import ValidationError from radis.chats.utils.chat_client import ChatClient from radis.core.processors import AnalysisTaskProcessor +from radis.extractions.utils.processor_utils import ( + generate_output_fields_prompt, + generate_output_fields_schema, +) from radis.reports.models import Report from .models import ( @@ -16,8 +22,10 @@ SubscriptionTask, ) from .utils.processor_utils import ( - generate_questions_for_prompt, - generate_questions_schema, + generate_filter_questions_prompt, + generate_filter_questions_schema, + get_filter_question_field_name, + get_output_field_name, ) logger = logging.getLogger(__name__) @@ -46,26 +54,102 @@ def process_task(self, task: SubscriptionTask) -> None: db.close_old_connections() def process_report(self, report: Report, task: SubscriptionTask) -> None: - subscription: Subscription = task.job.subscription - Schema = generate_questions_schema(subscription.questions) - prompt = Template(settings.QUESTIONS_SYSTEM_PROMPT).substitute( - { - "report": report.body, - "questions": generate_questions_for_prompt(subscription.questions), - } - ) - result = self.client.extract_data(prompt, Schema) - - is_accepted = all( - [getattr(result, field_name) for field_name in result.__pydantic_fields__] - ) - if is_accepted: + try: + subscription: Subscription = task.job.subscription + + filter_results: dict[str, bool] = {} + is_accepted = True + + filter_questions = list(subscription.filter_questions.order_by("pk")) + + if filter_questions: + filter_prompt = Template(settings.SUBSCRIPTION_FILTER_PROMPT).substitute( + { + "report": report.body, + "questions": generate_filter_questions_prompt(filter_questions), + } + ) + filter_schema = generate_filter_questions_schema(filter_questions) + + try: + filter_response = self.client.extract_data(filter_prompt, filter_schema) + + for question in filter_questions: + field_name = get_filter_question_field_name(question) + answer = getattr(filter_response, field_name, None) + if answer is None: + logger.debug( + f"LLM returned None for question {question.pk} ", + f"on report {report.pk}", + ) + is_accepted = False + break + else: + answer_bool = bool(answer) + filter_results[str(question.pk)] = answer_bool + if answer_bool != question.expected_answer_bool: + is_accepted = False + break + except RuntimeError as e: + logger.error(f"LLM API error filtering report {report.pk}: {e}") + return + except ValidationError as e: + logger.error(f"Response validation failed filtering report {report.pk}: {e}") + return + except ValueError: + logger.error(f"No parsed response received filtering report {report.pk}") + return + else: + logger.debug( + "Subscription %s has no filter questions; accepting report %s by default", + subscription.pk, + report.pk, + ) + + if not is_accepted: + logger.debug(f"Report {report.pk} was rejected by subscription {subscription.pk}") + return + + extraction_results: dict[str, Any] = {} + output_fields = list(subscription.output_fields.order_by("pk")) + + if output_fields: + extraction_prompt = Template(settings.SUBSCRIPTION_EXTRACTION_PROMPT).substitute( + { + "report": report.body, + "fields": generate_output_fields_prompt(output_fields), + } + ) + extraction_schema = generate_output_fields_schema(output_fields) + + try: + extraction_response = self.client.extract_data( + extraction_prompt, extraction_schema + ) + + for field in output_fields: + extraction_results[str(field.pk)] = getattr( + extraction_response, get_output_field_name(field), None + ) + except RuntimeError as e: + logger.error(f"LLM API error extracting from report {report.pk}: {e}") + return + except ValidationError as e: + logger.error( + f"Response validation failed extracting from report {report.pk}: {e}" + ) + return + except ValueError: + logger.error(f"No parsed response received extracting from report {report.pk}") + return + SubscribedItem.objects.create( subscription=task.job.subscription, job=task.job, report=report, - filter_fields_results=result.model_dump(), + filter_results=filter_results or None, + extraction_results=extraction_results or None, ) logger.debug(f"Report {report.pk} was accepted by subscription {subscription.pk}") - else: - logger.debug(f"Report {report.pk} was rejected by subscription {subscription.pk}") + finally: + db.close_old_connections() diff --git a/radis/subscriptions/static/subscriptions/subscriptions.css b/radis/subscriptions/static/subscriptions/subscriptions.css index e69de29b..016f8965 100644 --- a/radis/subscriptions/static/subscriptions/subscriptions.css +++ b/radis/subscriptions/static/subscriptions/subscriptions.css @@ -0,0 +1,6 @@ +fieldset > legend { + font-size: 1rem !important; + font-weight: 500 !important; + line-height: 1.5; + margin-bottom: 0.5rem; +} diff --git a/radis/subscriptions/tables.py b/radis/subscriptions/tables.py index 81ebacd3..8a0aa58d 100644 --- a/radis/subscriptions/tables.py +++ b/radis/subscriptions/tables.py @@ -1,4 +1,6 @@ import django_tables2 as tables +from django.urls import reverse +from django.utils.html import format_html from django_tables2.utils import A from .models import Subscription @@ -18,6 +20,20 @@ class SubscriptionTable(tables.Table): }, ) + def render_num_reports(self, value, record): + """Render the num_reports column with a notification badge for new reports.""" + num_new = getattr(record, "num_new_reports", 0) + url = reverse("subscription_inbox", args=[record.pk]) + + if num_new > 0: + return format_html( + '{}{} new', + url, + value, + num_new, + ) + return format_html('{}', url, value) + class Meta: model = Subscription fields = ("name", "num_reports") diff --git a/radis/subscriptions/tasks.py b/radis/subscriptions/tasks.py index c6248ca5..505b3c6a 100644 --- a/radis/subscriptions/tasks.py +++ b/radis/subscriptions/tasks.py @@ -9,8 +9,7 @@ from procrastinate.contrib.django import app from radis.reports.models import Report -from radis.search.site import Search, SearchFilters -from radis.search.utils.query_parser import QueryParser +from radis.search.site import SearchFilters from . import site from .models import Subscription, SubscriptionJob, SubscriptionTask @@ -40,7 +39,7 @@ def process_subscription_job(job_id: int) -> None: logger.debug("Collecting tasks for job %s", job) language_code = "" - if job.subscription.language and job.subscription.query != "": + if job.subscription.language: language_code = job.subscription.language.code filters = SearchFilters( @@ -54,38 +53,13 @@ def process_subscription_job(job_id: int) -> None: created_after=job.subscription.last_refreshed, ) - if job.subscription.query != "": - logger.debug("Searching new reports with query and filters for job %s", job) + logger.debug("Searching new reports with filters for job %s", job) - if site.subscription_retrieval_provider is None: - logger.error("Subscription retrieval provider is not configured for job %s", job) - raise ImproperlyConfigured("Subscription retrieval provider is not configured.") - retrieval_provider = site.subscription_retrieval_provider - - query_node, fixes = QueryParser().parse(job.subscription.query) - - if query_node is None: - raise ValueError(f"Not a valid query (evaluated as empty): {job.subscription.query}") - - if len(fixes) > 0: - logger.info(f"The following fixes were applied to the query:\n{'\n - '.join(fixes)}") - - search = Search( - query=query_node, - offset=0, - filters=filters, - ) - - new_document_ids = retrieval_provider.retrieve(search) - - else: - logger.debug("Searching new reports with filters for job %s", job) - - if site.subscription_filter_provider is None: - logger.error("Subscription filter provider is not configured for job %s", job) - raise ImproperlyConfigured("Subscription filter provider is not configured.") - filter_provider = site.subscription_filter_provider - new_document_ids = filter_provider.filter(filters) + if site.subscription_filter_provider is None: + logger.error("Subscription filter provider is not configured for job %s", job) + raise ImproperlyConfigured("Subscription filter provider is not configured.") + filter_provider = site.subscription_filter_provider + new_document_ids = filter_provider.filter(filters) for document_ids in batched(new_document_ids, settings.SUBSCRIPTION_REFRESH_TASK_BATCH_SIZE): logger.debug("Creating SubscriptionTask for document IDs: %s", document_ids) @@ -105,7 +79,7 @@ def process_subscription_job(job_id: int) -> None: job.save() -@app.periodic(cron=settings.SUBSCRIPTION_CRON) +# @app.periodic(cron=settings.SUBSCRIPTION_CRON) @app.task() def subscription_launcher(timestamp: int): logger.info("Launching SubscriptionJobs (Timestamp %s)", datetime.fromtimestamp(timestamp)) diff --git a/radis/subscriptions/templates/subscriptions/_subscribed_item_preview.html b/radis/subscriptions/templates/subscriptions/_subscribed_item_preview.html index 432891f5..894786f2 100644 --- a/radis/subscriptions/templates/subscriptions/_subscribed_item_preview.html +++ b/radis/subscriptions/templates/subscriptions/_subscribed_item_preview.html @@ -1,3 +1,4 @@ +{% load subscriptions_extras %}
@@ -17,6 +18,24 @@ x-cloak x-show="full">[Show summary]
+ {% if subscribed_item.extraction_results %} +
+
Extracted Fields
+
+ {# LLM-generated content - ensure HTML escaping for XSS protection #} + {% autoescape on %} + {% for field in subscribed_item.subscription.output_fields.all %} + {% with field_key=field.pk|stringformat:"s" %} +
{{ field.name }}
+
+ {{ subscribed_item.extraction_results|get_item:field_key|default_if_none:"—" }} +
+ {% endwith %} + {% endfor %} + {% endautoescape %} +
+
+ {% endif %}
{% include "reports/_report_buttons_panel.html" with report=subscribed_item.report %}
diff --git a/radis/subscriptions/templates/subscriptions/subscription_detail.html b/radis/subscriptions/templates/subscriptions/subscription_detail.html index 7002a2b1..d26dabe3 100644 --- a/radis/subscriptions/templates/subscriptions/subscription_detail.html +++ b/radis/subscriptions/templates/subscriptions/subscription_detail.html @@ -44,24 +44,8 @@

General

{{ subscription.last_refreshed }} -

Search Details

+

Filter Details

-
Search Provider
-
- {% if subscription.provider %} - {{ subscription.provider }} - {% else %} - – - {% endif %} -
-
Search Query
-
- {% if subscription.query %} - {{ subscription.query }} - {% else %} - – - {% endif %} -
Language
{% if subscription.language %} @@ -70,9 +54,6 @@

Search Details

– {% endif %}
-
-

Filter Details

-
Modalities
{{ subscription.modalities.all|join:", "|default:"–" }} @@ -118,8 +99,60 @@

Filter Details

{% endif %}
-

Questions

-
    - {% for question in subscription.questions.all %}
  • {{ question.question }}
  • {% endfor %} -
+

Filter Questions

+ {% with questions=subscription.filter_questions.all %} + {% if questions %} +
    + {% for question in questions %} +
  • + {{ question.question }} — accepts reports when answer is + {{ question.get_expected_answer_display|lower }} +
  • + {% endfor %} +
+ {% else %} +

No filter questions defined.

+ {% endif %} + {% endwith %} +

Extraction Fields

+ {% with fields=subscription.output_fields.all %} + {% if fields %} +
    + {% for field in fields %} +
  • +
    +
    Name
    +
    + {{ field.name }} +
    +
    Description
    +
    + {{ field.description }} +
    +
    Data Type
    +
    + {{ field.get_output_type_display }} +
    + {% if field.is_array %} +
    Array Output
    +
    + Yes — return multiple {{ field.get_output_type_display|lower }} values +
    + {% endif %} + {% if field.selection_options %} +
    Selections
    +
    +
      + {% for option in field.selection_options %}
    • {{ option }}
    • {% endfor %} +
    +
    + {% endif %} +
    +
  • + {% endfor %} +
+ {% else %} +

No extraction fields configured.

+ {% endif %} + {% endwith %} {% endblock content %} diff --git a/radis/subscriptions/templates/subscriptions/subscription_inbox.html b/radis/subscriptions/templates/subscriptions/subscription_inbox.html index 0e1eec5c..35a2f5ee 100644 --- a/radis/subscriptions/templates/subscriptions/subscription_inbox.html +++ b/radis/subscriptions/templates/subscriptions/subscription_inbox.html @@ -1,4 +1,6 @@ {% extends 'subscriptions/subscription_layout.html' %} +{% load crispy from crispy_forms_tags %} +{% load bootstrap_icon from common_extras %} {% block title %} Subscription Inbox {% endblock title %} @@ -6,7 +8,70 @@ {% endblock heading %} {% block content %} - {% for subscribed_item in object_list %} - {% include "subscriptions/_subscribed_item_preview.html" %} - {% endfor %} +
+
+ {# Sorting Controls and Download Button #} + + {# Card List #} + {% for subscribed_item in object_list %} + {% include "subscriptions/_subscribed_item_preview.html" %} + {% empty %} + + {% endfor %} + {# Pagination Controls #} +
+
+ +
+
+
+ {# Filter Sidebar #} +
+
+
+
+
Filters
+ {# Hidden inputs to preserve sort and pagination state #} + + + + {% crispy filter.form %} +
+
+
+
+
{% endblock content %} diff --git a/radis/subscriptions/templates/subscriptions/subscription_layout.html b/radis/subscriptions/templates/subscriptions/subscription_layout.html index 0b1b86c0..a0fdad2b 100644 --- a/radis/subscriptions/templates/subscriptions/subscription_layout.html +++ b/radis/subscriptions/templates/subscriptions/subscription_layout.html @@ -5,6 +5,9 @@ + @@ -12,5 +15,6 @@ {% block script %} {{ block.super }} + {% endblock script %} diff --git a/radis/subscriptions/templatetags/__init__.py b/radis/subscriptions/templatetags/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/radis/subscriptions/templatetags/__init__.py @@ -0,0 +1 @@ + diff --git a/radis/subscriptions/templatetags/subscriptions_extras.py b/radis/subscriptions/templatetags/subscriptions_extras.py new file mode 100644 index 00000000..ee0f3d7a --- /dev/null +++ b/radis/subscriptions/templatetags/subscriptions_extras.py @@ -0,0 +1,19 @@ +from collections.abc import Mapping +from typing import Any + +from django import template + +register = template.Library() + + +@register.filter +def get_item(mapping: Mapping[str, Any] | None, key: Any) -> Any: + """ + Safely retrieve a value from a mapping for template usage. + """ + if mapping is None: + return None + try: + return mapping.get(str(key)) + except AttributeError: + return None diff --git a/radis/subscriptions/tests/test_views.py b/radis/subscriptions/tests/test_views.py index 07fd10ef..f84e8163 100644 --- a/radis/subscriptions/tests/test_views.py +++ b/radis/subscriptions/tests/test_views.py @@ -1,11 +1,15 @@ +from datetime import datetime, timedelta, timezone + import pytest from adit_radis_shared.accounts.factories import GroupFactory, UserFactory from django.test import Client -from radis.reports.factories import LanguageFactory +from radis.extractions.factories import OutputFieldFactory +from radis.reports.factories import LanguageFactory, ReportFactory from radis.reports.models import Modality from radis.subscriptions.factories import ( - QuestionFactory, + FilterQuestionFactory, + SubscribedItemFactory, SubscriptionFactory, ) from radis.subscriptions.models import Subscription @@ -90,7 +94,6 @@ def test_subscription_create_view_post_valid(client: Client): data = { "name": "Test Subscription", - "query": "test query", "language": language.pk, "modalities": [modality.pk], "study_description": "Test study", @@ -99,11 +102,19 @@ def test_subscription_create_view_post_valid(client: Client): "age_till": 60, "patient_id": "12345", "send_finished_mail": True, - "questions-TOTAL_FORMS": "1", - "questions-INITIAL_FORMS": "0", - "questions-MIN_NUM_FORMS": "0", - "questions-MAX_NUM_FORMS": "3", - "questions-0-question": "What is the diagnosis?", + "filter_questions-TOTAL_FORMS": "1", + "filter_questions-INITIAL_FORMS": "0", + "filter_questions-MIN_NUM_FORMS": "0", + "filter_questions-MAX_NUM_FORMS": "3", + "filter_questions-0-question": "Does the report contain pneumothorax?", + "filter_questions-0-expected_answer": "Y", + "output_fields-TOTAL_FORMS": "1", + "output_fields-INITIAL_FORMS": "0", + "output_fields-MIN_NUM_FORMS": "0", + "output_fields-MAX_NUM_FORMS": "10", + "output_fields-0-name": "Pneumothorax status", + "output_fields-0-description": "Extract pneumothorax related findings", + "output_fields-0-output_type": "T", } response = client.post("/subscriptions/create/", data) @@ -112,6 +123,41 @@ def test_subscription_create_view_post_valid(client: Client): assert Subscription.objects.filter(name="Test Subscription").exists() +@pytest.mark.django_db +def test_subscription_create_view_ignores_empty_filter_question(client: Client): + user = UserFactory.create(is_active=True) + group = GroupFactory.create() + user.groups.add(group) + user.active_group = group + user.save() + + language = LanguageFactory.create(code="en") + + client.force_login(user) + + data = { + "name": "Subscription Without Filter", + "provider": "test_provider", + "language": language.pk, + "filter_questions-TOTAL_FORMS": "1", + "filter_questions-INITIAL_FORMS": "0", + "filter_questions-MIN_NUM_FORMS": "0", + "filter_questions-MAX_NUM_FORMS": "3", + "filter_questions-0-question": "", + "filter_questions-0-expected_answer": "", + "output_fields-TOTAL_FORMS": "0", + "output_fields-INITIAL_FORMS": "0", + "output_fields-MIN_NUM_FORMS": "0", + "output_fields-MAX_NUM_FORMS": "10", + } + + response = client.post("/subscriptions/create/", data) + assert response.status_code == 302 + + subscription = Subscription.objects.get(name="Subscription Without Filter") + assert subscription.filter_questions.count() == 0 + + @pytest.mark.django_db def test_subscription_create_view_post_duplicate_name(client: Client): user = UserFactory.create(is_active=True) @@ -124,12 +170,20 @@ def test_subscription_create_view_post_duplicate_name(client: Client): client.force_login(user) + language = LanguageFactory.create(code="en") + data = { "name": "Duplicate Name", - "questions-TOTAL_FORMS": "0", - "questions-INITIAL_FORMS": "0", - "questions-MIN_NUM_FORMS": "0", - "questions-MAX_NUM_FORMS": "3", + "provider": "test_provider", + "language": language.pk, + "filter_questions-TOTAL_FORMS": "0", + "filter_questions-INITIAL_FORMS": "0", + "filter_questions-MIN_NUM_FORMS": "0", + "filter_questions-MAX_NUM_FORMS": "3", + "output_fields-TOTAL_FORMS": "0", + "output_fields-INITIAL_FORMS": "0", + "output_fields-MIN_NUM_FORMS": "0", + "output_fields-MAX_NUM_FORMS": "10", } response = client.post("/subscriptions/create/", data) @@ -188,22 +242,33 @@ def test_subscription_update_view_unauthorized(client: Client): def test_subscription_update_view_post_valid(client: Client): user = UserFactory.create(is_active=True) subscription = create_test_subscription(owner=user, name="Original Name") - question = QuestionFactory.create(subscription=subscription) + question = FilterQuestionFactory.create(subscription=subscription) + output_field = OutputFieldFactory.create(subscription=subscription, job=None) client.force_login(user) data = { "name": "Updated Name", - "query": "updated query", "study_description": "Updated study", "patient_sex": "F", "send_finished_mail": False, - "questions-TOTAL_FORMS": "1", - "questions-INITIAL_FORMS": "1", - "questions-MIN_NUM_FORMS": "0", - "questions-MAX_NUM_FORMS": "3", - "questions-0-id": question.pk, - "questions-0-question": "Updated question?", + "filter_questions-TOTAL_FORMS": "1", + "filter_questions-INITIAL_FORMS": "1", + "filter_questions-MIN_NUM_FORMS": "0", + "filter_questions-MAX_NUM_FORMS": "3", + "filter_questions-0-id": question.pk, + "filter_questions-0-question": "Updated question?", + "filter_questions-0-expected_answer": "N", + "output_fields-TOTAL_FORMS": "1", + "output_fields-INITIAL_FORMS": "1", + "output_fields-MIN_NUM_FORMS": "0", + "output_fields-MAX_NUM_FORMS": "10", + "output_fields-0-id": output_field.pk, + "output_fields-0-name": "Volume", + "output_fields-0-description": "Volume description", + "output_fields-0-output_type": "N", + "output_fields-0-selection_options": "", + "output_fields-0-is_array": "false", } response = client.post(f"/subscriptions/{subscription.pk}/update/", data) @@ -300,3 +365,252 @@ def test_unauthenticated_access_redirects_to_login(client: Client): response = client.get(endpoint) assert response.status_code == 302 assert "/accounts/login/" in response["Location"] + + +# Subscription Inbox Pagination and Sorting Tests + + +@pytest.mark.django_db +def test_subscription_inbox_pagination(client: Client): + """Test that pagination works correctly in subscription inbox.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + language = LanguageFactory.create(code="en") + + # Create 25 subscribed items + base_time = datetime.now(timezone.utc) + for i in range(25): + item = SubscribedItemFactory.create( + subscription=subscription, report=ReportFactory.create(language=language) + ) + # Set created_at to ensure consistent ordering + item.created_at = base_time - timedelta(hours=i) + item.save() + + client.force_login(user) + + # Test first page with default page size (10) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/") + assert response.status_code == 200 + assert len(response.context["object_list"]) == 10 + assert response.context["page_obj"].number == 1 + assert response.context["page_obj"].paginator.num_pages == 3 + + # Test second page + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?page=2") + assert response.status_code == 200 + assert response.context["page_obj"].number == 2 + + # Test custom page size + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?per_page=25") + assert len(response.context["object_list"]) == 25 + assert response.context["page_obj"].paginator.num_pages == 1 + + +@pytest.mark.django_db +def test_subscription_inbox_sorting_by_created_date(client: Client): + """Test sorting by created_at date.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + language = LanguageFactory.create(code="en") + + # Create items with different created_at times + item1 = SubscribedItemFactory.create( + subscription=subscription, report=ReportFactory.create(language=language) + ) + item2 = SubscribedItemFactory.create( + subscription=subscription, report=ReportFactory.create(language=language) + ) + item3 = SubscribedItemFactory.create( + subscription=subscription, report=ReportFactory.create(language=language) + ) + + # Manually set created_at to ensure order + base_time = datetime.now(timezone.utc) + item1.created_at = base_time - timedelta(days=2) + item1.save() + item2.created_at = base_time - timedelta(days=1) + item2.save() + item3.created_at = base_time + item3.save() + + client.force_login(user) + + # Test descending (newest first - default) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?sort_by=created_at&order=desc") + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert items[0].pk == item3.pk + assert items[1].pk == item2.pk + assert items[2].pk == item1.pk + + # Test ascending (oldest first) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?sort_by=created_at&order=asc") + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert items[0].pk == item1.pk + assert items[1].pk == item2.pk + assert items[2].pk == item3.pk + + +@pytest.mark.django_db +def test_subscription_inbox_sorting_by_study_date(client: Client): + """Test sorting by study date.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + + # Create reports with different study dates + base_time = datetime.now(timezone.utc) + language = LanguageFactory.create(code="en") + report1 = ReportFactory.create(language=language, study_datetime=base_time - timedelta(days=10)) + report2 = ReportFactory.create(language=language, study_datetime=base_time - timedelta(days=5)) + report3 = ReportFactory.create(language=language, study_datetime=base_time - timedelta(days=1)) + + item1 = SubscribedItemFactory.create(subscription=subscription, report=report1) + item2 = SubscribedItemFactory.create(subscription=subscription, report=report2) + item3 = SubscribedItemFactory.create(subscription=subscription, report=report3) + + client.force_login(user) + + # Test descending (newest study date first) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?sort_by=study_date&order=desc") + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert items[0].pk == item3.pk + assert items[1].pk == item2.pk + assert items[2].pk == item1.pk + + # Test ascending (oldest study date first) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?sort_by=study_date&order=asc") + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert items[0].pk == item1.pk + assert items[1].pk == item2.pk + assert items[2].pk == item3.pk + + +@pytest.mark.django_db +def test_subscription_inbox_invalid_sort_parameters(client: Client): + """Test that invalid sort parameters fall back to defaults.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + language = LanguageFactory.create(code="en") + + SubscribedItemFactory.create( + subscription=subscription, report=ReportFactory.create(language=language) + ) + + client.force_login(user) + + # Invalid sort_by should default to created_at + response = client.get( + f"/subscriptions/{subscription.pk}/inbox/?sort_by=invalid_field&order=desc" + ) + assert response.status_code == 200 + assert response.context["current_sort_by"] == "created_at" + + # Invalid order should default to desc + response = client.get( + f"/subscriptions/{subscription.pk}/inbox/?sort_by=created_at&order=invalid" + ) + assert response.status_code == 200 + assert response.context["current_order"] == "desc" + + +@pytest.mark.django_db +def test_subscription_inbox_filtering_by_patient_id(client: Client): + """Test filtering by patient ID.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + language = LanguageFactory.create(code="en") + report1 = ReportFactory.create(language=language, patient_id="12345") + report2 = ReportFactory.create(language=language, patient_id="67890") + report3 = ReportFactory.create(language=language, patient_id="54321") + + item1 = SubscribedItemFactory.create(subscription=subscription, report=report1) + item2 = SubscribedItemFactory.create(subscription=subscription, report=report2) + item3 = SubscribedItemFactory.create(subscription=subscription, report=report3) + + client.force_login(user) + + # Filter by patient_id (partial match) + response = client.get(f"/subscriptions/{subscription.pk}/inbox/?patient_id=123") + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert len(items) == 1 # Should match "12345" + assert item1 in items + assert item2 and item3 not in items + + +@pytest.mark.django_db +def test_subscription_inbox_filtering_by_date_range(client: Client): + """Test filtering by study date range.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + + base_date = datetime.now(timezone.utc) + language = LanguageFactory.create(code="en") + report1 = ReportFactory.create(language=language, study_datetime=base_date - timedelta(days=10)) + report2 = ReportFactory.create(language=language, study_datetime=base_date - timedelta(days=5)) + report3 = ReportFactory.create(language=language, study_datetime=base_date - timedelta(days=1)) + + item1 = SubscribedItemFactory.create(subscription=subscription, report=report1) + item2 = SubscribedItemFactory.create(subscription=subscription, report=report2) + item3 = SubscribedItemFactory.create(subscription=subscription, report=report3) + + client.force_login(user) + + # Filter by date range + date_from = (base_date - timedelta(days=7)).strftime("%Y-%m-%d") + date_till = (base_date - timedelta(days=2)).strftime("%Y-%m-%d") + + response = client.get( + f"/subscriptions/{subscription.pk}/inbox/?study_date_from={date_from}&study_date_till={date_till}" + ) + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert len(items) == 1 + assert item2 in items + assert item1 and item3 not in items + + +@pytest.mark.django_db +def test_subscription_inbox_combined_filter_and_sort(client: Client): + """Test that filtering and sorting work together.""" + + user = UserFactory.create(is_active=True) + subscription = create_test_subscription(owner=user) + + # Create items with same patient but different study dates + base_date = datetime.now(timezone.utc) + language = LanguageFactory.create(code="en") + report1 = ReportFactory.create( + language=language, patient_id="12345", study_datetime=base_date - timedelta(days=5) + ) + report2 = ReportFactory.create( + language=language, patient_id="12345", study_datetime=base_date - timedelta(days=1) + ) + report3 = ReportFactory.create(language=language, patient_id="67890", study_datetime=base_date) + + item1 = SubscribedItemFactory.create(subscription=subscription, report=report1) + item2 = SubscribedItemFactory.create(subscription=subscription, report=report2) + item3 = SubscribedItemFactory.create(subscription=subscription, report=report3) + + client.force_login(user) + + # Filter by patient_id and sort by study_date ascending + response = client.get( + f"/subscriptions/{subscription.pk}/inbox/?patient_id=12345&sort_by=study_date&order=asc" + ) + assert response.status_code == 200 + items = list(response.context["object_list"]) + assert len(items) == 2 + assert items[0].pk == item1.pk # Older study date + assert items[1].pk == item2.pk # Newer study date + assert item3 not in items # Different patient diff --git a/radis/subscriptions/tests/unit/test_processors.py b/radis/subscriptions/tests/unit/test_processors.py new file mode 100644 index 00000000..bb7416bb --- /dev/null +++ b/radis/subscriptions/tests/unit/test_processors.py @@ -0,0 +1,125 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import create_model + +from radis.chats.utils.testing_helpers import create_openai_client_mock +from radis.subscriptions.models import SubscribedItem +from radis.subscriptions.processors import SubscriptionTaskProcessor +from radis.subscriptions.utils.processor_utils import ( + get_filter_question_field_name, + get_output_field_name, +) +from radis.subscriptions.utils.testing_helpers import create_subscription_task + + +@pytest.mark.django_db(transaction=True) +def test_subscription_task_processor_filters_and_extracts(): + task, filter_question, output_field, report = create_subscription_task() + + filter_field_name = get_filter_question_field_name(filter_question) + extraction_field_name = get_output_field_name(output_field) + filter_field_definitions = {} + filter_field_definitions[filter_field_name] = (bool, ...) + + extraction_field_definitions = {} + extraction_field_definitions[extraction_field_name] = (str, ...) + + FilterOutput = create_model("FilterOutput", **filter_field_definitions) + ExtractionOutput = create_model("ExtractionOutput", **extraction_field_definitions) + + filter_output = FilterOutput(**{filter_field_name: True}) + extraction_output = ExtractionOutput(**{extraction_field_name: "Pneumothorax status confirmed"}) + + filter_response = MagicMock(choices=[MagicMock(message=MagicMock(parsed=filter_output))]) + extraction_response = MagicMock( + choices=[MagicMock(message=MagicMock(parsed=extraction_output))] + ) + + openai_mock = create_openai_client_mock(extraction_output) + openai_mock.beta.chat.completions.parse = MagicMock( + side_effect=[filter_response, extraction_response] + ) + with patch("openai.OpenAI", return_value=openai_mock): + SubscriptionTaskProcessor(task).start() + + subscribed_item = SubscribedItem.objects.get(subscription=task.job.subscription, report=report) + assert subscribed_item.filter_results == {str(filter_question.pk): True} + assert subscribed_item.extraction_results == { + str(output_field.pk): "Pneumothorax status confirmed" + } + + +@pytest.mark.django_db(transaction=True) +def test_subscription_task_processor_handles_llm_null_response(): + task, _, _, report = create_subscription_task() + + processor = SubscriptionTaskProcessor(task) + processor.client.extract_data = MagicMock(return_value=None) + + processor.start() + + assert not SubscribedItem.objects.filter( + subscription=task.job.subscription, report=report + ).exists() + processor.client.extract_data.assert_called_once() + + +@pytest.mark.django_db(transaction=True) +def test_subscription_task_processor_with_no_expected_answer(): + task, filter_question, _, report = create_subscription_task() + + filter_field_name = get_filter_question_field_name(filter_question) + filter_response = MagicMock() + setattr(filter_response, filter_field_name, None) + + processor = SubscriptionTaskProcessor(task) + processor.client.extract_data = MagicMock(return_value=filter_response) + + processor.start() + + assert not SubscribedItem.objects.filter( + subscription=task.job.subscription, report=report + ).exists() + processor.client.extract_data.assert_called_once() + + +@pytest.mark.django_db(transaction=True) +def test_subscription_task_processor_extraction_only(): + task, _, output_field, report = create_subscription_task() + task.job.subscription.filter_questions.all().delete() + + extraction_field_name = get_output_field_name(output_field) + extraction_field_definitions = {} + extraction_field_definitions[extraction_field_name] = (str, ...) + + ExtractionOutput = create_model("ExtractionOnlyOutput", **extraction_field_definitions) + extraction_output = ExtractionOutput(**{extraction_field_name: "Only extraction response"}) + + openai_mock = create_openai_client_mock(extraction_output) + with patch("openai.OpenAI", return_value=openai_mock): + SubscriptionTaskProcessor(task).start() + + subscribed_item = SubscribedItem.objects.get(subscription=task.job.subscription, report=report) + assert subscribed_item.filter_results is None + assert subscribed_item.extraction_results == {str(output_field.pk): "Only extraction response"} + + +@pytest.mark.django_db(transaction=True) +def test_subscription_task_processor_filter_only(): + task, filter_question, output_field, report = create_subscription_task() + output_field.delete() + + filter_field_name = get_filter_question_field_name(filter_question) + filter_field_definitions = {} + filter_field_definitions[filter_field_name] = (bool, ...) + FilterOutput = create_model("FilterOnlyOutput", **filter_field_definitions) + filter_output = FilterOutput(**{filter_field_name: True}) + + openai_mock = create_openai_client_mock(filter_output) + with patch("openai.OpenAI", return_value=openai_mock): + SubscriptionTaskProcessor(task).start() + + subscribed_item = SubscribedItem.objects.get(subscription=task.job.subscription, report=report) + assert subscribed_item.filter_results == {str(filter_question.pk): True} + assert subscribed_item.extraction_results is None diff --git a/radis/subscriptions/urls.py b/radis/subscriptions/urls.py index 48631378..419269a8 100644 --- a/radis/subscriptions/urls.py +++ b/radis/subscriptions/urls.py @@ -5,6 +5,7 @@ SubscriptionCreateView, SubscriptionDeleteView, SubscriptionDetailView, + SubscriptionInboxDownloadView, SubscriptionInboxView, SubscriptionListView, SubscriptionUpdateView, @@ -22,4 +23,9 @@ name="subscription_help", ), path("/inbox/", SubscriptionInboxView.as_view(), name="subscription_inbox"), + path( + "/inbox/download/", + SubscriptionInboxDownloadView.as_view(), + name="subscription_inbox_download", + ), ] diff --git a/radis/subscriptions/utils/csv_export.py b/radis/subscriptions/utils/csv_export.py new file mode 100644 index 00000000..a5e84279 --- /dev/null +++ b/radis/subscriptions/utils/csv_export.py @@ -0,0 +1,87 @@ +"""Helpers for exporting subscription inbox items in CSV format.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import Any + +from django.db.models import QuerySet + +from radis.subscriptions.models import SubscribedItem, Subscription + + +def _format_cell(value: Any) -> str: + """Format a single output value for CSV export.""" + if value is None: + return "" + if isinstance(value, bool): + return "yes" if value else "no" + return str(value) + + +def iter_subscribed_item_rows( + subscription: Subscription, queryset: QuerySet[SubscribedItem] +) -> Iterable[Sequence[str]]: + """Yield rows for the subscription inbox CSV. + + Args: + subscription: The subscription whose items should be exported. + queryset: Pre-filtered queryset of SubscribedItems to export. + + Yields: + Sequences of stringified cell values suitable for csv.writer. + """ + # Get output field names in PK order (to match dict keys) + field_names: list[str] = list( + subscription.output_fields.order_by("pk").values_list("name", flat=True) + ) + + # Pre-fetch field PKs to avoid N+1 query in the loop below + field_pks: list[int] = list( + subscription.output_fields.order_by("pk").values_list("pk", flat=True) + ) + + # Header row + header = [ + "subscribed_item_id", + "report_id", + "patient_id", + "study_date", + "study_description", + "modalities", + ] + header.extend(field_names) + yield header + + # Data rows - prefetch related fields for efficiency + items = queryset.select_related("report").prefetch_related( + "report__modalities", "subscription__output_fields" + ) + + for item in items.iterator(chunk_size=1000): + # Format modalities as comma-separated codes + modality_codes = ",".join( + item.report.modalities.order_by("code").values_list("code", flat=True) + ) + + # Format study date + study_date = "" + if item.report.study_datetime: + study_date = item.report.study_datetime.strftime("%Y-%m-%d") + + row = [ + str(item.pk), + str(item.report.pk), + item.report.patient_id or "", + study_date, + item.report.study_description or "", + modality_codes, + ] + + # Add extraction results (keyed by field PK as string) + extraction_results: dict[str, Any] = item.extraction_results or {} + for field_pk in field_pks: + value = extraction_results.get(str(field_pk)) + row.append(_format_cell(value)) + + yield row diff --git a/radis/subscriptions/utils/processor_utils.py b/radis/subscriptions/utils/processor_utils.py index c1bd39d3..3d9ee654 100644 --- a/radis/subscriptions/utils/processor_utils.py +++ b/radis/subscriptions/utils/processor_utils.py @@ -1,22 +1,35 @@ -from typing import Any +from __future__ import annotations + +from typing import Any, Iterable -from django.db.models import QuerySet from pydantic import BaseModel, create_model -from ..models import Question +from radis.extractions.models import OutputField + +from ..models import FilterQuestion + + +def get_filter_question_field_name(question: FilterQuestion) -> str: + return f"question_{question.pk}" -def generate_questions_schema(questions: QuerySet[Question]) -> type[BaseModel]: +def get_output_field_name(field: OutputField) -> str: + return field.name + + +def generate_filter_questions_schema(questions: Iterable[FilterQuestion]) -> type[BaseModel]: field_definitions: dict[str, Any] = {} - for index, _ in enumerate(questions.all()): - field_definitions[f"question_{index}"] = (bool, ...) - return create_model("QuestionsModel", **field_definitions) + for question in questions: + field_name = get_filter_question_field_name(question) + field_definitions[field_name] = (bool, ...) + model_name = "SubscriptionFilterResultsModel" + return create_model(model_name, **field_definitions) -def generate_questions_for_prompt(fields: QuerySet[Question]) -> str: - prompt = "" - for index, question in enumerate(fields.all()): - prompt += f"question_{index}: {question.question}\n" +def generate_filter_questions_prompt(questions: Iterable[FilterQuestion]) -> str: + prompt = "" + for question in questions: + prompt += f"{get_filter_question_field_name(question)}: {question.question}\n" return prompt diff --git a/radis/subscriptions/utils/testing_helpers.py b/radis/subscriptions/utils/testing_helpers.py new file mode 100644 index 00000000..9ec03e91 --- /dev/null +++ b/radis/subscriptions/utils/testing_helpers.py @@ -0,0 +1,43 @@ +from adit_radis_shared.accounts.factories import GroupFactory, UserFactory +from adit_radis_shared.common.utils.testing_helpers import add_user_to_group + +from radis.extractions.factories import OutputFieldFactory +from radis.extractions.models import OutputType +from radis.reports.factories import LanguageFactory, ReportFactory +from radis.subscriptions.factories import FilterQuestionFactory, SubscriptionFactory +from radis.subscriptions.models import FilterQuestion, SubscriptionJob, SubscriptionTask + + +def create_subscription_task(): + language = LanguageFactory.create(code="en") + + user = UserFactory(is_active=True) + group = GroupFactory() + add_user_to_group(user, group) + user.active_group = group + user.save() + + subscription = SubscriptionFactory.create(owner=user, group=group, language=language) + + filter_question = FilterQuestionFactory.create( + subscription=subscription, expected_answer=FilterQuestion.ExpectedAnswer.YES + ) + output_field = OutputFieldFactory.create( + subscription=subscription, + job=None, + output_type=OutputType.TEXT, + ) + + job = SubscriptionJob.objects.create( + subscription=subscription, + owner=user, + owner_id=user.id, + status=SubscriptionJob.Status.PENDING, + ) + task = SubscriptionTask.objects.create(job=job, status=SubscriptionTask.Status.PENDING) + + report = ReportFactory.create(language=language, body="Pneumothorax observed.") + report.groups.add(group) + task.reports.add(report) + + return task, filter_question, output_field, report diff --git a/radis/subscriptions/views.py b/radis/subscriptions/views.py index 852bdb8e..afe4529e 100644 --- a/radis/subscriptions/views.py +++ b/radis/subscriptions/views.py @@ -1,6 +1,9 @@ +import csv +from collections.abc import Generator from logging import getLogger from typing import Any, Type, cast +from adit_radis_shared.accounts.models import User from adit_radis_shared.common.mixins import ( PageSizeSelectMixin, RelatedFilterMixin, @@ -9,19 +12,26 @@ from adit_radis_shared.common.types import AuthenticatedHttpRequest from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.messages.views import SuccessMessageMixin -from django.db import IntegrityError -from django.db.models import Count, QuerySet +from django.db import IntegrityError, transaction +from django.db.models import Count, F, Q, QuerySet from django.forms.models import BaseInlineFormSet -from django.http import HttpResponse, HttpResponseRedirect +from django.http import HttpResponse, HttpResponseRedirect, StreamingHttpResponse from django.urls import reverse, reverse_lazy +from django.utils import timezone +from django.utils.text import slugify from django.views.generic import CreateView, DeleteView, DetailView, UpdateView from django_tables2 import SingleTableView -from radis.subscriptions.filters import SubscriptionFilter +from radis.subscriptions.filters import SubscribedItemFilter, SubscriptionFilter from radis.subscriptions.tables import SubscriptionTable -from .forms import QuestionForm, QuestionFormSet, SubscriptionForm -from .models import Question, SubscribedItem, Subscription +from .forms import ( + FilterQuestionFormSet, + OutputFieldFormSet, + SubscriptionForm, +) +from .models import SubscribedItem, Subscription +from .utils.csv_export import iter_subscribed_item_rows logger = getLogger(__name__) @@ -41,6 +51,13 @@ def get_queryset(self) -> QuerySet[Subscription]: return ( Subscription.objects.filter(owner=self.request.user) .annotate(num_reports=Count("items")) + .annotate( + num_new_reports=Count( + "items", + filter=Q(items__created_at__gt=F("last_viewed_at")) + | Q(last_viewed_at__isnull=True), + ) + ) .order_by("-created_at") ) @@ -50,7 +67,12 @@ class SubscriptionDetailView(LoginRequiredMixin, DetailView): template_name = "subscriptions/subscription_detail.html" def get_queryset(self): - return super().get_queryset().filter(owner=self.request.user).prefetch_related("questions") + return ( + super() + .get_queryset() + .filter(owner=self.request.user) + .prefetch_related("filter_questions", "output_fields") + ) class SubscriptionCreateView(LoginRequiredMixin, CreateView): # TODO: Add PermissionRequiredMixin @@ -62,15 +84,19 @@ class SubscriptionCreateView(LoginRequiredMixin, CreateView): # TODO: Add Permi def get_context_data(self, **kwargs: Any) -> dict[str, Any]: ctx = super().get_context_data(**kwargs) if self.request.POST: - ctx["formset"] = QuestionFormSet(self.request.POST) + ctx["filter_formset"] = FilterQuestionFormSet(self.request.POST) + ctx["output_formset"] = OutputFieldFormSet(self.request.POST) else: - ctx["formset"] = QuestionFormSet() + ctx["filter_formset"] = FilterQuestionFormSet() + ctx["output_formset"] = OutputFieldFormSet() return ctx + @transaction.atomic() def form_valid(self, form) -> HttpResponse: ctx = self.get_context_data() - formset: BaseInlineFormSet[Question, Subscription, QuestionForm] = ctx["formset"] - if formset.is_valid(): + filter_formset: BaseInlineFormSet = ctx["filter_formset"] + output_formset: BaseInlineFormSet = ctx["output_formset"] + if filter_formset.is_valid() and output_formset.is_valid(): user = self.request.user form.instance.owner = user active_group = user.active_group @@ -80,12 +106,15 @@ def form_valid(self, form) -> HttpResponse: self.object: Subscription = form.save() except IntegrityError as e: if "unique_subscription_name_per_user" in str(e): - form.add_error("name", "An subscription with this name already exists.") + form.add_error("name", "A subscription with this name already exists.") return self.form_invalid(form) raise e - formset.instance = self.object - formset.save() + filter_formset.instance = self.object + filter_formset.save() + + output_formset.instance = self.object + output_formset.save() return HttpResponseRedirect(self.get_success_url()) else: return self.form_invalid(form) @@ -101,31 +130,44 @@ def get_success_url(self): return reverse("subscription_detail", kwargs={"pk": self.object.pk}) def get_queryset(self) -> QuerySet[Subscription]: - return super().get_queryset().filter(owner=self.request.user).prefetch_related("questions") + return ( + super() + .get_queryset() + .filter(owner=self.request.user) + .prefetch_related("filter_questions", "output_fields") + ) def get_context_data(self, **kwargs: Any) -> dict[str, Any]: ctx = super().get_context_data(**kwargs) if self.request.POST: - ctx["formset"] = QuestionFormSet(self.request.POST, instance=self.object) + ctx["filter_formset"] = FilterQuestionFormSet(self.request.POST, instance=self.object) + ctx["output_formset"] = OutputFieldFormSet(self.request.POST, instance=self.object) else: - ctx["formset"] = QuestionFormSet(instance=self.object) - ctx["formset"].extra = 0 # no additional empty form when editing + ctx["filter_formset"] = FilterQuestionFormSet(instance=self.object) + ctx["output_formset"] = OutputFieldFormSet(instance=self.object) + ctx["filter_formset"].extra = 0 # no additional empty form when editing + ctx["output_formset"].extra = 0 return ctx + @transaction.atomic() def form_valid(self, form) -> HttpResponse: ctx = self.get_context_data() - formset: BaseInlineFormSet[Question, Subscription, QuestionForm] = ctx["formset"] - if formset.is_valid(): + filter_formset = ctx["filter_formset"] + output_formset = ctx["output_formset"] + if filter_formset.is_valid() and output_formset.is_valid(): try: self.object = form.save() except IntegrityError as e: if "unique_subscription_name_per_user" in str(e): - form.add_error("name", "An subscription with this name already exists.") + form.add_error("name", "A subscription with this name already exists.") return self.form_invalid(form) raise e - formset.instance = self.object - formset.save() + filter_formset.instance = self.object + filter_formset.save() + + output_formset.instance = self.object + output_formset.save() return super().form_valid(form) else: @@ -150,20 +192,191 @@ class SubscriptionInboxView( ): model = Subscription template_name = "subscriptions/subscription_inbox.html" + filterset_class = SubscribedItemFilter + paginate_by = 10 + page_sizes = [10, 25, 50] + + def get_queryset(self) -> QuerySet[Subscription]: + assert self.model + model = cast(Type[Subscription], self.model) + user = cast(User, self.request.user) + if user.is_staff: + return model.objects.all() + return model.objects.filter(owner=self.request.user) + + def get_ordering(self) -> str: + """Get the ordering from query parameters, defaulting to -created_at.""" + sort_by = self.request.GET.get("sort_by", "created_at") + order = self.request.GET.get("order", "desc") + + # Define allowed sort fields to prevent injection + allowed_fields = { + "created_at": "created_at", + "study_date": "report__study_datetime", + } + + # Validate sort_by parameter + if sort_by not in allowed_fields: + sort_by = "created_at" + + # Validate order parameter + if order not in ["asc", "desc"]: + order = "desc" + + field = allowed_fields[sort_by] + return field if order == "asc" else f"-{field}" + + def get_related_queryset(self) -> QuerySet[SubscribedItem]: + subscription = cast(Subscription, self.get_object()) + ordering = self.get_ordering() + return ( + SubscribedItem.objects.filter(subscription_id=subscription.pk) + .select_related("subscription") + .prefetch_related( + "report", + "subscription__output_fields", + ) + .order_by(ordering) + ) + + def get_filter_queryset(self) -> QuerySet[SubscribedItem]: + return self.get_related_queryset() + + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: + context = super().get_context_data(**kwargs) + + # Get validated sort parameters + sort_by = self.request.GET.get("sort_by", "created_at") + order = self.request.GET.get("order", "desc") + + # Define allowed sort fields (same as in get_ordering) + allowed_fields = { + "created_at": "created_at", + "study_date": "report__study_datetime", + } + + # Validate and default if invalid + if sort_by not in allowed_fields: + sort_by = "created_at" + if order not in ["asc", "desc"]: + order = "desc" + + # Add validated sort parameters to context for template rendering + context["current_sort_by"] = sort_by + context["current_order"] = order + + # Update last_viewed_at to mark all current reports as seen + subscription = cast(Subscription, self.object) + subscription.last_viewed_at = timezone.now() + subscription.save(update_fields=["last_viewed_at"]) + + return context + + +class _Echo: + """Lightweight write-only buffer for csv.writer.""" + + def write(self, value: str) -> str: + return value + + +class SubscriptionInboxDownloadView(LoginRequiredMixin, RelatedFilterMixin, DetailView): + """Stream subscription inbox items as a CSV download. + + Applies the same filters as SubscriptionInboxView to ensure users + download exactly what they see (respecting filters but ignoring pagination). + """ + + model = Subscription + filterset_class = SubscribedItemFilter request: AuthenticatedHttpRequest def get_queryset(self) -> QuerySet[Subscription]: + """Return only subscriptions owned by the current user.""" assert self.model model = cast(Type[Subscription], self.model) - if self.request.user.is_staff: + user = cast(User, self.request.user) + if user.is_staff: return model.objects.all() return model.objects.filter(owner=self.request.user) + def get_ordering(self) -> str: + """Get the ordering from query parameters (same logic as SubscriptionInboxView).""" + sort_by = self.request.GET.get("sort_by", "created_at") + order = self.request.GET.get("order", "desc") + + # Define allowed sort fields to prevent injection + allowed_fields = { + "created_at": "created_at", + "study_date": "report__study_datetime", + } + + # Validate sort_by parameter + if sort_by not in allowed_fields: + sort_by = "created_at" + + # Validate order parameter + if order not in ["asc", "desc"]: + order = "desc" + + field = allowed_fields[sort_by] + return field if order == "asc" else f"-{field}" + def get_related_queryset(self) -> QuerySet[SubscribedItem]: + """Build queryset matching the inbox view (for filtering).""" subscription = cast(Subscription, self.get_object()) - return SubscribedItem.objects.filter(subscription_id=subscription.pk).prefetch_related( - "report" + ordering = self.get_ordering() + return ( + SubscribedItem.objects.filter(subscription_id=subscription.pk) + .exclude(extraction_results__isnull=True) # Only items with results + .exclude(extraction_results={}) # Only items with non-empty results + .select_related("subscription") + .prefetch_related( + "report", + "report__modalities", + "subscription__output_fields", + ) + .order_by(ordering) ) def get_filter_queryset(self) -> QuerySet[SubscribedItem]: + """Required by RelatedFilterMixin.""" return self.get_related_queryset() + + def get(self, request: AuthenticatedHttpRequest, *args, **kwargs) -> StreamingHttpResponse: + """Stream the CSV file response.""" + subscription = cast(Subscription, self.get_object()) + + # Manually instantiate the filterset to apply filters + # (RelatedFilterMixin doesn't provide get_filtered_queryset()) + filterset_class = self.get_filterset_class() + filterset_kwargs = self.get_filterset_kwargs(filterset_class) + assert filterset_class is not None + filterset = filterset_class(**filterset_kwargs) + + # Get the filtered queryset from filterset.qs + filtered_items = filterset.qs + + filename = self._build_filename(subscription) + + response = StreamingHttpResponse( + self._stream_rows(subscription, filtered_items), + content_type="text/csv", + ) + response["Content-Disposition"] = f'attachment; filename="{filename}"' + return response + + def _stream_rows( + self, subscription: Subscription, items: QuerySet[SubscribedItem] + ) -> Generator[str, None, None]: + """Yield serialized CSV rows for the response.""" + pseudo_buffer = _Echo() + writer = csv.writer(pseudo_buffer) + yield "\ufeff" # UTF-8 BOM for Excel compatibility + for row in iter_subscribed_item_rows(subscription, items): + yield writer.writerow(row) + + def _build_filename(self, subscription: Subscription) -> str: + """Generate a descriptive CSV filename for the subscription.""" + slug = slugify(subscription.name) or "inbox" + return f"subscription_{subscription.pk}_{slug}.csv" diff --git a/radis/urls.py b/radis/urls.py index a131cee5..12984602 100644 --- a/radis/urls.py +++ b/radis/urls.py @@ -14,7 +14,7 @@ 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ -from django.conf import settings +from django.apps import apps from django.contrib import admin from django.urls import include, path @@ -35,11 +35,15 @@ path("subscriptions/", include("radis.subscriptions.urls")), ] -# Debug Toolbar in Debug mode only -if settings.DEBUG: - import debug_toolbar - +# Some Django test runners force `DEBUG=False` even if the settings module enables it. +# If these apps/middlewares are installed, we must still include their URLs so +# templates can reverse them without raising `NoReverseMatch`. +if apps.is_installed("django_browser_reload"): urlpatterns = [ path("__reload__/", include("django_browser_reload.urls")), - path("__debug__/", include(debug_toolbar.urls)), + ] + urlpatterns + +if apps.is_installed("debug_toolbar"): + urlpatterns = [ + path("__debug__/", include("debug_toolbar.urls")), ] + urlpatterns