diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 7f271a40..b624a678 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -37,7 +37,7 @@ jobs: credentials_json: '${{ secrets.GCP_SA_KEY }}' - name: Set up Cloud SDK - uses: google-github-actions/setup-gcloud@v0 + uses: google-github-actions/setup-gcloud@v2 - name: Set up Go 1.25 uses: actions/setup-go@v5 diff --git a/Makefile b/Makefile index 7b48ed56..92d03df8 100644 --- a/Makefile +++ b/Makefile @@ -66,8 +66,8 @@ vet: # Generate code generate: controller-gen mockgen manifests - go generate ./... $(CONTROLLER_GEN) object paths="./..." + go generate ./... .PHONY: controller-gen controller-gen: $(CONTROLLER_GEN) diff --git a/api/v1/clusterwidenetworkpolicy_types.go b/api/v1/clusterwidenetworkpolicy_types.go index 7a7b32c5..e8b9d3fa 100644 --- a/api/v1/clusterwidenetworkpolicy_types.go +++ b/api/v1/clusterwidenetworkpolicy_types.go @@ -153,11 +153,18 @@ type FQDNSelector struct { // IPSet stores set name association to IP addresses type IPSet struct { - FQDN string `json:"fqdn,omitempty"` - SetName string `json:"setName,omitempty"` - IPs []string `json:"ips,omitempty"` + // FQDN which this IP set is for. + FQDN string `json:"fqdn,omitempty"` + // A hash value merely used for reference. + SetName string `json:"setName,omitempty"` + // Deprecated: use `IPExpirationTimes` instead. + IPs []string `json:"ips,omitempty"` + // Deprecated: use `IPExpirationTimes` instead. ExpirationTime metav1.Time `json:"expirationTime,omitempty"` - Version IPVersion `json:"version,omitempty"` + // Maps IP addresses to their expiration times. + IPExpirationTimes map[string]metav1.Time `json:"ipExpirationTimes,omitempty"` + // Whether this is a IPv4 or a IPv6 set. + Version IPVersion `json:"version,omitempty"` } func (l *ClusterwideNetworkPolicyList) GetFQDNs() []FQDNSelector { diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index c9863969..24698576 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -6,6 +6,7 @@ package v1 import ( networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" runtime "k8s.io/apimachinery/pkg/runtime" ) @@ -158,6 +159,13 @@ func (in *IPSet) DeepCopyInto(out *IPSet) { copy(*out, *in) } in.ExpirationTime.DeepCopyInto(&out.ExpirationTime) + if in.IPExpirationTimes != nil { + in, out := &in.IPExpirationTimes, &out.IPExpirationTimes + *out = make(map[string]metav1.Time, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IPSet. diff --git a/config/crd/bases/metal-stack.io_clusterwidenetworkpolicies.yaml b/config/crd/bases/metal-stack.io_clusterwidenetworkpolicies.yaml index 9d91f1ad..f9d60992 100644 --- a/config/crd/bases/metal-stack.io_clusterwidenetworkpolicies.yaml +++ b/config/crd/bases/metal-stack.io_clusterwidenetworkpolicies.yaml @@ -244,17 +244,28 @@ spec: description: IPSet stores set name association to IP addresses properties: expirationTime: + description: 'Deprecated: use `IPExpirationTimes` instead.' format: date-time type: string fqdn: + description: FQDN which this IP set is for. type: string + ipExpirationTimes: + additionalProperties: + format: date-time + type: string + description: Maps IP addresses to their expiration times. + type: object ips: + description: 'Deprecated: use `IPExpirationTimes` instead.' items: type: string type: array setName: + description: ' A hash value merely used for reference.' type: string version: + description: Whether this is a IPv4 or a IPv6 set. type: string type: object type: array diff --git a/controllers/clusterwidenetworkpolicy_controller.go b/controllers/clusterwidenetworkpolicy_controller.go index 9a169693..f8392fa8 100644 --- a/controllers/clusterwidenetworkpolicy_controller.go +++ b/controllers/clusterwidenetworkpolicy_controller.go @@ -18,10 +18,12 @@ import ( "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" firewallv2 "github.com/metal-stack/firewall-controller-manager/api/v2" @@ -38,6 +40,7 @@ type ClusterwideNetworkPolicyReconciler struct { SeedNamespace string Log logr.Logger + Ctx context.Context Recorder record.EventRecorder Interval time.Duration @@ -57,7 +60,7 @@ func (r *ClusterwideNetworkPolicyReconciler) SetupWithManager(mgr ctrl.Manager) } return ctrl.NewControllerManagedBy(mgr). - For(&firewallv1.ClusterwideNetworkPolicy{}). + For(&firewallv1.ClusterwideNetworkPolicy{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). Watches(&corev1.Service{}, &handler.EnqueueRequestForObject{}). WatchesRawSource(&source.Channel{Source: scheduleChan}, &handler.EnqueueRequestForObject{}). Complete(r) @@ -104,7 +107,7 @@ func (r *ClusterwideNetworkPolicyReconciler) Reconcile(ctx context.Context, _ ct cwnps.Items = validCwnps nftablesFirewall := nftables.NewFirewall(f, &cwnps, &services, r.DnsProxy, r.Log, r.Recorder) - if err := r.manageDNSProxy(ctx, f, cwnps, nftablesFirewall); err != nil { + if err := r.manageDNSProxy(f, cwnps, nftablesFirewall); err != nil { return ctrl.Result{}, err } updated, err := nftablesFirewall.Reconcile() @@ -127,7 +130,7 @@ func (r *ClusterwideNetworkPolicyReconciler) Reconcile(ctx context.Context, _ ct // manageDNSProxy start DNS proxy if toFQDN rules are present // if rules were deleted it will stop running DNS proxy func (r *ClusterwideNetworkPolicyReconciler) manageDNSProxy( - ctx context.Context, f *firewallv2.Firewall, cwnps firewallv1.ClusterwideNetworkPolicyList, nftablesFirewall *nftables.Firewall, + f *firewallv2.Firewall, cwnps firewallv1.ClusterwideNetworkPolicyList, nftablesFirewall *nftables.Firewall, ) (err error) { // Skipping is needed for testing if r.SkipDNS { @@ -142,10 +145,10 @@ func (r *ClusterwideNetworkPolicyReconciler) manageDNSProxy( if enableDNS && r.DnsProxy == nil { r.Log.Info("DNS Proxy is initialized") - if r.DnsProxy, err = dns.NewDNSProxy(f.Spec.DNSServerAddress, f.Spec.DNSPort, ctrl.Log.WithName("DNS proxy")); err != nil { + if r.DnsProxy, err = dns.NewDNSProxy(r.Ctx, f.Spec.DNSServerAddress, f.Spec.DNSPort, r.ShootClient, ctrl.Log.WithName("DNS proxy")); err != nil { return fmt.Errorf("failed to init DNS proxy: %w", err) } - go r.DnsProxy.Run(ctx) + go r.DnsProxy.Run() } else if !enableDNS && r.DnsProxy != nil { r.Log.Info("DNS Proxy is stopped") r.DnsProxy.Stop() @@ -217,7 +220,6 @@ func (r *ClusterwideNetworkPolicyReconciler) allowedCWNPs(ctx context.Context, c } for _, cwnp := range cwnps { - cwnp := cwnp oke, err := r.validateCWNPEgressTargetPrefix(cwnp, egressSet) if err != nil { return nil, err diff --git a/controllers/firewall_controller.go b/controllers/firewall_controller.go index b56caccb..96b5ae79 100644 --- a/controllers/firewall_controller.go +++ b/controllers/firewall_controller.go @@ -39,6 +39,7 @@ type FirewallReconciler struct { Recorder record.EventRecorder Log logr.Logger + Ctx context.Context Scheme *runtime.Scheme Updater *updater.Updater diff --git a/main.go b/main.go index 4029390d..43152de7 100644 --- a/main.go +++ b/main.go @@ -91,7 +91,13 @@ func main() { return } - jsonHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{}) + var sll slog.Level + err := sll.UnmarshalText([]byte(logLevel)) + if err != nil { + setupLog.Error(err, "failed to unmarshal log level") + os.Exit(1) + } + jsonHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: sll}) l := slog.New(jsonHandler) ctrl.SetLogger(logr.FromSlogHandler(jsonHandler)) @@ -105,7 +111,6 @@ func main() { // FIXME validation and controller start should be refactored into own func which returns error // instead Fatalw or Error and panic here. - var err error if firewallName == "" { firewallName, err = os.Hostname() if err != nil { @@ -263,6 +268,7 @@ func main() { SeedClient: seedMgr.GetClient(), ShootClient: shootMgr.GetClient(), Log: ctrl.Log.WithName("controllers").WithName("ClusterwideNetworkPolicy"), + Ctx: ctx, Recorder: shootMgr.GetEventRecorderFor("FirewallController"), FirewallName: firewallName, SeedNamespace: seedNamespace, diff --git a/pkg/dns/dnscache.go b/pkg/dns/dnscache.go index aa5be334..4e922f0d 100644 --- a/pkg/dns/dnscache.go +++ b/pkg/dns/dnscache.go @@ -1,12 +1,13 @@ package dns import ( + "context" "crypto/md5" //nolint:gosec "encoding/hex" "fmt" - "math" - "net" + "reflect" "regexp" + "slices" "sort" "strings" "sync" @@ -16,7 +17,12 @@ import ( "github.com/go-logr/logr" "github.com/google/nftables" dnsgo "github.com/miekg/dns" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/yaml" firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1" ) @@ -32,6 +38,11 @@ const ( // How many DNS redirections (CNAME/DNAME) are followed, to break up redirection loops. maxDNSRedirects = 10 + + // Configmap that holds the FQDN state + fqdnStateConfigmapName = "fqdnstate" + fqdnStateNamespace = firewallv1.ClusterwideNetworkPolicyNamespace + fqdnStateConfigmapKey = "state" ) // RenderIPSet stores set info for rendering @@ -41,32 +52,24 @@ type RenderIPSet struct { Version IPVersion `json:"version,omitempty"` } -type ipEntry struct { - ips []string - expirationTime time.Time - setName string +type iPEntry struct { + // ips is a map of the ip address and its expiration time which is the time of the DNS lookup + the TTL + IPs map[string]time.Time `json:"ips,omitempty"` + SetName string `json:"setName,omitempty"` } -func newIPEntry(setName string, expirationTime time.Time) *ipEntry { - return &ipEntry{ - expirationTime: expirationTime, - setName: setName, +func newIPEntry(setName string) *iPEntry { + return &iPEntry{ + SetName: setName, + IPs: map[string]time.Time{}, } } -func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error { - newIPs, deletedIPs := e.getNewAndDeletedIPs(ips) - if !e.expirationTime.After(time.Now()) { - e.expirationTime = expirationTime - } +func (e *iPEntry) update(log logr.Logger, setName string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error { + deletedIPs := e.expireIPs() + newIPs := e.addAndUpdateIPs(log, rrs, lookupTime) if newIPs != nil || deletedIPs != nil { - e.ips = make([]string, len(ips)) - for i, ip := range ips { - e.ips[i] = ip.String() - } - sort.Strings(e.ips) - if err := updateNftSet(newIPs, deletedIPs, setName, dtype); err != nil { return fmt.Errorf("failed to update nft set: %w", err) } @@ -75,33 +78,38 @@ func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time, return nil } -func (e *ipEntry) getNewAndDeletedIPs(ips []net.IP) (newIPs, deletedIPs []nftables.SetElement) { - currentIps := make(map[string]bool, len(e.ips)) - for _, ip := range e.ips { - currentIps[ip] = false - } - - for _, ip := range ips { - s := ip.String() - if _, ok := currentIps[s]; ok { - currentIps[s] = true - } else { - newIPs = append(newIPs, nftables.SetElement{Key: ip}) +func (e *iPEntry) expireIPs() (deletedIPs []nftables.SetElement) { + for ip, expirationTime := range e.IPs { + if expirationTime.Before(time.Now()) { + deletedIPs = append(deletedIPs, nftables.SetElement{Key: []byte(ip)}) + delete(e.IPs, ip) } } + return +} - for ip, exists := range currentIps { - if !exists { - deletedIPs = append(deletedIPs, nftables.SetElement{Key: net.ParseIP(ip)}) +func (e *iPEntry) addAndUpdateIPs(log logr.Logger, rrs []dnsgo.RR, lookupTime time.Time) (newIPs []nftables.SetElement) { + for _, rr := range rrs { + var s string + switch r := rr.(type) { + case *dnsgo.A: + s = r.A.String() + case *dnsgo.AAAA: + s = r.AAAA.String() } - } + if _, ok := e.IPs[s]; !ok { + newIPs = append(newIPs, nftables.SetElement{Key: []byte(s)}) + } + log.WithValues("ip", s, "rr header ttl", rr.Header().Ttl, "expiration time", lookupTime.Add(time.Duration(rr.Header().Ttl)*time.Second)) + e.IPs[s] = lookupTime.Add(time.Duration(rr.Header().Ttl) * time.Second) + } return } type cacheEntry struct { - ipv4 *ipEntry - ipv6 *ipEntry + IPv4 *iPEntry `json:"ipv4,omitempty"` + IPv6 *iPEntry `json:"ipv6,omitempty"` } type DNSCache struct { @@ -111,25 +119,123 @@ type DNSCache struct { fqdnToEntry map[string]cacheEntry setNames map[string]struct{} dnsServerAddr string + shootClient client.Client + ctx context.Context ipv4Enabled bool ipv6Enabled bool } -func newDNSCache(dns string, ipv4Enabled, ipv6Enabled bool, log logr.Logger) *DNSCache { - return &DNSCache{ +func newDNSCache(ctx context.Context, dns string, ipv4Enabled, ipv6Enabled bool, shootClient client.Client, log logr.Logger) (*DNSCache, error) { + c := DNSCache{ log: log, fqdnToEntry: map[string]cacheEntry{}, setNames: map[string]struct{}{}, dnsServerAddr: dns, + shootClient: shootClient, + ctx: ctx, ipv4Enabled: ipv4Enabled, ipv6Enabled: ipv6Enabled, } + + nn := types.NamespacedName{Name: fqdnStateConfigmapName, Namespace: fqdnStateNamespace} + scm := &v1.ConfigMap{} + + ctxWithTimeout, cancel := context.WithTimeout(c.ctx, 5*time.Second) + defer func() { + cancel() + }() + + err := shootClient.Get(ctxWithTimeout, nn, scm) + if err != nil && !apierrors.IsNotFound(err) { + c.log.Error(err, "error reading fqndstate configmap") + return nil, err + } + if scm.Data == nil { + c.log.V(4).Info("DEBUG fqdnstate cm not found or contains no data", "cm", scm) + return &c, nil + + } + if scm.Data[fqdnStateConfigmapKey] == "" { + c.log.Error(fmt.Errorf("error reading fqdnstate configmap, ignoring content"), "fqdnstate configmap does not contain the right key", "configmap", scm, "key", fqdnStateConfigmapKey) + return &c, nil + } + c.log.V(4).Info("DEBUG state stored in fqdnstate cm, trying to unmarshal", fqdnStateConfigmapKey, scm.Data[fqdnStateConfigmapKey]) + err = yaml.UnmarshalStrict([]byte(scm.Data[fqdnStateConfigmapKey]), &c.fqdnToEntry) + if err != nil { + c.log.Info("could not unmarshal state from fqdnstate configmap, ignoring content.", "error", err) + } + return &c, nil } -// getSetsForFQDN returns sets for FQDN selector -func (c *DNSCache) getSetsForFQDN(fqdn firewallv1.FQDNSelector, fqdnSets []firewallv1.IPSet) (result []firewallv1.IPSet) { - c.restoreSets(fqdnSets) +// writeStateToConfigmap writes the whole DNS cache to the state configmap +func (c *DNSCache) writeStateToConfigmap() error { + s, err := yaml.Marshal(c.fqdnToEntry) + if err != nil { + return err + } + + var ( + cm = &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: fqdnStateConfigmapName, + Namespace: fqdnStateNamespace, + }, + } + + data = map[string]string{ + fqdnStateConfigmapKey: string(s), + } + + debugLog = c.log.V(4).WithValues("namespace", cm.Namespace, "name", cm.Name) + ) + + debugLog.Info("DEBUG writing cache to configmap", "fqdnToEntry", string(s)) + + debugLog.Info("DEBUG looking for configmap") + + ctxWithTimeout, cancel := context.WithTimeout(c.ctx, 5*time.Second) + defer func() { + cancel() + }() + + err = c.shootClient.Get(ctxWithTimeout, client.ObjectKeyFromObject(cm), cm) + if err != nil && !apierrors.IsNotFound(err) { + return err + } + + if apierrors.IsNotFound(err) { + debugLog.Info("DEBUG configmap not found, trying to create configmap") + cm.Data = data + + err = c.shootClient.Create(ctxWithTimeout, cm) + if err != nil { + return err + } + + return nil + } + + if !reflect.DeepEqual(cm.Data, data) { + cm.Data = data + + err = c.shootClient.Update(ctxWithTimeout, cm) + if err != nil { + return err + } + + debugLog.Info("DEBUG updated cm", "old data", cm.Data, "new data", data) + + return nil + } + + debugLog.Info("DEBUG no need to update cm, already up to date") + + return nil +} + +// getSetsForFQDN returns sets for FQDN selector +func (c *DNSCache) getSetsForFQDN(fqdn firewallv1.FQDNSelector) (result []firewallv1.IPSet) { sets := map[string]firewallv1.IPSet{} if fqdn.MatchName != "" { for _, s := range c.getSetNameForFQDN(fqdn.GetMatchName()) { @@ -146,6 +252,9 @@ func (c *DNSCache) getSetsForFQDN(fqdn firewallv1.FQDNSelector, fqdnSets []firew for _, s := range sets { result = append(result, s) } + slices.SortStableFunc(result, func(a, b firewallv1.IPSet) int { + return strings.Compare(a.SetName, b.SetName) + }) c.log.WithValues("fqdn", fqdn, "sets", result).Info("sets for FQDN") return @@ -168,14 +277,17 @@ func (c *DNSCache) getSetsForRendering(fqdns []firewallv1.FQDNSelector) (result } } if matched { - if e.ipv4 != nil { - result = append(result, createRenderIPSetFromIPEntry(IPv4, e.ipv4)) + if e.IPv4 != nil { + result = append(result, createRenderIPSetFromIPEntry(IPv4, e.IPv4)) } - if e.ipv6 != nil { - result = append(result, createRenderIPSetFromIPEntry(IPv6, e.ipv6)) + if e.IPv6 != nil { + result = append(result, createRenderIPSetFromIPEntry(IPv6, e.IPv6)) } } } + slices.SortStableFunc(result, func(a, b RenderIPSet) int { + return strings.Compare(a.SetName, b.SetName) + }) return } @@ -184,36 +296,6 @@ func (c *DNSCache) updateDNSServerAddr(addr string) { c.dnsServerAddr = addr } -// restoreSets add missing sets from FQDNSelector.Sets -func (c *DNSCache) restoreSets(fqdnSets []firewallv1.IPSet) { - for _, s := range fqdnSets { - // Add cache entries from fqdn.Sets if missing - c.Lock() - if _, ok := c.setNames[s.SetName]; !ok { - c.setNames[s.SetName] = struct{}{} - entry, exists := c.fqdnToEntry[s.FQDN] - if !exists { - entry = cacheEntry{} - } - - ipe := &ipEntry{ - ips: s.IPs, - expirationTime: s.ExpirationTime.Time, - setName: s.SetName, - } - switch s.Version { - case firewallv1.IPv4: - entry.ipv4 = ipe - case firewallv1.IPv6: - entry.ipv6 = ipe - } - - c.fqdnToEntry[s.FQDN] = entry - } - c.Unlock() - } -} - // getSetNameForFQDN returns FQDN set data func (c *DNSCache) getSetNameForFQDN(fqdn string) (result []firewallv1.IPSet) { c.RLock() @@ -235,11 +317,11 @@ func (c *DNSCache) getSetNameForFQDN(fqdn string) (result []firewallv1.IPSet) { } defer c.RUnlock() - if entry.ipv4 != nil { - result = append(result, createIPSetFromIPEntry(fqdn, firewallv1.IPv4, entry.ipv4)) + if entry.IPv4 != nil { + result = append(result, createIPSetFromIPEntry(fqdn, firewallv1.IPv4, entry.IPv4)) } - if entry.ipv6 != nil { - result = append(result, createIPSetFromIPEntry(fqdn, firewallv1.IPv6, entry.ipv6)) + if entry.IPv6 != nil { + result = append(result, createIPSetFromIPEntry(fqdn, firewallv1.IPv6, entry.IPv6)) } return } @@ -264,7 +346,7 @@ func (c *DNSCache) loadDataFromDNSServer(fqdns []string) error { return fmt.Errorf("failed to get DNS data about fqdn %s: %w", fqdns[0], err) } c.log.V(4).Info("DEBUG dnscache loadDataFromDNSServer function calling Update function", "answer", in, "fqdns", fqdns) - if _, err = c.Update(time.Now(), qname, in, fqdns); err != nil { + if _, err = c.Update(time.Now().UTC(), qname, in, fqdns); err != nil { return fmt.Errorf("failed to update DNS data for fqdn %s: %w", fqdns[0], err) } } @@ -282,11 +364,11 @@ func (c *DNSCache) getSetNameForRegex(regex string) (sets []firewallv1.IPSet) { continue } - if e.ipv4 != nil { - sets = append(sets, createIPSetFromIPEntry(n, firewallv1.IPv4, e.ipv4)) + if e.IPv4 != nil { + sets = append(sets, createIPSetFromIPEntry(n, firewallv1.IPv4, e.IPv4)) } - if e.ipv6 != nil { - sets = append(sets, createIPSetFromIPEntry(n, firewallv1.IPv6, e.ipv6)) + if e.IPv6 != nil { + sets = append(sets, createIPSetFromIPEntry(n, firewallv1.IPv6, e.IPv6)) } } @@ -311,10 +393,8 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq return true, fmt.Errorf("too many hops, fqdn chain: %s", strings.Join(fqdns, ",")) } - ipv4 := []net.IP{} - ipv6 := []net.IP{} - minIPv4TTL := uint32(math.MaxUint32) - minIPv6TTL := uint32(math.MaxUint32) + ipv4 := []dnsgo.RR{} + ipv6 := []dnsgo.RR{} found := false for _, ans := range msg.Answer { @@ -326,17 +406,11 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq switch rr := ans.(type) { case *dnsgo.A: - ipv4 = append(ipv4, rr.A) - if minIPv4TTL > rr.Hdr.Ttl { - minIPv4TTL = rr.Hdr.Ttl - } + ipv4 = append(ipv4, rr) found = true c.log.V(4).Info("DEBUG dnscache Update function A record found", "IPs", ipv4) case *dnsgo.AAAA: - ipv6 = append(ipv6, rr.AAAA) - if minIPv6TTL > rr.Hdr.Ttl { - minIPv6TTL = rr.Hdr.Ttl - } + ipv6 = append(ipv6, rr) found = true c.log.V(4).Info("DEBUG dnscache Update function AAAA record found", "IPs", ipv6) case *dnsgo.CNAME: @@ -359,27 +433,37 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq } } + ipEntriesUpdated := false + for _, fqdn := range fqdns { c.log.V(4).Info("DEBUG dnscache Update function Updating DNS cache for", "fqdn", fqdn, "ipv4", ipv4, "ipv6", ipv6) if c.ipv4Enabled && len(ipv4) > 0 { - if err := c.updateIPEntry(fqdn, ipv4, lookupTime.Add(time.Duration(minIPv4TTL)), nftables.TypeIPAddr); err != nil { + if err := c.updateIPEntry(fqdn, ipv4, lookupTime, nftables.TypeIPAddr); err != nil { return false, fmt.Errorf("failed to update IPv4 addresses: %w", err) } + ipEntriesUpdated = true } if c.ipv6Enabled && len(ipv6) > 0 { - if err := c.updateIPEntry(fqdn, ipv6, lookupTime.Add(time.Duration(minIPv6TTL)), nftables.TypeIP6Addr); err != nil { + if err := c.updateIPEntry(fqdn, ipv6, lookupTime, nftables.TypeIP6Addr); err != nil { return false, fmt.Errorf("failed to update IPv6 addresses: %w", err) } + ipEntriesUpdated = true + } + } + + if ipEntriesUpdated { + if err := c.writeStateToConfigmap(); err != nil { + c.log.V(4).Info("DEBUG could not write updated DNS cache to state configmap", "configmap", fqdnStateConfigmapName, "namespace", fqdnStateNamespace, "error", err) } } return found, nil } -func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error { +func (c *DNSCache) updateIPEntry(qname string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error { scopedLog := c.log.WithValues( "fqdn", qname, - "ip_len", len(ips), + "ip_len", len(rrs), "dtype", dtype.Name, ) @@ -391,27 +475,28 @@ func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time entry = cacheEntry{} } - var ipe *ipEntry + var ipe *iPEntry switch dtype { case nftables.TypeIPAddr: - if entry.ipv4 == nil { + if entry.IPv4 == nil { setName := c.createSetName(qname, dtype.Name, 0) - ipe = newIPEntry(setName, expirationTime) - entry.ipv4 = ipe + ipe = newIPEntry(setName) + entry.IPv4 = ipe } - ipe = entry.ipv4 + ipe = entry.IPv4 case nftables.TypeIP6Addr: - if entry.ipv6 == nil { + if entry.IPv6 == nil { setName := c.createSetName(qname, dtype.Name, 0) - ipe = newIPEntry(setName, expirationTime) - entry.ipv6 = ipe + ipe = newIPEntry(setName) + entry.IPv6 = ipe } - ipe = entry.ipv6 + ipe = entry.IPv6 } - setName := ipe.setName - if err := ipe.update(setName, ips, expirationTime, dtype); err != nil { - return fmt.Errorf("failed to update ipEntry: %w", err) + setName := ipe.SetName + scopedLog.WithValues("set", setName, "lookupTime", lookupTime, "rrs", rrs).Info("updating ip entry") + if err := ipe.update(scopedLog, setName, rrs, lookupTime, dtype); err != nil { + return fmt.Errorf("failed to update IPEntry: %w", err) } c.fqdnToEntry[qname] = entry @@ -477,20 +562,28 @@ func updateNftSet( return nil } -func createIPSetFromIPEntry(fqdn string, version firewallv1.IPVersion, entry *ipEntry) firewallv1.IPSet { - return firewallv1.IPSet{ - FQDN: fqdn, - SetName: entry.setName, - IPs: entry.ips, - ExpirationTime: metav1.Time{Time: entry.expirationTime}, - Version: version, +func createIPSetFromIPEntry(fqdn string, version firewallv1.IPVersion, entry *iPEntry) firewallv1.IPSet { + ips := firewallv1.IPSet{ + FQDN: fqdn, + SetName: entry.SetName, + IPExpirationTimes: make(map[string]metav1.Time), + Version: version, } + for ip, expirationTime := range entry.IPs { + ips.IPExpirationTimes[ip] = metav1.NewTime(expirationTime) + } + return ips } -func createRenderIPSetFromIPEntry(version IPVersion, entry *ipEntry) RenderIPSet { +func createRenderIPSetFromIPEntry(version IPVersion, entry *iPEntry) RenderIPSet { + var ips []string + for ip := range entry.IPs { + ips = append(ips, ip) + } + sort.Strings(ips) return RenderIPSet{ - SetName: entry.setName, - IPs: entry.ips, + SetName: entry.SetName, + IPs: ips, Version: version, } } diff --git a/pkg/dns/dnscache_test.go b/pkg/dns/dnscache_test.go index ade37e6c..92c2b1fb 100644 --- a/pkg/dns/dnscache_test.go +++ b/pkg/dns/dnscache_test.go @@ -2,34 +2,48 @@ package dns import ( "testing" + "time" "github.com/go-logr/logr" - + "github.com/google/go-cmp/cmp" firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func Test_GetSetsForFQDN(t *testing.T) { tests := []struct { - name string - fqdnToEntry map[string]cacheEntry - expectedSets []string - fqdnSelector firewallv1.FQDNSelector - cachedSets []firewallv1.IPSet + name string + fqdnToEntry map[string]cacheEntry + want []firewallv1.IPSet + fqdn firewallv1.FQDNSelector }{ { name: "get result for matchName", fqdnToEntry: map[string]cacheEntry{ "test.com.": { - ipv4: &ipEntry{ - setName: "testv4", + IPv4: &iPEntry{ + SetName: "testv4", }, - ipv6: &ipEntry{ - setName: "testv6", + IPv6: &iPEntry{ + SetName: "testv6", }, }, }, - expectedSets: []string{"testv4", "testv6"}, - fqdnSelector: firewallv1.FQDNSelector{ + want: []firewallv1.IPSet{ + { + SetName: "testv4", + FQDN: "test.com.", + Version: "ip", + IPExpirationTimes: map[string]v1.Time{}, + }, + { + SetName: "testv6", + FQDN: "test.com.", + Version: "ip6", + IPExpirationTimes: map[string]v1.Time{}, + }, + }, + fqdn: firewallv1.FQDNSelector{ MatchName: "test.com", }, }, @@ -37,40 +51,77 @@ func Test_GetSetsForFQDN(t *testing.T) { name: "get result for matchPattern", fqdnToEntry: map[string]cacheEntry{ "test.com.": { - ipv4: &ipEntry{ - setName: "testv4", + IPv4: &iPEntry{ + SetName: "testv4", }, - ipv6: &ipEntry{ - setName: "testv6", + IPv6: &iPEntry{ + SetName: "testv6", }, }, "test.io.": { - ipv4: &ipEntry{ - setName: "testiov4", + IPv4: &iPEntry{ + SetName: "testiov4", }, - ipv6: &ipEntry{ - setName: "testiov6", + IPv6: &iPEntry{ + SetName: "testiov6", }, }, "example.com.": { - ipv4: &ipEntry{ - setName: "examplev4", + IPv4: &iPEntry{ + SetName: "examplev4", }, - ipv6: &ipEntry{ - setName: "examplev6", + IPv6: &iPEntry{ + SetName: "examplev6", }, }, "second.example.com.": { - ipv4: &ipEntry{ - setName: "2examplev4", + IPv4: &iPEntry{ + SetName: "2examplev4", }, - ipv6: &ipEntry{ - setName: "2examplev6", + IPv6: &iPEntry{ + SetName: "2examplev6", }, }, }, - expectedSets: []string{"testv4", "testv6", "examplev4", "examplev6", "2examplev4", "2examplev6"}, - fqdnSelector: firewallv1.FQDNSelector{ + want: []firewallv1.IPSet{ + { + SetName: "2examplev4", + FQDN: "second.example.com.", + Version: "ip", + IPExpirationTimes: map[string]v1.Time{}, + }, + { + SetName: "2examplev6", + FQDN: "second.example.com.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip6", + }, + { + SetName: "examplev4", + FQDN: "example.com.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip", + }, + { + SetName: "examplev6", + FQDN: "example.com.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip6", + }, + { + SetName: "testv4", + FQDN: "test.com.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip", + }, + { + SetName: "testv6", + FQDN: "test.com.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip6", + }, + }, + fqdn: firewallv1.FQDNSelector{ MatchPattern: "*.com", }, }, @@ -78,66 +129,99 @@ func Test_GetSetsForFQDN(t *testing.T) { name: "pattern from integration testing", fqdnToEntry: map[string]cacheEntry{ "www.freechess.org.": { - ipv4: &ipEntry{ - setName: "testv4", + IPv4: &iPEntry{ + SetName: "testv4", }, - ipv6: &ipEntry{ - setName: "testv6", + IPv6: &iPEntry{ + SetName: "testv6", }, }, }, - expectedSets: []string{"testv4", "testv6"}, - fqdnSelector: firewallv1.FQDNSelector{ + want: []firewallv1.IPSet{ + { + SetName: "testv4", + FQDN: "www.freechess.org.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip", + }, + { + SetName: "testv6", + FQDN: "www.freechess.org.", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip6", + }, + }, + fqdn: firewallv1.FQDNSelector{ MatchPattern: "ww*.freechess.org", }, }, - { - name: "restore sets", - fqdnToEntry: map[string]cacheEntry{}, - fqdnSelector: firewallv1.FQDNSelector{}, - cachedSets: []firewallv1.IPSet{{ - FQDN: "test-fqdn", - SetName: "test-set", - }}, - }, } for _, tt := range tests { - tc := tt - t.Run(tc.name, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { cache := DNSCache{ log: logr.Discard(), - fqdnToEntry: tc.fqdnToEntry, + fqdnToEntry: tt.fqdnToEntry, setNames: make(map[string]struct{}), ipv4Enabled: true, ipv6Enabled: true, } - result := cache.getSetsForFQDN(tc.fqdnSelector, tc.cachedSets) - set := make(map[string]bool, len(tc.expectedSets)) - for _, s := range tc.expectedSets { - set[s] = false - } - for _, r := range result { - if _, found := set[r.SetName]; !found { - t.Errorf("set name %s wasn't expected", r.SetName) - } - set[r.SetName] = true - } - for s, b := range set { - if !b { - t.Errorf("set name %s didn't occurred in result", s) - } + got := cache.getSetsForFQDN(tt.fqdn) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("DNSCache.getSetsForFQDN diff = %s", diff) } + }) + } +} - // Check if cache was updated - for _, s := range tc.cachedSets { - if _, ok := cache.setNames[s.SetName]; !ok { - t.Errorf("set name %s wasn't added to cache", s.SetName) - } - if _, ok := cache.fqdnToEntry[s.FQDN]; !ok { - t.Errorf("FQDN %s wasn't added to cache", s.FQDN) - } +func Test_createIPSetFromIPEntry(t *testing.T) { + tests := []struct { + name string + fqdn string + version firewallv1.IPVersion + entry *iPEntry + want firewallv1.IPSet + }{ + { + name: "empty ip entry", + fqdn: "www.freechess.org", + version: "ip", + entry: &iPEntry{ + SetName: "test", + }, + want: firewallv1.IPSet{ + FQDN: "www.freechess.org", + SetName: "test", + IPExpirationTimes: map[string]v1.Time{}, + Version: "ip", + }, + }, + { + name: "entry contains ips", + fqdn: "www.freechess.org", + version: "ip", + entry: &iPEntry{ + SetName: "test", + IPs: map[string]time.Time{ + "1.2.3.4": time.Date(2100, time.January, 1, 0, 0, 0, 0, time.UTC), + }, + }, + want: firewallv1.IPSet{ + FQDN: "www.freechess.org", + SetName: "test", + IPExpirationTimes: map[string]v1.Time{ + "1.2.3.4": v1.NewTime(time.Date(2100, time.January, 1, 0, 0, 0, 0, time.UTC)), + }, + Version: "ip", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := createIPSetFromIPEntry(tt.fqdn, tt.version, tt.entry) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("createIPSetFromIPEntry() diff = %s", diff) } }) } diff --git a/pkg/dns/dnsproxy.go b/pkg/dns/dnsproxy.go index a867afbb..b9a16212 100644 --- a/pkg/dns/dnsproxy.go +++ b/pkg/dns/dnsproxy.go @@ -5,8 +5,10 @@ import ( "fmt" "net" "strconv" + "time" "github.com/metal-stack/metal-networker/pkg/netconf" + "sigs.k8s.io/controller-runtime/pkg/client" firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1" @@ -26,9 +28,10 @@ type DNSHandler interface { } type DNSProxy struct { - log logr.Logger - cache *DNSCache - stopCh chan struct{} + log logr.Logger + ctx context.Context + cancelFunc context.CancelFunc + cache *DNSCache udpServer *dnsgo.Server tcpServer *dnsgo.Server @@ -36,12 +39,10 @@ type DNSProxy struct { handler DNSHandler } -func NewDNSProxy(dns string, port *uint, log logr.Logger) (*DNSProxy, error) { +func NewDNSProxy(ctx context.Context, dns string, port *uint, shootClient client.Client, log logr.Logger) (*DNSProxy, error) { if dns == "" { dns = defaultDNSServerAddr } - cache := newDNSCache(dns, true, false, log.WithName("DNS cache")) - handler := NewDNSProxyHandler(log, cache) host, err := getHost() if err != nil { @@ -57,13 +58,22 @@ func NewDNSProxy(dns string, port *uint, log logr.Logger) (*DNSProxy, error) { return nil, fmt.Errorf("failed to bind to port: %w", err) } + backgroundCtx, cancel := context.WithCancel(ctx) + cache, err := newDNSCache(backgroundCtx, dns, true, false, shootClient, log.WithName("DNS cache")) + if err != nil { + cancel() + return nil, err + } + handler := NewDNSProxyHandler(log, cache) + udpServer := &dnsgo.Server{PacketConn: udpConn, Addr: udpConn.LocalAddr().String(), Net: "udp", Handler: handler} tcpServer := &dnsgo.Server{Listener: tcpListener, Addr: udpConn.LocalAddr().String(), Net: "tcp", Handler: handler} return &DNSProxy{ - log: log, - cache: cache, - stopCh: make(chan struct{}), + log: log, + ctx: backgroundCtx, + cancelFunc: cancel, + cache: cache, udpServer: udpServer, tcpServer: tcpServer, @@ -73,7 +83,7 @@ func NewDNSProxy(dns string, port *uint, log logr.Logger) (*DNSProxy, error) { } // Run starts TCP/UDP servers -func (p *DNSProxy) Run(ctx context.Context) { +func (p *DNSProxy) Run() { go func() { p.log.Info("starting UDP server") if err := p.udpServer.ActivateAndServe(); err != nil { @@ -88,7 +98,9 @@ func (p *DNSProxy) Run(ctx context.Context) { } }() - <-p.stopCh + <-p.ctx.Done() + ctx, cancel := context.WithTimeout(p.ctx, time.Second*5) + defer cancel() if err := p.udpServer.ShutdownContext(ctx); err != nil { p.log.Error(err, "failed to shut down UDP server") @@ -100,7 +112,7 @@ func (p *DNSProxy) Run(ctx context.Context) { // Stop starts TCP/UDP servers func (p *DNSProxy) Stop() { - close(p.stopCh) + p.cancelFunc() } func (p *DNSProxy) UpdateDNSServerAddr(addr string) error { @@ -116,8 +128,8 @@ func (p *DNSProxy) GetSetsForRendering(fqdns []firewallv1.FQDNSelector) (result return p.cache.getSetsForRendering(fqdns) } -func (p *DNSProxy) GetSetsForFQDN(fqdn firewallv1.FQDNSelector, fqdnSets []firewallv1.IPSet) (result []firewallv1.IPSet) { - return p.cache.getSetsForFQDN(fqdn, fqdnSets) +func (p *DNSProxy) GetSetsForFQDN(fqdn firewallv1.FQDNSelector) (result []firewallv1.IPSet) { + return p.cache.getSetsForFQDN(fqdn) } func (p *DNSProxy) IsInitialized() bool { diff --git a/pkg/nftables/firewall.go b/pkg/nftables/firewall.go index 7c938caa..463a616f 100644 --- a/pkg/nftables/firewall.go +++ b/pkg/nftables/firewall.go @@ -41,7 +41,7 @@ var templates embed.FS //go:generate ../../bin/mockgen -destination=./mocks/mock_fqdncache.go -package=mocks . FQDNCache type FQDNCache interface { GetSetsForRendering(fqdns []firewallv1.FQDNSelector) (result []dns.RenderIPSet) - GetSetsForFQDN(fqdn firewallv1.FQDNSelector, fqdnSets []firewallv1.IPSet) (result []firewallv1.IPSet) + GetSetsForFQDN(fqdn firewallv1.FQDNSelector) (result []firewallv1.IPSet) IsInitialized() bool CacheAddr() (string, error) } @@ -199,14 +199,10 @@ func getConfiguredIPs(networkID string) []string { } var ips []string for _, nw := range c.Networks { - nw := nw if nw.Networkid == nil || *nw.Networkid != networkID { continue } - for _, ip := range nw.Ips { - ip := ip - ips = append(ips, ip) - } + ips = append(ips, nw.Ips...) } return ips } diff --git a/pkg/nftables/mocks/mock_fqdncache.go b/pkg/nftables/mocks/mock_fqdncache.go index fbd5bc3a..d1f26694 100644 --- a/pkg/nftables/mocks/mock_fqdncache.go +++ b/pkg/nftables/mocks/mock_fqdncache.go @@ -57,17 +57,17 @@ func (mr *MockFQDNCacheMockRecorder) CacheAddr() *gomock.Call { } // GetSetsForFQDN mocks base method. -func (m *MockFQDNCache) GetSetsForFQDN(fqdn v1.FQDNSelector, fqdnSets []v1.IPSet) []v1.IPSet { +func (m *MockFQDNCache) GetSetsForFQDN(fqdn v1.FQDNSelector) []v1.IPSet { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSetsForFQDN", fqdn, fqdnSets) + ret := m.ctrl.Call(m, "GetSetsForFQDN", fqdn) ret0, _ := ret[0].([]v1.IPSet) return ret0 } // GetSetsForFQDN indicates an expected call of GetSetsForFQDN. -func (mr *MockFQDNCacheMockRecorder) GetSetsForFQDN(fqdn, fqdnSets any) *gomock.Call { +func (mr *MockFQDNCacheMockRecorder) GetSetsForFQDN(fqdn any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetsForFQDN", reflect.TypeOf((*MockFQDNCache)(nil).GetSetsForFQDN), fqdn, fqdnSets) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetsForFQDN", reflect.TypeOf((*MockFQDNCache)(nil).GetSetsForFQDN), fqdn) } // GetSetsForRendering mocks base method. diff --git a/pkg/nftables/networkpolicy.go b/pkg/nftables/networkpolicy.go index de472f03..85b56812 100644 --- a/pkg/nftables/networkpolicy.go +++ b/pkg/nftables/networkpolicy.go @@ -78,6 +78,7 @@ func clusterwideNetworkPolicyEgressRules( np firewallv1.ClusterwideNetworkPolicy, logAcceptedConnections bool, ) (rules nftablesRules, updated firewallv1.ClusterwideNetworkPolicy) { + var fqdnState firewallv1.FQDNState for _, e := range np.Spec.Egress { tcpPorts, udpPorts := calculatePorts(e.Ports) ruleBases := []ruleBase{} @@ -94,10 +95,9 @@ func clusterwideNetworkPolicyEgressRules( } ruleBases = append(ruleBases, ruleBase{base: rb}) } else if len(e.ToFQDNs) > 0 && cache.IsInitialized() { - // Generate allow rules based on DNS selectors - rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, e) - np.Status.FQDNState = u + rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, fqdnState, e) ruleBases = append(ruleBases, rbs...) + fqdnState = u } comment := fmt.Sprintf("accept traffic for np %s", np.Name) @@ -111,6 +111,7 @@ func clusterwideNetworkPolicyEgressRules( } } + np.Status.FQDNState = fqdnState return uniqueSorted(rules), np } @@ -125,9 +126,12 @@ func clusterwideNetworkPolicyEgressToRules(e firewallv1.EgressRule) (allow, exce func clusterwideNetworkPolicyEgressToFQDNRules( cache FQDNCache, + fqdnState firewallv1.FQDNState, e firewallv1.EgressRule, ) (rules []ruleBase, updatedState firewallv1.FQDNState) { - fqdnState := firewallv1.FQDNState{} + if fqdnState == nil { + fqdnState = firewallv1.FQDNState{} + } for _, fqdn := range e.ToFQDNs { fqdnName := fqdn.MatchName @@ -135,7 +139,7 @@ func clusterwideNetworkPolicyEgressToFQDNRules( fqdnName = fqdn.MatchPattern } - fqdnState[fqdnName] = cache.GetSetsForFQDN(fqdn, fqdnState[fqdnName]) + fqdnState[fqdnName] = cache.GetSetsForFQDN(fqdn) for _, set := range fqdnState[fqdnName] { rb := []string{"ip saddr == @cluster_prefixes"} rb = append(rb, fmt.Sprintf(string(set.Version)+" daddr @%s", set.SetName)) diff --git a/pkg/nftables/networkpolicy_test.go b/pkg/nftables/networkpolicy_test.go index 295cfd6e..6580d7b5 100644 --- a/pkg/nftables/networkpolicy_test.go +++ b/pkg/nftables/networkpolicy_test.go @@ -99,14 +99,11 @@ func TestClusterwideNetworkPolicyRules(t *testing.T) { `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, }, ingressAL: nftablesRules{ - `ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } counter accept comment "accept traffic for k8s network policy tcp"`, + `ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } counter accept comment "accept traffic for k8s network policy tcp"`, }, egressAL: nftablesRules{ - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53, 443-448 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53, 443-448 } counter accept comment "accept traffic for np tcp"`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53, 443-448 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53, 443-448 } counter accept comment "accept traffic for np tcp"`, + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, }, }, }, @@ -184,10 +181,8 @@ func TestClusterwideNetworkPolicyEgressRules(t *testing.T) { `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, }, egressAL: nftablesRules{ - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53 } counter accept comment "accept traffic for np tcp"`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } tcp dport { 53 } counter accept comment "accept traffic for np tcp"`, + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr == @cluster_prefixes ip daddr != { 1.1.0.1 } ip daddr { 1.1.0.0/24, 1.1.1.0/24 } udp dport { 53 } counter accept comment "accept traffic for np udp"`, }, }, }, @@ -226,11 +221,11 @@ func TestClusterwideNetworkPolicyEgressRules(t *testing.T) { Return(true) cache. EXPECT(). - GetSetsForFQDN(gomock.Any(), gomock.Any()). + GetSetsForFQDN(gomock.Any()). Return([]firewallv1.IPSet{{SetName: "test", Version: firewallv1.IPv4}}) cache. EXPECT(). - GetSetsForFQDN(gomock.Any(), gomock.Any()). + GetSetsForFQDN(gomock.Any()). Return([]firewallv1.IPSet{{SetName: "test2", Version: firewallv1.IPv6}}) }, want: want{ diff --git a/pkg/nftables/rendering.go b/pkg/nftables/rendering.go index e8d78380..e69b30de 100644 --- a/pkg/nftables/rendering.go +++ b/pkg/nftables/rendering.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "sort" "strings" "text/template" @@ -37,6 +38,8 @@ func newFirewallRenderingData(f *Firewall) (*firewallRenderingData, error) { egress = append(egress, e...) f.clusterwideNetworkPolicies.Items[ind] = u } + sort.Strings(ingress) + sort.Strings(egress) var serviceAllowedSet *netipx.IPSet if len(f.firewall.Spec.AllowedNetworks.Ingress) > 0 { @@ -72,6 +75,10 @@ func newFirewallRenderingData(f *Firewall) (*firewallRenderingData, error) { } egress = append(egress, rules...) } + + ingress = splitRules(ingress) + egress = splitRules(egress) + return &firewallRenderingData{ AdditionalDNSAddrs: dnsAddrs, PrivateVrfID: uint(*f.primaryPrivateNet.Vrf), // nolint:gosec diff --git a/pkg/nftables/service.go b/pkg/nftables/service.go index 8eb59cbf..efb00cb8 100644 --- a/pkg/nftables/service.go +++ b/pkg/nftables/service.go @@ -44,7 +44,6 @@ func serviceRules(svc corev1.Service, allowed *netipx.IPSet, logAcceptedConnecti tcpPorts := []string{} udpPorts := []string{} for _, p := range svc.Spec.Ports { - p := p proto := proto(&p.Protocol) switch proto { case "tcp": diff --git a/pkg/nftables/service_test.go b/pkg/nftables/service_test.go index 9dcae4e5..466eccbe 100644 --- a/pkg/nftables/service_test.go +++ b/pkg/nftables/service_test.go @@ -60,8 +60,7 @@ func TestServiceRules(t *testing.T) { `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } ip daddr { 185.0.0.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, ingressAL: nftablesRules{ - `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } ip daddr { 185.0.0.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } ip daddr { 185.0.0.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, + `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } ip daddr { 185.0.0.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } ip daddr { 185.0.0.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, }, }, @@ -131,8 +130,7 @@ func TestServiceRules(t *testing.T) { `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, ingressAL: nftablesRules{ - `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, + `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip saddr { 185.0.0.0/16, 185.1.0.0/16 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, }, }, @@ -170,8 +168,7 @@ func TestServiceRules(t *testing.T) { `ip daddr { 185.0.1.2, 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, ingressAL: nftablesRules{ - `ip daddr { 185.0.1.2, 185.0.1.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip daddr { 185.0.1.2, 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, + `ip daddr { 185.0.1.2, 185.0.1.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip daddr { 185.0.1.2, 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, }, }, @@ -209,8 +206,7 @@ func TestServiceRules(t *testing.T) { `ip daddr { 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, ingressAL: nftablesRules{ - `ip daddr { 185.0.1.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second`, - `ip daddr { 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, + `ip daddr { 185.0.1.1 } tcp dport { 443 } log prefix "nftables-firewall-accepted: " limit rate 10/second` + "\n" + `ip daddr { 185.0.1.1 } tcp dport { 443 } counter accept comment "accept traffic for k8s service test/svc"`, }, }, }, diff --git a/pkg/nftables/util.go b/pkg/nftables/util.go index 1e8e6f15..4bbb1de5 100644 --- a/pkg/nftables/util.go +++ b/pkg/nftables/util.go @@ -16,11 +16,15 @@ func uniqueSorted(elements []string) []string { for _, e := range elements { t[e] = true } - rawRules := []string{} + rules := []string{} for k := range t { - rawRules = append(rawRules, k) + rules = append(rules, k) } - sort.Strings(rawRules) + sort.Strings(rules) + return rules +} + +func splitRules(rawRules []string) []string { rules := []string{} for _, r := range rawRules { // split multiline log\naccept rules for pretty nftables file formatting rules = append(rules, strings.Split(r, "\n")...) diff --git a/pkg/nftables/util_test.go b/pkg/nftables/util_test.go index ba88186b..995ba146 100644 --- a/pkg/nftables/util_test.go +++ b/pkg/nftables/util_test.go @@ -3,6 +3,8 @@ package nftables import ( "os" "testing" + + "github.com/google/go-cmp/cmp" ) func Test_equal(t *testing.T) { @@ -148,7 +150,6 @@ table ip firewall { }, } for _, tt := range tests { - tt := tt s, err := os.CreateTemp("/tmp", "source") if err != nil { t.Fail() @@ -173,3 +174,30 @@ table ip firewall { }) } } + +func Test_splitRules(t *testing.T) { + tests := []struct { + name string + rawRules []string + want []string + }{ + { + name: "split multiline string into separate strings", + rawRules: []string{ + "ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } log prefix \"nftables-firewall-accepted: \" limit rate 10/second\nip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } counter accept comment \"accept traffic for k8s network policy tcp\"", + }, + want: []string{ + "ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } log prefix \"nftables-firewall-accepted: \" limit rate 10/second", + "ip saddr != { 1.1.0.1 } ip saddr { 1.1.0.0/24 } tcp dport { 80, 443-448 } counter accept comment \"accept traffic for k8s network policy tcp\"", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitRules(tt.rawRules) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("splitRules() diff = %s", diff) + } + }) + } +}