diff --git a/security/interceptors/connect/authentication.go b/security/interceptors/connect/authentication.go index 17da092..190d9b9 100644 --- a/security/interceptors/connect/authentication.go +++ b/security/interceptors/connect/authentication.go @@ -26,7 +26,7 @@ var ( ErrInvalidToken = errors.New("invalid authorization token") ) -// AuthInterceptor implements connect.Interceptor for JWT authentication. +// AuthInterceptor implements connect.ValidationInterceptor for JWT authentication. type AuthInterceptor struct { authenticator security.Authenticator } diff --git a/security/interceptors/connect/pbvalidation.go b/security/interceptors/connect/pbvalidation.go index fc990ef..2b5cb4f 100644 --- a/security/interceptors/connect/pbvalidation.go +++ b/security/interceptors/connect/pbvalidation.go @@ -19,22 +19,22 @@ const ( maxStructFields = 200 ) -// An Option configures an [Interceptor]. +// An Option configures an [ValidationInterceptor]. type Option interface { - apply(*Interceptor) + apply(*ValidationInterceptor) } -// WithValidator configures the [Interceptor] to use a customized +// WithValidator configures the [ValidationInterceptor] to use a customized // [protovalidate.Validator]. By default, [protovalidate.GlobalInterceptor] // is used See [protovalidate.ValidatorOption] for the range of available // customizations. func WithValidator(validator protovalidate.Validator) Option { - return optionFunc(func(i *Interceptor) { + return optionFunc(func(i *ValidationInterceptor) { i.validator = validator }) } -// WithValidateResponses configures the [Interceptor] to also validate reponses +// WithValidateResponses configures the [ValidationInterceptor] to also validate reponses // in addition to validating requests. // // By default: @@ -45,21 +45,21 @@ func WithValidator(validator protovalidate.Validator) Option { // // However, these messages are all validated if this option is set. func WithValidateResponses() Option { - return optionFunc(func(i *Interceptor) { + return optionFunc(func(i *ValidationInterceptor) { i.validateResponses = true }) } -// WithoutErrorDetails configures the [Interceptor] to elide error details from +// WithoutErrorDetails configures the [ValidationInterceptor] to elide error details from // validation errors. By default, a [protovalidate.ValidationError] is added // as a detail when validation errors are returned. func WithoutErrorDetails() Option { - return optionFunc(func(i *Interceptor) { + return optionFunc(func(i *ValidationInterceptor) { i.noErrorDetails = true }) } -// Interceptor is a [connect.Interceptor] that ensures that RPC request +// ValidationInterceptor is a [connect.Interceptor] that ensures that RPC request // messages match the constraints expressed in their Protobuf schemas. It does // not validate response messages unless the [WithValidateResponses] option // is specified. @@ -79,16 +79,16 @@ func WithoutErrorDetails() Option { // schema. // // [detailed representation of the error]: https://pkg.go.dev/buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate#Violations -type Interceptor struct { +type ValidationInterceptor struct { validator protovalidate.Validator validateResponses bool noErrorDetails bool } -// NewInterceptor builds an Interceptor. The default configuration is +// NewValidationInterceptor builds an ValidationInterceptor. The default configuration is // appropriate for most use cases. -func NewInterceptor(opts ...Option) *Interceptor { - var interceptor Interceptor +func NewValidationInterceptor(opts ...Option) *ValidationInterceptor { + var interceptor ValidationInterceptor for _, opt := range opts { opt.apply(&interceptor) } @@ -100,8 +100,8 @@ func NewInterceptor(opts ...Option) *Interceptor { return &interceptor } -// WrapUnary implements connect.Interceptor. -func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { +// WrapUnary implements connect.ValidationInterceptor. +func (i *ValidationInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { if err := i.validateRequest(req.Any()); err != nil { return nil, err @@ -117,8 +117,8 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -// WrapStreamingClient implements connect.Interceptor. -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +// WrapStreamingClient implements connect.ValidationInterceptor. +func (i *ValidationInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return &streamingClientInterceptor{ StreamingClientConn: next(ctx, spec), @@ -127,8 +127,8 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn } } -// WrapStreamingHandler implements connect.Interceptor. -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +// WrapStreamingHandler implements connect.ValidationInterceptor. +func (i *ValidationInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { return next(ctx, &streamingHandlerInterceptor{ StreamingHandlerConn: conn, @@ -137,18 +137,18 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co } } -func (i *Interceptor) validateRequest(msg any) error { +func (i *ValidationInterceptor) validateRequest(msg any) error { return i.validate(msg, connect.CodeInvalidArgument) } -func (i *Interceptor) validateResponse(msg any) error { +func (i *ValidationInterceptor) validateResponse(msg any) error { if !i.validateResponses { return nil } return i.validate(msg, connect.CodeInternal) } -func (i *Interceptor) validate(msg any, code connect.Code) error { +func (i *ValidationInterceptor) validate(msg any, code connect.Code) error { if msg == nil { return nil } @@ -169,7 +169,7 @@ func (i *Interceptor) validate(msg any, code connect.Code) error { return nil } -func (i *Interceptor) wrapValidationError(originalErr error, code connect.Code) error { +func (i *ValidationInterceptor) wrapValidationError(originalErr error, code connect.Code) error { connectErr := connect.NewError(code, originalErr) if i.noErrorDetails { return connectErr @@ -272,7 +272,7 @@ func validateSingleStruct(s *structpb.Struct) error { type streamingClientInterceptor struct { connect.StreamingClientConn - interceptor *Interceptor + interceptor *ValidationInterceptor } func (s *streamingClientInterceptor) Send(msg any) error { @@ -292,7 +292,7 @@ func (s *streamingClientInterceptor) Receive(msg any) error { type streamingHandlerInterceptor struct { connect.StreamingHandlerConn - interceptor *Interceptor + interceptor *ValidationInterceptor } func (s *streamingHandlerInterceptor) Send(msg any) error { @@ -309,6 +309,6 @@ func (s *streamingHandlerInterceptor) Receive(msg any) error { return s.interceptor.validateRequest(msg) } -type optionFunc func(*Interceptor) +type optionFunc func(*ValidationInterceptor) -func (f optionFunc) apply(i *Interceptor) { f(i) } +func (f optionFunc) apply(i *ValidationInterceptor) { f(i) } diff --git a/security/interceptors/connect/pbvalidation_test.go b/security/interceptors/connect/pbvalidation_test.go index a8b104d..3f74b49 100644 --- a/security/interceptors/connect/pbvalidation_test.go +++ b/security/interceptors/connect/pbvalidation_test.go @@ -18,7 +18,7 @@ import ( func TestNewInterceptor(t *testing.T) { t.Run("default configuration", func(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() assert.NotNil(t, interceptor.validator) assert.False(t, interceptor.validateResponses) assert.False(t, interceptor.noErrorDetails) @@ -26,23 +26,23 @@ func TestNewInterceptor(t *testing.T) { t.Run("with custom validator", func(t *testing.T) { customValidator := protovalidate.GlobalValidator - interceptor := NewInterceptor(WithValidator(customValidator)) + interceptor := NewValidationInterceptor(WithValidator(customValidator)) assert.Equal(t, customValidator, interceptor.validator) }) t.Run("with validate responses", func(t *testing.T) { - interceptor := NewInterceptor(WithValidateResponses()) + interceptor := NewValidationInterceptor(WithValidateResponses()) assert.True(t, interceptor.validateResponses) }) t.Run("with no error details", func(t *testing.T) { - interceptor := NewInterceptor(WithoutErrorDetails()) + interceptor := NewValidationInterceptor(WithoutErrorDetails()) assert.True(t, interceptor.noErrorDetails) }) t.Run("multiple options", func(t *testing.T) { customValidator := protovalidate.GlobalValidator - interceptor := NewInterceptor( + interceptor := NewValidationInterceptor( WithValidator(customValidator), WithValidateResponses(), WithoutErrorDetails(), @@ -55,7 +55,7 @@ func TestNewInterceptor(t *testing.T) { } func TestValidateRequest(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() t.Run("nil message", func(t *testing.T) { err := interceptor.validateRequest(nil) @@ -77,14 +77,14 @@ func TestValidateRequest(t *testing.T) { func TestValidateResponse(t *testing.T) { t.Run("without validate responses option", func(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() msg := &wrapperspb.StringValue{Value: "test"} err := interceptor.validateResponse(msg) require.NoError(t, err) // Should not validate responses by default }) t.Run("with validate responses option", func(t *testing.T) { - interceptor := NewInterceptor(WithValidateResponses()) + interceptor := NewValidationInterceptor(WithValidateResponses()) msg := &wrapperspb.StringValue{Value: "test"} err := interceptor.validateResponse(msg) require.NoError(t, err) @@ -224,7 +224,7 @@ func TestValidateAllStructs(t *testing.T) { func TestWrapValidationError(t *testing.T) { t.Run("with error details", func(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() originalErr := &protovalidate.ValidationError{} wrappedErr := interceptor.wrapValidationError(originalErr, connect.CodeInvalidArgument) @@ -234,7 +234,7 @@ func TestWrapValidationError(t *testing.T) { }) t.Run("without error details", func(t *testing.T) { - interceptor := NewInterceptor(WithoutErrorDetails()) + interceptor := NewValidationInterceptor(WithoutErrorDetails()) originalErr := &protovalidate.ValidationError{} wrappedErr := interceptor.wrapValidationError(originalErr, connect.CodeInvalidArgument) @@ -245,7 +245,7 @@ func TestWrapValidationError(t *testing.T) { } func TestUnaryInterceptor(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() t.Run("valid request", func(t *testing.T) { callCount := 0 @@ -297,7 +297,7 @@ func TestUnaryInterceptor(t *testing.T) { } func TestStreamingInterceptors(t *testing.T) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() t.Run("streaming client interceptor", func(t *testing.T) { wrapped := interceptor.WrapStreamingClient( @@ -378,7 +378,7 @@ func BenchmarkValidateAllStructs(b *testing.B) { } func BenchmarkInterceptorWrapUnary(b *testing.B) { - interceptor := NewInterceptor() + interceptor := NewValidationInterceptor() next := func(_ context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) { return connect.NewResponse(&wrapperspb.StringValue{Value: "response"}), nil }