diff --git a/internal/client/client.go b/internal/client/client.go index 38e3fc09..f1ebc04c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -44,6 +44,7 @@ const ( defaultRetryJitterFraction = 0.5 importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships" exportBulkRoute = "/authzed.api.v1.PermissionsService/ExportBulkRelationships" + watchRoute = "/authzed.api.v1.WatchService/Watch" ) // NewClient defines an (overridable) means of creating a new client. @@ -235,7 +236,7 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti // retrying the bulk import in backup/restore logic is handled manually. // retrying bulk export is also handled manually, because the default behavior is // to start at the beginning of the stream, which produces duplicate values. - selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute))), + selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))), } if !cobrautil.MustGetBool(cmd, "skip-version-check") { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index b33b392a..66790da4 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -133,30 +133,36 @@ func TestGetCurrentTokenWithCLIOverrideWithoutSecretFile(t *testing.T) { require.Equal(&bTrue, token.Insecure) } -type fakeSchemaServer struct { +type fakeServer struct { v1.UnimplementedSchemaServiceServer v1.UnimplementedExperimentalServiceServer + v1.UnimplementedWatchServiceServer v1.UnimplementedPermissionsServiceServer testFunc func() } -func (fss *fakeSchemaServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) { +func (fss *fakeServer) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) { fss.testFunc() return nil, status.Error(codes.Unavailable, "") } -func (fss *fakeSchemaServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error { +func (fss *fakeServer) ImportBulkRelationships(grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error { fss.testFunc() return status.Errorf(codes.Aborted, "") } +func (fss *fakeServer) Watch(*v1.WatchRequest, grpc.ServerStreamingServer[v1.WatchResponse]) error { + fss.testFunc() + return status.Errorf(codes.Unavailable, "") +} + func TestRetries(t *testing.T) { ctx := t.Context() var callCount uint lis := bufconn.Listen(1024 * 1024) s := grpc.NewServer() - fakeServer := &fakeSchemaServer{testFunc: func() { + fakeServer := &fakeServer{testFunc: func() { callCount++ }} v1.RegisterSchemaServiceServer(s, fakeServer) @@ -185,22 +191,25 @@ func TestRetries(t *testing.T) { c, err := authzed.NewClient("passthrough://bufnet", dialOpts...) require.NoError(t, err) - _, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{}) - grpcutil.RequireStatus(t, codes.Unavailable, err) - require.Equal(t, retries, callCount) + t.Run("read_schema", func(t *testing.T) { + _, err = c.ReadSchema(ctx, &v1.ReadSchemaRequest{}) + grpcutil.RequireStatus(t, codes.Unavailable, err) + require.Equal(t, retries, callCount) + }) } -func TestDoesNotRetryBackupRestore(t *testing.T) { +func TestDoesNotRetry(t *testing.T) { ctx := t.Context() var callCount uint lis := bufconn.Listen(1024 * 1024) s := grpc.NewServer() - fakeServer := &fakeSchemaServer{testFunc: func() { + fakeServer := &fakeServer{testFunc: func() { callCount++ }} v1.RegisterPermissionsServiceServer(s, fakeServer) v1.RegisterExperimentalServiceServer(s, fakeServer) + v1.RegisterWatchServiceServer(s, fakeServer) go func() { _ = s.Serve(lis) @@ -226,20 +235,23 @@ func TestDoesNotRetryBackupRestore(t *testing.T) { c, err := authzed.NewClientWithExperimentalAPIs("passthrough://bufnet", dialOpts...) require.NoError(t, err) - ibc, err := c.ImportBulkRelationships(ctx) - require.NoError(t, err) - err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{}) - require.NoError(t, err) - _, err = ibc.CloseAndRecv() - grpcutil.RequireStatus(t, codes.Aborted, err) - require.Equal(t, uint(1), callCount) + t.Run("import_bulk", func(t *testing.T) { + ibc, err := c.ImportBulkRelationships(ctx) + require.NoError(t, err) + err = ibc.SendMsg(&v1.ImportBulkRelationshipsRequest{}) + require.NoError(t, err) + _, err = ibc.CloseAndRecv() + grpcutil.RequireStatus(t, codes.Aborted, err) + require.Equal(t, uint(1), callCount) + }) - callCount = 0 - bic, err := c.ImportBulkRelationships(ctx) - require.NoError(t, err) - err = bic.SendMsg(&v1.ImportBulkRelationshipsRequest{}) - require.NoError(t, err) - _, err = bic.CloseAndRecv() - grpcutil.RequireStatus(t, codes.Aborted, err) - require.Equal(t, uint(1), callCount) + t.Run("watch", func(t *testing.T) { + callCount = 0 + watchReq, err := c.Watch(ctx, &v1.WatchRequest{}) + require.NoError(t, err) + resp, err := watchReq.Recv() + require.Nil(t, resp) + grpcutil.RequireStatus(t, codes.Unavailable, err) + require.Equal(t, uint(1), callCount) + }) } diff --git a/internal/commands/watch.go b/internal/commands/watch.go index fa5cf390..1e0313e9 100644 --- a/internal/commands/watch.go +++ b/internal/commands/watch.go @@ -9,7 +9,10 @@ import ( "syscall" "time" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" @@ -44,7 +47,7 @@ func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command { var watchCmd = &cobra.Command{ Use: "watch [object_types, ...] [start_cursor]", - Short: "Watches the stream of relationship updates from the server", + Short: "Watches the stream of relationship updates and schema updates from the server", Args: ValidationWrapper(cobra.RangeArgs(0, 2)), RunE: watchCmdFunc, Deprecated: "please use `zed relationships watch` instead", @@ -52,18 +55,21 @@ var watchCmd = &cobra.Command{ var watchRelationshipsCmd = &cobra.Command{ Use: "watch [object_types, ...] [start_cursor]", - Short: "Watches the stream of relationship updates from the server", + Short: "Watches the stream of relationship updates and schema updates from the server", Args: ValidationWrapper(cobra.RangeArgs(0, 2)), RunE: watchCmdFunc, } func watchCmdFunc(cmd *cobra.Command, _ []string) error { - console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision) - - cli, err := client.NewClient(cmd) + client, err := client.NewClient(cmd) if err != nil { return err } + return watchCmdFuncImpl(cmd, client, processResponse) +} + +func watchCmdFuncImpl(cmd *cobra.Command, watchClient v1.WatchServiceClient, processResponse func(resp *v1.WatchResponse)) error { + console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision) relFilters := make([]*v1.RelationshipFilter, 0, len(watchRelationshipFilters)) for _, filter := range watchRelationshipFilters { @@ -74,21 +80,26 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error { relFilters = append(relFilters, relFilter) } + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + defer interruptCancel() + req := &v1.WatchRequest{ OptionalObjectTypes: watchObjectTypes, OptionalRelationshipFilters: relFilters, + OptionalUpdateKinds: []v1.WatchKind{ + v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS, // keeps connection open during quiet periods + v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES, + }, } + if watchRevision != "" { req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision} } - ctx, cancel := context.WithCancel(cmd.Context()) - defer cancel() - - signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - defer interruptCancel() - - watchStream, err := cli.Watch(ctx, req) + watchStream, err := watchClient.Watch(ctx, req) if err != nil { return err } @@ -104,40 +115,74 @@ func watchCmdFunc(cmd *cobra.Command, _ []string) error { default: resp, err := watchStream.Recv() if err != nil { - return err - } + ok, err := isRetryable(err) + if !ok { + return err + } - for _, update := range resp.Updates { - if watchTimestamps { - console.Printf("%v: ", time.Now()) + log.Trace().Err(err).Msg("will retry from the last known revision " + watchRevision) + req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision} + watchStream, err = watchClient.Watch(ctx, req) + if err != nil { + return err } + continue + } - switch update.Operation { - case v1.RelationshipUpdate_OPERATION_CREATE: - console.Printf("CREATED ") + processResponse(resp) + } + } +} - case v1.RelationshipUpdate_OPERATION_DELETE: - console.Printf("DELETED ") +func isRetryable(err error) (bool, error) { + statusErr, ok := status.FromError(err) + if !ok || (statusErr.Code() != codes.Unavailable) { + return false, err + } + return true, nil +} - case v1.RelationshipUpdate_OPERATION_TOUCH: - console.Printf("TOUCHED ") - } +func processResponse(resp *v1.WatchResponse) { + if resp.ChangesThrough != nil { + watchRevision = resp.ChangesThrough.Token + } - subjectRelation := "" - if update.Relationship.Subject.OptionalRelation != "" { - subjectRelation = " " + update.Relationship.Subject.OptionalRelation - } + if resp.SchemaUpdated { + if watchTimestamps { + console.Printf("%v: ", time.Now()) + } + console.Println("SCHEMA UPDATED") + } - console.Printf("%s:%s %s %s:%s%s\n", - update.Relationship.Resource.ObjectType, - update.Relationship.Resource.ObjectId, - update.Relationship.Relation, - update.Relationship.Subject.Object.ObjectType, - update.Relationship.Subject.Object.ObjectId, - subjectRelation, - ) - } + for _, update := range resp.Updates { + if watchTimestamps { + console.Printf("%v: ", time.Now()) } + + switch update.Operation { + case v1.RelationshipUpdate_OPERATION_CREATE: + console.Printf("CREATED ") + + case v1.RelationshipUpdate_OPERATION_DELETE: + console.Printf("DELETED ") + + case v1.RelationshipUpdate_OPERATION_TOUCH: + console.Printf("TOUCHED ") + } + + subjectRelation := "" + if update.Relationship.Subject.OptionalRelation != "" { + subjectRelation = " " + update.Relationship.Subject.OptionalRelation + } + + console.Printf("%s:%s %s %s:%s%s\n", + update.Relationship.Resource.ObjectType, + update.Relationship.Resource.ObjectId, + update.Relationship.Relation, + update.Relationship.Subject.Object.ObjectType, + update.Relationship.Subject.Object.ObjectId, + subjectRelation, + ) } } diff --git a/internal/commands/watch_test.go b/internal/commands/watch_test.go index 5f6e96da..7fc9c2e8 100644 --- a/internal/commands/watch_test.go +++ b/internal/commands/watch_test.go @@ -1,10 +1,22 @@ package commands import ( + "context" + "io" "reflect" + "sync" "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/zed/internal/client" + zedtesting "github.com/authzed/zed/internal/testing" ) func TestParseRelationshipFilter(t *testing.T) { @@ -108,3 +120,109 @@ func TestParseRelationshipFilter(t *testing.T) { } } } + +type mockWatchClient struct { + client.Client + grpc.ServerStreamingClient[v1.WatchResponse] + callCounter int +} + +var _ v1.WatchServiceClient = (*mockWatchClient)(nil) + +func (m *mockWatchClient) Recv() (*v1.WatchResponse, error) { + update1 := &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_CREATE, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{ + ObjectType: "document", + ObjectId: "object1", + }, + Relation: "viewer", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "alice", + }, + }, + }, + } + update2 := &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_CREATE, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{ + ObjectType: "document", + ObjectId: "object2", + }, + Relation: "viewer", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "alice", + }, + }, + }, + } + + response1 := &v1.WatchResponse{ + Updates: []*v1.RelationshipUpdate{update1}, + ChangesThrough: &v1.ZedToken{Token: "revision1"}, + } + response2 := &v1.WatchResponse{ + Updates: []*v1.RelationshipUpdate{update2}, + ChangesThrough: &v1.ZedToken{Token: "revision2"}, + } + + switch m.callCounter { + case 0: + m.callCounter++ + return response1, nil + case 1: + m.callCounter++ + return nil, status.Error(codes.Unavailable, "simulated error") + case 2: + m.callCounter++ + return response2, nil + default: + return nil, io.EOF + } +} + +func (m *mockWatchClient) Watch(_ context.Context, _ *v1.WatchRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[v1.WatchResponse], error) { + return m, nil +} + +func TestWatchCmdFunc(t *testing.T) { + cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cmd.SetContext(ctx) + + watchErr := make(chan error, 1) + + receivedResponses := make([]*v1.WatchResponse, 0) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + watchErr <- watchCmdFuncImpl(cmd, &mockWatchClient{}, func(resp *v1.WatchResponse) { + receivedResponses = append(receivedResponses, resp) + }) + }() + + time.Sleep(1 * time.Second) + + cancel() + + wg.Wait() + + err := <-watchErr + require.ErrorIs(t, err, io.EOF) + + require.Len(t, receivedResponses, 2) + require.Equal(t, "object1", receivedResponses[0].Updates[0].Relationship.Resource.ObjectId) + require.Equal(t, `token:"revision1"`, receivedResponses[0].ChangesThrough.String()) + require.Equal(t, "object2", receivedResponses[1].Updates[0].Relationship.Resource.ObjectId) + require.Equal(t, `token:"revision2"`, receivedResponses[1].ChangesThrough.String()) +}