diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 17cb068..cda89ff 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -1786,6 +1786,7 @@ func (wkd *wellKnownDomainsTracker) sendUpdate(ipBytes []byte, msg *dns.Msg, daw func (wkd *wellKnownDomainsTracker) rotateTracker(edm *dnstapMinimiser, dawgFile string, rotationTime time.Time) (*wellKnownDomainsData, error) { dawgFileChanged := false var dawgFinder dawg.Finder + var dawgModTime time.Time fileInfo, err := os.Stat(dawgFile) if err != nil { @@ -1793,9 +1794,9 @@ func (wkd *wellKnownDomainsTracker) rotateTracker(edm *dnstapMinimiser, dawgFile } if fileInfo.ModTime() != wkd.dawgModTime { - dawgFinder, err = dawg.Load(dawgFile) + dawgFinder, dawgModTime, err = loadDawgFile(dawgFile) if err != nil { - return nil, fmt.Errorf("rotateTracker: dawg.Load(): %w", err) + return nil, fmt.Errorf("rotateTracker: loadDawgFile(): %w", err) } dawgFileChanged = true edm.log.Info("dawg file modification changed, will reload file", "prev_time", wkd.dawgModTime, "cur_time", fileInfo.ModTime()) @@ -1810,7 +1811,7 @@ func (wkd *wellKnownDomainsTracker) rotateTracker(edm *dnstapMinimiser, dawgFile wkd.m = map[int]*histogramData{} if dawgFileChanged { wkd.dawgFinder = dawgFinder - wkd.dawgModTime = fileInfo.ModTime() + wkd.dawgModTime = dawgModTime prevWKD.dawgIsRotated = true } wkd.mutex.Unlock() diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index 2bbe84b..2df15c0 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -3,6 +3,7 @@ package runner import ( "bytes" "encoding/binary" + "errors" "flag" "io" "log/slog" @@ -215,6 +216,41 @@ func TestWKD(t *testing.T) { } } +func TestRotateTrackerUsesSafeDawgLoader(t *testing.T) { + dBuilder := dawg.New() + dBuilder.Add("example.com.") + dFinder := dBuilder.Finish() + + dawgFile := t.TempDir() + "/well-known-domains.dawg" + if _, err := dFinder.Save(dawgFile); err != nil { + t.Fatalf("Save: %s", err) + } + fileInfo, err := os.Stat(dawgFile) + if err != nil { + t.Fatalf("Stat: %s", err) + } + + wkd, err := newWellKnownDomainsTracker(dFinder, fileInfo.ModTime()) + if err != nil { + t.Fatalf("newWellKnownDomainsTracker: %s", err) + } + edm := &dnstapMinimiser{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + + if err := os.WriteFile(dawgFile, nil, 0o644); err != nil { + t.Fatalf("WriteFile: %s", err) + } + nextModTime := fileInfo.ModTime().Add(time.Second) + if err := os.Chtimes(dawgFile, nextModTime, nextModTime); err != nil { + t.Fatalf("Chtimes: %s", err) + } + + if _, err := wkd.rotateTracker(edm, dawgFile, time.Now()); !errors.Is(err, errEmptyDawgFile) { + t.Fatalf("rotateTracker error have: %v, want: %v", err, errEmptyDawgFile) + } +} + func TestIgnoredClientIPsValid(t *testing.T) { discardLogger := slog.NewTextHandler(io.Discard, nil) logger := slog.New(discardLogger)