diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..2087cec138 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,464 @@ +# EditorConfig to support per-solution formatting. +# Use the EditorConfig VS add-in to make this work. +# http://editorconfig.org/ +# +# Resources for what's supported for .NET/C# +# https://kent-boogaart.com/blog/editorconfig-reference-for-c-developers +# https://learn.microsoft.com/visualstudio/ide/editorconfig-code-style-settings-reference + +root = true + +[*] +indent_style = space +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.cs] +indent_size = 4 +dotnet_sort_system_directives_first = true + +# Don't use this. qualifier +dotnet_style_qualification_for_field = false:warning +dotnet_style_qualification_for_property = false:warning + +# use int x = .. over Int32 +dotnet_style_predefined_type_for_locals_parameters_members = true:warning + +# use int.MaxValue over Int32.MaxValue +dotnet_style_predefined_type_for_member_access = true:warning + +# Require var all the time (modern C#) +csharp_style_var_for_built_in_types = true:warning +csharp_style_var_when_type_is_apparent = true:warning +csharp_style_var_elsewhere = true:warning + +# Disallow throw expressions in most cases +csharp_style_throw_expression = false:suggestion + +# Newline settings - Allman style (braces on new line) +csharp_new_line_before_open_brace = all +csharp_new_line_before_else = true +csharp_new_line_before_catch = true +csharp_new_line_before_finally = true +csharp_new_line_before_members_in_object_initializers = true +csharp_new_line_before_members_in_anonymous_types = true + +# Indentation settings +csharp_indent_case_contents = true +csharp_indent_switch_labels = true +csharp_indent_labels = no_change +csharp_indent_block_contents = true +csharp_indent_braces = false +csharp_indent_case_contents_when_block = false + +# Spacing settings +csharp_space_after_cast = false 
+csharp_space_after_keywords_in_control_flow_statements = true +csharp_space_between_method_call_parameter_list_parentheses = false +csharp_space_between_method_declaration_parameter_list_parentheses = false +csharp_space_between_parentheses = false +csharp_space_before_colon_in_inheritance_clause = true +csharp_space_after_colon_in_inheritance_clause = true +csharp_space_around_binary_operators = before_and_after +csharp_space_between_method_declaration_empty_parameter_list_parentheses = false +csharp_space_between_method_call_name_and_opening_parenthesis = false +csharp_space_between_method_call_empty_parameter_list_parentheses = false +csharp_space_after_comma = true +csharp_space_after_dot = false +csharp_space_after_semicolon_in_for_statement = true +csharp_space_before_semicolon_in_for_statement = false +csharp_space_around_declaration_statements = false +csharp_space_before_open_square_brackets = false +csharp_space_between_empty_square_brackets = false +csharp_space_between_square_brackets = false + +# Wrapping settings +csharp_preserve_single_line_statements = false +csharp_preserve_single_line_blocks = true + +# Namespace settings - C# 10+ file-scoped namespaces +csharp_style_namespace_declarations = file_scoped:warning + +# Brace settings - ALWAYS use braces, even for single-line blocks +csharp_prefer_braces = true:warning + +# Expression-bodied members (modern C#) +csharp_style_expression_bodied_methods = when_on_single_line:suggestion +csharp_style_expression_bodied_constructors = false:suggestion +csharp_style_expression_bodied_operators = when_on_single_line:suggestion +csharp_style_expression_bodied_properties = true:suggestion +csharp_style_expression_bodied_indexers = true:suggestion +csharp_style_expression_bodied_accessors = true:suggestion +csharp_style_expression_bodied_lambdas = true:suggestion +csharp_style_expression_bodied_local_functions = when_on_single_line:suggestion + +# Pattern matching (C# 7+) 
+csharp_style_pattern_matching_over_is_with_cast_check = true:warning +csharp_style_pattern_matching_over_as_with_null_check = true:warning +csharp_style_prefer_switch_expression = true:suggestion +csharp_style_prefer_pattern_matching = true:suggestion +csharp_style_prefer_not_pattern = true:warning +csharp_style_prefer_extended_property_pattern = true:suggestion + +# Null checking (C# 6+) +csharp_style_conditional_delegate_call = true:warning +dotnet_style_coalesce_expression = true:warning +dotnet_style_null_propagation = true:warning +dotnet_style_prefer_is_null_check_over_reference_equality_method = true:warning + +# Modern C# features +csharp_prefer_simple_using_statement = true:warning +csharp_style_prefer_method_group_conversion = true:suggestion +csharp_style_prefer_top_level_statements = false:suggestion +csharp_style_prefer_primary_constructors = true:warning +csharp_style_prefer_local_over_anonymous_function = true:suggestion +csharp_style_prefer_tuple_swap = true:warning +csharp_style_implicit_object_creation_when_type_is_apparent = true:warning +csharp_style_prefer_utf8_string_literals = true:suggestion +csharp_style_prefer_readonly_struct = true:warning +csharp_style_prefer_readonly_struct_member = true:warning + +# Use collection expressions (C# 12+) +dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion + +# Code block preferences +csharp_prefer_simple_default_expression = true:suggestion +dotnet_style_prefer_compound_assignment = true:warning +dotnet_style_prefer_simplified_boolean_expressions = true:warning +dotnet_style_prefer_conditional_expression_over_assignment = false:suggestion +dotnet_style_prefer_conditional_expression_over_return = false:suggestion +dotnet_style_prefer_inferred_tuple_names = true:suggestion +dotnet_style_prefer_inferred_anonymous_type_member_names = true:suggestion +dotnet_style_prefer_auto_properties = true:warning +dotnet_style_prefer_simplified_interpolation = true:suggestion + +# 
Object/collection initializers +dotnet_style_object_initializer = true:warning +dotnet_style_collection_initializer = true:warning +dotnet_style_explicit_tuple_names = true:warning + +# Parameter preferences +dotnet_code_quality_unused_parameters = all:warning +csharp_style_unused_value_expression_statement_preference = discard_variable:suggestion +csharp_style_unused_value_assignment_preference = discard_variable:suggestion + +# this. and Me. preferences - don't use this. unless required +dotnet_style_qualification_for_event = false:warning +dotnet_style_qualification_for_method = false:warning + +# Modifier preferences +dotnet_style_require_accessibility_modifiers = for_non_interface_members:warning +dotnet_style_readonly_field = true:warning +csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning + +# Parentheses preferences - be explicit for clarity +dotnet_style_parentheses_in_arithmetic_binary_operators = always_for_clarity:suggestion +dotnet_style_parentheses_in_relational_binary_operators = always_for_clarity:suggestion +dotnet_style_parentheses_in_other_binary_operators = always_for_clarity:suggestion +dotnet_style_parentheses_in_other_operators = never_if_unnecessary:suggestion + +# Naming conventions - name all constant fields using PascalCase +dotnet_naming_rule.constant_fields_should_be_pascal_case.severity = warning +dotnet_naming_rule.constant_fields_should_be_pascal_case.symbols = constant_fields +dotnet_naming_rule.constant_fields_should_be_pascal_case.style = pascal_case_style +dotnet_naming_symbols.constant_fields.applicable_kinds = field +dotnet_naming_symbols.constant_fields.required_modifiers = const +dotnet_naming_style.pascal_case_style.capitalization = pascal_case + +# internal and private fields should be _camelCase +dotnet_naming_rule.camel_case_for_private_internal_fields.severity = warning 
+dotnet_naming_rule.camel_case_for_private_internal_fields.symbols = private_internal_fields +dotnet_naming_rule.camel_case_for_private_internal_fields.style = camel_case_underscore_style +dotnet_naming_symbols.private_internal_fields.applicable_kinds = field +dotnet_naming_symbols.private_internal_fields.applicable_accessibilities = private, internal +dotnet_naming_style.camel_case_underscore_style.required_prefix = _ +dotnet_naming_style.camel_case_underscore_style.capitalization = camel_case + +# Async methods should end with Async +dotnet_naming_rule.async_methods_should_end_with_async.severity = warning +dotnet_naming_rule.async_methods_should_end_with_async.symbols = async_methods +dotnet_naming_rule.async_methods_should_end_with_async.style = end_in_async_style +dotnet_naming_symbols.async_methods.applicable_kinds = method +dotnet_naming_symbols.async_methods.required_modifiers = async +dotnet_naming_style.end_in_async_style.required_suffix = Async +dotnet_naming_style.end_in_async_style.capitalization = pascal_case + +# Interfaces should start with I +dotnet_naming_rule.interfaces_should_start_with_i.severity = warning +dotnet_naming_rule.interfaces_should_start_with_i.symbols = interfaces +dotnet_naming_rule.interfaces_should_start_with_i.style = i_prefix_style +dotnet_naming_symbols.interfaces.applicable_kinds = interface +dotnet_naming_style.i_prefix_style.required_prefix = I +dotnet_naming_style.i_prefix_style.capitalization = pascal_case + +[*.{xml,config,*proj,nuspec,props,resx,targets,yml,tasks}] +indent_size = 2 + +[*.{props,targets,ruleset,config,nuspec,resx,vsixmanifest,vsct}] +indent_size = 2 + +[*.json] +indent_size = 2 + +[*.{ps1,psm1}] +indent_size = 4 + +[*.sh] +indent_size = 4 +end_of_line = lf + +[*.{razor,cshtml}] +charset = utf-8-bom + +[*.{cs,vb}] + +# SYSLIB1054: Use 'LibraryImportAttribute' instead of 'DllImportAttribute' to generate P/Invoke marshalling code at compile time +dotnet_diagnostic.SYSLIB1054.severity = warning + +# CA1018: 
Mark attributes with AttributeUsageAttribute +dotnet_diagnostic.CA1018.severity = warning + +# CA1047: Do not declare protected member in sealed type +dotnet_diagnostic.CA1047.severity = warning + +# CA1305: Specify IFormatProvider +dotnet_diagnostic.CA1305.severity = warning + +# CA1507: Use nameof to express symbol names (critical for AGENTS.md) +dotnet_diagnostic.CA1507.severity = warning + +# CA1510: Use ArgumentNullException throw helper +dotnet_diagnostic.CA1510.severity = warning + +# CA1511: Use ArgumentException throw helper +dotnet_diagnostic.CA1511.severity = warning + +# CA1512: Use ArgumentOutOfRangeException throw helper +dotnet_diagnostic.CA1512.severity = warning + +# CA1513: Use ObjectDisposedException throw helper +dotnet_diagnostic.CA1513.severity = warning + +# CA1725: Parameter names should match base declaration +dotnet_diagnostic.CA1725.severity = suggestion + +# CA1802: Use literals where appropriate +dotnet_diagnostic.CA1802.severity = warning + +# CA1805: Do not initialize unnecessarily +dotnet_diagnostic.CA1805.severity = warning + +# CA1810: Do not initialize static fields unnecessarily +dotnet_diagnostic.CA1810.severity = warning + +# CA1821: Remove empty Finalizers +dotnet_diagnostic.CA1821.severity = warning + +# CA1822: Make member static +dotnet_diagnostic.CA1822.severity = warning +dotnet_code_quality.CA1822.api_surface = private, internal + +# CA1823: Avoid unused private fields +dotnet_diagnostic.CA1823.severity = warning + +# CA1825: Avoid zero-length array allocations +dotnet_diagnostic.CA1825.severity = warning + +# CA1826: Do not use Enumerable methods on indexable collections +dotnet_diagnostic.CA1826.severity = warning + +# CA1827: Do not use Count() or LongCount() when Any() can be used +dotnet_diagnostic.CA1827.severity = warning + +# CA1828: Do not use CountAsync() or LongCountAsync() when AnyAsync() can be used +dotnet_diagnostic.CA1828.severity = warning + +# CA1829: Use Length/Count property instead of Count() when 
available +dotnet_diagnostic.CA1829.severity = warning + +# CA1830: Prefer strongly-typed Append and Insert method overloads on StringBuilder +dotnet_diagnostic.CA1830.severity = warning + +# CA1831-CA1833: Use AsSpan or AsMemory instead of Range-based indexers when appropriate +dotnet_diagnostic.CA1831.severity = warning +dotnet_diagnostic.CA1832.severity = warning +dotnet_diagnostic.CA1833.severity = warning + +# CA1834: Consider using 'StringBuilder.Append(char)' when applicable +dotnet_diagnostic.CA1834.severity = warning + +# CA1835: Prefer the 'Memory'-based overloads for 'ReadAsync' and 'WriteAsync' +dotnet_diagnostic.CA1835.severity = warning + +# CA1836: Prefer IsEmpty over Count +dotnet_diagnostic.CA1836.severity = warning + +# CA1837: Use 'Environment.ProcessId' +dotnet_diagnostic.CA1837.severity = warning + +# CA1838: Avoid 'StringBuilder' parameters for P/Invokes +dotnet_diagnostic.CA1838.severity = warning + +# CA1839: Use 'Environment.ProcessPath' +dotnet_diagnostic.CA1839.severity = warning + +# CA1840: Use 'Environment.CurrentManagedThreadId' +dotnet_diagnostic.CA1840.severity = warning + +# CA1841: Prefer Dictionary.Contains methods +dotnet_diagnostic.CA1841.severity = warning + +# CA1842: Do not use 'WhenAll' with a single task +dotnet_diagnostic.CA1842.severity = warning + +# CA1843: Do not use 'WaitAll' with a single task +dotnet_diagnostic.CA1843.severity = warning + +# CA1844: Provide memory-based overrides of async methods when subclassing 'Stream' +dotnet_diagnostic.CA1844.severity = warning + +# CA1845: Use span-based 'string.Concat' +dotnet_diagnostic.CA1845.severity = warning + +# CA1846: Prefer AsSpan over Substring +dotnet_diagnostic.CA1846.severity = warning + +# CA1847: Use string.Contains(char) instead of string.Contains(string) with single characters +dotnet_diagnostic.CA1847.severity = warning + +# CA1852: Seal internal types +dotnet_diagnostic.CA1852.severity = warning + +# CA1854: Prefer the IDictionary.TryGetValue(TKey, out 
TValue) method +dotnet_diagnostic.CA1854.severity = warning + +# CA1855: Prefer 'Clear' over 'Fill' +dotnet_diagnostic.CA1855.severity = warning + +# CA1856: Incorrect usage of ConstantExpected attribute +dotnet_diagnostic.CA1856.severity = error + +# CA1857: A constant is expected for the parameter +dotnet_diagnostic.CA1857.severity = warning + +# CA1858: Use 'StartsWith' instead of 'IndexOf' +dotnet_diagnostic.CA1858.severity = warning + +# CA2007: DISABLED - Never use ConfigureAwait(false) per AGENTS.md +dotnet_diagnostic.CA2007.severity = none + +# CA2008: Do not create tasks without passing a TaskScheduler +dotnet_diagnostic.CA2008.severity = warning + +# CA2009: Do not call ToImmutableCollection on an ImmutableCollection value +dotnet_diagnostic.CA2009.severity = warning + +# CA2011: Avoid infinite recursion +dotnet_diagnostic.CA2011.severity = warning + +# CA2012: Use ValueTask correctly +dotnet_diagnostic.CA2012.severity = warning + +# CA2013: Do not use ReferenceEquals with value types +dotnet_diagnostic.CA2013.severity = warning + +# CA2014: Do not use stackalloc in loops +dotnet_diagnostic.CA2014.severity = warning + +# CA2016: Forward the 'CancellationToken' parameter to methods that take one +dotnet_diagnostic.CA2016.severity = warning + +# CA2022: Avoid inexact read with `Stream.Read` +dotnet_diagnostic.CA2022.severity = warning + +# CA2200: Rethrow to preserve stack details +dotnet_diagnostic.CA2200.severity = warning + +# CA2201: Do not raise reserved exception types +dotnet_diagnostic.CA2201.severity = warning + +# CA2208: Instantiate argument exceptions correctly +dotnet_diagnostic.CA2208.severity = warning + +# CA2245: Do not assign a property to itself +dotnet_diagnostic.CA2245.severity = warning + +# CA2246: Assigning symbol and its member in the same statement +dotnet_diagnostic.CA2246.severity = warning + +# CA2249: Use string.Contains instead of string.IndexOf to improve readability +dotnet_diagnostic.CA2249.severity = warning + +# IDE0005: 
Remove unnecessary usings +dotnet_diagnostic.IDE0005.severity = warning + +# IDE0011: Curly braces to surround blocks of code +dotnet_diagnostic.IDE0011.severity = warning + +# IDE0020: Use pattern matching to avoid is check followed by a cast (with variable) +dotnet_diagnostic.IDE0020.severity = warning + +# IDE0029: Use coalesce expression (non-nullable types) +dotnet_diagnostic.IDE0029.severity = warning + +# IDE0030: Use coalesce expression (nullable types) +dotnet_diagnostic.IDE0030.severity = warning + +# IDE0031: Use null propagation +dotnet_diagnostic.IDE0031.severity = warning + +# IDE0035: Remove unreachable code +dotnet_diagnostic.IDE0035.severity = warning + +# IDE0036: Order modifiers +csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning +dotnet_diagnostic.IDE0036.severity = warning + +# IDE0038: Use pattern matching to avoid is check followed by a cast (without variable) +dotnet_diagnostic.IDE0038.severity = warning + +# IDE0043: Format string contains invalid placeholder +dotnet_diagnostic.IDE0043.severity = warning + +# IDE0044: Make field readonly +dotnet_diagnostic.IDE0044.severity = warning + +# IDE0051: Remove unused private members +dotnet_diagnostic.IDE0051.severity = warning + +# IDE0055: All formatting rules +dotnet_diagnostic.IDE0055.severity = suggestion + +# IDE0059: Unnecessary assignment to a value +dotnet_diagnostic.IDE0059.severity = warning + +# IDE0060: Remove unused parameter +dotnet_code_quality_unused_parameters = non_public +dotnet_diagnostic.IDE0060.severity = warning + +# IDE0062: Make local function static +dotnet_diagnostic.IDE0062.severity = warning + +# IDE0073: File header - disabled +dotnet_diagnostic.IDE0073.severity = none + +# IDE0161: Convert to file-scoped namespace +dotnet_diagnostic.IDE0161.severity = warning + +# IDE0200: Lambda expression can be removed +dotnet_diagnostic.IDE0200.severity = warning + +# 
IDE2000: Disallow multiple blank lines +dotnet_style_allow_multiple_blank_lines_experimental = false +dotnet_diagnostic.IDE2000.severity = warning + +# Verify settings for test files +[*.{received,verified}.{txt,xml,json}] +charset = utf-8-bom +end_of_line = lf +indent_size = unset +indent_style = unset +insert_final_newline = false +tab_width = unset +trim_trailing_whitespace = false diff --git a/AGENTS.md b/AGENTS.md index 3c49ac5052..3ed5071350 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,9 +1,12 @@ # Repository Guidelines -# Rules to follow +## Rules to follow - Always run `dotnet build GraphRag.slnx` (or the relevant project) before executing any `dotnet test` command. - Default to the latest available versions (e.g., Apache AGE `latest`) when selecting dependencies, per user request ("тобі треба latest"). - Do not create or rely on fake database stores (e.g., `FakePostgresGraphStore`); all tests must use real connectors/backing services. +- Keep default prompts in static C# classes; do not rely on prompt files under `prompts/` for built-in templates. +- Register language models through Microsoft.Extensions.AI keyed services; avoid bespoke `LanguageModelConfig` providers. +- Always run `dotnet format GraphRag.slnx` before finishing work. 
# Conversations any resulting updates to agents.md should go under the section "## Rules to follow" diff --git a/Directory.Build.props b/Directory.Build.props index ab2b8f05fe..58057d740e 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -25,8 +25,8 @@ https://github.com/managedcode/graphrag https://github.com/managedcode/graphrag Managed Code GraphRag - 0.0.2 - 0.0.2 + 0.0.3 + 0.0.3 @@ -42,7 +42,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/Directory.Packages.props b/Directory.Packages.props index dcd78bb678..b214525f9f 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,6 +3,7 @@ + @@ -20,4 +21,4 @@ - \ No newline at end of file + diff --git a/README.md b/README.md index 721cd314ac..e0afbcbeab 100644 --- a/README.md +++ b/README.md @@ -86,10 +86,72 @@ graphrag/ ## Integration Testing Strategy -- **No fakes.** We removed the legacy fake Postgres store. Every graph operation in tests uses real services orchestrated by Testcontainers. -- **Security coverage.** `Integration/PostgresGraphStoreIntegrationTests.cs` includes payloads that mimic SQL/Cypher injection attempts to ensure values remain literals and labels/types are strictly validated. -- **Cross-backend validation.** `Integration/GraphStoreIntegrationTests.cs` exercises Postgres, Neo4j, and Cosmos (when available) through the shared `IGraphStore` abstraction. +- **No fakes.** We removed the legacy fake Postgres store. Every graph operation in tests uses real services orchestrated by Testcontainers. +- **Security coverage.** `Integration/PostgresGraphStoreIntegrationTests.cs` includes payloads that mimic SQL/Cypher injection attempts to ensure values remain literals and labels/types are strictly validated. +- **Cross-backend validation.** `Integration/GraphStoreIntegrationTests.cs` exercises Postgres, Neo4j, and Cosmos (when available) through the shared `IGraphStore` abstraction. 
- **Workflow smoke tests.** Pipelines (e.g., `IndexingPipelineRunnerTests`) and finalization steps run end-to-end with the fixture-provisioned infrastructure. +- **Prompt precedence.** `Integration/CommunitySummariesIntegrationTests.cs` proves manual prompt overrides win over auto-tuned assets while still falling back to auto templates when manual text is absent. +- **Callback and stats instrumentation.** `Runtime/PipelineExecutorTests.cs` now asserts that pipeline callbacks fire and runtime statistics are captured even when workflows fail early, so custom telemetry remains reliable. + +--- + +## Pipeline Cache + +Pipelines exchange state through the `IPipelineCache` abstraction. Every workflow step receives the same cache instance via `PipelineRunContext`, so it can reuse expensive results (LLM calls, chunk expansions, graph lookups) that were produced earlier in the run instead of recomputing them. The cache also keeps optional debug payloads per entry so you can persist trace metadata alongside the main value. + +To use the built-in in-memory cache, register it alongside the standard ASP.NET Core services: + +```csharp +using GraphRag.Cache; + +builder.Services.AddMemoryCache(); +builder.Services.AddSingleton<IPipelineCache, MemoryPipelineCache>(); +``` + +Prefer a different backend? Implement `IPipelineCache` yourself and register it through DI—the pipeline will pick up your custom cache automatically. + +- **Per-scope isolation.** `MemoryPipelineCache.CreateChild("stage")` scopes keys by prefix (`parent:stage:key`). Calling `ClearAsync` on the parent removes every nested key, so multi-step workflows do not leak data between stages. +- **Debug traces.** The cache stores optional debug payloads per entry; `DeleteAsync` and `ClearAsync` always clear these traces, preventing the diagnostic dictionary from growing unbounded. 
+- **Lifecycle guidance.** Create the root cache once per pipeline run (the default context factory does this for you) and spawn children inside individual workflows when you need an isolated namespace. + +--- + +## Language Model Registration + +GraphRAG delegates language-model configuration to [Microsoft.Extensions.AI](https://learn.microsoft.com/dotnet/ai/overview). Register keyed clients for every `ModelId` you reference in configuration—pick any string key that matches your config: + +```csharp +using Azure; +using Azure.AI.OpenAI; +using GraphRag.Config; +using Microsoft.Extensions.AI; + +var openAi = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(key)); +const string chatModelId = "chat_model"; +const string embeddingModelId = "embedding_model"; + +builder.Services.AddKeyedSingleton<IChatClient>( + chatModelId, + _ => openAi.GetChatClient(chatDeployment)); + +builder.Services.AddKeyedSingleton<IEmbeddingGenerator<string, Embedding<float>>>( + embeddingModelId, + _ => openAi.GetEmbeddingClient(embeddingDeployment)); +``` + +Rate limits, retries, and other policies should be configured when you create these clients (for example by wrapping them with `Polly` handlers). `GraphRagConfig.Models` simply tracks the set of model keys that have been registered so overrides can validate references. + +--- + +## Indexing, Querying, and Prompt Tuning Alignment + +The .NET port mirrors the [GraphRAG indexing architecture](https://microsoft.github.io/graphrag/index/overview/) and its query workflows so downstream applications retain parity with the Python reference implementation. + +- **Indexing overview.** Workflows such as `extract_graph`, `create_communities`, and `community_summaries` map 1:1 to the [default data flow](https://microsoft.github.io/graphrag/index/default_dataflow/) and persist the same tables (`text_units`, `entities`, `relationships`, `communities`, `community_reports`, `covariates`). 
The new prompt template loader honours manual or auto-tuned prompts before falling back to the stock templates in `prompts/`. +- **Query capabilities.** The query pipeline retains global search, local search, drift search, and question generation semantics described in the [GraphRAG query overview](https://microsoft.github.io/graphrag/query/overview/). Each orchestrator continues to assemble context from the indexed tables so you can reference [global](https://microsoft.github.io/graphrag/query/global_search/) or [local](https://microsoft.github.io/graphrag/query/local_search/) narratives interchangeably. +- **Prompt tuning.** GraphRAG’s [manual](https://microsoft.github.io/graphrag/prompt_tuning/manual_prompt_tuning/) and [auto](https://microsoft.github.io/graphrag/prompt_tuning/auto_prompt_tuning/) strategies are surfaced through `GraphRagConfig.PromptTuning`. Store custom templates under `prompts/` or point `PromptTuning.Manual.Directory`/`PromptTuning.Auto.Directory` at your tuning outputs. You can also skip files entirely by assigning inline text (multi-line or prefixed with `inline:`) to workflow prompt properties. Stage keys and placeholders are documented in `docs/indexing-and-query.md`. + +See [`docs/indexing-and-query.md`](docs/indexing-and-query.md) for a deeper mapping between the .NET workflows and the research publications underpinning GraphRAG. --- diff --git a/docs/indexing-and-query.md b/docs/indexing-and-query.md new file mode 100644 index 0000000000..6e7df9e8d5 --- /dev/null +++ b/docs/indexing-and-query.md @@ -0,0 +1,99 @@ +# Indexing, Querying, and Prompt Tuning in GraphRAG for .NET + +GraphRAG for .NET keeps feature parity with the Python reference project described in the [Microsoft Research blog](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/) and the [GraphRAG paper](https://arxiv.org/pdf/2404.16130). 
This document explains how the .NET workflows map to the concepts documented on [microsoft.github.io/graphrag](https://microsoft.github.io/graphrag/), highlights the supported query modes, and shows how to customise prompts via manual or auto tuning outputs. + +## Indexing Architecture + +- **Workflow parity.** Each indexing stage matches the Python pipeline and the [default data flow](https://microsoft.github.io/graphrag/index/default_dataflow/): + - `load_input_documents` → `create_base_text_units` → `summarize_descriptions` + - `extract_graph` persists `entities` and `relationships` + - `create_communities` produces `communities` + - `community_summaries` writes `community_reports` + - `extract_covariates` stores `covariates` +- **Storage schema.** Tables share the column layout described under [index outputs](https://microsoft.github.io/graphrag/index/outputs/). The new strongly-typed records (`CommunityRecord`, `CovariateRecord`, etc.) mirror the JSON representation used by the Python implementation. +- **Cluster configuration.** `GraphRagConfig.ClusterGraph` exposes the same knobs as the Python `cluster_graph` settings, enabling largest-component filtering and deterministic seeding. + +## Language Model Registration + +Workflows resolve language models from the DI container via [Microsoft.Extensions.AI](https://learn.microsoft.com/dotnet/ai/overview). 
Register keyed services for every `ModelId` you plan to reference: + +```csharp +using Azure; +using Azure.AI.OpenAI; +using GraphRag.Config; +using Microsoft.Extensions.AI; + +var openAi = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(key)); +const string chatModelId = "chat_model"; +const string embeddingModelId = "embedding_model"; + +services.AddKeyedSingleton<IChatClient>(chatModelId, _ => openAi.GetChatClient(chatDeployment)); +services.AddKeyedSingleton<IEmbeddingGenerator<string, Embedding<float>>>(embeddingModelId, _ => openAi.GetEmbeddingClient(embeddingDeployment)); +``` + +Configure retries, rate limits, and logging when you construct the concrete clients. `GraphRagConfig.Models` simply records the set of registered keys so configuration overrides can validate references. + +## Pipeline Cache + +`IPipelineCache` is intentionally infrastructure-neutral. To mirror ASP.NET Core's in-memory behaviour, register the built-in cache services alongside the provided adapter: + +```csharp +services.AddMemoryCache(); +services.AddSingleton<IPipelineCache, MemoryPipelineCache>(); +``` + +Need Redis or something else? Implement `IPipelineCache` yourself and register it through DI; the pipeline will automatically consume your custom cache. + +## Query Capabilities + +The query layer ports the orchestrators documented in the [GraphRAG query overview](https://microsoft.github.io/graphrag/query/overview/): + +- **Global search** ([docs](https://microsoft.github.io/graphrag/query/global_search/)) traverses community summaries and graph context to craft answers spanning the corpus. +- **Local search** ([docs](https://microsoft.github.io/graphrag/query/local_search/)) anchors on a document neighbourhood when you need focused context. +- **Drift search** ([docs](https://microsoft.github.io/graphrag/query/drift_search/)) monitors narrative changes across time slices. +- **Question generation** ([docs](https://microsoft.github.io/graphrag/query/question_generation/)) produces follow-up questions to extend an investigation. 
+ +Every orchestrator consumes the same indexed tables as the Python project, so the .NET stack interoperates with BYOG scenarios described in the [index architecture guide](https://microsoft.github.io/graphrag/index/architecture/). + +## Prompt Tuning + +Manual and auto prompt tuning are both available without code changes: + +1. **Manual overrides** follow the rules from [manual prompt tuning](https://microsoft.github.io/graphrag/prompt_tuning/manual_prompt_tuning/). + - Place custom templates under a directory referenced by `GraphRagConfig.PromptTuning.Manual.Directory` and set `Enabled = true`. + - Filenames follow the stage key pattern `section/workflow/kind.txt` (see table below). +2. **Auto tuning** integrates the outputs documented in [auto prompt tuning](https://microsoft.github.io/graphrag/prompt_tuning/auto_prompt_tuning/). + - Point `GraphRagConfig.PromptTuning.Auto.Directory` at the folder containing the generated prompts and set `Enabled = true`. + - The runtime prefers explicit paths from workflow configs, then manual overrides, then auto-tuned files, and finally the built-in defaults defined in `GraphRagPromptLibrary`. +3. **Inline overrides** can be injected directly from code: set `ExtractGraphConfig.SystemPrompt`, `ExtractGraphConfig.Prompt`, or the equivalent properties to either a multi-line string or a value prefixed with `inline:`. Inline values bypass template file lookups and are used as-is. + +### Stage Keys and Placeholders + +| Workflow | Stage key | Purpose | Supported placeholders | +|----------|-----------|---------|------------------------| +| `extract_graph` (system) | `index/extract_graph/system.txt` | System prompt that instructs the extractor. | _N/A_ | +| `extract_graph` (user) | `index/extract_graph/user.txt` | User prompt template for individual text units. | `{{max_entities}}`, `{{text}}` | +| `community_summaries` (system) | `index/community_reports/system.txt` | System guidance for cluster summarisation. 
| _N/A_ | +| `community_summaries` (user) | `index/community_reports/user.txt` | User prompt template for entity lists. | `{{max_length}}`, `{{entities}}` | + +Placeholders are replaced at runtime with values drawn from workflow configuration: + +- `{{max_entities}}` → `ExtractGraphConfig.EntityTypes.Count + 5` (minimum 1) +- `{{text}}` → the original text unit content +- `{{max_length}}` → `CommunityReportsConfig.MaxLength` +- `{{entities}}` → bullet list of entity titles and descriptions + +If a template is omitted, the runtime falls back to the built-in prompts defined in `GraphRagPromptLibrary`. + +## Integration Tests + +`tests/ManagedCode.GraphRag.Tests/Integration/CommunitySummariesIntegrationTests.cs` exercises the new prompt loader end-to-end using the file-backed pipeline storage. Combined with the existing Aspire-powered suites, the tests demonstrate how indexing, community detection, and summarisation behave with tuned prompts while remaining faithful to the [GraphRAG BYOG guidance](https://microsoft.github.io/graphrag/index/byog/). + +## Further Reading + +- [GraphRAG prompt tuning overview](https://microsoft.github.io/graphrag/prompt_tuning/overview/) +- [GraphRAG index methods](https://microsoft.github.io/graphrag/index/methods/) +- [GraphRAG query overview](https://microsoft.github.io/graphrag/query/overview/) +- [GraphRAG default dataflow](https://microsoft.github.io/graphrag/index/default_dataflow/) + +These resources underpin the .NET implementation and provide broader context for customising or extending the library. diff --git a/prompts/community_graph.txt b/prompts/community_graph.txt new file mode 100644 index 0000000000..db1370edae --- /dev/null +++ b/prompts/community_graph.txt @@ -0,0 +1,2 @@ +You are an investigative analyst. Produce concise, neutral summaries that describe the shared theme binding the supplied entities. +Highlight how they relate, why the cluster matters, and any notable signals the reader should know. 
Do not invent facts. diff --git a/prompts/community_text.txt b/prompts/community_text.txt new file mode 100644 index 0000000000..4080cdac7b --- /dev/null +++ b/prompts/community_text.txt @@ -0,0 +1,6 @@ +Summarise the key theme that connects the following entities in no more than {{max_length}} characters. Focus on what unites them and why the group matters. Avoid bullet lists. + +Entities: +{{entities}} + +Provide a single paragraph answer. diff --git a/prompts/index/extract_graph.system.txt b/prompts/index/extract_graph.system.txt new file mode 100644 index 0000000000..bdc048e535 --- /dev/null +++ b/prompts/index/extract_graph.system.txt @@ -0,0 +1,9 @@ +You are a precise information extraction engine. Analyse the supplied text and return structured JSON describing: +- distinct entities (people, organisations, locations, products, events, concepts, technologies, dates, other) +- relationships between those entities + +Rules: +- Only use information explicitly stated or implied in the text. +- Prefer short, human-readable titles. +- Use snake_case relationship types (e.g., "works_with", "located_in"). +- Always return valid JSON adhering to the response schema. diff --git a/prompts/index/extract_graph.user.txt b/prompts/index/extract_graph.user.txt new file mode 100644 index 0000000000..506c606aac --- /dev/null +++ b/prompts/index/extract_graph.user.txt @@ -0,0 +1,28 @@ +Extract up to {{max_entities}} of the most important entities and their relationships from the following text. 
+ +Text (between and markers): + +{{text}} + + +Respond with JSON matching this schema: +{ + "entities": [ + { + "title": "string", + "type": "person | organization | location | product | event | concept | technology | date | other", + "description": "short description", + "confidence": 0.0 - 1.0 + } + ], + "relationships": [ + { + "source": "entity title", + "target": "entity title", + "type": "relationship_type", + "description": "short description", + "weight": 0.0 - 1.0, + "bidirectional": true | false + } + ] +} diff --git a/src/ManagedCode.GraphRag.CosmosDb/CosmosGraphStore.cs b/src/ManagedCode.GraphRag.CosmosDb/CosmosGraphStore.cs index 0914edea61..95955af584 100644 --- a/src/ManagedCode.GraphRag.CosmosDb/CosmosGraphStore.cs +++ b/src/ManagedCode.GraphRag.CosmosDb/CosmosGraphStore.cs @@ -1,9 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; using GraphRag.Graphs; using Microsoft.Azure.Cosmos; using Microsoft.Azure.Cosmos.Linq; @@ -11,22 +6,13 @@ namespace GraphRag.Storage.Cosmos; -public sealed class CosmosGraphStore : IGraphStore +public sealed class CosmosGraphStore(CosmosClient client, string databaseId, string nodesContainerId, string edgesContainerId, ILogger logger) : IGraphStore { - private readonly CosmosClient _client; - private readonly string _databaseId; - private readonly string _nodesContainerId; - private readonly string _edgesContainerId; - private readonly ILogger _logger; - - public CosmosGraphStore(CosmosClient client, string databaseId, string nodesContainerId, string edgesContainerId, ILogger logger) - { - _client = client ?? throw new ArgumentNullException(nameof(client)); - _databaseId = databaseId ?? throw new ArgumentNullException(nameof(databaseId)); - _nodesContainerId = nodesContainerId ?? throw new ArgumentNullException(nameof(nodesContainerId)); - _edgesContainerId = edgesContainerId ?? 
throw new ArgumentNullException(nameof(edgesContainerId)); - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - } + private readonly CosmosClient _client = client ?? throw new ArgumentNullException(nameof(client)); + private readonly string _databaseId = databaseId ?? throw new ArgumentNullException(nameof(databaseId)); + private readonly string _nodesContainerId = nodesContainerId ?? throw new ArgumentNullException(nameof(nodesContainerId)); + private readonly string _edgesContainerId = edgesContainerId ?? throw new ArgumentNullException(nameof(edgesContainerId)); + private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); public async Task InitializeAsync(CancellationToken cancellationToken = default) { @@ -60,7 +46,7 @@ public async Task UpsertRelationshipAsync(string sourceId, string targetId, stri public IAsyncEnumerable GetOutgoingRelationshipsAsync(string sourceId, CancellationToken cancellationToken = default) { ArgumentException.ThrowIfNullOrWhiteSpace(sourceId); - return Fetch(); + return Fetch(cancellationToken); async IAsyncEnumerable Fetch([EnumeratorCancellation] CancellationToken token = default) { diff --git a/src/ManagedCode.GraphRag.CosmosDb/ServiceCollectionExtensions.cs b/src/ManagedCode.GraphRag.CosmosDb/ServiceCollectionExtensions.cs index fa8b907b2f..3559459ca7 100644 --- a/src/ManagedCode.GraphRag.CosmosDb/ServiceCollectionExtensions.cs +++ b/src/ManagedCode.GraphRag.CosmosDb/ServiceCollectionExtensions.cs @@ -1,4 +1,3 @@ -using System; using System.Text.Json; using GraphRag.Graphs; using Microsoft.Azure.Cosmos; diff --git a/src/ManagedCode.GraphRag.CosmosDb/SystemTextJsonCosmosSerializer.cs b/src/ManagedCode.GraphRag.CosmosDb/SystemTextJsonCosmosSerializer.cs index 0f57472fba..623304b9d6 100644 --- a/src/ManagedCode.GraphRag.CosmosDb/SystemTextJsonCosmosSerializer.cs +++ b/src/ManagedCode.GraphRag.CosmosDb/SystemTextJsonCosmosSerializer.cs @@ -1,11 +1,9 @@ -using System.IO; -using 
System.Text; using System.Text.Json; using Microsoft.Azure.Cosmos; namespace GraphRag.Storage.Cosmos; -internal sealed class SystemTextJsonCosmosSerializer : CosmosSerializer +internal sealed class SystemTextJsonCosmosSerializer(JsonSerializerOptions? options = null) : CosmosSerializer { private static readonly JsonSerializerOptions DefaultOptions = new(JsonSerializerDefaults.Web) { @@ -13,19 +11,11 @@ internal sealed class SystemTextJsonCosmosSerializer : CosmosSerializer WriteIndented = false }; - private readonly JsonSerializerOptions _options; - - public SystemTextJsonCosmosSerializer(JsonSerializerOptions? options = null) - { - _options = options ?? DefaultOptions; - } + private readonly JsonSerializerOptions _options = options ?? DefaultOptions; public override T FromStream(Stream stream) { - if (stream is null) - { - throw new ArgumentNullException(nameof(stream)); - } + ArgumentNullException.ThrowIfNull(stream); if (typeof(T) == typeof(Stream)) { diff --git a/src/ManagedCode.GraphRag.Neo4j/Neo4jGraphStore.cs b/src/ManagedCode.GraphRag.Neo4j/Neo4jGraphStore.cs index a28c3b380e..7de7fa03bb 100644 --- a/src/ManagedCode.GraphRag.Neo4j/Neo4jGraphStore.cs +++ b/src/ManagedCode.GraphRag.Neo4j/Neo4jGraphStore.cs @@ -1,31 +1,20 @@ -using System; -using System.Collections.Generic; -using System.Linq; using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; using GraphRag.Graphs; using Microsoft.Extensions.Logging; using Neo4j.Driver; namespace GraphRag.Storage.Neo4j; -public sealed class Neo4jGraphStore : IGraphStore, IAsyncDisposable +public sealed class Neo4jGraphStore(IDriver driver, ILogger logger) : IGraphStore, IAsyncDisposable { - private readonly IDriver _driver; - private readonly ILogger _logger; + private readonly IDriver _driver = driver ?? throw new ArgumentNullException(nameof(driver)); + private readonly ILogger _logger = logger ?? 
throw new ArgumentNullException(nameof(logger)); public Neo4jGraphStore(string uri, string username, string password, ILogger logger) : this(GraphDatabase.Driver(uri, AuthTokens.Basic(username, password)), logger) { } - public Neo4jGraphStore(IDriver driver, ILogger logger) - { - _driver = driver ?? throw new ArgumentNullException(nameof(driver)); - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - } - public async Task InitializeAsync(CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); @@ -64,7 +53,7 @@ public IAsyncEnumerable GetOutgoingRelationshipsAsync(string { ArgumentException.ThrowIfNullOrWhiteSpace(sourceId); cancellationToken.ThrowIfCancellationRequested(); - return Fetch(); + return Fetch(cancellationToken); async IAsyncEnumerable Fetch([EnumeratorCancellation] CancellationToken token = default) { diff --git a/src/ManagedCode.GraphRag.Neo4j/ServiceCollectionExtensions.cs b/src/ManagedCode.GraphRag.Neo4j/ServiceCollectionExtensions.cs index 72fca2b88b..8941cb421a 100644 --- a/src/ManagedCode.GraphRag.Neo4j/ServiceCollectionExtensions.cs +++ b/src/ManagedCode.GraphRag.Neo4j/ServiceCollectionExtensions.cs @@ -1,4 +1,3 @@ -using System; using GraphRag.Graphs; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; diff --git a/src/ManagedCode.GraphRag.Postgres/PostgresGraphStore.cs b/src/ManagedCode.GraphRag.Postgres/PostgresGraphStore.cs index 7050fbee51..21b771ecd6 100644 --- a/src/ManagedCode.GraphRag.Postgres/PostgresGraphStore.cs +++ b/src/ManagedCode.GraphRag.Postgres/PostgresGraphStore.cs @@ -37,8 +37,20 @@ public PostgresGraphStore(PostgresGraphStoreOptions options, ILogger 0) { @@ -121,9 +137,15 @@ public async Task UpsertRelationshipAsync(string sourceId, string targetId, stri var propertyAssignments = BuildPropertyAssignments("rel", ConvertProperties(properties), parameters, "rel_prop"); var queryBuilder = new StringBuilder(); - 
queryBuilder.Append($"MATCH (source {{ id: ${CypherParameterNames.SourceId} }}), (target {{ id: ${CypherParameterNames.TargetId} }})"); + queryBuilder.Append("MATCH (source { id: $"); + queryBuilder.Append(CypherParameterNames.SourceId); + queryBuilder.Append(" }), (target { id: $"); + queryBuilder.Append(CypherParameterNames.TargetId); + queryBuilder.Append(" })"); queryBuilder.AppendLine(); - queryBuilder.Append($"MERGE (source)-[rel:{EscapeLabel(type)}]->(target)"); + queryBuilder.Append("MERGE (source)-[rel:"); + queryBuilder.Append(EscapeLabel(type)); + queryBuilder.Append("]->(target)"); if (propertyAssignments.Count > 0) { diff --git a/src/ManagedCode.GraphRag/Cache/MemoryPipelineCache.cs b/src/ManagedCode.GraphRag/Cache/MemoryPipelineCache.cs new file mode 100644 index 0000000000..31b97d56ab --- /dev/null +++ b/src/ManagedCode.GraphRag/Cache/MemoryPipelineCache.cs @@ -0,0 +1,103 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.Caching.Memory; + +namespace GraphRag.Cache; + +/// +/// implementation backed by . +/// +public sealed class MemoryPipelineCache : IPipelineCache +{ + private readonly IMemoryCache _memoryCache; + private readonly string _scope; + private readonly ConcurrentDictionary _keys; + + public MemoryPipelineCache(IMemoryCache memoryCache) + : this(memoryCache, Guid.NewGuid().ToString("N"), new ConcurrentDictionary()) + { + } + + private MemoryPipelineCache(IMemoryCache memoryCache, string scope, ConcurrentDictionary keys) + { + _memoryCache = memoryCache ?? 
throw new ArgumentNullException(nameof(memoryCache)); + _scope = scope; + _keys = keys; + } + + public Task GetAsync(string key, CancellationToken cancellationToken = default) + { + ArgumentException.ThrowIfNullOrWhiteSpace(key); + cancellationToken.ThrowIfCancellationRequested(); + + if (_memoryCache.TryGetValue(GetCacheKey(key), out var value) && value is CacheEntry entry) + { + return Task.FromResult(entry.Value); + } + + return Task.FromResult(null); + } + + public Task SetAsync(string key, object? value, IReadOnlyDictionary? debugData = null, CancellationToken cancellationToken = default) + { + ArgumentException.ThrowIfNullOrWhiteSpace(key); + cancellationToken.ThrowIfCancellationRequested(); + + var cacheKey = GetCacheKey(key); + _memoryCache.Set(cacheKey, new CacheEntry(value, debugData)); + _keys[cacheKey] = 0; + return Task.CompletedTask; + } + + public Task HasAsync(string key, CancellationToken cancellationToken = default) + { + ArgumentException.ThrowIfNullOrWhiteSpace(key); + cancellationToken.ThrowIfCancellationRequested(); + + return Task.FromResult(_memoryCache.TryGetValue(GetCacheKey(key), out _)); + } + + public Task DeleteAsync(string key, CancellationToken cancellationToken = default) + { + ArgumentException.ThrowIfNullOrWhiteSpace(key); + cancellationToken.ThrowIfCancellationRequested(); + + var cacheKey = GetCacheKey(key); + _memoryCache.Remove(cacheKey); + _keys.TryRemove(cacheKey, out _); + return Task.CompletedTask; + } + + public Task ClearAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var scopePrefix = string.Concat(_scope, ":"); + + foreach (var cacheKey in _keys.Keys) + { + if (!cacheKey.StartsWith(scopePrefix, StringComparison.Ordinal)) + { + continue; + } + + _memoryCache.Remove(cacheKey); + _keys.TryRemove(cacheKey, out _); + } + + return Task.CompletedTask; + } + + public IPipelineCache CreateChild(string name) + { + ArgumentException.ThrowIfNullOrWhiteSpace(name); + 
var childScope = string.Concat(_scope, ":", name); + return new MemoryPipelineCache(_memoryCache, childScope, _keys); + } + + private string GetCacheKey(string key) + { + return string.Concat(_scope, ":", key); + } + + private sealed record CacheEntry(object? Value, IReadOnlyDictionary? DebugData); +} diff --git a/src/ManagedCode.GraphRag/Community/CommunityBuilder.cs b/src/ManagedCode.GraphRag/Community/CommunityBuilder.cs new file mode 100644 index 0000000000..9cf716b611 --- /dev/null +++ b/src/ManagedCode.GraphRag/Community/CommunityBuilder.cs @@ -0,0 +1,237 @@ +using System.Collections.Immutable; +using System.Globalization; +using GraphRag.Config; +using GraphRag.Entities; +using GraphRag.Relationships; + +namespace GraphRag.Community; + +internal static class CommunityBuilder +{ + public static IReadOnlyList Build( + IReadOnlyList entities, + IReadOnlyList relationships, + ClusterGraphConfig? config) + { + ArgumentNullException.ThrowIfNull(entities); + ArgumentNullException.ThrowIfNull(relationships); + + config ??= new ClusterGraphConfig(); + + if (entities.Count == 0) + { + return Array.Empty(); + } + + var adjacency = BuildAdjacency(entities, relationships); + var titleLookup = entities.ToDictionary(entity => entity.Title, StringComparer.OrdinalIgnoreCase); + var random = new Random(config.Seed); + + var orderedTitles = titleLookup.Keys + .OrderBy(_ => random.Next()) + .ToList(); + + var visited = new HashSet(StringComparer.OrdinalIgnoreCase); + var components = new List>(); + + foreach (var title in orderedTitles) + { + if (!visited.Add(title)) + { + continue; + } + + var component = new List(); + var queue = new Queue(); + queue.Enqueue(title); + + while (queue.Count > 0) + { + var current = queue.Dequeue(); + component.Add(current); + + if (!adjacency.TryGetValue(current, out var neighbors) || neighbors.Count == 0) + { + continue; + } + + var orderedNeighbors = neighbors + .OrderBy(_ => random.Next()) + .ToList(); + + foreach (var neighbor in 
orderedNeighbors) + { + if (visited.Add(neighbor)) + { + queue.Enqueue(neighbor); + } + } + } + + components.Add(component); + } + + if (config.UseLargestConnectedComponent && components.Count > 0) + { + var largestSize = components.Max(component => component.Count); + components = components + .Where(component => component.Count == largestSize) + .Take(1) + .ToList(); + } + + var clusters = components + .SelectMany(component => SplitComponent(component, config.MaxClusterSize)) + .ToList(); + + if (clusters.Count == 0) + { + return Array.Empty(); + } + + var period = DateTime.UtcNow.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture); + var communityRecords = new List(clusters.Count); + var relationshipLookup = relationships.ToList(); + + var communityIndex = 0; + foreach (var cluster in clusters) + { + var memberTitles = cluster + .Distinct(StringComparer.OrdinalIgnoreCase) + .Where(titleLookup.ContainsKey) + .ToList(); + + if (memberTitles.Count == 0) + { + continue; + } + + var members = memberTitles + .Select(title => titleLookup[title]) + .OrderBy(entity => entity.HumanReadableId) + .ToList(); + + if (members.Count == 0) + { + continue; + } + + communityIndex++; + var communityId = communityIndex; + + var entityIds = members + .Select(member => member.Id) + .ToImmutableArray(); + + var membership = new HashSet(memberTitles, StringComparer.OrdinalIgnoreCase); + var relationshipIds = new HashSet(StringComparer.OrdinalIgnoreCase); + var textUnitIds = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var relationship in relationshipLookup) + { + if (!membership.Contains(relationship.Source) || !membership.Contains(relationship.Target)) + { + continue; + } + + relationshipIds.Add(relationship.Id); + + foreach (var textUnitId in relationship.TextUnitIds) + { + if (!string.IsNullOrWhiteSpace(textUnitId)) + { + textUnitIds.Add(textUnitId); + } + } + } + + if (textUnitIds.Count == 0) + { + foreach (var member in members) + { + foreach (var textUnitId in 
member.TextUnitIds) + { + if (!string.IsNullOrWhiteSpace(textUnitId)) + { + textUnitIds.Add(textUnitId); + } + } + } + } + + var record = new CommunityRecord( + Id: Guid.NewGuid().ToString(), + HumanReadableId: communityId, + CommunityId: communityId, + Level: 0, + ParentId: -1, + Children: ImmutableArray.Empty, + Title: $"Community {communityId}", + EntityIds: entityIds, + RelationshipIds: relationshipIds + .OrderBy(id => id, StringComparer.Ordinal) + .ToImmutableArray(), + TextUnitIds: textUnitIds + .OrderBy(id => id, StringComparer.Ordinal) + .ToImmutableArray(), + Period: period, + Size: members.Count); + + communityRecords.Add(record); + } + + return communityRecords; + } + + private static Dictionary> BuildAdjacency( + IReadOnlyList entities, + IReadOnlyList relationships) + { + var adjacency = new Dictionary>(StringComparer.OrdinalIgnoreCase); + + foreach (var entity in entities) + { + adjacency.TryAdd(entity.Title, new HashSet(StringComparer.OrdinalIgnoreCase)); + } + + foreach (var relationship in relationships) + { + if (!adjacency.TryGetValue(relationship.Source, out var sourceNeighbors)) + { + sourceNeighbors = new HashSet(StringComparer.OrdinalIgnoreCase); + adjacency[relationship.Source] = sourceNeighbors; + } + + if (!adjacency.TryGetValue(relationship.Target, out var targetNeighbors)) + { + targetNeighbors = new HashSet(StringComparer.OrdinalIgnoreCase); + adjacency[relationship.Target] = targetNeighbors; + } + + sourceNeighbors.Add(relationship.Target); + targetNeighbors.Add(relationship.Source); + } + + return adjacency; + } + + private static IEnumerable> SplitComponent(List component, int maxClusterSize) + { + if (component.Count == 0) + { + yield break; + } + + if (maxClusterSize <= 0 || component.Count <= maxClusterSize) + { + yield return component; + yield break; + } + + for (var index = 0; index < component.Count; index += maxClusterSize) + { + var length = Math.Min(maxClusterSize, component.Count - index); + yield return 
component.GetRange(index, length); + } + } +} diff --git a/src/ManagedCode.GraphRag/Community/CommunityRecord.cs b/src/ManagedCode.GraphRag/Community/CommunityRecord.cs new file mode 100644 index 0000000000..79385842d8 --- /dev/null +++ b/src/ManagedCode.GraphRag/Community/CommunityRecord.cs @@ -0,0 +1,20 @@ +using System.Collections.Immutable; + +namespace GraphRag.Community; + +/// +/// Represents a finalized community row emitted by the GraphRAG pipeline. +/// +public sealed record CommunityRecord( + string Id, + int HumanReadableId, + int CommunityId, + int Level, + int? ParentId, + ImmutableArray Children, + string Title, + ImmutableArray EntityIds, + ImmutableArray RelationshipIds, + ImmutableArray TextUnitIds, + string? Period, + int Size); diff --git a/src/ManagedCode.GraphRag/Config/CacheConfig.cs b/src/ManagedCode.GraphRag/Config/CacheConfig.cs deleted file mode 100644 index 5c456e032a..0000000000 --- a/src/ManagedCode.GraphRag/Config/CacheConfig.cs +++ /dev/null @@ -1,16 +0,0 @@ -namespace GraphRag.Config; - -public sealed class CacheConfig -{ - public CacheType Type { get; set; } = CacheType.File; - - public string BaseDir { get; set; } = "cache"; - - public string? ConnectionString { get; set; } - - public string? ContainerName { get; set; } - - public string? StorageAccountBlobUrl { get; set; } - - public string? 
CosmosDbAccountUrl { get; set; } -} diff --git a/src/ManagedCode.GraphRag/Config/ChunkingConfig.cs b/src/ManagedCode.GraphRag/Config/ChunkingConfig.cs index bacd30a168..5bc7f1308a 100644 --- a/src/ManagedCode.GraphRag/Config/ChunkingConfig.cs +++ b/src/ManagedCode.GraphRag/Config/ChunkingConfig.cs @@ -10,7 +10,7 @@ public sealed class ChunkingConfig public ChunkStrategyType Strategy { get; set; } = ChunkStrategyType.Tokens; - public string EncodingModel { get; set; } = "cl100k_base"; + public string EncodingModel { get; set; } = GraphRag.Constants.TokenizerDefaults.DefaultEncoding; public bool PrependMetadata { get; set; } diff --git a/src/ManagedCode.GraphRag/Config/ClusterGraphConfig.cs b/src/ManagedCode.GraphRag/Config/ClusterGraphConfig.cs new file mode 100644 index 0000000000..16a3b43eaa --- /dev/null +++ b/src/ManagedCode.GraphRag/Config/ClusterGraphConfig.cs @@ -0,0 +1,25 @@ +namespace GraphRag.Config; + +/// +/// Configuration settings for graph community clustering. +/// +public sealed class ClusterGraphConfig +{ + /// + /// Gets or sets the maximum number of entities allowed in a single community cluster. + /// A value less than or equal to zero disables the limit. + /// + public int MaxClusterSize { get; set; } = 10; + + /// + /// Gets or sets a value indicating whether the largest connected component + /// should be used when clustering. + /// + public bool UseLargestConnectedComponent { get; set; } = true; + + /// + /// Gets or sets the seed used when ordering traversal operations to keep + /// results deterministic across runs. 
+ /// + public int Seed { get; set; } = unchecked((int)0xDEADBEEF); +} diff --git a/src/ManagedCode.GraphRag/Config/CommunityReportsConfig.cs b/src/ManagedCode.GraphRag/Config/CommunityReportsConfig.cs index 0ca14f0f84..974f81f3cf 100644 --- a/src/ManagedCode.GraphRag/Config/CommunityReportsConfig.cs +++ b/src/ManagedCode.GraphRag/Config/CommunityReportsConfig.cs @@ -2,13 +2,11 @@ namespace GraphRag.Config; public sealed class CommunityReportsConfig { - public string ModelId { get; set; } = "default_chat_model"; + public string? ModelId { get; set; } public string? GraphPrompt { get; set; } - = "prompts/community_graph.txt"; public string? TextPrompt { get; set; } - = "prompts/community_text.txt"; public int MaxLength { get; set; } = 2000; diff --git a/src/ManagedCode.GraphRag/Config/Enums.cs b/src/ManagedCode.GraphRag/Config/Enums.cs index 42b9170dce..d81d3ff080 100644 --- a/src/ManagedCode.GraphRag/Config/Enums.cs +++ b/src/ManagedCode.GraphRag/Config/Enums.cs @@ -1,14 +1,5 @@ namespace GraphRag.Config; -public enum CacheType -{ - File, - Memory, - None, - Blob, - CosmosDb -} - public enum InputFileType { Csv, @@ -37,30 +28,6 @@ public enum ReportingType Blob } -public enum ModelType -{ - OpenAiEmbedding, - AzureOpenAiEmbedding, - Embedding, - OpenAiChat, - AzureOpenAiChat, - Chat, - MockChat, - MockEmbedding -} - -public enum AuthType -{ - ApiKey, - AzureManagedIdentity -} - -public enum AsyncType -{ - AsyncIo, - Threaded -} - public enum ChunkStrategyType { Tokens, diff --git a/src/ManagedCode.GraphRag/Config/ExtractGraphConfig.cs b/src/ManagedCode.GraphRag/Config/ExtractGraphConfig.cs index af8232ad63..d05e80cfc4 100644 --- a/src/ManagedCode.GraphRag/Config/ExtractGraphConfig.cs +++ b/src/ManagedCode.GraphRag/Config/ExtractGraphConfig.cs @@ -2,10 +2,11 @@ namespace GraphRag.Config; public sealed class ExtractGraphConfig { - public string ModelId { get; set; } = "default_chat_model"; + public string? ModelId { get; set; } + + public string? 
SystemPrompt { get; set; } public string? Prompt { get; set; } - = "prompts/index/extract_graph.txt"; public List EntityTypes { get; set; } = new() { "person", "organization", "location" }; diff --git a/src/ManagedCode.GraphRag/Config/GraphRagConfig.cs b/src/ManagedCode.GraphRag/Config/GraphRagConfig.cs index 0dceaa7bdb..bfdd062b64 100644 --- a/src/ManagedCode.GraphRag/Config/GraphRagConfig.cs +++ b/src/ManagedCode.GraphRag/Config/GraphRagConfig.cs @@ -7,15 +7,7 @@ public sealed class GraphRagConfig { public string RootDir { get; set; } = Directory.GetCurrentDirectory(); - public Dictionary Models { get; set; } = new(StringComparer.OrdinalIgnoreCase) - { - ["default_chat_model"] = new LanguageModelConfig(), - ["default_embedding_model"] = new LanguageModelConfig - { - Type = ModelType.Embedding, - Model = "text-embedding-3-small" - } - }; + public HashSet Models { get; set; } = new(StringComparer.OrdinalIgnoreCase); public InputConfig Input { get; set; } = new(); @@ -32,8 +24,6 @@ public sealed class GraphRagConfig Type = StorageType.File }; - public CacheConfig Cache { get; set; } = new(); - public ReportingConfig Reporting { get; set; } = new(); public Dictionary VectorStore { get; set; } = new(StringComparer.OrdinalIgnoreCase) @@ -50,23 +40,16 @@ public sealed class GraphRagConfig public SummarizeDescriptionsConfig SummarizeDescriptions { get; set; } = new(); + public ClusterGraphConfig ClusterGraph { get; set; } = new(); + public CommunityReportsConfig CommunityReports { get; set; } = new(); + public PromptTuningConfig PromptTuning { get; set; } = new(); + public SnapshotsConfig Snapshots { get; set; } = new(); public Dictionary Extensions { get; set; } = new(StringComparer.OrdinalIgnoreCase); - public LanguageModelConfig GetLanguageModelConfig(string modelId) - { - ArgumentException.ThrowIfNullOrWhiteSpace(modelId); - if (!Models.TryGetValue(modelId, out var config)) - { - throw new KeyNotFoundException($"Model ID '{modelId}' not found in configuration."); - } 
- - return config; - } - public VectorStoreConfig GetVectorStoreConfig(string vectorStoreId) { ArgumentException.ThrowIfNullOrWhiteSpace(vectorStoreId); diff --git a/src/ManagedCode.GraphRag/Config/LanguageModelConfig.cs b/src/ManagedCode.GraphRag/Config/LanguageModelConfig.cs deleted file mode 100644 index 98b06cad0c..0000000000 --- a/src/ManagedCode.GraphRag/Config/LanguageModelConfig.cs +++ /dev/null @@ -1,66 +0,0 @@ -namespace GraphRag.Config; - -public sealed class LanguageModelConfig -{ - public string? ApiKey { get; set; } - - public AuthType AuthType { get; set; } = AuthType.ApiKey; - - public ModelType Type { get; set; } = ModelType.Chat; - - public string? ModelProvider { get; set; } = "openai"; - - public string Model { get; set; } = "gpt-4-turbo-preview"; - - public string EncodingModel { get; set; } = "cl100k_base"; - - public string? ApiBase { get; set; } - - public string? ApiVersion { get; set; } - - public string? DeploymentName { get; set; } - - public string? Organization { get; set; } - - public string? Proxy { get; set; } - - public string? Audience { get; set; } - - public bool? ModelSupportsJson { get; set; } - - public double RequestTimeout { get; set; } = 120; - - public int? TokensPerMinute { get; set; } - - public int? RequestsPerMinute { get; set; } - - public string? RateLimitStrategy { get; set; } - - public string RetryStrategy { get; set; } = "exponential_backoff"; - - public int MaxRetries { get; set; } = 6; - - public double MaxRetryWait { get; set; } = 60; - - public int ConcurrentRequests { get; set; } = 4; - - public AsyncType AsyncMode { get; set; } = AsyncType.AsyncIo; - - public IList? Responses { get; set; } - - public int? MaxTokens { get; set; } - - public double Temperature { get; set; } - - public int? MaxCompletionTokens { get; set; } - - public string? 
ReasoningEffort { get; set; } - - public double TopP { get; set; } = 1; - - public int N { get; set; } = 1; - - public double FrequencyPenalty { get; set; } - - public double PresencePenalty { get; set; } -} diff --git a/src/ManagedCode.GraphRag/Config/PromptTuningConfig.cs b/src/ManagedCode.GraphRag/Config/PromptTuningConfig.cs new file mode 100644 index 0000000000..5e8b7a4f9b --- /dev/null +++ b/src/ManagedCode.GraphRag/Config/PromptTuningConfig.cs @@ -0,0 +1,24 @@ +namespace GraphRag.Config; + +public sealed class PromptTuningConfig +{ + public ManualPromptTuningConfig Manual { get; set; } = new(); + + public AutoPromptTuningConfig Auto { get; set; } = new(); +} + +public sealed class ManualPromptTuningConfig +{ + public bool Enabled { get; set; } + + public string? Directory { get; set; } +} + +public sealed class AutoPromptTuningConfig +{ + public bool Enabled { get; set; } + + public string? Directory { get; set; } + + public string? Strategy { get; set; } +} diff --git a/src/ManagedCode.GraphRag/Config/SummarizeDescriptionsConfig.cs b/src/ManagedCode.GraphRag/Config/SummarizeDescriptionsConfig.cs index db1b467acc..5a95f48855 100644 --- a/src/ManagedCode.GraphRag/Config/SummarizeDescriptionsConfig.cs +++ b/src/ManagedCode.GraphRag/Config/SummarizeDescriptionsConfig.cs @@ -2,11 +2,11 @@ namespace GraphRag.Config; public sealed class SummarizeDescriptionsConfig { - public string ModelId { get; set; } = "default_chat_model"; + public string? ModelId { get; set; } - public string? Prompt { get; set; } = "prompts/index/summarize_entities.txt"; + public string? Prompt { get; set; } - public string? RelationshipPrompt { get; set; } = "prompts/index/summarize_relationships.txt"; + public string? 
RelationshipPrompt { get; set; } public int MaxLength { get; set; } = 400; diff --git a/src/ManagedCode.GraphRag/Config/TextEmbeddingConfig.cs b/src/ManagedCode.GraphRag/Config/TextEmbeddingConfig.cs index 07b4ad0062..9ba21a3df3 100644 --- a/src/ManagedCode.GraphRag/Config/TextEmbeddingConfig.cs +++ b/src/ManagedCode.GraphRag/Config/TextEmbeddingConfig.cs @@ -2,7 +2,7 @@ namespace GraphRag.Config; public sealed class TextEmbeddingConfig { - public string ModelId { get; set; } = "default_embedding_model"; + public string? ModelId { get; set; } public string VectorStoreId { get; set; } = "default_vector_store"; diff --git a/src/ManagedCode.GraphRag/Constants/PipelineTableNames.cs b/src/ManagedCode.GraphRag/Constants/PipelineTableNames.cs index 2f4778426a..ece46f62f5 100644 --- a/src/ManagedCode.GraphRag/Constants/PipelineTableNames.cs +++ b/src/ManagedCode.GraphRag/Constants/PipelineTableNames.cs @@ -6,5 +6,7 @@ public static class PipelineTableNames public const string TextUnits = "text_units"; public const string Entities = "entities"; public const string Relationships = "relationships"; + public const string Communities = "communities"; public const string CommunityReports = "community_reports"; + public const string Covariates = "covariates"; } diff --git a/src/ManagedCode.GraphRag/Constants/PromptTemplateKeys.cs b/src/ManagedCode.GraphRag/Constants/PromptTemplateKeys.cs new file mode 100644 index 0000000000..67399a8368 --- /dev/null +++ b/src/ManagedCode.GraphRag/Constants/PromptTemplateKeys.cs @@ -0,0 +1,12 @@ +namespace GraphRag.Constants; + +internal static class PromptTemplateKeys +{ + public const string ExtractGraphSystem = "index/extract_graph/system"; + + public const string ExtractGraphUser = "index/extract_graph/user"; + + public const string CommunitySummarySystem = "index/community_reports/system"; + + public const string CommunitySummaryUser = "index/community_reports/user"; +} diff --git a/src/ManagedCode.GraphRag/Covariates/CovariateRecord.cs 
b/src/ManagedCode.GraphRag/Covariates/CovariateRecord.cs new file mode 100644 index 0000000000..16b6f86a9b --- /dev/null +++ b/src/ManagedCode.GraphRag/Covariates/CovariateRecord.cs @@ -0,0 +1,18 @@ +namespace GraphRag.Covariates; + +/// +/// Represents a finalized covariate (claim) row emitted by the GraphRAG pipeline. +/// +public sealed record CovariateRecord( + string Id, + int HumanReadableId, + string CovariateType, + string? Type, + string? Description, + string SubjectId, + string? ObjectId, + string? Status, + string? StartDate, + string? EndDate, + string? SourceText, + string TextUnitId); diff --git a/src/ManagedCode.GraphRag/Covariates/TextUnitCovariateJoiner.cs b/src/ManagedCode.GraphRag/Covariates/TextUnitCovariateJoiner.cs new file mode 100644 index 0000000000..59cba9cc19 --- /dev/null +++ b/src/ManagedCode.GraphRag/Covariates/TextUnitCovariateJoiner.cs @@ -0,0 +1,72 @@ +using GraphRag.Data; + +namespace GraphRag.Covariates; + +/// +/// Provides helpers for attaching extracted covariates back onto text unit records. 
+/// +public static class TextUnitCovariateJoiner +{ + public static IReadOnlyList Attach( + IReadOnlyList textUnits, + IReadOnlyList covariates) + { + ArgumentNullException.ThrowIfNull(textUnits); + ArgumentNullException.ThrowIfNull(covariates); + + if (textUnits.Count == 0 || covariates.Count == 0) + { + return textUnits; + } + + var lookup = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var covariate in covariates) + { + if (string.IsNullOrWhiteSpace(covariate.TextUnitId)) + { + continue; + } + + if (!lookup.TryGetValue(covariate.TextUnitId, out var ids)) + { + ids = new HashSet(StringComparer.OrdinalIgnoreCase); + lookup[covariate.TextUnitId] = ids; + } + + if (!string.IsNullOrWhiteSpace(covariate.Id)) + { + ids.Add(covariate.Id); + } + } + + if (lookup.Count == 0) + { + return textUnits; + } + + var results = new List(textUnits.Count); + foreach (var unit in textUnits) + { + if (!lookup.TryGetValue(unit.Id, out var ids)) + { + results.Add(unit); + continue; + } + + var existing = unit.CovariateIds ?? 
Array.Empty(); + var combined = new HashSet(existing, StringComparer.OrdinalIgnoreCase); + foreach (var id in ids) + { + combined.Add(id); + } + + var ordered = combined + .OrderBy(value => value, StringComparer.Ordinal) + .ToArray(); + + results.Add(unit with { CovariateIds = ordered }); + } + + return results; + } +} diff --git a/src/ManagedCode.GraphRag/Indexing/IndexingPipelineRunner.cs b/src/ManagedCode.GraphRag/Indexing/IndexingPipelineRunner.cs index 4f7d9ba4c0..a9ddaef42f 100644 --- a/src/ManagedCode.GraphRag/Indexing/IndexingPipelineRunner.cs +++ b/src/ManagedCode.GraphRag/Indexing/IndexingPipelineRunner.cs @@ -3,6 +3,7 @@ using GraphRag.Config; using GraphRag.Indexing.Runtime; using GraphRag.Storage; +using Microsoft.Extensions.DependencyInjection; namespace GraphRag.Indexing; @@ -22,7 +23,7 @@ public async Task> RunAsync(GraphRagConfig conf var inputStorage = PipelineStorageFactory.Create(config.Input.Storage); var outputStorage = PipelineStorageFactory.Create(config.Output); var previousStorage = PipelineStorageFactory.Create(config.UpdateIndexOutput); - var cache = new InMemoryPipelineCache(); + var cache = _services.GetService(); var callbacks = NoopWorkflowCallbacks.Instance; var stats = new PipelineRunStats(); var state = new PipelineState(); diff --git a/src/ManagedCode.GraphRag/Indexing/Runtime/IndexingPipelineDefinitions.cs b/src/ManagedCode.GraphRag/Indexing/Runtime/IndexingPipelineDefinitions.cs index c8e34c8cf0..2bad212dc9 100644 --- a/src/ManagedCode.GraphRag/Indexing/Runtime/IndexingPipelineDefinitions.cs +++ b/src/ManagedCode.GraphRag/Indexing/Runtime/IndexingPipelineDefinitions.cs @@ -9,6 +9,7 @@ public static class IndexingPipelineDefinitions LoadInputDocumentsWorkflow.Name, CreateBaseTextUnitsWorkflow.Name, ExtractGraphWorkflow.Name, + CreateCommunitiesWorkflow.Name, CommunitySummariesWorkflow.Name, CreateFinalDocumentsWorkflow.Name }); diff --git a/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineContextFactory.cs 
b/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineContextFactory.cs index 2b1a988e7e..474ecee9be 100644 --- a/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineContextFactory.cs +++ b/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineContextFactory.cs @@ -22,7 +22,7 @@ public static PipelineRunContext Create( inputStorage ?? new MemoryPipelineStorage(), outputStorage ?? new MemoryPipelineStorage(), previousStorage ?? new MemoryPipelineStorage(), - cache ?? new InMemoryPipelineCache(), + cache, callbacks ?? NoopWorkflowCallbacks.Instance, stats ?? new PipelineRunStats(), state ?? new PipelineState(), diff --git a/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineRunContext.cs b/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineRunContext.cs index 76e12f088f..427bc0f22c 100644 --- a/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineRunContext.cs +++ b/src/ManagedCode.GraphRag/Indexing/Runtime/PipelineRunContext.cs @@ -11,7 +11,7 @@ public sealed class PipelineRunContext( IPipelineStorage inputStorage, IPipelineStorage outputStorage, IPipelineStorage previousStorage, - IPipelineCache cache, + IPipelineCache? cache, IWorkflowCallbacks callbacks, PipelineRunStats stats, PipelineState state, @@ -24,7 +24,7 @@ public sealed class PipelineRunContext( public IPipelineStorage PreviousStorage { get; } = previousStorage ?? throw new ArgumentNullException(nameof(previousStorage)); - public IPipelineCache Cache { get; } = cache ?? throw new ArgumentNullException(nameof(cache)); + public IPipelineCache? Cache { get; } = cache; public IWorkflowCallbacks Callbacks { get; } = callbacks ?? 
throw new ArgumentNullException(nameof(callbacks)); diff --git a/src/ManagedCode.GraphRag/Indexing/Workflows/CommunitySummariesWorkflow.cs b/src/ManagedCode.GraphRag/Indexing/Workflows/CommunitySummariesWorkflow.cs index 4046863cc5..14952aa45c 100644 --- a/src/ManagedCode.GraphRag/Indexing/Workflows/CommunitySummariesWorkflow.cs +++ b/src/ManagedCode.GraphRag/Indexing/Workflows/CommunitySummariesWorkflow.cs @@ -16,6 +16,7 @@ namespace GraphRag.Indexing.Workflows; internal static class CommunitySummariesWorkflow { public const string Name = "community_summaries"; + private const string CommunityReportsCountKey = "community_reports:count"; public static WorkflowDelegate Create() { @@ -45,10 +46,38 @@ await context.OutputStorage relationships = Array.Empty(); } + IReadOnlyList communities; + if (await context.OutputStorage.TableExistsAsync(PipelineTableNames.Communities, cancellationToken).ConfigureAwait(false)) + { + communities = await context.OutputStorage + .LoadTableAsync(PipelineTableNames.Communities, cancellationToken) + .ConfigureAwait(false); + } + else + { + communities = DetectCommunities(entities, relationships, config.ClusterGraph); + } + + if (communities.Count == 0) + { + await context.OutputStorage + .WriteTableAsync(PipelineTableNames.CommunityReports, Array.Empty(), cancellationToken) + .ConfigureAwait(false); + return new WorkflowResult(Array.Empty()); + } + + var entityLookup = entities.ToDictionary(entity => entity.Id, StringComparer.OrdinalIgnoreCase); var logger = context.Services.GetService()?.CreateLogger(typeof(CommunitySummariesWorkflow)); var reportsConfig = config.CommunityReports ?? 
new CommunityReportsConfig(); var chatClient = ResolveChatClient(context.Services, reportsConfig.ModelId, logger); - var communities = DetectCommunities(entities, relationships); + var promptLoader = PromptTemplateLoader.Create(config); + var systemPrompt = promptLoader.ResolveOrDefault( + PromptTemplateKeys.CommunitySummarySystem, + reportsConfig.GraphPrompt, + GraphRagPromptLibrary.CommunitySummarySystemPrompt); + var userTemplate = promptLoader.ResolveOptional( + PromptTemplateKeys.CommunitySummaryUser, + reportsConfig.TextPrompt); var reports = new List(communities.Count); for (var index = 0; index < communities.Count; index++) @@ -56,14 +85,28 @@ await context.OutputStorage cancellationToken.ThrowIfCancellationRequested(); var community = communities[index]; + var members = community.EntityIds + .Select(id => entityLookup.TryGetValue(id, out var entity) ? entity : null) + .Where(static entity => entity is not null) + .Cast() + .ToArray(); + + if (members.Length == 0) + { + continue; + } + var summary = string.Empty; try { summary = await GenerateCommunitySummaryAsync( chatClient, - GraphRagPromptLibrary.CommunitySummarySystemPrompt, - GraphRagPromptLibrary.BuildCommunitySummaryUserPrompt(community, reportsConfig.MaxLength), + systemPrompt, + GraphRagPromptLibrary.BuildCommunitySummaryUserPrompt( + members, + reportsConfig.MaxLength, + userTemplate), cancellationToken).ConfigureAwait(false); } catch (Exception ex) @@ -73,14 +116,14 @@ await context.OutputStorage if (string.IsNullOrWhiteSpace(summary)) { - summary = BuildFallbackSummary(community); + summary = BuildFallbackSummary(members); } var keywords = ExtractKeywords(summary); reports.Add(new CommunityReportRecord( - CommunityId: $"community_{index + 1}", + CommunityId: $"community_{community.CommunityId}", Level: 0, - EntityTitles: community.Select(static e => e.Title).ToArray(), + EntityTitles: members.Select(static e => e.Title).ToArray(), Summary: summary.Trim(), Keywords: keywords)); } @@ -89,83 
+132,17 @@ await context.OutputStorage .WriteTableAsync(PipelineTableNames.CommunityReports, reports, cancellationToken) .ConfigureAwait(false); - context.Items["community_reports:count"] = reports.Count; + context.Items[CommunityReportsCountKey] = reports.Count; return new WorkflowResult(reports); }; } - private static List> DetectCommunities( + private static IReadOnlyList DetectCommunities( IReadOnlyList entities, - IReadOnlyList relationships) + IReadOnlyList relationships, + ClusterGraphConfig? clusterConfig) { - var adjacency = new Dictionary>(StringComparer.OrdinalIgnoreCase); - foreach (var entity in entities) - { - adjacency.TryAdd(entity.Title, new HashSet(StringComparer.OrdinalIgnoreCase)); - } - - foreach (var relationship in relationships) - { - if (!adjacency.TryGetValue(relationship.Source, out var sourceNeighbors)) - { - sourceNeighbors = new HashSet(StringComparer.OrdinalIgnoreCase); - adjacency[relationship.Source] = sourceNeighbors; - } - - if (!adjacency.TryGetValue(relationship.Target, out var targetNeighbors)) - { - targetNeighbors = new HashSet(StringComparer.OrdinalIgnoreCase); - adjacency[relationship.Target] = targetNeighbors; - } - - sourceNeighbors.Add(relationship.Target); - targetNeighbors.Add(relationship.Source); - } - - var entityLookup = entities.ToDictionary(entity => entity.Title, StringComparer.OrdinalIgnoreCase); - var visited = new HashSet(StringComparer.OrdinalIgnoreCase); - var communities = new List>(); - - foreach (var entity in entities) - { - if (!visited.Add(entity.Title)) - { - continue; - } - - var queue = new Queue(); - queue.Enqueue(entity.Title); - - var members = new List(); - - while (queue.Count > 0) - { - var current = queue.Dequeue(); - if (!entityLookup.TryGetValue(current, out var record)) - { - continue; - } - - members.Add(record); - - if (!adjacency.TryGetValue(current, out var neighbors)) - { - continue; - } - - foreach (var neighbor in neighbors) - { - if (visited.Add(neighbor)) - { - 
queue.Enqueue(neighbor); - } - } - } - - communities.Add(members); - } - - return communities; + return CommunityBuilder.Build(entities, relationships, clusterConfig); } private static IReadOnlyList ExtractKeywords(string summary) diff --git a/src/ManagedCode.GraphRag/Indexing/Workflows/CreateCommunitiesWorkflow.cs b/src/ManagedCode.GraphRag/Indexing/Workflows/CreateCommunitiesWorkflow.cs new file mode 100644 index 0000000000..72e1b375fd --- /dev/null +++ b/src/ManagedCode.GraphRag/Indexing/Workflows/CreateCommunitiesWorkflow.cs @@ -0,0 +1,55 @@ +using GraphRag.Community; +using GraphRag.Config; +using GraphRag.Constants; +using GraphRag.Entities; +using GraphRag.Indexing.Runtime; +using GraphRag.Relationships; +using GraphRag.Storage; + +namespace GraphRag.Indexing.Workflows; + +internal static class CreateCommunitiesWorkflow +{ + public const string Name = "create_communities"; + private const string CommunityCountKey = "create_communities:count"; + + public static WorkflowDelegate Create() + { + return async (config, context, cancellationToken) => + { + var entities = await context.OutputStorage + .LoadTableAsync(PipelineTableNames.Entities, cancellationToken) + .ConfigureAwait(false); + + if (entities.Count == 0) + { + await context.OutputStorage + .WriteTableAsync(PipelineTableNames.Communities, Array.Empty(), cancellationToken) + .ConfigureAwait(false); + return new WorkflowResult(Array.Empty()); + } + + IReadOnlyList relationships; + if (await context.OutputStorage.TableExistsAsync(PipelineTableNames.Relationships, cancellationToken).ConfigureAwait(false)) + { + relationships = await context.OutputStorage + .LoadTableAsync(PipelineTableNames.Relationships, cancellationToken) + .ConfigureAwait(false); + } + else + { + relationships = Array.Empty(); + } + + var clusterConfig = config.ClusterGraph ?? 
new ClusterGraphConfig(); + var communities = CommunityBuilder.Build(entities, relationships, clusterConfig); + + await context.OutputStorage + .WriteTableAsync(PipelineTableNames.Communities, communities, cancellationToken) + .ConfigureAwait(false); + + context.Items[CommunityCountKey] = communities.Count; + return new WorkflowResult(communities); + }; + } +} diff --git a/src/ManagedCode.GraphRag/Indexing/Workflows/ExtractGraphWorkflow.cs b/src/ManagedCode.GraphRag/Indexing/Workflows/ExtractGraphWorkflow.cs index 0b3fa7e69a..aff1191aff 100644 --- a/src/ManagedCode.GraphRag/Indexing/Workflows/ExtractGraphWorkflow.cs +++ b/src/ManagedCode.GraphRag/Indexing/Workflows/ExtractGraphWorkflow.cs @@ -20,6 +20,9 @@ internal static class ExtractGraphWorkflow private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web); + private const string EntityCountKey = "extract_graph:entity_count"; + private const string RelationshipCountKey = "extract_graph:relationship_count"; + public static WorkflowDelegate Create() { return async (config, context, cancellationToken) => @@ -42,6 +45,14 @@ public static WorkflowDelegate Create() var relationshipAggregator = new RelationshipAggregator(); var chatClient = ResolveChatClient(context.Services, extractionConfig.ModelId, logger); + var promptLoader = PromptTemplateLoader.Create(config); + var systemPrompt = promptLoader.ResolveOrDefault( + PromptTemplateKeys.ExtractGraphSystem, + extractionConfig.SystemPrompt, + GraphRagPromptLibrary.ExtractGraphSystemPrompt); + var userPromptTemplate = promptLoader.ResolveOptional( + PromptTemplateKeys.ExtractGraphUser, + extractionConfig.Prompt); foreach (var unit in textUnits) { @@ -56,8 +67,11 @@ public static WorkflowDelegate Create() { var extraction = await GenerateExtractionAsync( chatClient, - GraphRagPromptLibrary.ExtractGraphSystemPrompt, - GraphRagPromptLibrary.BuildExtractGraphUserPrompt(unit.Text, Math.Max(1, allowedTypes.Count + 5)), + systemPrompt, + 
GraphRagPromptLibrary.BuildExtractGraphUserPrompt( + unit.Text, + Math.Max(1, allowedTypes.Count + 5), + userPromptTemplate), logger, cancellationToken).ConfigureAwait(false); @@ -94,8 +108,8 @@ await context.OutputStorage .WriteTableAsync(PipelineTableNames.Relationships, finalization.Relationships, cancellationToken) .ConfigureAwait(false); - context.Items["extract_graph:entity_count"] = finalization.Entities.Count; - context.Items["extract_graph:relationship_count"] = finalization.Relationships.Count; + context.Items[EntityCountKey] = finalization.Entities.Count; + context.Items[RelationshipCountKey] = finalization.Relationships.Count; return new WorkflowResult(finalization.Entities); }; diff --git a/src/ManagedCode.GraphRag/LanguageModels/GraphRagPromptLibrary.cs b/src/ManagedCode.GraphRag/LanguageModels/GraphRagPromptLibrary.cs index 00fae7b622..70f7be2206 100644 --- a/src/ManagedCode.GraphRag/LanguageModels/GraphRagPromptLibrary.cs +++ b/src/ManagedCode.GraphRag/LanguageModels/GraphRagPromptLibrary.cs @@ -1,3 +1,4 @@ +using System.Globalization; using System.Text; using GraphRag.Entities; @@ -22,8 +23,15 @@ internal static class GraphRagPromptLibrary Highlight how they relate, why the cluster matters, and any notable signals the reader should know. Do not invent facts. """; - internal static string BuildExtractGraphUserPrompt(string textUnit, int maxEntities) + internal static string BuildExtractGraphUserPrompt(string textUnit, int maxEntities, string? template = null) { + if (!string.IsNullOrWhiteSpace(template)) + { + return template + .Replace("{{max_entities}}", maxEntities.ToString(CultureInfo.InvariantCulture), StringComparison.Ordinal) + .Replace("{{text}}", textUnit, StringComparison.Ordinal); + } + return @$"Extract up to {maxEntities} of the most important entities and their relationships from the following text. 
Text (between and markers): @@ -54,8 +62,17 @@ internal static string BuildExtractGraphUserPrompt(string textUnit, int maxEntit }}"; } - internal static string BuildCommunitySummaryUserPrompt(IReadOnlyList community, int maxLength) + internal static string BuildCommunitySummaryUserPrompt(IReadOnlyList community, int maxLength, string? template = null) { + var entityLines = BuildEntityLines(community); + + if (!string.IsNullOrWhiteSpace(template)) + { + return template + .Replace("{{max_length}}", maxLength.ToString(CultureInfo.InvariantCulture), StringComparison.Ordinal) + .Replace("{{entities}}", entityLines, StringComparison.Ordinal); + } + var builder = new StringBuilder(); builder.Append("Summarise the key theme that connects the following entities in no more than "); builder.Append(maxLength); @@ -63,6 +80,15 @@ internal static string BuildCommunitySummaryUserPrompt(IReadOnlyList community) + { + var builder = new StringBuilder(); foreach (var entity in community) { builder.Append("- "); @@ -76,8 +102,6 @@ internal static string BuildCommunitySummaryUserPrompt(IReadOnlyList _cache = new(StringComparer.OrdinalIgnoreCase); + + private PromptTemplateLoader(GraphRagConfig config) + { + _config = config; + } + + public static PromptTemplateLoader Create(GraphRagConfig config) + { + ArgumentNullException.ThrowIfNull(config); + return new PromptTemplateLoader(config); + } + + public string ResolveOrDefault(string stageKey, string? explicitPath, string defaultValue) + { + ArgumentException.ThrowIfNullOrWhiteSpace(stageKey); + ArgumentNullException.ThrowIfNull(defaultValue); + + var resolved = ResolveInternal(stageKey, explicitPath); + return string.IsNullOrWhiteSpace(resolved) ? defaultValue : resolved!; + } + + public string? ResolveOptional(string stageKey, string? explicitPath) + { + ArgumentException.ThrowIfNullOrWhiteSpace(stageKey); + return ResolveInternal(stageKey, explicitPath); + } + + private string? ResolveInternal(string stageKey, string? 
explicitPath) + { + var cacheKey = $"{stageKey}::{explicitPath}"; + if (_cache.TryGetValue(cacheKey, out var cached)) + { + return cached; + } + + var result = LoadPrompt(stageKey, explicitPath); + _cache[cacheKey] = result; + return result; + } + + private string? LoadPrompt(string stageKey, string? explicitPath) + { + if (TryReadFile(explicitPath, out var value)) + { + return value; + } + + if (TryReadFromDirectory(_config.PromptTuning?.Manual, stageKey, out value)) + { + return value; + } + + if (TryReadFromDirectory(_config.PromptTuning?.Auto, stageKey, out value)) + { + return value; + } + + if (!string.IsNullOrWhiteSpace(explicitPath)) + { + var inline = ExtractInlinePrompt(explicitPath); + if (!string.IsNullOrWhiteSpace(inline)) + { + return inline; + } + } + + return null; + } + + private bool TryReadFromDirectory(ManualPromptTuningConfig? tuning, string stageKey, out string? value) + { + value = null; + if (tuning is null || !tuning.Enabled || string.IsNullOrWhiteSpace(tuning.Directory)) + { + return false; + } + + var directory = ResolveDirectory(tuning.Directory!); + var candidate = BuildPath(directory, stageKey); + return TryReadFile(candidate, out value); + } + + private bool TryReadFromDirectory(AutoPromptTuningConfig? tuning, string stageKey, out string? value) + { + value = null; + if (tuning is null || !tuning.Enabled || string.IsNullOrWhiteSpace(tuning.Directory)) + { + return false; + } + + var directory = ResolveDirectory(tuning.Directory!); + var candidate = BuildPath(directory, stageKey); + return TryReadFile(candidate, out value); + } + + private bool TryReadFile(string? path, out string? 
value) + { + value = null; + if (string.IsNullOrWhiteSpace(path)) + { + return false; + } + + var resolved = ResolvePath(path); + if (!File.Exists(resolved)) + { + return false; + } + + value = File.ReadAllText(resolved); + return true; + } + + private string ResolveDirectory(string directory) + { + if (Path.IsPathRooted(directory)) + { + return Path.GetFullPath(directory); + } + + var root = string.IsNullOrWhiteSpace(_config.RootDir) + ? Directory.GetCurrentDirectory() + : _config.RootDir; + + return Path.GetFullPath(Path.Combine(root, directory)); + } + + private string ResolvePath(string path) + { + if (Path.IsPathRooted(path)) + { + return Path.GetFullPath(path); + } + + var root = string.IsNullOrWhiteSpace(_config.RootDir) + ? Directory.GetCurrentDirectory() + : _config.RootDir; + + return Path.GetFullPath(Path.Combine(root, path)); + } + + private static string BuildPath(string directory, string stageKey) + { + var relative = stageKey.Replace('/', Path.DirectorySeparatorChar); + var candidate = Path.Combine(directory, relative); + return Path.HasExtension(candidate) ? candidate : candidate + ".txt"; + } + + private static readonly char[] InlineSeparators = new[] { '\r', '\n' }; + + private static string? ExtractInlinePrompt(string candidate) + { + if (candidate.StartsWith("inline:", StringComparison.OrdinalIgnoreCase)) + { + return candidate[7..].TrimStart(); + } + + return candidate.IndexOfAny(InlineSeparators) >= 0 + ? 
candidate + : null; + } +} diff --git a/src/ManagedCode.GraphRag/ManagedCode.GraphRag.csproj b/src/ManagedCode.GraphRag/ManagedCode.GraphRag.csproj index 4e6e7f8ccc..e7b550b75a 100644 --- a/src/ManagedCode.GraphRag/ManagedCode.GraphRag.csproj +++ b/src/ManagedCode.GraphRag/ManagedCode.GraphRag.csproj @@ -14,6 +14,7 @@ + diff --git a/src/ManagedCode.GraphRag/ServiceCollectionExtensions.cs b/src/ManagedCode.GraphRag/ServiceCollectionExtensions.cs index ea8d8ee72d..33a9937331 100644 --- a/src/ManagedCode.GraphRag/ServiceCollectionExtensions.cs +++ b/src/ManagedCode.GraphRag/ServiceCollectionExtensions.cs @@ -19,6 +19,7 @@ public static IServiceCollection AddGraphRag(this IServiceCollection services) services.AddKeyedSingleton(Indexing.Workflows.LoadInputDocumentsWorkflow.Name, static (_, _) => Indexing.Workflows.LoadInputDocumentsWorkflow.Create()); services.AddKeyedSingleton(Indexing.Workflows.CreateBaseTextUnitsWorkflow.Name, static (_, _) => Indexing.Workflows.CreateBaseTextUnitsWorkflow.Create()); services.AddKeyedSingleton(Indexing.Workflows.ExtractGraphWorkflow.Name, static (_, _) => Indexing.Workflows.ExtractGraphWorkflow.Create()); + services.AddKeyedSingleton(Indexing.Workflows.CreateCommunitiesWorkflow.Name, static (_, _) => Indexing.Workflows.CreateCommunitiesWorkflow.Create()); services.AddKeyedSingleton(Indexing.Workflows.CommunitySummariesWorkflow.Name, static (_, _) => Indexing.Workflows.CommunitySummariesWorkflow.Create()); services.AddKeyedSingleton(Indexing.Workflows.CreateFinalDocumentsWorkflow.Name, static (_, _) => Indexing.Workflows.CreateFinalDocumentsWorkflow.Create()); diff --git a/src/ManagedCode.GraphRag/Vectors/IVectorStore.cs b/src/ManagedCode.GraphRag/Vectors/IVectorStore.cs index 30bae2aac3..dcf8b48ab1 100644 --- a/src/ManagedCode.GraphRag/Vectors/IVectorStore.cs +++ b/src/ManagedCode.GraphRag/Vectors/IVectorStore.cs @@ -2,9 +2,15 @@ namespace GraphRag.Vectors; public interface IVectorStore { - Task UpsertAsync(string collection, 
ReadOnlyMemory embedding, IReadOnlyDictionary metadata, CancellationToken cancellationToken = default); + Task UpsertAsync( + string collection, + ReadOnlyMemory embedding, + IReadOnlyDictionary metadata, + CancellationToken cancellationToken = default); - IAsyncEnumerable SearchAsync(string collection, ReadOnlyMemory embedding, int limit, CancellationToken cancellationToken = default); + IAsyncEnumerable SearchAsync( + string collection, + ReadOnlyMemory embedding, + int limit, + CancellationToken cancellationToken = default); } - -public sealed record VectorSearchResult(string Id, double Score, IReadOnlyDictionary Metadata); diff --git a/src/ManagedCode.GraphRag/Vectors/VectorSearchResult.cs b/src/ManagedCode.GraphRag/Vectors/VectorSearchResult.cs new file mode 100644 index 0000000000..51c5c3091d --- /dev/null +++ b/src/ManagedCode.GraphRag/Vectors/VectorSearchResult.cs @@ -0,0 +1,77 @@ +using System.Collections.ObjectModel; + +namespace GraphRag.Vectors; + +/// +/// Represents a vector store match with associated metadata. +/// +public sealed class VectorSearchResult(string id, double score, IReadOnlyDictionary? metadata = null) +{ + private static readonly IReadOnlyDictionary EmptyMetadata = + new ReadOnlyDictionary(new Dictionary()); + + public string Id { get; } = id ?? throw new ArgumentNullException(nameof(id)); + + public double Score { get; } = score; + + public IReadOnlyDictionary Metadata { get; } = metadata ?? EmptyMetadata; + + public override bool Equals(object? 
obj) + { + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj is not VectorSearchResult other) + { + return false; + } + + return string.Equals(Id, other.Id, StringComparison.Ordinal) && + Score.Equals(other.Score) && + DictionaryEquals(Metadata, other.Metadata); + } + + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(Id, StringComparer.Ordinal); + hash.Add(Score); + foreach (var pair in Metadata.OrderBy(static kvp => kvp.Key, StringComparer.Ordinal)) + { + hash.Add(pair.Key, StringComparer.Ordinal); + hash.Add(pair.Value); + } + + return hash.ToHashCode(); + } + + private static bool DictionaryEquals(IReadOnlyDictionary first, IReadOnlyDictionary second) + { + if (ReferenceEquals(first, second)) + { + return true; + } + + if (first.Count != second.Count) + { + return false; + } + + foreach (var pair in first) + { + if (!second.TryGetValue(pair.Key, out var value)) + { + return false; + } + + if (!Equals(pair.Value, value)) + { + return false; + } + } + + return true; + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Cache/InMemoryPipelineCacheTests.cs b/tests/ManagedCode.GraphRag.Tests/Cache/InMemoryPipelineCacheTests.cs deleted file mode 100644 index b647d6fd6c..0000000000 --- a/tests/ManagedCode.GraphRag.Tests/Cache/InMemoryPipelineCacheTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -using GraphRag.Cache; -using Xunit; - -namespace ManagedCode.GraphRag.Tests.Cache; - -public sealed class InMemoryPipelineCacheTests -{ - [Fact] - public async Task CacheStoresAndRetrievesValues() - { - var cache = new InMemoryPipelineCache(); - await cache.SetAsync("key", "value"); - - Assert.True(await cache.HasAsync("key")); - Assert.Equal("value", await cache.GetAsync("key")); - - await cache.DeleteAsync("key"); - Assert.False(await cache.HasAsync("key")); - } - - [Fact] - public async Task CreateChild_SharesUnderlyingEntries() - { - var cache = new InMemoryPipelineCache(); - var child = cache.CreateChild("child"); - - await 
cache.SetAsync("shared", 42); - Assert.Equal(42, await child.GetAsync("shared")); - - await child.ClearAsync(); - Assert.False(await cache.HasAsync("shared")); - } -} diff --git a/tests/ManagedCode.GraphRag.Tests/Cache/MemoryPipelineCacheTests.cs b/tests/ManagedCode.GraphRag.Tests/Cache/MemoryPipelineCacheTests.cs new file mode 100644 index 0000000000..358c69c46d --- /dev/null +++ b/tests/ManagedCode.GraphRag.Tests/Cache/MemoryPipelineCacheTests.cs @@ -0,0 +1,104 @@ +using System.Collections.Concurrent; +using System.Reflection; +using GraphRag.Cache; +using Microsoft.Extensions.Caching.Memory; + +namespace ManagedCode.GraphRag.Tests.Cache; + +public sealed class MemoryPipelineCacheTests +{ + [Fact] + public async Task SetAndGet_ReturnsStoredValue() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var cache = new MemoryPipelineCache(memoryCache); + + await cache.SetAsync("foo", 42); + var value = await cache.GetAsync("foo"); + + Assert.Equal(42, value); + Assert.True(await cache.HasAsync("foo")); + } + + [Fact] + public async Task ClearAsync_RemovesEntries() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var cache = new MemoryPipelineCache(memoryCache); + + await cache.SetAsync("foo", "bar"); + await cache.ClearAsync(); + + Assert.False(await cache.HasAsync("foo")); + } + + [Fact] + public async Task ChildCache_IsolatedFromParent() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var parent = new MemoryPipelineCache(memoryCache); + var child = parent.CreateChild("child"); + + await child.SetAsync("value", "child"); + + Assert.False(await parent.HasAsync("value")); + Assert.Equal("child", await child.GetAsync("value")); + } + + [Fact] + public async Task ClearAsync_RemovesChildEntries() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var parent = new MemoryPipelineCache(memoryCache); + var child = parent.CreateChild("child"); + + await parent.SetAsync("parentValue", "parent"); + await 
child.SetAsync("childValue", "child"); + + await parent.ClearAsync(); + + Assert.False(await parent.HasAsync("parentValue")); + Assert.False(await child.HasAsync("childValue")); + } + + [Fact] + public async Task DeleteAsync_RemovesTrackedKeyEvenWithDebugData() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var cache = new MemoryPipelineCache(memoryCache); + + await cache.SetAsync("debug", 123, new Dictionary { ["token"] = "value" }); + var keys = GetTrackedKeys(cache); + Assert.Contains(keys.Keys, key => key.EndsWith(":debug", StringComparison.Ordinal)); + + await cache.DeleteAsync("debug"); + + Assert.DoesNotContain(GetTrackedKeys(cache).Keys, key => key.EndsWith(":debug", StringComparison.Ordinal)); + Assert.False(await cache.HasAsync("debug")); + } + + [Fact] + public async Task CreateChild_AfterParentWrites_StillClearsChildEntries() + { + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var parent = new MemoryPipelineCache(memoryCache); + + await parent.SetAsync("root", "root"); + var child = parent.CreateChild("later-child"); + await child.SetAsync("inner", "child"); + + await parent.ClearAsync(); + + Assert.False(await parent.HasAsync("root")); + Assert.False(await child.HasAsync("inner")); + } + + private static ConcurrentDictionary GetTrackedKeys(MemoryPipelineCache cache) + { + var field = typeof(MemoryPipelineCache) + .GetField("_keys", BindingFlags.NonPublic | BindingFlags.Instance) + ?? 
throw new InvalidOperationException("Could not access keys field."); + + return (ConcurrentDictionary)field.GetValue(cache)!; + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Callbacks/WorkflowCallbacksManagerTests.cs b/tests/ManagedCode.GraphRag.Tests/Callbacks/WorkflowCallbacksManagerTests.cs index da3b7dff68..a153529b0a 100644 --- a/tests/ManagedCode.GraphRag.Tests/Callbacks/WorkflowCallbacksManagerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Callbacks/WorkflowCallbacksManagerTests.cs @@ -1,9 +1,6 @@ -using System; -using System.Collections.Generic; using GraphRag.Callbacks; using GraphRag.Indexing.Runtime; using GraphRag.Logging; -using Xunit; namespace ManagedCode.GraphRag.Tests.Callbacks; diff --git a/tests/ManagedCode.GraphRag.Tests/Chunking/MarkdownTextChunkerTests.cs b/tests/ManagedCode.GraphRag.Tests/Chunking/MarkdownTextChunkerTests.cs index 64a7c5c978..750562bb13 100644 --- a/tests/ManagedCode.GraphRag.Tests/Chunking/MarkdownTextChunkerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Chunking/MarkdownTextChunkerTests.cs @@ -1,9 +1,7 @@ -using System.Linq; using GraphRag.Chunking; using GraphRag.Config; using GraphRag.Constants; using GraphRag.Tokenization; -using Xunit; namespace ManagedCode.GraphRag.Tests.Chunking; @@ -73,7 +71,8 @@ public void Chunk_RespectsOverlapBetweenChunks() var tokenizer = TokenizerRegistry.GetTokenizer(config.EncodingModel); var firstTokens = tokenizer.EncodeToIds(chunks[0].Text); - var secondTokens = tokenizer.EncodeToIds(chunks[1].Text); + + _ = tokenizer.EncodeToIds(chunks[1].Text); var overlapTokens = firstTokens.Skip(Math.Max(0, firstTokens.Count - config.Overlap)).ToArray(); Assert.True(overlapTokens.Length > 0); var overlapText = tokenizer.Decode(overlapTokens).TrimStart(); diff --git a/tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs b/tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs index f3cfc9ab81..28d8bf0a37 100644 --- 
a/tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs @@ -1,9 +1,7 @@ -using System.Linq; using GraphRag.Chunking; using GraphRag.Config; using GraphRag.Constants; using GraphRag.Tokenization; -using Xunit; namespace ManagedCode.GraphRag.Tests.Chunking; diff --git a/tests/ManagedCode.GraphRag.Tests/Config/ConfigTests.cs b/tests/ManagedCode.GraphRag.Tests/Config/ConfigTests.cs index dfcdf19d35..a7c74d3297 100644 --- a/tests/ManagedCode.GraphRag.Tests/Config/ConfigTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Config/ConfigTests.cs @@ -1,5 +1,4 @@ using GraphRag.Config; -using Xunit; namespace ManagedCode.GraphRag.Tests.Config; @@ -26,24 +25,6 @@ public void StorageConfig_AllowsCustomValues() Assert.Equal("https://cosmos.com", config.CosmosDbAccountUrl); } - [Fact] - public void CacheConfig_AllowsCustomValues() - { - var config = new CacheConfig - { - Type = CacheType.Memory, - BaseDir = "cache", - ConnectionString = "conn", - ContainerName = "container", - StorageAccountBlobUrl = "https://blob", - CosmosDbAccountUrl = "https://cosmos" - }; - - Assert.Equal(CacheType.Memory, config.Type); - Assert.Equal("cache", config.BaseDir); - Assert.Equal("conn", config.ConnectionString); - } - [Fact] public void ReportingConfig_AllowsCustomValues() { @@ -94,4 +75,12 @@ public void VectorStoreSchemaConfig_AllowsCustomization() Assert.Equal(42, config.VectorSize); Assert.Equal("index", config.IndexName); } + + [Fact] + public void GraphRagConfig_InitializesEmptyModelSet() + { + var config = new GraphRagConfig(); + + Assert.Empty(config.Models); + } } diff --git a/tests/ManagedCode.GraphRag.Tests/Covariates/TextUnitCovariateJoinerTests.cs b/tests/ManagedCode.GraphRag.Tests/Covariates/TextUnitCovariateJoinerTests.cs new file mode 100644 index 0000000000..ad450337ad --- /dev/null +++ b/tests/ManagedCode.GraphRag.Tests/Covariates/TextUnitCovariateJoinerTests.cs @@ -0,0 +1,56 @@ +using 
GraphRag.Covariates; +using GraphRag.Data; + +namespace ManagedCode.GraphRag.Tests.Covariates; + +public sealed class TextUnitCovariateJoinerTests +{ + [Fact] + public void Attach_MergesCovariateIdentifiersOntoTextUnits() + { + var textUnits = new[] + { + new TextUnitRecord + { + Id = "unit-1", + Text = "Alpha", + DocumentIds = new[] { "doc-1" }, + TokenCount = 12, + CovariateIds = Array.Empty() + }, + new TextUnitRecord + { + Id = "unit-2", + Text = "Beta", + DocumentIds = new[] { "doc-2" }, + TokenCount = 15, + CovariateIds = new[] { "existing" } + }, + new TextUnitRecord + { + Id = "unit-3", + Text = "Gamma", + DocumentIds = new[] { "doc-3" }, + TokenCount = 9 + } + }; + + var covariates = new[] + { + new CovariateRecord("cov-1", 0, "claim", "fraud", "", "entity-1", null, "OPEN", null, null, null, "unit-1"), + new CovariateRecord("cov-2", 1, "claim", "fraud", "", "entity-1", null, "OPEN", null, null, null, "unit-1"), + new CovariateRecord("cov-3", 2, "claim", "audit", "", "entity-2", null, "OPEN", null, null, null, "unit-2") + }; + + var updated = TextUnitCovariateJoiner.Attach(textUnits, covariates); + + var first = Assert.Single(updated, unit => unit.Id == "unit-1"); + Assert.Equal(new[] { "cov-1", "cov-2" }, first.CovariateIds); + + var second = Assert.Single(updated, unit => unit.Id == "unit-2"); + Assert.Equal(new[] { "cov-3", "existing" }, second.CovariateIds); + + var third = Assert.Single(updated, unit => unit.Id == "unit-3"); + Assert.Empty(third.CovariateIds); + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Finalization/GraphFinalizerTests.cs b/tests/ManagedCode.GraphRag.Tests/Finalization/GraphFinalizerTests.cs index 71e1931b77..38bb1f0266 100644 --- a/tests/ManagedCode.GraphRag.Tests/Finalization/GraphFinalizerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Finalization/GraphFinalizerTests.cs @@ -1,9 +1,6 @@ -using System; -using System.Linq; using GraphRag.Entities; using GraphRag.Finalization; using GraphRag.Relationships; -using Xunit; 
namespace ManagedCode.GraphRag.Tests.Finalization; diff --git a/tests/ManagedCode.GraphRag.Tests/Graphs/GraphRelationshipTests.cs b/tests/ManagedCode.GraphRag.Tests/Graphs/GraphRelationshipTests.cs index 87443d63c9..4f4eec7efa 100644 --- a/tests/ManagedCode.GraphRag.Tests/Graphs/GraphRelationshipTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Graphs/GraphRelationshipTests.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; using GraphRag.Graphs; -using Xunit; namespace ManagedCode.GraphRag.Tests.Graphs; diff --git a/src/ManagedCode.GraphRag/Cache/InMemoryPipelineCache.cs b/tests/ManagedCode.GraphRag.Tests/Infrastructure/StubPipelineCache.cs similarity index 54% rename from src/ManagedCode.GraphRag/Cache/InMemoryPipelineCache.cs rename to tests/ManagedCode.GraphRag.Tests/Infrastructure/StubPipelineCache.cs index 7595b8b63e..52a5faccc8 100644 --- a/src/ManagedCode.GraphRag/Cache/InMemoryPipelineCache.cs +++ b/tests/ManagedCode.GraphRag.Tests/Infrastructure/StubPipelineCache.cs @@ -1,59 +1,42 @@ -using System.Collections.Concurrent; +using GraphRag.Cache; -namespace GraphRag.Cache; +namespace ManagedCode.GraphRag.Tests.Infrastructure; -public sealed class InMemoryPipelineCache : IPipelineCache +internal sealed class StubPipelineCache : IPipelineCache { - private readonly ConcurrentDictionary _entries; - - public InMemoryPipelineCache() - : this(new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase)) - { - } - - private InMemoryPipelineCache(ConcurrentDictionary entries) - { - _entries = entries; - } - public Task GetAsync(string key, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - return Task.FromResult(_entries.TryGetValue(key, out var value) ? value.Value : null); + return Task.FromResult(null); } public Task SetAsync(string key, object? value, IReadOnlyDictionary? 
debugData = null, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - _entries[key] = new CacheEntry(value, debugData); return Task.CompletedTask; } public Task HasAsync(string key, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - return Task.FromResult(_entries.ContainsKey(key)); + return Task.FromResult(false); } public Task DeleteAsync(string key, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - _entries.TryRemove(key, out _); return Task.CompletedTask; } public Task ClearAsync(CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - _entries.Clear(); return Task.CompletedTask; } public IPipelineCache CreateChild(string name) { ArgumentException.ThrowIfNullOrWhiteSpace(name); - return new InMemoryPipelineCache(_entries); + return this; } - - private sealed record CacheEntry(object? Value, IReadOnlyDictionary? DebugData); } diff --git a/tests/ManagedCode.GraphRag.Tests/Infrastructure/TestChatClientFactory.cs b/tests/ManagedCode.GraphRag.Tests/Infrastructure/TestChatClientFactory.cs index c958f642f4..ff1af8664f 100644 --- a/tests/ManagedCode.GraphRag.Tests/Infrastructure/TestChatClientFactory.cs +++ b/tests/ManagedCode.GraphRag.Tests/Infrastructure/TestChatClientFactory.cs @@ -1,32 +1,17 @@ -using System; -using System.Collections.Generic; -using System.Linq; using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; namespace ManagedCode.GraphRag.Tests.Infrastructure; -internal sealed class TestChatClientFactory +internal sealed class TestChatClientFactory(Func, ChatResponse>? responseFactory = null) { - private readonly Func, ChatResponse> _responseFactory; - - public TestChatClientFactory(Func, ChatResponse>? responseFactory = null) - { - _responseFactory = responseFactory ?? 
(messages => new ChatResponse(new ChatMessage(ChatRole.Assistant, "{}"))); - } + private readonly Func, ChatResponse> _responseFactory = responseFactory ?? (messages => new ChatResponse(new ChatMessage(ChatRole.Assistant, "{}"))); public IChatClient CreateClient() => new TestChatClient(_responseFactory); - private sealed class TestChatClient : IChatClient + private sealed class TestChatClient(Func, ChatResponse> responseFactory) : IChatClient { - private readonly Func, ChatResponse> _responseFactory; - - public TestChatClient(Func, ChatResponse> responseFactory) - { - _responseFactory = responseFactory; - } + private readonly Func, ChatResponse> _responseFactory = responseFactory; public void Dispose() { diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/CommunitySummariesIntegrationTests.cs b/tests/ManagedCode.GraphRag.Tests/Integration/CommunitySummariesIntegrationTests.cs new file mode 100644 index 0000000000..97dd8e6f9d --- /dev/null +++ b/tests/ManagedCode.GraphRag.Tests/Integration/CommunitySummariesIntegrationTests.cs @@ -0,0 +1,235 @@ +using System.Collections.Immutable; +using GraphRag; +using GraphRag.Callbacks; +using GraphRag.Community; +using GraphRag.Config; +using GraphRag.Constants; +using GraphRag.Entities; +using GraphRag.Indexing.Runtime; +using GraphRag.Indexing.Workflows; +using GraphRag.Relationships; +using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.GraphRag.Tests.Integration; + +public sealed class CommunitySummariesIntegrationTests : IDisposable +{ + private readonly string _rootDir; + + public CommunitySummariesIntegrationTests() + { + _rootDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(_rootDir); + } + + [Fact] + public async Task CommunitySummariesWorkflow_UsesManualPromptOverrides() + { + var outputDir = Path.Combine(_rootDir, "output"); + var 
inputDir = Path.Combine(_rootDir, "input"); + var previousDir = Path.Combine(_rootDir, "previous"); + Directory.CreateDirectory(outputDir); + Directory.CreateDirectory(inputDir); + Directory.CreateDirectory(previousDir); + + var manualDirectory = Path.Combine(_rootDir, "prompt_overrides"); + var systemOverride = Path.Combine(manualDirectory, "index", "community_reports", "system.txt"); + var userOverride = Path.Combine(manualDirectory, "index", "community_reports", "user.txt"); + Directory.CreateDirectory(Path.GetDirectoryName(systemOverride)!); + Directory.CreateDirectory(Path.GetDirectoryName(userOverride)!); + + const string systemTemplate = "Manual system guidance"; + const string userTemplate = "Manual template for {{entities}} within {{max_length}} characters."; + File.WriteAllText(systemOverride, systemTemplate); + File.WriteAllText(userOverride, userTemplate); + + var outputStorage = new FilePipelineStorage(outputDir); + await outputStorage.WriteTableAsync(PipelineTableNames.Entities, new[] + { + new EntityRecord("entity-1", 0, "Alice", "Person", "Researcher", new[] { "unit-1" }.ToImmutableArray(), 2, 1, 0, 0), + new EntityRecord("entity-2", 1, "Bob", "Person", "Policy expert", new[] { "unit-2" }.ToImmutableArray(), 1, 1, 0, 0) + }); + + await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[] + { + new RelationshipRecord("rel-1", 0, "Alice", "Bob", "collaborates_with", "Joint work", 0.8, 2, new[] { "unit-1" }.ToImmutableArray(), true) + }); + + var capturedSystem = string.Empty; + var capturedUser = string.Empty; + var services = new ServiceCollection() + .AddSingleton(new TestChatClientFactory(messages => + { + var system = messages.First(m => m.Role == ChatRole.System); + var user = messages.First(m => m.Role == ChatRole.User); + capturedSystem = system.Text ?? string.Empty; + capturedUser = user.Text ?? 
string.Empty; + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "Manual summary output")); + }).CreateClient()) + .AddGraphRag() + .BuildServiceProvider(); + + var config = new GraphRagConfig + { + RootDir = _rootDir, + PromptTuning = new PromptTuningConfig + { + Manual = new ManualPromptTuningConfig + { + Enabled = true, + Directory = "prompt_overrides" + } + }, + CommunityReports = new CommunityReportsConfig + { + GraphPrompt = null, + TextPrompt = null, + MaxLength = 512 + } + }; + + var context = new PipelineRunContext( + inputStorage: new FilePipelineStorage(inputDir), + outputStorage: outputStorage, + previousStorage: new FilePipelineStorage(previousDir), + cache: new StubPipelineCache(), + callbacks: NoopWorkflowCallbacks.Instance, + stats: new PipelineRunStats(), + state: new PipelineState(), + services: services); + + var createCommunities = CreateCommunitiesWorkflow.Create(); + await createCommunities(config, context, CancellationToken.None); + + var summaries = CommunitySummariesWorkflow.Create(); + await summaries(config, context, CancellationToken.None); + + Assert.Equal(systemTemplate, capturedSystem); + Assert.DoesNotContain("{{", capturedUser, StringComparison.Ordinal); + Assert.Contains("Alice", capturedUser, StringComparison.Ordinal); + Assert.Contains("Bob", capturedUser, StringComparison.Ordinal); + + var reports = await outputStorage.LoadTableAsync(PipelineTableNames.CommunityReports); + var report = Assert.Single(reports); + Assert.Equal("Manual summary output", report.Summary); + Assert.Equal(2, report.EntityTitles.Count); + Assert.Equal(1, context.Items["community_reports:count"]); + Assert.True(File.Exists(Path.Combine(outputDir, $"{PipelineTableNames.CommunityReports}.json"))); + } + + [Fact] + public async Task CommunitySummariesWorkflow_PrefersManualOverAutoPrompts() + { + var outputDir = Path.Combine(_rootDir, "output-auto"); + var inputDir = Path.Combine(_rootDir, "input-auto"); + var previousDir = Path.Combine(_rootDir, 
"previous-auto"); + Directory.CreateDirectory(outputDir); + Directory.CreateDirectory(inputDir); + Directory.CreateDirectory(previousDir); + + var manualDirectory = Path.Combine(_rootDir, "prompt_manual"); + var autoDirectory = Path.Combine(_rootDir, "prompt_auto"); + + var manualSystem = Path.Combine(manualDirectory, "index", "community_reports", "system.txt"); + Directory.CreateDirectory(Path.GetDirectoryName(manualSystem)!); + File.WriteAllText(manualSystem, "Manual system override"); + + var autoSystem = Path.Combine(autoDirectory, "index", "community_reports", "system.txt"); + var autoUser = Path.Combine(autoDirectory, "index", "community_reports", "user.txt"); + Directory.CreateDirectory(Path.GetDirectoryName(autoSystem)!); + File.WriteAllText(autoSystem, "Auto system value"); + File.WriteAllText(autoUser, "Auto template for {{entities}} within {{max_length}} characters."); + + var outputStorage = new FilePipelineStorage(outputDir); + await outputStorage.WriteTableAsync(PipelineTableNames.Entities, new[] + { + new EntityRecord("entity-1", 0, "Alice", "Person", "Investigator", new[] { "unit-1" }.ToImmutableArray(), 2, 1, 0, 0), + new EntityRecord("entity-2", 1, "Eve", "Person", "Analyst", new[] { "unit-2" }.ToImmutableArray(), 1, 1, 0, 0) + }); + + await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[] + { + new RelationshipRecord("rel-1", 0, "Alice", "Eve", "collaborates_with", "Joint research", 0.7, 2, new[] { "unit-1" }.ToImmutableArray(), true) + }); + + var capturedSystem = string.Empty; + var capturedUser = string.Empty; + var services = new ServiceCollection() + .AddSingleton(new TestChatClientFactory(messages => + { + var system = messages.First(m => m.Role == ChatRole.System); + var user = messages.First(m => m.Role == ChatRole.User); + capturedSystem = system.Text ?? string.Empty; + capturedUser = user.Text ?? 
string.Empty; + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "Combined summary")); + }).CreateClient()) + .AddGraphRag() + .BuildServiceProvider(); + + var config = new GraphRagConfig + { + RootDir = _rootDir, + PromptTuning = new PromptTuningConfig + { + Manual = new ManualPromptTuningConfig + { + Enabled = true, + Directory = "prompt_manual" + }, + Auto = new AutoPromptTuningConfig + { + Enabled = true, + Directory = "prompt_auto" + } + }, + CommunityReports = new CommunityReportsConfig + { + GraphPrompt = null, + TextPrompt = null, + MaxLength = 256 + } + }; + + var context = new PipelineRunContext( + inputStorage: new FilePipelineStorage(inputDir), + outputStorage: outputStorage, + previousStorage: new FilePipelineStorage(previousDir), + cache: new StubPipelineCache(), + callbacks: NoopWorkflowCallbacks.Instance, + stats: new PipelineRunStats(), + state: new PipelineState(), + services: services); + + var createCommunities = CreateCommunitiesWorkflow.Create(); + await createCommunities(config, context, CancellationToken.None); + + var summaries = CommunitySummariesWorkflow.Create(); + await summaries(config, context, CancellationToken.None); + + Assert.Equal("Manual system override", capturedSystem); + Assert.Contains("Auto template", capturedUser, StringComparison.Ordinal); + Assert.DoesNotContain("{{", capturedUser, StringComparison.Ordinal); + + var reports = await outputStorage.LoadTableAsync(PipelineTableNames.CommunityReports); + var report = Assert.Single(reports); + Assert.Equal("Combined summary", report.Summary); + Assert.Equal(2, report.EntityTitles.Count); + } + + public void Dispose() + { + try + { + if (Directory.Exists(_rootDir)) + { + Directory.Delete(_rootDir, recursive: true); + } + } + catch + { + } + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/Finalization/GraphFinalizerTests.cs b/tests/ManagedCode.GraphRag.Tests/Integration/Finalization/GraphFinalizerTests.cs index 3c36ea2e7f..8d6c4ef1cc 100644 --- 
a/tests/ManagedCode.GraphRag.Tests/Integration/Finalization/GraphFinalizerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/Finalization/GraphFinalizerTests.cs @@ -1,9 +1,6 @@ -using System; -using System.Linq; using GraphRag.Entities; using GraphRag.Finalization; using GraphRag.Relationships; -using Xunit; namespace ManagedCode.GraphRag.Tests.Integration.Finalization; diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationCollection.cs b/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationCollection.cs index ae86fc080a..fd32032f74 100644 --- a/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationCollection.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationCollection.cs @@ -1,5 +1,3 @@ -using Xunit; - namespace ManagedCode.GraphRag.Tests.Integration; [CollectionDefinition(nameof(GraphRagApplicationCollection))] diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationFixture.cs b/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationFixture.cs index 29e8ef4195..c4897a52e8 100644 --- a/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationFixture.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/GraphRagApplicationFixture.cs @@ -1,22 +1,17 @@ -using System; -using System.Collections.Generic; -using System.Threading; +using DotNet.Testcontainers.Builders; using GraphRag; using GraphRag.Graphs; using GraphRag.Indexing.Runtime; using GraphRag.Storage.Cosmos; using GraphRag.Storage.Neo4j; using GraphRag.Storage.Postgres; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using DotNet.Testcontainers.Builders; using Npgsql; using Testcontainers.Neo4j; using Testcontainers.PostgreSql; -using Xunit; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Integration; 
@@ -36,31 +31,42 @@ public sealed class GraphRagApplicationFixture : IAsyncLifetime public async Task InitializeAsync() { - _neo4jContainer = new Neo4jBuilder() - .WithImage("neo4j:5.23.0-community") - .WithEnvironment("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") - .WithEnvironment("NEO4J_PLUGINS", "[\"apoc\"]") - .WithEnvironment("NEO4J_dbms_default__listen__address", "0.0.0.0") - .WithEnvironment("NEO4J_dbms_default__advertised__address", "localhost") - .WithEnvironment("NEO4J_AUTH", $"neo4j/{Neo4jPassword}") - .WithWaitStrategy(Wait.ForUnixContainer().UntilInternalTcpPortIsAvailable(7687)) - .Build(); - - _postgresContainer = new PostgreSqlBuilder() - .WithImage("apache/age:latest") - .WithDatabase(PostgresDatabase) - .WithUsername("postgres") - .WithPassword(PostgresPassword) - .WithCleanUp(true) - .WithWaitStrategy(Wait.ForUnixContainer().UntilInternalTcpPortIsAvailable(5432)) - .Build(); - - await Task.WhenAll(_neo4jContainer.StartAsync(), _postgresContainer.StartAsync()).ConfigureAwait(false); - - await EnsurePostgresDatabaseAsync().ConfigureAwait(false); - - var boltEndpoint = new Uri(_neo4jContainer.GetConnectionString(), UriKind.Absolute); - var postgresConnection = _postgresContainer.GetConnectionString(); + var skipContainers = string.Equals( + Environment.GetEnvironmentVariable("GRAPHRAG_SKIP_TESTCONTAINERS"), + "1", + StringComparison.OrdinalIgnoreCase); + + Uri? boltEndpoint = null; + string? 
postgresConnection = null; + + if (!skipContainers) + { + _neo4jContainer = new Neo4jBuilder() + .WithImage("neo4j:5.23.0-community") + .WithEnvironment("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") + .WithEnvironment("NEO4J_PLUGINS", "[\"apoc\"]") + .WithEnvironment("NEO4J_dbms_default__listen__address", "0.0.0.0") + .WithEnvironment("NEO4J_dbms_default__advertised__address", "localhost") + .WithEnvironment("NEO4J_AUTH", $"neo4j/{Neo4jPassword}") + .WithWaitStrategy(Wait.ForUnixContainer().UntilInternalTcpPortIsAvailable(7687)) + .Build(); + + _postgresContainer = new PostgreSqlBuilder() + .WithImage("apache/age:latest") + .WithDatabase(PostgresDatabase) + .WithUsername("postgres") + .WithPassword(PostgresPassword) + .WithCleanUp(true) + .WithWaitStrategy(Wait.ForUnixContainer().UntilInternalTcpPortIsAvailable(5432)) + .Build(); + + await Task.WhenAll(_neo4jContainer.StartAsync(), _postgresContainer.StartAsync()).ConfigureAwait(false); + await EnsurePostgresDatabaseAsync().ConfigureAwait(false); + + boltEndpoint = new Uri(_neo4jContainer.GetConnectionString(), UriKind.Absolute); + postgresConnection = _postgresContainer.GetConnectionString(); + } + var cosmosConnectionString = Environment.GetEnvironmentVariable("COSMOS_EMULATOR_CONNECTION_STRING"); var includeCosmos = !string.IsNullOrWhiteSpace(cosmosConnectionString); @@ -74,39 +80,55 @@ public async Task InitializeAsync() services.AddGraphRag(); - services.AddKeyedSingleton("neo4j-seed", static (_, _) => async (config, context, token) => + if (!skipContainers && boltEndpoint is not null && postgresConnection is not null) { - var graph = context.Services.GetRequiredKeyedService("neo4j"); - await graph.InitializeAsync(token).ConfigureAwait(false); - await graph.UpsertNodeAsync("alice", "Person", new Dictionary { ["name"] = "Alice" }, token).ConfigureAwait(false); - await graph.UpsertNodeAsync("bob", "Person", new Dictionary { ["name"] = "Bob" }, token).ConfigureAwait(false); - await 
graph.UpsertRelationshipAsync("alice", "bob", "KNOWS", new Dictionary { ["since"] = 2024 }, token).ConfigureAwait(false); - var relationships = new List(); - await foreach (var relationship in graph.GetOutgoingRelationshipsAsync("alice", token).ConfigureAwait(false)) + services.AddKeyedSingleton("neo4j-seed", static (_, _) => async (config, context, token) => { - relationships.Add(relationship); - } + var graph = context.Services.GetRequiredKeyedService("neo4j"); + await graph.InitializeAsync(token).ConfigureAwait(false); + await graph.UpsertNodeAsync("alice", "Person", new Dictionary { ["name"] = "Alice" }, token).ConfigureAwait(false); + await graph.UpsertNodeAsync("bob", "Person", new Dictionary { ["name"] = "Bob" }, token).ConfigureAwait(false); + await graph.UpsertRelationshipAsync("alice", "bob", "KNOWS", new Dictionary { ["since"] = 2024 }, token).ConfigureAwait(false); + var relationships = new List(); + await foreach (var relationship in graph.GetOutgoingRelationshipsAsync("alice", token).ConfigureAwait(false)) + { + relationships.Add(relationship); + } - context.Items["neo4j:relationship-count"] = relationships.Count; - return new WorkflowResult(null); - }); + context.Items["neo4j:relationship-count"] = relationships.Count; + return new WorkflowResult(null); + }); - services.AddKeyedSingleton("postgres-seed", static (_, _) => async (config, context, token) => - { - var graph = context.Services.GetRequiredKeyedService("postgres"); - await graph.InitializeAsync(token).ConfigureAwait(false); - await graph.UpsertNodeAsync("chapter-1", "Chapter", new Dictionary { ["title"] = "Origins" }, token).ConfigureAwait(false); - await graph.UpsertNodeAsync("chapter-2", "Chapter", new Dictionary { ["title"] = "Discovery" }, token).ConfigureAwait(false); - await graph.UpsertRelationshipAsync("chapter-1", "chapter-2", "LEADS_TO", new Dictionary { ["weight"] = 0.9 }, token).ConfigureAwait(false); - var relationships = new List(); - await foreach (var relationship in 
graph.GetOutgoingRelationshipsAsync("chapter-1", token).ConfigureAwait(false)) + services.AddKeyedSingleton("postgres-seed", static (_, _) => async (config, context, token) => { - relationships.Add(relationship); - } + var graph = context.Services.GetRequiredKeyedService("postgres"); + await graph.InitializeAsync(token).ConfigureAwait(false); + await graph.UpsertNodeAsync("chapter-1", "Chapter", new Dictionary { ["title"] = "Origins" }, token).ConfigureAwait(false); + await graph.UpsertNodeAsync("chapter-2", "Chapter", new Dictionary { ["title"] = "Discovery" }, token).ConfigureAwait(false); + await graph.UpsertRelationshipAsync("chapter-1", "chapter-2", "LEADS_TO", new Dictionary { ["weight"] = 0.9 }, token).ConfigureAwait(false); + var relationships = new List(); + await foreach (var relationship in graph.GetOutgoingRelationshipsAsync("chapter-1", token).ConfigureAwait(false)) + { + relationships.Add(relationship); + } + + context.Items["postgres:relationship-count"] = relationships.Count; + return new WorkflowResult(null); + }); - context.Items["postgres:relationship-count"] = relationships.Count; - return new WorkflowResult(null); - }); + services.AddNeo4jGraphStore("neo4j", options => + { + options.Uri = boltEndpoint.ToString(); + options.Username = "neo4j"; + options.Password = Neo4jPassword; + }, makeDefault: true); + + services.AddPostgresGraphStore("postgres", options => + { + options.ConnectionString = postgresConnection!; + options.GraphName = "graphrag"; + }); + } if (includeCosmos) { @@ -128,19 +150,6 @@ public async Task InitializeAsync() }); } - services.AddNeo4jGraphStore("neo4j", options => - { - options.Uri = boltEndpoint.ToString(); - options.Username = "neo4j"; - options.Password = Neo4jPassword; - }, makeDefault: true); - - services.AddPostgresGraphStore("postgres", options => - { - options.ConnectionString = postgresConnection; - options.GraphName = "graphrag"; - }); - if (includeCosmos) { services.AddCosmosGraphStore("cosmos", options => diff 
--git a/tests/ManagedCode.GraphRag.Tests/Integration/GraphStoreIntegrationTests.cs b/tests/ManagedCode.GraphRag.Tests/Integration/GraphStoreIntegrationTests.cs index 01ab76a9bd..a8040098f5 100644 --- a/tests/ManagedCode.GraphRag.Tests/Integration/GraphStoreIntegrationTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/GraphStoreIntegrationTests.cs @@ -1,7 +1,5 @@ -using System.Linq; using GraphRag.Graphs; using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace ManagedCode.GraphRag.Tests.Integration; @@ -58,6 +56,7 @@ public async Task GraphStores_RoundTripRelationshipsAsync(string providerKey) } [Fact] + [Trait("Category", "Cosmos")] public async Task CosmosGraphStore_RoundTrips_WhenEmulatorAvailable() { var cosmosStore = fixture.Services.GetKeyedService("cosmos"); diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/IndexingPipelineRunnerTests.cs b/tests/ManagedCode.GraphRag.Tests/Integration/IndexingPipelineRunnerTests.cs index b9620bb31f..8a32d6a2e1 100644 --- a/tests/ManagedCode.GraphRag.Tests/Integration/IndexingPipelineRunnerTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/IndexingPipelineRunnerTests.cs @@ -1,17 +1,12 @@ -using System.IO; using System.Text.Json; -using System.Threading.Tasks; using GraphRag; using GraphRag.Config; -using GraphRag.Indexing; -using GraphRag.Storage; -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Xunit; using GraphRag.Constants; +using GraphRag.Indexing; using GraphRag.Indexing.Workflows; using ManagedCode.GraphRag.Tests.Infrastructure; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; namespace ManagedCode.GraphRag.Tests.Integration; diff --git a/tests/ManagedCode.GraphRag.Tests/Integration/PostgresGraphStoreIntegrationTests.cs b/tests/ManagedCode.GraphRag.Tests/Integration/PostgresGraphStoreIntegrationTests.cs index 1e7488326b..d512f6aaed 100644 --- 
a/tests/ManagedCode.GraphRag.Tests/Integration/PostgresGraphStoreIntegrationTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Integration/PostgresGraphStoreIntegrationTests.cs @@ -1,30 +1,19 @@ -using System; -using System.Collections.Generic; using System.Globalization; -using System.IO; -using System.Linq; using System.Text; using System.Text.Json; -using System.Threading.Tasks; using GraphRag.Graphs; using GraphRag.Storage.Postgres; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Abstractions; using Npgsql; using NpgsqlTypes; -using Xunit; namespace ManagedCode.GraphRag.Tests.Integration; [Collection(nameof(GraphRagApplicationCollection))] -public sealed class PostgresGraphStoreIntegrationTests +public sealed class PostgresGraphStoreIntegrationTests(GraphRagApplicationFixture fixture) { - private readonly GraphRagApplicationFixture _fixture; - - public PostgresGraphStoreIntegrationTests(GraphRagApplicationFixture fixture) - { - _fixture = fixture; - } + private readonly GraphRagApplicationFixture _fixture = fixture; [Fact] public async Task UpsertNode_NormalizesNestedProperties() diff --git a/tests/ManagedCode.GraphRag.Tests/LanguageModels/PromptTemplateLoaderTests.cs b/tests/ManagedCode.GraphRag.Tests/LanguageModels/PromptTemplateLoaderTests.cs new file mode 100644 index 0000000000..89e13b759c --- /dev/null +++ b/tests/ManagedCode.GraphRag.Tests/LanguageModels/PromptTemplateLoaderTests.cs @@ -0,0 +1,121 @@ +using GraphRag.Config; +using GraphRag.Constants; +using GraphRag.LanguageModels; + +namespace ManagedCode.GraphRag.Tests.LanguageModels; + +public sealed class PromptTemplateLoaderTests : IDisposable +{ + private readonly string _rootDir; + + public PromptTemplateLoaderTests() + { + _rootDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(_rootDir); + } + + [Fact] + public void ResolveOrDefault_UsesExplicitPath() + { + var templatePath = Path.Combine(_rootDir, "custom.txt"); + 
File.WriteAllText(templatePath, "explicit prompt"); + + var config = new GraphRagConfig { RootDir = _rootDir }; + var loader = PromptTemplateLoader.Create(config); + + var prompt = loader.ResolveOrDefault(PromptTemplateKeys.ExtractGraphUser, "custom.txt", "fallback"); + + Assert.Equal("explicit prompt", prompt); + } + + [Fact] + public void ResolveOptional_ReadsManualDirectory() + { + var manualDirectory = Path.Combine(_rootDir, "manual"); + Directory.CreateDirectory(Path.Combine(manualDirectory, "index", "community_reports")); + var manualPath = Path.Combine(manualDirectory, "index", "community_reports", "user.txt"); + File.WriteAllText(manualPath, "manual prompt"); + + var config = new GraphRagConfig + { + RootDir = _rootDir, + PromptTuning = new PromptTuningConfig + { + Manual = new ManualPromptTuningConfig + { + Enabled = true, + Directory = "manual" + } + } + }; + + var loader = PromptTemplateLoader.Create(config); + var prompt = loader.ResolveOptional(PromptTemplateKeys.CommunitySummaryUser, null); + + Assert.Equal("manual prompt", prompt); + } + + [Fact] + public void ResolveOrDefault_FallsBackToAutoDirectory() + { + var autoDirectory = Path.Combine(_rootDir, "auto"); + Directory.CreateDirectory(Path.Combine(autoDirectory, "index", "extract_graph")); + var autoPath = Path.Combine(autoDirectory, "index", "extract_graph", "system.txt"); + File.WriteAllText(autoPath, "auto system"); + + var config = new GraphRagConfig + { + RootDir = _rootDir, + PromptTuning = new PromptTuningConfig + { + Auto = new AutoPromptTuningConfig + { + Enabled = true, + Directory = "auto" + } + } + }; + + var loader = PromptTemplateLoader.Create(config); + var prompt = loader.ResolveOrDefault(PromptTemplateKeys.ExtractGraphSystem, null, "fallback"); + + Assert.Equal("auto system", prompt); + } + + [Fact] + public void ResolveOrDefault_AllowsInlinePromptWithPrefix() + { + var config = new GraphRagConfig(); + var loader = PromptTemplateLoader.Create(config); + + var prompt = 
loader.ResolveOrDefault(PromptTemplateKeys.ExtractGraphSystem, "inline:custom text", "fallback"); + + Assert.Equal("custom text", prompt); + } + + [Fact] + public void ResolveOrDefault_AllowsInlinePromptWithNewlines() + { + var config = new GraphRagConfig(); + var loader = PromptTemplateLoader.Create(config); + + var inline = "line1\nline2"; + var prompt = loader.ResolveOrDefault(PromptTemplateKeys.ExtractGraphUser, inline, "fallback"); + + Assert.Equal(inline, prompt); + } + + public void Dispose() + { + try + { + if (Directory.Exists(_rootDir)) + { + Directory.Delete(_rootDir, recursive: true); + } + } + catch + { + } + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Relationships/RelationshipRecordTests.cs b/tests/ManagedCode.GraphRag.Tests/Relationships/RelationshipRecordTests.cs index fc2a3b9060..2b11ee0ac1 100644 --- a/tests/ManagedCode.GraphRag.Tests/Relationships/RelationshipRecordTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Relationships/RelationshipRecordTests.cs @@ -1,7 +1,5 @@ -using System.Collections.Generic; using System.Collections.Immutable; using GraphRag.Relationships; -using Xunit; namespace ManagedCode.GraphRag.Tests.Relationships; diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/DefaultPipelineFactoryTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/DefaultPipelineFactoryTests.cs index b8bcb3b5dd..5f26ff1ada 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/DefaultPipelineFactoryTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Runtime/DefaultPipelineFactoryTests.cs @@ -1,7 +1,5 @@ -using System.Threading.Tasks; using GraphRag.Indexing.Runtime; using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace ManagedCode.GraphRag.Tests.Runtime; diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/IndexingPipelineRunnerTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/IndexingPipelineRunnerTests.cs index d01139ec11..7c9cbf5f2f 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/IndexingPipelineRunnerTests.cs +++ 
b/tests/ManagedCode.GraphRag.Tests/Runtime/IndexingPipelineRunnerTests.cs @@ -1,12 +1,9 @@ -using System.Collections.Generic; -using System.Threading.Tasks; using GraphRag.Config; using GraphRag.Indexing; using GraphRag.Indexing.Runtime; using GraphRag.Storage; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Abstractions; -using Xunit; namespace ManagedCode.GraphRag.Tests.Runtime; @@ -48,14 +45,9 @@ public async Task RunAsync_ExecutesPipelineAndReturnsResults() Assert.IsType(capturedContexts[0].InputStorage); } - private sealed class StubPipelineFactory : IPipelineFactory + private sealed class StubPipelineFactory(WorkflowDelegate step) : IPipelineFactory { - private readonly WorkflowDelegate _step; - - public StubPipelineFactory(WorkflowDelegate step) - { - _step = step; - } + private readonly WorkflowDelegate _step = step; public WorkflowPipeline BuildIndexingPipeline(IndexingPipelineDescriptor descriptor) { diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineBuilderTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineBuilderTests.cs index cda76cb786..4b7cf0980d 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineBuilderTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineBuilderTests.cs @@ -1,6 +1,4 @@ -using System.Threading.Tasks; using GraphRag.Indexing.Runtime; -using Xunit; namespace ManagedCode.GraphRag.Tests.Runtime; diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineContextFactoryTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineContextFactoryTests.cs index 18873606db..3ad3a96cf0 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineContextFactoryTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineContextFactoryTests.cs @@ -1,10 +1,8 @@ -using System.Collections.Generic; -using GraphRag.Cache; using GraphRag.Callbacks; using GraphRag.Indexing.Runtime; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using 
Microsoft.Extensions.DependencyInjection; -using Xunit; namespace ManagedCode.GraphRag.Tests.Runtime; @@ -16,7 +14,7 @@ public void Create_UsesProvidedComponents() var input = new MemoryPipelineStorage(); var output = new MemoryPipelineStorage(); var previous = new MemoryPipelineStorage(); - var cache = new InMemoryPipelineCache(); + var cache = new StubPipelineCache(); var callbacks = WorkflowCallbacksManagerFactory(); var stats = new PipelineRunStats(); var state = new PipelineState(); @@ -34,7 +32,7 @@ public void Create_ProvidesDefaultsWhenNull() var context = PipelineContextFactory.Create(); Assert.IsType(context.InputStorage); - Assert.IsType(context.Cache); + Assert.Null(context.Cache); Assert.NotNull(context.Services); } diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineExecutorTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineExecutorTests.cs index 7d206fad3d..775f2fafb1 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineExecutorTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Runtime/PipelineExecutorTests.cs @@ -1,16 +1,11 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; using GraphRag.Config; using GraphRag.Indexing.Runtime; +using GraphRag.Logging; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Abstractions; -using Xunit; namespace ManagedCode.GraphRag.Tests.Runtime; @@ -24,7 +19,7 @@ public async Task ExecuteAsync_StopsOnException() new MemoryPipelineStorage(), new MemoryPipelineStorage(), new MemoryPipelineStorage(), - new InMemoryPipelineCache(), + new StubPipelineCache(), NoopWorkflowCallbacks.Instance, new PipelineRunStats(), new PipelineState(), @@ -59,7 +54,7 @@ public async Task ExecuteAsync_HonoursStopSignal() new MemoryPipelineStorage(), new MemoryPipelineStorage(), new 
MemoryPipelineStorage(), - new InMemoryPipelineCache(), + new StubPipelineCache(), NoopWorkflowCallbacks.Instance, new PipelineRunStats(), new PipelineState(), @@ -81,4 +76,136 @@ public async Task ExecuteAsync_HonoursStopSignal() Assert.Single(outputs); Assert.Equal("first", outputs[0].Workflow); } + + [Fact] + public async Task ExecuteAsync_InvokesCallbacksAndUpdatesStats() + { + var services = new ServiceCollection().BuildServiceProvider(); + var callbacks = new RecordingCallbacks(); + var stats = new PipelineRunStats(); + var context = new PipelineRunContext( + new MemoryPipelineStorage(), + new MemoryPipelineStorage(), + new MemoryPipelineStorage(), + new StubPipelineCache(), + callbacks, + stats, + new PipelineState(), + services); + + var pipeline = new WorkflowPipeline("stats", new[] + { + new WorkflowStep("first", async (cfg, ctx, token) => + { + await Task.Delay(5, token); + return new WorkflowResult("ok"); + }), + new WorkflowStep("second", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("done"))) + }); + + var executor = new PipelineExecutor(new NullLogger()); + var results = new List(); + + await foreach (var result in executor.ExecuteAsync(pipeline, new GraphRagConfig(), context)) + { + results.Add(result); + } + + Assert.Equal(new[] { "first", "second" }, callbacks.WorkflowStarts); + Assert.Equal(callbacks.WorkflowStarts, callbacks.WorkflowEnds); + Assert.Equal(2, callbacks.PipelineEndResults?.Count); + Assert.True(callbacks.PipelineStartedWith?.SequenceEqual(pipeline.Names)); + + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.Null(r.Errors)); + + Assert.True(stats.TotalRuntime >= 0); + Assert.True(stats.Workflows.ContainsKey("first")); + Assert.True(stats.Workflows["first"].ContainsKey("overall")); + Assert.True(stats.Workflows.ContainsKey("second")); + Assert.True(stats.Workflows["second"].ContainsKey("overall")); + } + + [Fact] + public async Task ExecuteAsync_RecordsExceptionInResultsAndStats() + { + var services = 
new ServiceCollection().BuildServiceProvider(); + var stats = new PipelineRunStats(); + var callbacks = new RecordingCallbacks(); + var context = new PipelineRunContext( + new MemoryPipelineStorage(), + new MemoryPipelineStorage(), + new MemoryPipelineStorage(), + new StubPipelineCache(), + callbacks, + stats, + new PipelineState(), + services); + + var failure = new InvalidOperationException("fail"); + var pipeline = new WorkflowPipeline("failing", new[] + { + new WorkflowStep("good", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("done"))), + new WorkflowStep("bad", (cfg, ctx, token) => throw failure), + new WorkflowStep("skipped", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("nope"))) + }); + + var executor = new PipelineExecutor(new NullLogger()); + var results = new List(); + + await foreach (var result in executor.ExecuteAsync(pipeline, new GraphRagConfig(), context)) + { + results.Add(result); + } + + Assert.Equal(2, results.Count); + Assert.Null(results[0].Errors); + var errorResult = results[1]; + Assert.NotNull(errorResult.Errors); + var captured = Assert.Single(errorResult.Errors!); + Assert.Same(failure, captured); + + Assert.Equal(new[] { "good", "bad" }, callbacks.WorkflowStarts); + Assert.Equal(callbacks.WorkflowStarts, callbacks.WorkflowEnds); + Assert.Equal(2, callbacks.PipelineEndResults?.Count); + + Assert.True(stats.Workflows.ContainsKey("good")); + Assert.True(stats.Workflows.ContainsKey("bad")); + Assert.False(stats.Workflows.ContainsKey("skipped")); + Assert.True(stats.TotalRuntime >= 0); + } + + private sealed class RecordingCallbacks : IWorkflowCallbacks + { + public IReadOnlyList? PipelineStartedWith { get; private set; } + public List WorkflowStarts { get; } = new(); + public List WorkflowEnds { get; } = new(); + public IReadOnlyList? 
PipelineEndResults { get; private set; } + public List ProgressUpdates { get; } = new(); + + public void PipelineStart(IReadOnlyList names) + { + PipelineStartedWith = names.ToArray(); + } + + public void PipelineEnd(IReadOnlyList results) + { + PipelineEndResults = results.ToArray(); + } + + public void WorkflowStart(string name, object? instance) + { + WorkflowStarts.Add(name); + } + + public void WorkflowEnd(string name, object? instance) + { + WorkflowEnds.Add(name); + } + + public void ReportProgress(ProgressSnapshot progress) + { + ProgressUpdates.Add(progress); + } + } } diff --git a/tests/ManagedCode.GraphRag.Tests/Runtime/ServiceCollectionExtensionsTests.cs b/tests/ManagedCode.GraphRag.Tests/Runtime/ServiceCollectionExtensionsTests.cs index e29e4ce6a0..6ca9479f1c 100644 --- a/tests/ManagedCode.GraphRag.Tests/Runtime/ServiceCollectionExtensionsTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Runtime/ServiceCollectionExtensionsTests.cs @@ -1,13 +1,11 @@ -using System.Threading.Tasks; using GraphRag; using GraphRag.Chunking; using GraphRag.Indexing.Runtime; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Xunit; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Runtime; diff --git a/tests/ManagedCode.GraphRag.Tests/Storage/FilePipelineStorageTests.cs b/tests/ManagedCode.GraphRag.Tests/Storage/FilePipelineStorageTests.cs index 560ff33ec8..843a8fe66c 100644 --- a/tests/ManagedCode.GraphRag.Tests/Storage/FilePipelineStorageTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Storage/FilePipelineStorageTests.cs @@ -1,10 +1,6 @@ -using System.Collections.Generic; -using System.IO; using System.Text; using System.Text.RegularExpressions; -using System.Threading.Tasks; using GraphRag.Storage; -using Xunit; namespace ManagedCode.GraphRag.Tests.Storage; diff --git 
a/tests/ManagedCode.GraphRag.Tests/Storage/MemoryPipelineStorageTests.cs b/tests/ManagedCode.GraphRag.Tests/Storage/MemoryPipelineStorageTests.cs index 73f3a83fa8..1d569c7705 100644 --- a/tests/ManagedCode.GraphRag.Tests/Storage/MemoryPipelineStorageTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Storage/MemoryPipelineStorageTests.cs @@ -1,10 +1,6 @@ -using System.Collections.Generic; -using System.IO; using System.Text; using System.Text.RegularExpressions; -using System.Threading.Tasks; using GraphRag.Storage; -using Xunit; namespace ManagedCode.GraphRag.Tests.Storage; diff --git a/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageExtensionsTests.cs b/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageExtensionsTests.cs index 81d19adfd9..4df718c526 100644 --- a/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageExtensionsTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageExtensionsTests.cs @@ -1,8 +1,4 @@ -using System.Collections.Generic; -using System.IO; -using System.Threading.Tasks; using GraphRag.Storage; -using Xunit; namespace ManagedCode.GraphRag.Tests.Storage; diff --git a/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageFactoryTests.cs b/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageFactoryTests.cs index 1f1723fbba..baf5178142 100644 --- a/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageFactoryTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Storage/PipelineStorageFactoryTests.cs @@ -1,8 +1,5 @@ -using System; -using System.IO; using GraphRag.Config; using GraphRag.Storage; -using Xunit; namespace ManagedCode.GraphRag.Tests.Storage; diff --git a/tests/ManagedCode.GraphRag.Tests/Storage/Postgres/ServiceCollectionExtensionsTests.cs b/tests/ManagedCode.GraphRag.Tests/Storage/Postgres/ServiceCollectionExtensionsTests.cs index 677636edf9..85387c8fdf 100644 --- a/tests/ManagedCode.GraphRag.Tests/Storage/Postgres/ServiceCollectionExtensionsTests.cs +++ 
b/tests/ManagedCode.GraphRag.Tests/Storage/Postgres/ServiceCollectionExtensionsTests.cs @@ -1,12 +1,9 @@ -using System.Collections.Generic; -using System.Threading.Tasks; using GraphRag.Config; using GraphRag.Graphs; using GraphRag.Storage.Postgres; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Xunit; namespace ManagedCode.GraphRag.Tests.Storage.Postgres; diff --git a/tests/ManagedCode.GraphRag.Tests/Tokenization/TokenizerRegistryTests.cs b/tests/ManagedCode.GraphRag.Tests/Tokenization/TokenizerRegistryTests.cs index 8a859059d4..bedfc492ba 100644 --- a/tests/ManagedCode.GraphRag.Tests/Tokenization/TokenizerRegistryTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Tokenization/TokenizerRegistryTests.cs @@ -1,6 +1,5 @@ using GraphRag.Constants; using GraphRag.Tokenization; -using Xunit; namespace ManagedCode.GraphRag.Tests.Tokenization; diff --git a/tests/ManagedCode.GraphRag.Tests/Vectors/VectorSearchResultTests.cs b/tests/ManagedCode.GraphRag.Tests/Vectors/VectorSearchResultTests.cs index 97e4ea5aa6..3de1b000fd 100644 --- a/tests/ManagedCode.GraphRag.Tests/Vectors/VectorSearchResultTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Vectors/VectorSearchResultTests.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; using GraphRag.Vectors; -using Xunit; namespace ManagedCode.GraphRag.Tests.Vectors; diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/CommunitySummariesWorkflowTests.cs b/tests/ManagedCode.GraphRag.Tests/Workflows/CommunitySummariesWorkflowTests.cs index 9a891fb241..f05a00691c 100644 --- a/tests/ManagedCode.GraphRag.Tests/Workflows/CommunitySummariesWorkflowTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/CommunitySummariesWorkflowTests.cs @@ -1,9 +1,5 @@ -using System.Collections.Generic; using System.Collections.Immutable; -using System.Threading; -using System.Threading.Tasks; using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; using 
GraphRag.Community; using GraphRag.Config; @@ -13,10 +9,9 @@ using GraphRag.Indexing.Workflows; using GraphRag.Relationships; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Xunit; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Workflows; @@ -48,14 +43,19 @@ await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), services: services); + var config = new GraphRagConfig(); + + var createCommunities = CreateCommunitiesWorkflow.Create(); + await createCommunities(config, context, CancellationToken.None); + var workflow = CommunitySummariesWorkflow.Create(); - await workflow(new GraphRagConfig(), context, CancellationToken.None); + await workflow(config, context, CancellationToken.None); var reports = await outputStorage.LoadTableAsync(PipelineTableNames.CommunityReports); var report = Assert.Single(reports); diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/CreateBaseTextUnitsWorkflowTests.cs b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateBaseTextUnitsWorkflowTests.cs index affd5457ab..578ce45726 100644 --- a/tests/ManagedCode.GraphRag.Tests/Workflows/CreateBaseTextUnitsWorkflowTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateBaseTextUnitsWorkflowTests.cs @@ -1,22 +1,14 @@ -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; -using GraphRag.Chunking; using GraphRag.Config; using GraphRag.Constants; using GraphRag.Data; using GraphRag.Indexing.Runtime; using 
GraphRag.Indexing.Workflows; -using GraphRag.Logging; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Xunit; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Workflows; @@ -49,7 +41,7 @@ await outputStorage.WriteTableAsync(PipelineTableNames.Documents, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), @@ -103,7 +95,7 @@ await outputStorage.WriteTableAsync(PipelineTableNames.Documents, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), @@ -151,7 +143,7 @@ await outputStorage.WriteTableAsync(PipelineTableNames.Documents, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/CreateCommunitiesWorkflowTests.cs b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateCommunitiesWorkflowTests.cs new file mode 100644 index 0000000000..5204fe6ca9 --- /dev/null +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateCommunitiesWorkflowTests.cs @@ -0,0 +1,123 @@ +using System.Collections.Immutable; +using GraphRag; +using GraphRag.Callbacks; +using GraphRag.Community; +using GraphRag.Config; +using GraphRag.Constants; +using GraphRag.Entities; +using 
GraphRag.Indexing.Runtime; +using GraphRag.Indexing.Workflows; +using GraphRag.Relationships; +using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.GraphRag.Tests.Workflows; + +public sealed class CreateCommunitiesWorkflowTests +{ + [Fact] + public async Task RunWorkflow_GroupsEntitiesAndPersistsCommunities() + { + var outputStorage = new MemoryPipelineStorage(); + await outputStorage.WriteTableAsync(PipelineTableNames.Entities, new[] + { + new EntityRecord("entity-alice", 0, "Alice", "person", "Researcher", new[] { "unit-1" }.ToImmutableArray(), 3, 2, 0, 0), + new EntityRecord("entity-bob", 1, "Bob", "person", "Policy expert", new[] { "unit-2" }.ToImmutableArray(), 2, 2, 0, 0), + new EntityRecord("entity-carol", 2, "Carol", "person", "Analyst", new[] { "unit-3" }.ToImmutableArray(), 1, 1, 0, 0), + new EntityRecord("entity-dave", 3, "Dave", "person", "Observer", new[] { "unit-4" }.ToImmutableArray(), 1, 0, 0, 0) + }); + + await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[] + { + new RelationshipRecord("rel-1", 0, "Alice", "Bob", "collaborates_with", null, 0.8, 3, new[] { "unit-1", "unit-2" }.ToImmutableArray(), true), + new RelationshipRecord("rel-2", 1, "Bob", "Carol", "mentors", null, 0.6, 3, new[] { "unit-3" }.ToImmutableArray(), false) + }); + + var context = new PipelineRunContext( + inputStorage: new MemoryPipelineStorage(), + outputStorage: outputStorage, + previousStorage: new MemoryPipelineStorage(), + cache: new StubPipelineCache(), + callbacks: NoopWorkflowCallbacks.Instance, + stats: new PipelineRunStats(), + state: new PipelineState(), + services: new ServiceCollection().AddGraphRag().BuildServiceProvider()); + + var workflow = CreateCommunitiesWorkflow.Create(); + var config = new GraphRagConfig + { + ClusterGraph = new ClusterGraphConfig + { + MaxClusterSize = 2, + UseLargestConnectedComponent = false, + Seed = 1337 + } + }; + + 
await workflow(config, context, CancellationToken.None); + + var communities = await outputStorage.LoadTableAsync(PipelineTableNames.Communities); + Assert.Equal(3, communities.Count); + Assert.True(context.Items.TryGetValue("create_communities:count", out var countValue)); + Assert.Equal(3, Assert.IsType(countValue)); + + var communityByMembers = communities.ToDictionary( + community => community.EntityIds.OrderBy(id => id).ToArray(), + community => community, + new SequenceComparer()); + + var aliceBob = communityByMembers[new[] { "entity-alice", "entity-bob" }]; + Assert.Equal(2, aliceBob.Size); + Assert.Equal(aliceBob.CommunityId, aliceBob.HumanReadableId); + Assert.Contains("rel-1", aliceBob.RelationshipIds); + Assert.Contains("unit-1", aliceBob.TextUnitIds); + Assert.Contains("unit-2", aliceBob.TextUnitIds); + Assert.Equal(-1, aliceBob.ParentId); + + var carol = communityByMembers[new[] { "entity-carol" }]; + Assert.Empty(carol.RelationshipIds); + Assert.Contains("unit-3", carol.TextUnitIds); + + var dave = communityByMembers[new[] { "entity-dave" }]; + Assert.Empty(dave.RelationshipIds); + Assert.Contains("unit-4", dave.TextUnitIds); + } + + private sealed class SequenceComparer : IEqualityComparer> where T : notnull + { + public bool Equals(IReadOnlyList? x, IReadOnlyList? 
y) + { + if (x is null || y is null) + { + return x is null && y is null; + } + + if (x.Count != y.Count) + { + return false; + } + + for (var index = 0; index < x.Count; index++) + { + if (!EqualityComparer.Default.Equals(x[index], y[index])) + { + return false; + } + } + + return true; + } + + public int GetHashCode(IReadOnlyList obj) + { + var hash = new HashCode(); + foreach (var item in obj) + { + hash.Add(item); + } + + return hash.ToHashCode(); + } + } +} diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/CreateFinalDocumentsWorkflowTests.cs b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateFinalDocumentsWorkflowTests.cs index e87345400c..d9598116f3 100644 --- a/tests/ManagedCode.GraphRag.Tests/Workflows/CreateFinalDocumentsWorkflowTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/CreateFinalDocumentsWorkflowTests.cs @@ -1,20 +1,14 @@ -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; using GraphRag.Config; +using GraphRag.Constants; using GraphRag.Data; using GraphRag.Indexing.Runtime; using GraphRag.Indexing.Workflows; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Xunit; -using GraphRag.Constants; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Workflows; @@ -58,7 +52,7 @@ await outputStorage.WriteTableAsync(PipelineTableNames.TextUnits, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/ExtractGraphWorkflowTests.cs 
b/tests/ManagedCode.GraphRag.Tests/Workflows/ExtractGraphWorkflowTests.cs index f3754b169a..4bde2a879e 100644 --- a/tests/ManagedCode.GraphRag.Tests/Workflows/ExtractGraphWorkflowTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/ExtractGraphWorkflowTests.cs @@ -1,10 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; using GraphRag.Config; using GraphRag.Constants; @@ -14,10 +8,9 @@ using GraphRag.Indexing.Workflows; using GraphRag.Relationships; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Xunit; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Workflows; @@ -61,7 +54,7 @@ await outputStorage.WriteTableAsync(PipelineTableNames.TextUnits, new[] inputStorage: new MemoryPipelineStorage(), outputStorage: outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), diff --git a/tests/ManagedCode.GraphRag.Tests/Workflows/LoadInputDocumentsWorkflowTests.cs b/tests/ManagedCode.GraphRag.Tests/Workflows/LoadInputDocumentsWorkflowTests.cs index 4cb3248f68..b340b4cb88 100644 --- a/tests/ManagedCode.GraphRag.Tests/Workflows/LoadInputDocumentsWorkflowTests.cs +++ b/tests/ManagedCode.GraphRag.Tests/Workflows/LoadInputDocumentsWorkflowTests.cs @@ -1,23 +1,16 @@ -using System.Collections.Generic; -using System.IO; -using System.Linq; using System.Text; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; using GraphRag; -using GraphRag.Cache; using GraphRag.Callbacks; using GraphRag.Config; +using GraphRag.Constants; using GraphRag.Data; using GraphRag.Indexing.Runtime; using 
GraphRag.Indexing.Workflows; using GraphRag.Storage; +using ManagedCode.GraphRag.Tests.Infrastructure; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Xunit; -using GraphRag.Constants; -using ManagedCode.GraphRag.Tests.Infrastructure; namespace ManagedCode.GraphRag.Tests.Workflows; @@ -39,7 +32,7 @@ public async Task RunWorkflow_LoadsTextFiles() inputStorage, outputStorage, previousStorage: new MemoryPipelineStorage(), - cache: new InMemoryPipelineCache(), + cache: new StubPipelineCache(), callbacks: NoopWorkflowCallbacks.Instance, stats: new PipelineRunStats(), state: new PipelineState(), @@ -80,7 +73,7 @@ public async Task RunWorkflow_LoadsCsvFiles() inputStorage, outputStorage, new MemoryPipelineStorage(), - new InMemoryPipelineCache(), + new StubPipelineCache(), NoopWorkflowCallbacks.Instance, new PipelineRunStats(), new PipelineState(), @@ -127,7 +120,7 @@ public async Task RunWorkflow_LoadsJsonFiles() inputStorage, outputStorage, new MemoryPipelineStorage(), - new InMemoryPipelineCache(), + new StubPipelineCache(), NoopWorkflowCallbacks.Instance, new PipelineRunStats(), new PipelineState(), @@ -170,7 +163,7 @@ public async Task RunWorkflow_ParsesJsonLinesFallback() inputStorage, outputStorage, new MemoryPipelineStorage(), - new InMemoryPipelineCache(), + new StubPipelineCache(), NoopWorkflowCallbacks.Instance, new PipelineRunStats(), new PipelineState(),