diff --git a/gateway/internal/api/v1/models.go b/gateway/internal/api/v1/models.go index 160f991..df24e33 100644 --- a/gateway/internal/api/v1/models.go +++ b/gateway/internal/api/v1/models.go @@ -13,6 +13,11 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M allProviderModels := map[string]*llmv1.ProviderModels{} for name := range base.ProviderRegistry { + // Check if the provider is healthy before fetching models + if !router.DefaultHealthChecker{}.IsHealthy(name) { + continue + } + provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name}) if err != nil { continue diff --git a/gateway/internal/api/v1/providers.go b/gateway/internal/api/v1/providers.go index 54d5ddd..b65e4c5 100644 --- a/gateway/internal/api/v1/providers.go +++ b/gateway/internal/api/v1/providers.go @@ -20,6 +20,10 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt data := []*llmv1.Provider{} for _, provider := range providers { + // Check if the provider is healthy before adding to the list + if !router.DefaultHealthChecker{}.IsHealthy(provider.Info().Name) { + continue + } providerInfo := provider.Info() data = append(data, &llmv1.Provider{ Title: providerInfo.Title, @@ -34,6 +38,11 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt } func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name}) if err != nil { return nil, errors.NewNotFound(err.Error()) @@ -63,6 +72,11 @@ func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1. } func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llmv1.CreateProviderRequest]) (*connect.Response[llmv1.CreateProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()} p, err := s.iProviderService.GetProvider(provider) @@ -111,6 +125,11 @@ func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llm } func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llmv1.UpdateProviderRequest]) (*connect.Response[llmv1.UpdateProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()} p, err := s.iProviderService.GetProvider(provider) @@ -172,6 +191,11 @@ func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llm } func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + p, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name}) if err != nil { return nil, errors.NewNotFound(err.Error()) diff --git a/gateway/internal/router/health_checker.go b/gateway/internal/router/health_checker.go new file mode 100644 index 0000000..8cb1f28 --- /dev/null +++ b/gateway/internal/router/health_checker.go @@ -0,0 +1,17 @@ +package router + +import ( + "log" +) + +type HealthChecker interface { + IsHealthy(providerName string) bool +} + +type DefaultHealthChecker struct{} + +func (d *DefaultHealthChecker) IsHealthy(providerName string) bool { + // Placeholder for actual health check logic + // Currently returns true, assuming all providers are healthy + return true +} diff --git a/gateway/internal/router/priority.go b/gateway/internal/router/priority.go index 7b0f80e..71a5136 100644 --- a/gateway/internal/router/priority.go +++ b/gateway/internal/router/priority.go @@ -2,6 +2,8 @@ package router import ( "sync/atomic" + "log" + "gateway/internal/router" // Importing to use HealthChecker ) const ( @@ -21,11 +23,20 @@ func NewPriorityRouter(providers []RouterConfig) *PriorityRouter { } func (r *PriorityRouter) Next() (*RouterConfig, error) { - idx := int(r.idx.Load()) - - // Todo: make a check for healthy provider - model := &r.providers[idx] - r.idx.Add(1) - - return model, nil + providerLen := len(r.providers) + originalIdx := r.idx.Load() + var healthyProvider *RouterConfig + for i := 0; i < providerLen; i++ { + idx := (originalIdx + uint64(i)) % uint64(providerLen) + if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) { + healthyProvider = &r.providers[idx] + r.idx.Store(idx + 1) + break + } + } + if healthyProvider == nil { + log.Println("Error: No healthy providers available.") + return nil, fmt.Errorf("no healthy providers available") + } + return healthyProvider, nil } diff --git a/gateway/internal/router/round_robin.go b/gateway/internal/router/round_robin.go index 60ac24c..5d88108 100644 --- a/gateway/internal/router/round_robin.go +++ b/gateway/internal/router/round_robin.go @@ -1,7 +1,9 @@ package router import ( + "log" "sync/atomic" + "gateway/internal/router" // Importing to use HealthChecker ) const ( @@ -26,9 +28,22 @@ func (r *RoundRobinRouter) Iterator() RouterIterator { func (r *RoundRobinRouter) Next() *RouterConfig { providerLen := len(r.providers) - // Todo: make a check for healthy provider - idx := r.idx.Add(1) - 1 - model := &r.providers[idx%uint64(providerLen)] + // Iterate through providers to find a healthy one + var healthyProvider *RouterConfig + originalIdx := r.idx.Load() + for i := 0; i < providerLen; i++ { + idx := (originalIdx + uint64(i)) % uint64(providerLen) + if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) { + healthyProvider = &r.providers[idx] + r.idx.Add(1) + break + } + } + + if healthyProvider == nil { + log.Println("Error: No healthy providers available.") + return nil + } - return model + return healthyProvider }