diff --git a/.gitignore b/.gitignore index 1dfa72c..0f686a8 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,6 @@ overleaf.kubeconfig # coverage report coverage.out coverage.html + +# claude code +CLAUDE.md \ No newline at end of file diff --git a/internal/api/chat/get_citation_keys.go b/internal/api/chat/get_citation_keys.go new file mode 100644 index 0000000..d2a0fa0 --- /dev/null +++ b/internal/api/chat/get_citation_keys.go @@ -0,0 +1,43 @@ +package chat + +import ( + "context" + + "paperdebugger/internal/libs/contextutil" + "paperdebugger/internal/models" + chatv2 "paperdebugger/pkg/gen/api/chat/v2" +) + +func (s *ChatServerV2) GetCitationKeys( + ctx context.Context, + req *chatv2.GetCitationKeysRequest, +) (*chatv2.GetCitationKeysResponse, error) { + actor, err := contextutil.GetActor(ctx) + if err != nil { + return nil, err + } + + settings, err := s.userService.GetUserSettings(ctx, actor.ID) + if err != nil { + return nil, err + } + + llmProvider := &models.LLMProviderConfig{ + APIKey: settings.OpenAIAPIKey, + } + + citationKeys, err := s.aiClientV2.GetCitationKeys( + ctx, + req.GetSentence(), + actor.ID, + req.GetProjectId(), + llmProvider, + ) + if err != nil { + return nil, err + } + + return &chatv2.GetCitationKeysResponse{ + CitationKeys: citationKeys, + }, nil +} diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go new file mode 100644 index 0000000..03d9a0e --- /dev/null +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -0,0 +1,53 @@ +package client + +// TODO: This file should not place in the client package. +import ( + "context" + "fmt" + "paperdebugger/internal/models" + "strings" + + "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" +) + +func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userId bson.ObjectID, projectId string, llmProvider *models.LLMProviderConfig) (string, error) { + // Get bibliography from mongodb + project, err := a.projectService.GetProject(ctx, userId, projectId) + if err != nil { + return "", err + } + + var bibFiles []string + for _, doc := range project.Docs { + if doc.Filepath != "" && strings.HasSuffix(doc.Filepath, ".bib") { + bibFiles = append(bibFiles, doc.Lines...) + } + } + bibliography := strings.Join(bibFiles, "\n") + + // Get citation keys from LLM + emptyCitation := "none" + message := fmt.Sprintf("Sentence: %s\nBibliography: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", sentence, bibliography, emptyCitation) + + _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), + openai.UserMessage(message), + }, llmProvider) + + if err != nil { + return "", err + } + + if len(resp) == 0 { + return "", nil + } + + citationKeys := strings.TrimSpace(resp[0].Payload.GetAssistant().GetContent()) + + if citationKeys == emptyCitation { + return "", nil + } + + return citationKeys, nil +} diff --git a/pkg/gen/api/chat/v2/chat.pb.go b/pkg/gen/api/chat/v2/chat.pb.go index 3ba45df..01f94f3 100644 --- a/pkg/gen/api/chat/v2/chat.pb.go +++ b/pkg/gen/api/chat/v2/chat.pb.go @@ -1897,6 +1897,113 @@ func (*CreateConversationMessageStreamResponse_StreamError) isCreateConversation func (*CreateConversationMessageStreamResponse_ReasoningChunk) isCreateConversationMessageStreamResponse_ResponsePayload() { } +// Request to suggest citation keys based on context +type GetCitationKeysRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sentence string `protobuf:"bytes,1,opt,name=sentence,proto3" json:"sentence,omitempty"` + ProjectId string `protobuf:"bytes,2,opt,name=project_id,json=projectId,proto3" json:"project_id,omitempty"` + ModelSlug *string `protobuf:"bytes,3,opt,name=model_slug,json=modelSlug,proto3,oneof" json:"model_slug,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetCitationKeysRequest) Reset() { + *x = GetCitationKeysRequest{} + mi := &file_chat_v2_chat_proto_msgTypes[30] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetCitationKeysRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCitationKeysRequest) ProtoMessage() {} + +func (x *GetCitationKeysRequest) ProtoReflect() protoreflect.Message { + mi := &file_chat_v2_chat_proto_msgTypes[30] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCitationKeysRequest.ProtoReflect.Descriptor instead. +func (*GetCitationKeysRequest) Descriptor() ([]byte, []int) { + return file_chat_v2_chat_proto_rawDescGZIP(), []int{30} +} + +func (x *GetCitationKeysRequest) GetSentence() string { + if x != nil { + return x.Sentence + } + return "" +} + +func (x *GetCitationKeysRequest) GetProjectId() string { + if x != nil { + return x.ProjectId + } + return "" +} + +func (x *GetCitationKeysRequest) GetModelSlug() string { + if x != nil && x.ModelSlug != nil { + return *x.ModelSlug + } + return "" +} + +// Response containing the suggested citation keys +type GetCitationKeysResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // A comma-separated string of keys, or empty if none found + CitationKeys string `protobuf:"bytes,1,opt,name=citation_keys,json=citationKeys,proto3" json:"citation_keys,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetCitationKeysResponse) Reset() { + *x = GetCitationKeysResponse{} + mi := &file_chat_v2_chat_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetCitationKeysResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetCitationKeysResponse) ProtoMessage() {} + +func (x *GetCitationKeysResponse) ProtoReflect() protoreflect.Message { + mi := &file_chat_v2_chat_proto_msgTypes[31] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetCitationKeysResponse.ProtoReflect.Descriptor instead. +func (*GetCitationKeysResponse) Descriptor() ([]byte, []int) { + return file_chat_v2_chat_proto_rawDescGZIP(), []int{31} +} + +func (x *GetCitationKeysResponse) GetCitationKeys() string { + if x != nil { + return x.CitationKeys + } + return "" +} + var File_chat_v2_chat_proto protoreflect.FileDescriptor const file_chat_v2_chat_proto_rawDesc = "" + @@ -2030,17 +2137,27 @@ const file_chat_v2_chat_proto_rawDesc = "" + "\x13stream_finalization\x18\x06 \x01(\v2\x1b.chat.v2.StreamFinalizationH\x00R\x12streamFinalization\x129\n" + "\fstream_error\x18\a \x01(\v2\x14.chat.v2.StreamErrorH\x00R\vstreamError\x12B\n" + "\x0freasoning_chunk\x18\b \x01(\v2\x17.chat.v2.ReasoningChunkH\x00R\x0ereasoningChunkB\x12\n" + - "\x10response_payload*R\n" + + "\x10response_payload\"\x86\x01\n" + + "\x16GetCitationKeysRequest\x12\x1a\n" + + "\bsentence\x18\x01 \x01(\tR\bsentence\x12\x1d\n" + + "\n" + + "project_id\x18\x02 \x01(\tR\tprojectId\x12\"\n" + + "\n" + + "model_slug\x18\x03 \x01(\tH\x00R\tmodelSlug\x88\x01\x01B\r\n" + + "\v_model_slug\">\n" + + "\x17GetCitationKeysResponse\x12#\n" + + "\rcitation_keys\x18\x01 \x01(\tR\fcitationKeys*R\n" + "\x10ConversationType\x12!\n" + "\x1dCONVERSATION_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n" + - "\x17CONVERSATION_TYPE_DEBUG\x10\x012\xa8\a\n" + + "\x17CONVERSATION_TYPE_DEBUG\x10\x012\xab\b\n" + "\vChatService\x12\x83\x01\n" + "\x11ListConversations\x12!.chat.v2.ListConversationsRequest\x1a\".chat.v2.ListConversationsResponse\"'\x82\xd3\xe4\x93\x02!\x12\x1f/_pd/api/v2/chats/conversations\x12\x8f\x01\n" + "\x0fGetConversation\x12\x1f.chat.v2.GetConversationRequest\x1a .chat.v2.GetConversationResponse\"9\x82\xd3\xe4\x93\x023\x121/_pd/api/v2/chats/conversations/{conversation_id}\x12\xc2\x01\n" + "\x1fCreateConversationMessageStream\x12/.chat.v2.CreateConversationMessageStreamRequest\x1a0.chat.v2.CreateConversationMessageStreamResponse\":\x82\xd3\xe4\x93\x024:\x01*\"//_pd/api/v2/chats/conversations/messages/stream0\x01\x12\x9b\x01\n" + "\x12UpdateConversation\x12\".chat.v2.UpdateConversationRequest\x1a#.chat.v2.UpdateConversationResponse\"<\x82\xd3\xe4\x93\x026:\x01*21/_pd/api/v2/chats/conversations/{conversation_id}\x12\x98\x01\n" + "\x12DeleteConversation\x12\".chat.v2.DeleteConversationRequest\x1a#.chat.v2.DeleteConversationResponse\"9\x82\xd3\xe4\x93\x023*1/_pd/api/v2/chats/conversations/{conversation_id}\x12\x82\x01\n" + - "\x13ListSupportedModels\x12#.chat.v2.ListSupportedModelsRequest\x1a$.chat.v2.ListSupportedModelsResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/_pd/api/v2/chats/modelsB\x7f\n" + + "\x13ListSupportedModels\x12#.chat.v2.ListSupportedModelsRequest\x1a$.chat.v2.ListSupportedModelsResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/_pd/api/v2/chats/models\x12\x80\x01\n" + + "\x0fGetCitationKeys\x12\x1f.chat.v2.GetCitationKeysRequest\x1a .chat.v2.GetCitationKeysResponse\"*\x82\xd3\xe4\x93\x02$:\x01*\"\x1f/_pd/api/v2/chats/citation-keysB\x7f\n" + "\vcom.chat.v2B\tChatProtoP\x01Z(paperdebugger/pkg/gen/api/chat/v2;chatv2\xa2\x02\x03CXX\xaa\x02\aChat.V2\xca\x02\aChat\\V2\xe2\x02\x13Chat\\V2\\GPBMetadata\xea\x02\bChat::V2b\x06proto3" var ( @@ -2056,7 +2173,7 @@ func file_chat_v2_chat_proto_rawDescGZIP() []byte { } var file_chat_v2_chat_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_chat_v2_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 30) +var file_chat_v2_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 32) var file_chat_v2_chat_proto_goTypes = []any{ (ConversationType)(0), // 0: chat.v2.ConversationType (*MessageTypeToolCall)(nil), // 1: chat.v2.MessageTypeToolCall @@ -2089,6 +2206,8 @@ var file_chat_v2_chat_proto_goTypes = []any{ (*StreamError)(nil), // 28: chat.v2.StreamError (*CreateConversationMessageStreamRequest)(nil), // 29: chat.v2.CreateConversationMessageStreamRequest (*CreateConversationMessageStreamResponse)(nil), // 30: chat.v2.CreateConversationMessageStreamResponse + (*GetCitationKeysRequest)(nil), // 31: chat.v2.GetCitationKeysRequest + (*GetCitationKeysResponse)(nil), // 32: chat.v2.GetCitationKeysResponse } var file_chat_v2_chat_proto_depIdxs = []int32{ 3, // 0: chat.v2.MessagePayload.system:type_name -> chat.v2.MessageTypeSystem @@ -2120,14 +2239,16 @@ var file_chat_v2_chat_proto_depIdxs = []int32{ 14, // 26: chat.v2.ChatService.UpdateConversation:input_type -> chat.v2.UpdateConversationRequest 16, // 27: chat.v2.ChatService.DeleteConversation:input_type -> chat.v2.DeleteConversationRequest 19, // 28: chat.v2.ChatService.ListSupportedModels:input_type -> chat.v2.ListSupportedModelsRequest - 11, // 29: chat.v2.ChatService.ListConversations:output_type -> chat.v2.ListConversationsResponse - 13, // 30: chat.v2.ChatService.GetConversation:output_type -> chat.v2.GetConversationResponse - 30, // 31: chat.v2.ChatService.CreateConversationMessageStream:output_type -> chat.v2.CreateConversationMessageStreamResponse - 15, // 32: chat.v2.ChatService.UpdateConversation:output_type -> chat.v2.UpdateConversationResponse - 17, // 33: chat.v2.ChatService.DeleteConversation:output_type -> chat.v2.DeleteConversationResponse - 20, // 34: chat.v2.ChatService.ListSupportedModels:output_type -> chat.v2.ListSupportedModelsResponse - 29, // [29:35] is the sub-list for method output_type - 23, // [23:29] is the sub-list for method input_type + 31, // 29: chat.v2.ChatService.GetCitationKeys:input_type -> chat.v2.GetCitationKeysRequest + 11, // 30: chat.v2.ChatService.ListConversations:output_type -> chat.v2.ListConversationsResponse + 13, // 31: chat.v2.ChatService.GetConversation:output_type -> chat.v2.GetConversationResponse + 30, // 32: chat.v2.ChatService.CreateConversationMessageStream:output_type -> chat.v2.CreateConversationMessageStreamResponse + 15, // 33: chat.v2.ChatService.UpdateConversation:output_type -> chat.v2.UpdateConversationResponse + 17, // 34: chat.v2.ChatService.DeleteConversation:output_type -> chat.v2.DeleteConversationResponse + 20, // 35: chat.v2.ChatService.ListSupportedModels:output_type -> chat.v2.ListSupportedModelsResponse + 32, // 36: chat.v2.ChatService.GetCitationKeys:output_type -> chat.v2.GetCitationKeysResponse + 30, // [30:37] is the sub-list for method output_type + 23, // [23:30] is the sub-list for method input_type 23, // [23:23] is the sub-list for extension type_name 23, // [23:23] is the sub-list for extension extendee 0, // [0:23] is the sub-list for field type_name @@ -2161,13 +2282,14 @@ func file_chat_v2_chat_proto_init() { (*CreateConversationMessageStreamResponse_StreamError)(nil), (*CreateConversationMessageStreamResponse_ReasoningChunk)(nil), } + file_chat_v2_chat_proto_msgTypes[30].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_chat_v2_chat_proto_rawDesc), len(file_chat_v2_chat_proto_rawDesc)), NumEnums: 1, - NumMessages: 30, + NumMessages: 32, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/gen/api/chat/v2/chat.pb.gw.go b/pkg/gen/api/chat/v2/chat.pb.gw.go index 81f7e4e..0e7a57c 100644 --- a/pkg/gen/api/chat/v2/chat.pb.gw.go +++ b/pkg/gen/api/chat/v2/chat.pb.gw.go @@ -237,6 +237,33 @@ func local_request_ChatService_ListSupportedModels_0(ctx context.Context, marsha return msg, metadata, err } +func request_ChatService_GetCitationKeys_0(ctx context.Context, marshaler runtime.Marshaler, client ChatServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetCitationKeysRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetCitationKeys(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_ChatService_GetCitationKeys_0(ctx context.Context, marshaler runtime.Marshaler, server ChatServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetCitationKeysRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.GetCitationKeys(ctx, &protoReq) + return msg, metadata, err +} + // RegisterChatServiceHandlerServer registers the http handlers for service ChatService to "mux". // UnaryRPC :call ChatServiceServer directly. // StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. @@ -350,6 +377,26 @@ func RegisterChatServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux } forward_ChatService_ListSupportedModels_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_ChatService_GetCitationKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/chat.v2.ChatService/GetCitationKeys", runtime.WithHTTPPathPattern("/_pd/api/v2/chats/citation-keys")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_ChatService_GetCitationKeys_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_ChatService_GetCitationKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) return nil } @@ -492,6 +539,23 @@ func RegisterChatServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux } forward_ChatService_ListSupportedModels_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_ChatService_GetCitationKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/chat.v2.ChatService/GetCitationKeys", runtime.WithHTTPPathPattern("/_pd/api/v2/chats/citation-keys")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_ChatService_GetCitationKeys_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_ChatService_GetCitationKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) return nil } @@ -502,6 +566,7 @@ var ( pattern_ChatService_UpdateConversation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"_pd", "api", "v2", "chats", "conversations", "conversation_id"}, "")) pattern_ChatService_DeleteConversation_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"_pd", "api", "v2", "chats", "conversations", "conversation_id"}, "")) pattern_ChatService_ListSupportedModels_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"_pd", "api", "v2", "chats", "models"}, "")) + pattern_ChatService_GetCitationKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"_pd", "api", "v2", "chats", "citation-keys"}, "")) ) var ( @@ -511,4 +576,5 @@ var ( forward_ChatService_UpdateConversation_0 = runtime.ForwardResponseMessage forward_ChatService_DeleteConversation_0 = runtime.ForwardResponseMessage forward_ChatService_ListSupportedModels_0 = runtime.ForwardResponseMessage + forward_ChatService_GetCitationKeys_0 = runtime.ForwardResponseMessage ) diff --git a/pkg/gen/api/chat/v2/chat_grpc.pb.go b/pkg/gen/api/chat/v2/chat_grpc.pb.go index 8303a8a..2f1ea65 100644 --- a/pkg/gen/api/chat/v2/chat_grpc.pb.go +++ b/pkg/gen/api/chat/v2/chat_grpc.pb.go @@ -25,6 +25,7 @@ const ( ChatService_UpdateConversation_FullMethodName = "/chat.v2.ChatService/UpdateConversation" ChatService_DeleteConversation_FullMethodName = "/chat.v2.ChatService/DeleteConversation" ChatService_ListSupportedModels_FullMethodName = "/chat.v2.ChatService/ListSupportedModels" + ChatService_GetCitationKeys_FullMethodName = "/chat.v2.ChatService/GetCitationKeys" ) // ChatServiceClient is the client API for ChatService service. @@ -37,6 +38,7 @@ type ChatServiceClient interface { UpdateConversation(ctx context.Context, in *UpdateConversationRequest, opts ...grpc.CallOption) (*UpdateConversationResponse, error) DeleteConversation(ctx context.Context, in *DeleteConversationRequest, opts ...grpc.CallOption) (*DeleteConversationResponse, error) ListSupportedModels(ctx context.Context, in *ListSupportedModelsRequest, opts ...grpc.CallOption) (*ListSupportedModelsResponse, error) + GetCitationKeys(ctx context.Context, in *GetCitationKeysRequest, opts ...grpc.CallOption) (*GetCitationKeysResponse, error) } type chatServiceClient struct { @@ -116,6 +118,16 @@ func (c *chatServiceClient) ListSupportedModels(ctx context.Context, in *ListSup return out, nil } +func (c *chatServiceClient) GetCitationKeys(ctx context.Context, in *GetCitationKeysRequest, opts ...grpc.CallOption) (*GetCitationKeysResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetCitationKeysResponse) + err := c.cc.Invoke(ctx, ChatService_GetCitationKeys_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // ChatServiceServer is the server API for ChatService service. // All implementations must embed UnimplementedChatServiceServer // for forward compatibility. @@ -126,6 +138,7 @@ type ChatServiceServer interface { UpdateConversation(context.Context, *UpdateConversationRequest) (*UpdateConversationResponse, error) DeleteConversation(context.Context, *DeleteConversationRequest) (*DeleteConversationResponse, error) ListSupportedModels(context.Context, *ListSupportedModelsRequest) (*ListSupportedModelsResponse, error) + GetCitationKeys(context.Context, *GetCitationKeysRequest) (*GetCitationKeysResponse, error) mustEmbedUnimplementedChatServiceServer() } @@ -154,6 +167,9 @@ func (UnimplementedChatServiceServer) DeleteConversation(context.Context, *Delet func (UnimplementedChatServiceServer) ListSupportedModels(context.Context, *ListSupportedModelsRequest) (*ListSupportedModelsResponse, error) { return nil, status.Error(codes.Unimplemented, "method ListSupportedModels not implemented") } +func (UnimplementedChatServiceServer) GetCitationKeys(context.Context, *GetCitationKeysRequest) (*GetCitationKeysResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetCitationKeys not implemented") +} func (UnimplementedChatServiceServer) mustEmbedUnimplementedChatServiceServer() {} func (UnimplementedChatServiceServer) testEmbeddedByValue() {} @@ -276,6 +292,24 @@ func _ChatService_ListSupportedModels_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _ChatService_GetCitationKeys_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetCitationKeysRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ChatServiceServer).GetCitationKeys(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ChatService_GetCitationKeys_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ChatServiceServer).GetCitationKeys(ctx, req.(*GetCitationKeysRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ChatService_ServiceDesc is the grpc.ServiceDesc for ChatService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -303,6 +337,10 @@ var ChatService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ListSupportedModels", Handler: _ChatService_ListSupportedModels_Handler, }, + { + MethodName: "GetCitationKeys", + Handler: _ChatService_GetCitationKeys_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/proto/chat/v2/chat.proto b/proto/chat/v2/chat.proto index 8dd650a..856c25f 100644 --- a/proto/chat/v2/chat.proto +++ b/proto/chat/v2/chat.proto @@ -31,6 +31,12 @@ service ChatService { rpc ListSupportedModels(ListSupportedModelsRequest) returns (ListSupportedModelsResponse) { option (google.api.http) = {get: "/_pd/api/v2/chats/models"}; } + rpc GetCitationKeys(GetCitationKeysRequest) returns (GetCitationKeysResponse) { + option (google.api.http) = { + post: "/_pd/api/v2/chats/citation-keys" + body: "*" + }; + } } message MessageTypeToolCall { @@ -234,3 +240,16 @@ message CreateConversationMessageStreamResponse { ReasoningChunk reasoning_chunk = 8; } } + +// Request to suggest citation keys based on context +message GetCitationKeysRequest { + string sentence = 1; + string project_id = 2; + optional string model_slug = 3; +} + +// Response containing the suggested citation keys +message GetCitationKeysResponse { + // A comma-separated string of keys, or empty if none found + string citation_keys = 1; +} \ No newline at end of file diff --git a/webapp/_webapp/src/libs/inline-suggestion.ts b/webapp/_webapp/src/libs/inline-suggestion.ts index 7c21883..1a2ef16 100644 --- a/webapp/_webapp/src/libs/inline-suggestion.ts +++ b/webapp/_webapp/src/libs/inline-suggestion.ts @@ -33,6 +33,85 @@ import { import { logDebug, logError, logInfo } from "./logger"; import { useSettingStore } from "../stores/setting-store"; +import { getCitationKeys } from "../query/api"; +import { getProjectId } from "./helpers"; + +/** A completion trigger associates a pattern with a handler function. */ +type CompletionTrigger = { + pattern: string; + handler: (state: EditorState, triggerPattern: string) => Promise; +}; + +/** Completion handler for citation keys (triggered by \cite{). */ +async function completeCitationKeys(state: EditorState, triggerPattern: string): Promise { + const cursorPos = state.selection.main.head; + const textBefore = state.doc.sliceString(0, cursorPos - triggerPattern.length); + const lastSentence = textBefore + .split(/(?<=[.!?])\s+/) + .filter((s) => s.trim().length > 0) + .slice(-1)[0]; + if (!lastSentence) { + return ""; + } + + const projectId = getProjectId(); + if (!projectId) { + return ""; + } + + try { + const response = await getCitationKeys({ + sentence: lastSentence, + projectId: projectId, + }); + return response.citationKeys || ""; + } catch (err) { + logError("inline completion: failed", err); + return ""; + } +} + +/** Registry of completion triggers. Add new triggers here to extend functionality. */ +const COMPLETION_TRIGGERS: CompletionTrigger[] = [{ pattern: "\\cite{", handler: completeCitationKeys }]; + +/** Returns the trigger that matches at cursor position, or null if none. */ +function getTriggerAtCursor(state: EditorState): CompletionTrigger | null { + const cursorPos = state.selection.main.head; + for (const trigger of COMPLETION_TRIGGERS) { + const start = Math.max(0, cursorPos - trigger.pattern.length); + if (state.doc.sliceString(start, cursorPos) === trigger.pattern) { + return trigger; + } + } + return null; +} + +/** Returns true when the cursor sits right after any registered trigger pattern. */ +function isTriggerAtCursor(state: EditorState): boolean { + return getTriggerAtCursor(state) !== null; +} + +/** Inserts a suggestion into the editor and dispatches the acceptance effect. */ +function acceptSuggestion( + view: EditorView, + suggestionText: string, + suggestionAcceptanceEffect: StateEffectType, +) { + view.dispatch({ + ...insertCompletionText( + view.state, + suggestionText, + view.state.selection.main.head, + view.state.selection.main.head, + ), + }); + + view.dispatch({ + effects: suggestionAcceptanceEffect.of({ + acceptance: SuggestionAcceptance.ACCEPTED, + }), + }); +} export enum SuggestionAcceptance { REJECTED = 0, @@ -102,13 +181,21 @@ export function debouncePromise any>( // eslint-di }; } -export async function completion(_state: EditorState): Promise { +/** Main completion function that dispatches to the appropriate handler based on trigger. */ +export async function completion(state: EditorState): Promise { + // Only trigger when enable completion setting is on const settings = useSettingStore.getState().settings; if (!settings?.enableCompletion) { return ""; } - return "Unsupported Feature"; + // Find matching trigger and call its handler + const trigger = getTriggerAtCursor(state); + if (!trigger) { + return ""; + } + + return trigger.handler(state, trigger.pattern); } /** @@ -211,28 +298,13 @@ class InlineSuggestionWidget extends WidgetType { return span; } accept(e: MouseEvent, view: EditorView) { - const suggestionText = this.suggestion; const config = view.state.field(this.configState); if (!config.acceptOnClick) return; e.stopPropagation(); e.preventDefault(); - view.dispatch({ - ...insertCompletionText( - view.state, - suggestionText, - view.state.selection.main.head, - view.state.selection.main.head, - ), - }); - - view.dispatch({ - effects: this.suggestionAcceptanceEffect.of({ - acceptance: SuggestionAcceptance.ACCEPTED, - }), - }); - + acceptSuggestion(view, this.suggestion, this.suggestionAcceptanceEffect); return true; } } @@ -300,14 +372,7 @@ export function createExtensionKeymapBinding( return false; } - view.dispatch({ - ...insertCompletionText( - view.state, - suggestionText, - view.state.selection.main.head, - view.state.selection.main.head, - ), - }); + acceptSuggestion(view, suggestionText, suggestionAcceptanceEffect); logInfo("tab handler: suggestion accepted"); return true; } catch (e) { @@ -438,15 +503,15 @@ export function createSuggestionFetchPlugin( // Check if the docChange is due to an remote collaborator // @ts-expect-error - changedRanges is only available in the Overleaf version of CodeMirror - const updatePos = update.changedRanges[0].toB; + const changedRanges = update.changedRanges; const localPos = update.view.state.selection.main.head; - if (updatePos !== localPos) { - return; - } - const isAutocompleted = update.transactions.some((t) => t.isUserEvent("input.complete")); - if (isAutocompleted) { - return; + // Local changes should have the cursor within or at the end of the changed range + if (changedRanges && changedRanges.length > 0) { + const changedRange = changedRanges[0]; + if (localPos < changedRange.fromB || localPos > changedRange.toB) { + return; + } } const config = update.state.field(suggestionConfig); @@ -530,6 +595,87 @@ export function createRenderInlineSuggestionPlugin( ); } +/** + * Creates a CodeMirror ViewPlugin that suppresses Overleaf's built-in + * autocomplete when our inline citation suggestion is active or pending + * (i.e. `\cite{` was just typed). + * + * Three mechanisms work together: + * 1. A dynamically injected `