Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@ gen-protoc:
protoc --go_out=. --go_opt=paths=source_relative \
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
gravity.proto
@cd dns/proto && \
protoc --go_out=. --go_opt=paths=source_relative \
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
dns.proto
132 changes: 44 additions & 88 deletions dns/aether.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
"time"

cstr "github.com/agentuity/go-common/string"
pb "github.com/agentuity/go-common/dns/proto"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
Expand All @@ -19,6 +19,8 @@ type DNSBaseAction struct {
Reply string `json:"reply,omitempty"`
}

// DNSAddAction represents a DNS record addition request.
// Deprecated: This type will be replaced by the protobuf-generated types in dns/proto.
type DNSAddAction struct {
DNSBaseAction
Name string `json:"name"`
Expand Down Expand Up @@ -74,6 +76,8 @@ func (a *DNSAddAction) WithPort(port int) *DNSAddAction {
return a
}

// DNSDeleteAction represents a DNS record deletion request.
// Deprecated: This type will be replaced by the protobuf-generated types in dns/proto.
type DNSDeleteAction struct {
DNSBaseAction
// Name is the name of the DNS record to delete.
Expand All @@ -83,6 +87,8 @@ type DNSDeleteAction struct {
IDs []string `json:"ids,omitempty"`
}

// DNSCertAction represents a certificate request for a domain.
// Deprecated: This type will be replaced by the protobuf-generated types in dns/proto.
type DNSCertAction struct {
DNSBaseAction
Name string `json:"name"`
Expand Down Expand Up @@ -205,93 +211,59 @@ func (a DNSBaseAction) GetAction() string {
return a.Action
}

// Transport is an interface for a transport layer for the DNS server
type Transport interface {
Subscribe(ctx context.Context, channel string) Subscriber
Publish(ctx context.Context, channel string, payload []byte) error
}

// Message is a message from the transport layer
type Message struct {
Payload []byte
}

// Subscriber is an interface for a subscriber to the transport layer
type Subscriber interface {
// Close closes the subscriber
Close() error
// Channel returns a channel of messages
Channel() <-chan *Message
// transport is an interface for a transport layer for the DNS server
type transport interface {
// Publish sends a DNS action and waits for a response
Publish(ctx context.Context, action DNSAction) ([]byte, error)
// PublishAsync sends a DNS action without waiting for a response
PublishAsync(ctx context.Context, action DNSAction) error
}

type option struct {
transport Transport
transport transport
timeout time.Duration
reply bool
}

type optionHandler func(*option)
type OptionHandler func(*option)

// WithReply sets whether the DNS action should wait for a reply from the DNS server
func WithReply(reply bool) optionHandler {
func WithReply(reply bool) OptionHandler {
return func(o *option) {
o.reply = reply
}
}

// WithTransport sets a custom transport for the DNS action
func WithTransport(transport Transport) optionHandler {
//
// for Testing
func withTransport(t transport) OptionHandler {
return func(o *option) {
o.transport = transport
o.transport = t
}
}

// WithTimeout sets a custom timeout for the DNS action
func WithTimeout(timeout time.Duration) optionHandler {
func WithTimeout(timeout time.Duration) OptionHandler {
return func(o *option) {
o.timeout = timeout
}
}

// WithRedis uses a redis client as the transport for the DNS action
func WithRedis(redis *redis.Client) optionHandler {
func WithRedis(redis *redis.Client) OptionHandler {
return func(o *option) {
o.transport = &redisTransport{redis: redis}
}
}

type redisSubscriber struct {
sub *redis.PubSub
}

var _ Subscriber = (*redisSubscriber)(nil)

func (s *redisSubscriber) Close() error {
return s.sub.Close()
}

func (s *redisSubscriber) Channel() <-chan *Message {
ch := make(chan *Message)
go func() {
for msg := range s.sub.Channel() {
ch <- &Message{Payload: []byte(msg.Payload)}
// WithGRPC uses a gRPC client as the transport for the DNS action
func WithGRPC(client pb.DNSServiceClient) OptionHandler {
return func(o *option) {
o.transport = &grpcTransport{
client: client,
}
}()
return ch
}

type redisTransport struct {
redis *redis.Client
}

var _ Transport = (*redisTransport)(nil)

func (t *redisTransport) Subscribe(ctx context.Context, channel string) Subscriber {
return &redisSubscriber{sub: t.redis.Subscribe(ctx, channel)}
}

func (t *redisTransport) Publish(ctx context.Context, channel string, payload []byte) error {
return t.redis.Publish(ctx, channel, payload).Err()
}
}

// ActionFromChannel returns the action from the channel string
Expand Down Expand Up @@ -332,7 +304,7 @@ func NewDNSResponse[R any, T TypedDNSAction[R]](action T, data *R, err error) *D
}

// SendDNSAction sends a DNS action to the DNS server with a timeout. If the timeout is 0, the default timeout will be used.
func SendDNSAction[R any, T TypedDNSAction[R]](ctx context.Context, action T, opts ...optionHandler) (*R, error) {
func SendDNSAction[R any, T TypedDNSAction[R]](ctx context.Context, action T, opts ...OptionHandler) (*R, error) {
var o option
o.timeout = DefaultDNSTimeout
o.reply = true
Expand All @@ -345,43 +317,27 @@ func SendDNSAction[R any, T TypedDNSAction[R]](ctx context.Context, action T, op
return nil, ErrTransportRequired
}

id := action.GetID()
if id == "" {
return nil, errors.New("message ID not found")
if o.timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, o.timeout)
defer cancel()
}

var sub Subscriber

if o.reply {
action.SetReply("aether:response:" + action.GetAction() + ":" + id)
sub = o.transport.Subscribe(ctx, action.GetReply())
defer sub.Close()
if !o.reply {
return nil, o.transport.PublishAsync(ctx, action)
}

if err := o.transport.Publish(ctx, "aether:request:"+action.GetAction()+":"+id, []byte(cstr.JSONStringify(action))); err != nil {
responseBytes, err := o.transport.Publish(ctx, action)
if err != nil {
return nil, err
}

if o.reply {
select {
case <-ctx.Done():
return nil, ctx.Err()
case msg := <-sub.Channel():
if msg == nil {
return nil, ErrClosed
}
var response DNSResponse[R]
if err := json.Unmarshal([]byte(msg.Payload), &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal dns action response: %w", err)
}
if !response.Success {
return nil, errors.New(response.Error)
}
return response.Data, nil
case <-time.After(o.timeout):
return nil, ErrTimeout
}
var response DNSResponse[R]
if err := json.Unmarshal(responseBytes, &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal dns action response: %w", err)
}

return nil, nil
if !response.Success {
return nil, errors.New(response.Error)
}
return response.Data, nil
}
71 changes: 30 additions & 41 deletions dns/aether_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,42 @@ import (
)

type testTransport struct {
publish chan *Message
subscribe chan *Message
lastAction DNSAction
}

var _ Transport = (*testTransport)(nil)

func (t *testTransport) Subscribe(ctx context.Context, channel string) Subscriber {
t.subscribe = make(chan *Message, 1)
t.publish = make(chan *Message, 1)
return &testSubscriber{channel: channel, messages: t.subscribe}
}

func (t *testTransport) Publish(ctx context.Context, channel string, payload []byte) error {
t.publish <- &Message{Payload: payload}
var cert DNSCert
cert.Certificate = []byte("cert")
cert.Expires = time.Now().Add(time.Hour * 24 * 365 * 2)
cert.PrivateKey = []byte("private")
var response DNSResponse[DNSCert]
response.Success = true
response.Data = &cert
response.MsgID = uuid.New().String()
response.Error = ""
response.Data = &cert
responseBytes, err := json.Marshal(response)
if err != nil {
return err
var _ transport = (*testTransport)(nil)

func (t *testTransport) Publish(ctx context.Context, action DNSAction) ([]byte, error) {
t.lastAction = action

// Return different responses based on action type
switch action.(type) {
case *DNSCertAction:
var cert DNSCert
cert.Certificate = []byte("cert")
cert.Expires = time.Now().Add(time.Hour * 24 * 365 * 2)
cert.PrivateKey = []byte("private")
var response DNSResponse[DNSCert]
response.Success = true
response.Data = &cert
return json.Marshal(response)
case *DNSAddAction:
var record DNSRecord
record.IDs = []string{uuid.New().String()}
var response DNSResponse[DNSRecord]
response.Success = true
response.Data = &record
return json.Marshal(response)
default:
return nil, nil
}
t.subscribe <- &Message{Payload: responseBytes}
return nil
}

type testSubscriber struct {
channel string
messages chan *Message
}

var _ Subscriber = (*testSubscriber)(nil)

func (s *testSubscriber) Close() error {
func (t *testTransport) PublishAsync(ctx context.Context, action DNSAction) error {
t.lastAction = action
return nil
}

func (s *testSubscriber) Channel() <-chan *Message {
return s.messages
}

func TestDNSAction(t *testing.T) {
var transport testTransport

Expand All @@ -69,7 +58,7 @@ func TestDNSAction(t *testing.T) {
Name: "test",
}

reply, err := SendDNSAction(context.Background(), action, WithTransport(&transport), WithTimeout(time.Second))
reply, err := SendDNSAction(context.Background(), action, withTransport(&transport), WithTimeout(time.Second))
if err != nil {
t.Fatalf("failed to send dns action: %v", err)
}
Expand All @@ -89,7 +78,7 @@ func TestDNSCertAction(t *testing.T) {
Name: "test",
}

reply, err := SendDNSAction(context.Background(), action, WithTransport(&transport), WithTimeout(time.Second))
reply, err := SendDNSAction(context.Background(), action, withTransport(&transport), WithTimeout(time.Second))
if err != nil {
t.Fatalf("failed to send dns cert action: %v", err)
}
Expand Down
Loading
Loading