diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 0587c156f9a5..ef907861a6d3 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -76,5 +76,6 @@ export 'src/tool.dart' FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, + RAGEngineGrounding, Tool, ToolConfig; diff --git a/packages/firebase_ai/firebase_ai/lib/src/tool.dart b/packages/firebase_ai/firebase_ai/lib/src/tool.dart index 394cb555e7af..fb934d5d8439 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/tool.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/tool.dart @@ -21,12 +21,13 @@ import 'schema.dart'; /// knowledge and scope of the model. final class Tool { // ignore: public_member_api_docs - Tool._(this._functionDeclarations, this._googleSearch); + Tool._( + this._functionDeclarations, this._googleSearch, this._ragEngineGrounding); /// Returns a [Tool] instance with list of [FunctionDeclaration]. static Tool functionDeclarations( List functionDeclarations) { - return Tool._(functionDeclarations, null); + return Tool._(functionDeclarations, null, null); } /// Creates a tool that allows the model to use Grounding with Google Search. @@ -47,7 +48,18 @@ final class Tool { /// /// Returns a `Tool` configured for Google Search. static Tool googleSearch({GoogleSearch googleSearch = const GoogleSearch()}) { - return Tool._(null, googleSearch); + return Tool._(null, googleSearch, null); + } + + /// Creates a tool that allows the model to use RAG Engine Grounding. + /// + /// RAG Engine Grounding can be used to allow the model to connect to Vertex AI + /// RAG Engine for retrieving and incorporating external knowledge into its + /// responses. + /// + /// Only available in Vertex AI + static Tool ragEngine(RAGEngineGrounding ragEngineGrounding) { + return Tool._(null, null, ragEngineGrounding); } /// A list of `FunctionDeclarations` available to the model that can be used @@ -65,13 +77,74 @@ final class Tool { /// responses. final GoogleSearch? _googleSearch; + /// A tool that allows the generative model to connect to RAG Engine for + /// retrieving and incorporating external knowledge into its responses. + final RAGEngineGrounding? _ragEngineGrounding; + /// Convert to json object. Map toJson() => { if (_functionDeclarations case final _functionDeclarations?) 'functionDeclarations': _functionDeclarations.map((f) => f.toJson()).toList(), if (_googleSearch case final _googleSearch?) - 'googleSearch': _googleSearch.toJson() + 'googleSearch': _googleSearch.toJson(), + if (_ragEngineGrounding case final _ragEngineGrounding?) + 'retrieval': _ragEngineGrounding.toJson() + }; +} + +/// Configuration for RAG (Retrieval-Augmented Generation) Engine Grounding. +/// +/// This tool allows grounding a model's response in a specific corpus of data +/// stored in Vertex AI. It helps the model generate more accurate and +/// contextually relevant answers by retrieving information from your own data sources. +/// +/// Use this to configure the grounding settings for a generative model tool. +final class RAGEngineGrounding { + /// Creates a new instance of [RAGEngineGrounding]. + /// + /// [projectId] is the ID of the Vertex AI project. + /// [location] is the location of the corpus (e.g., 'us-central1'). + /// [corpusId] is the specific ID of the corpus to use for grounding. + /// [topK] specifies the number of top contexts to retrieve, defaulting to 20. + const RAGEngineGrounding({ + required this.projectId, + required this.location, + required this.corpusId, + this.topK = 20, + }); + + /// The project ID of the Vertex AI project that contains the corpus. + final String projectId; + + /// The location of the corpus. + final String location; + + /// The ID of the corpus to use for RAG Engine Grounding. + final String corpusId; + + /// The number of top contexts to retrieve. + /// + /// Must be between 1 and 20. + final int topK; + + /// The path to the corpus in the format: + /// `projects/{projectId}/locations/{location}/ragCorpora/{corpusId}` + String get corpusPath => + 'projects/$projectId/locations/$location/ragCorpora/$corpusId'; + + /// Converts this [RAGEngineGrounding] object into a JSON-compatible Map. + Map toJson() => { + 'vertexRagStore': { + 'ragResources': [ + { + 'ragCorpus': corpusPath, + }, + ], + 'ragRetrievalConfig': { + 'topK': topK, + }, + }, }; } diff --git a/packages/firebase_ai/firebase_ai/test/model_test.dart b/packages/firebase_ai/firebase_ai/test/model_test.dart index 860b8e19ba7c..634051e5c70f 100644 --- a/packages/firebase_ai/firebase_ai/test/model_test.dart +++ b/packages/firebase_ai/firebase_ai/test/model_test.dart @@ -268,6 +268,47 @@ void main() { ); }); + test('can pass rag engine grounding tool', () async { + final (client, model) = createModel( + tools: [ + Tool.ragEngine( + const RAGEngineGrounding( + corpusId: 'corpusId', + projectId: 'project-12345', + location: 'global', + topK: 15, + ), + ), + ], + ); + + const prompt = 'Some prompt'; + + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + verifyRequest: (_, request) { + expect(request['tools'], [ + { + 'retrieval': { + 'vertexRagStore': { + 'ragResources': [ + { + 'ragCorpus': + 'projects/project-12345/locations/global/ragCorpora/corpusId', + } + ], + 'ragRetrievalConfig': { + 'topK': 15, + } + } + } + }, + ]); + }, + response: arbitraryGenerateContentResponse, + ); + }); + test('can pass a google search tool', () async { final (client, model) = createModel( tools: [Tool.googleSearch()],