diff --git a/pkg/api/prometheus/prometheus.go b/pkg/api/prometheus/prometheus.go index 051d122f6..6bd4ee51f 100644 --- a/pkg/api/prometheus/prometheus.go +++ b/pkg/api/prometheus/prometheus.go @@ -17,6 +17,11 @@ const ( RequestRevalidationCounter = "souin_request_revalidation_counter" NoCachedResponseCounter = "souin_no_cached_response_counter" CachedResponseCounter = "souin_cached_response_counter" + SoftPurgeHitCounter = "souin_soft_purge_hit_counter" + SoftPurgeRefreshCounter = "souin_soft_purge_refresh_counter" + SoftPurgeRefreshSuccess = "souin_soft_purge_refresh_success_counter" + SoftPurgeRefreshFailure = "souin_soft_purge_refresh_failure_counter" + SoftPurgeRefreshDeduped = "souin_soft_purge_refresh_deduped_counter" AvgResponseTime = "souin_avg_response_time" ) @@ -103,5 +108,10 @@ func run() { push(counter, RequestRevalidationCounter, "Total revalidation request revalidation counter") push(counter, NoCachedResponseCounter, "No cached response counter") push(counter, CachedResponseCounter, "Cached response counter") + push(counter, SoftPurgeHitCounter, "Soft purge stale hit counter") + push(counter, SoftPurgeRefreshCounter, "Soft purge background refresh counter") + push(counter, SoftPurgeRefreshSuccess, "Soft purge background refresh success counter") + push(counter, SoftPurgeRefreshFailure, "Soft purge background refresh failure counter") + push(counter, SoftPurgeRefreshDeduped, "Soft purge background refresh deduplicated counter") push(average, AvgResponseTime, "Average response time") } diff --git a/pkg/api/prometheus/prometheus_test.go b/pkg/api/prometheus/prometheus_test.go index 197cbb072..a4903e9a3 100644 --- a/pkg/api/prometheus/prometheus_test.go +++ b/pkg/api/prometheus/prometheus_test.go @@ -13,8 +13,8 @@ func Test_Run(t *testing.T) { } run() - if len(registered) != 5 { - t.Error("The registered additional metrics array must have 5 items.") + if len(registered) != 10 { + t.Error("The registered additional metrics array must have 10 items.") } i, ok := registered[RequestCounter] @@ -53,6 +53,17 @@ func Test_Run(t *testing.T) { t.Errorf("The souin_cached_response_counter element must be a *prometheus.Counter object, %T given.", i) } + for _, key := range []string{SoftPurgeHitCounter, SoftPurgeRefreshCounter, SoftPurgeRefreshSuccess, SoftPurgeRefreshFailure, SoftPurgeRefreshDeduped} { + i, ok = registered[key] + if !ok { + t.Errorf("The registered array must have the %s key", key) + continue + } + if _, counterOK := i.(*prometheus.Counter); counterOK { + t.Errorf("The %s element must be a *prometheus.Counter object, %T given.", key, i) + } + } + i, ok = registered[AvgResponseTime] if !ok { t.Error("The registered array must have the souin_avg_response_time key") diff --git a/pkg/api/souin.go b/pkg/api/souin.go index d5888b50d..1cc337111 100644 --- a/pkg/api/souin.go +++ b/pkg/api/souin.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "regexp" "strings" "time" @@ -23,6 +24,7 @@ type SouinAPI struct { storers []types.Storer surrogateStorage providers.SurrogateInterface allowedMethods []string + logger core.Logger } type invalidationType string @@ -32,6 +34,10 @@ const ( uriPrefixInvalidationType invalidationType = "uri-prefix" originInvalidationType invalidationType = "origin" groupInvalidationType invalidationType = "group" + + SoftPurgeModeHeader = "Souin-Purge-Mode" + softPurgeModeValue = "soft" + softPurgeKeyPrefix = "SOFTPURGE_" ) type invalidation struct { @@ -62,46 +68,80 @@ func initializeSouin( storers, surrogateStorage, allowedMethods, + configuration.GetLogger(), + } +} + +func SoftPurgeMarkerKey(key string) string { + return softPurgeKeyPrefix + key +} + +func (s *SouinAPI) logInfof(template string, args ...any) { + if s.logger != nil { + s.logger.Infof(template, args...) + } +} + +func (s *SouinAPI) logWarnf(template string, args ...any) { + if s.logger != nil { + s.logger.Warnf(template, args...) } } +func IsSoftPurgeRequest(r *http.Request) bool { + return strings.EqualFold(r.Header.Get(SoftPurgeModeHeader), softPurgeModeValue) +} + // BulkDelete allow user to delete multiple items with regexp func (s *SouinAPI) BulkDelete(key string, purge bool) { key, _ = strings.CutPrefix(key, core.MappingKeyPrefix) for _, current := range s.storers { + infiniteStoreDuration := storageToInfiniteTTLMap[current.Name()] + decodedKey, _ := url.QueryUnescape(key) + + if purge { + current.Delete(SoftPurgeMarkerKey(key)) + } else if err := current.Set(SoftPurgeMarkerKey(key), []byte(time.Now().UTC().Format(time.RFC3339Nano)), infiniteStoreDuration); err != nil { + s.logWarnf("Unable to soft-purge cache key %s in %s: %v", decodedKey, current.Name(), err) + } else { + s.logInfof("Soft-purged cache key %s in %s", decodedKey, current.Name()) + } + if b := current.Get(core.MappingKeyPrefix + key); len(b) > 0 { var mapping core.StorageMapper if e := proto.Unmarshal(b, &mapping); e == nil { - for k := range mapping.GetMapping() { - current.Delete(k) - } - } - - if !purge { - newFreshTime := time.Now() - for k, v := range mapping.Mapping { - v.FreshTime = timestamppb.New(newFreshTime) - mapping.Mapping[k] = v - } + if purge { + for k := range mapping.GetMapping() { + current.Delete(k) + } + } else { + newFreshTime := time.Now().Add(-time.Second) + for k, v := range mapping.Mapping { + v.FreshTime = timestamppb.New(newFreshTime) + mapping.Mapping[k] = v + } - v, e := proto.Marshal(&mapping) - if e != nil { - fmt.Println("Impossible to re-encode the mapping", core.MappingKeyPrefix+key) - current.Delete(core.MappingKeyPrefix + key) + v, e := proto.Marshal(&mapping) + if e != nil { + fmt.Println("Impossible to re-encode the mapping", core.MappingKeyPrefix+key) + current.Delete(core.MappingKeyPrefix + key) + } else { + _ = current.Set(core.MappingKeyPrefix+key, v, infiniteStoreDuration) + } } - _ = current.Set(core.MappingKeyPrefix+key, v, storageToInfiniteTTLMap[current.Name()]) } } if purge { current.Delete(core.MappingKeyPrefix + key) + current.Delete(key) } - - current.Delete(key) } - s.Delete(key) + if purge { + s.Delete(key) + } } // Delete will delete a record into the provider cache system and will update the Souin API if enabled @@ -212,13 +252,14 @@ func (s *SouinAPI) purgeMapping() { // HandleRequest will handle the request func (s *SouinAPI) HandleRequest(w http.ResponseWriter, r *http.Request) { res := []byte{} - compile := regexp.MustCompile(s.GetBasePath()+"/.+").FindString(r.RequestURI) != "" + requestPath := r.URL.Path + compile := regexp.MustCompile(s.GetBasePath()+"/.+").FindString(requestPath) != "" switch r.Method { case http.MethodGet: - if regexp.MustCompile(s.GetBasePath()+"/surrogate_keys").FindString(r.RequestURI) != "" { + if regexp.MustCompile(s.GetBasePath()+"/surrogate_keys").FindString(requestPath) != "" { res, _ = json.Marshal(s.surrogateStorage.List()) } else if compile { - search := regexp.MustCompile(s.GetBasePath()+"/(.+)").FindAllStringSubmatch(r.RequestURI, -1)[0][1] + search := regexp.MustCompile(s.GetBasePath()+"/(.+)").FindAllStringSubmatch(requestPath, -1)[0][1] res, _ = json.Marshal(s.listKeys(search)) if len(res) == 2 { w.WriteHeader(http.StatusNotFound) @@ -303,35 +344,57 @@ func (s *SouinAPI) HandleRequest(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) case "PURGE": + softPurge := IsSoftPurgeRequest(r) if compile { keysRg := regexp.MustCompile(s.GetBasePath() + "/(.+)") flushRg := regexp.MustCompile(s.GetBasePath() + "/flush$") mappingRg := regexp.MustCompile(s.GetBasePath() + "/mapping$") - if flushRg.FindString(r.RequestURI) != "" { + if flushRg.FindString(requestPath) != "" { + if softPurge { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("soft purge is not supported for flush")) + return + } for _, current := range s.storers { current.DeleteMany(".+") } - e := s.surrogateStorage.Destruct() - if e != nil { - fmt.Printf("Error while purging the surrogate keys: %+v.", e) + if s.surrogateStorage != nil { + s.surrogateStorage.Clear() } fmt.Println("Successfully clear the cache and the surrogate keys storage.") - } else if mappingRg.FindString(r.RequestURI) != "" { + } else if mappingRg.FindString(requestPath) != "" { s.purgeMapping() } else { - submatch := keysRg.FindAllStringSubmatch(r.RequestURI, -1)[0][1] - for _, current := range s.storers { - current.DeleteMany(submatch) + submatch := keysRg.FindAllStringSubmatch(requestPath, -1)[0][1] + if softPurge { + s.BulkDelete(submatch, false) + } else { + for _, current := range s.storers { + current.DeleteMany(submatch) + } } } } else { - ck, surrogateKeys := s.surrogateStorage.Purge(r.Header) - for _, k := range ck { - s.BulkDelete(k, true) + if s.surrogateStorage == nil { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("surrogate storage is not initialized")) + return } - for _, k := range surrogateKeys { - s.BulkDelete("SURROGATE_"+k, true) + ck, surrogateKeys := s.surrogateStorage.Purge(r.Header) + if softPurge { + s.logInfof("Soft purge requested for surrogate keys: %s", strings.Join(surrogateKeys, ", ")) + for _, k := range ck { + s.BulkDelete(k, false) + } + } else { + s.logInfof("Hard purge requested for surrogate keys: %s", strings.Join(surrogateKeys, ", ")) + for _, k := range ck { + s.BulkDelete(k, true) + } + for _, k := range surrogateKeys { + s.BulkDelete("SURROGATE_"+k, true) + } } } w.WriteHeader(http.StatusNoContent) diff --git a/pkg/api/souin_test.go b/pkg/api/souin_test.go new file mode 100644 index 000000000..a9150bf54 --- /dev/null +++ b/pkg/api/souin_test.go @@ -0,0 +1,183 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/darkweak/souin/pkg/storage" + "github.com/darkweak/souin/pkg/storage/types" + "github.com/darkweak/souin/pkg/surrogate/providers" + "github.com/darkweak/souin/tests" + "github.com/darkweak/storages/core" + "google.golang.org/protobuf/proto" +) + +func newTestSouinAPI(t *testing.T) (*SouinAPI, types.Storer) { + t.Helper() + + memoryStorer, _ := storage.Factory(tests.MockConfiguration(tests.BaseConfiguration)) + core.RegisterStorage(memoryStorer) + + cfg := tests.MockConfiguration(tests.BaseConfiguration) + + return initializeSouin(cfg, []types.Storer{memoryStorer}, nil), memoryStorer +} + +func TestIsSoftPurgeRequest(t *testing.T) { + reqWithHeader, _ := http.NewRequest("PURGE", "http://example.com/souin-api/souin", nil) + reqWithHeader.Header.Set(SoftPurgeModeHeader, "soft") + if !IsSoftPurgeRequest(reqWithHeader) { + t.Fatal("request with soft purge header should be detected as soft purge") + } + + reqWithQuery, _ := http.NewRequest("PURGE", "http://example.com/souin-api/souin?mode=soft", nil) + if IsSoftPurgeRequest(reqWithQuery) { + t.Fatal("request with only soft purge query parameter should not be detected as soft purge") + } + + reqHard, _ := http.NewRequest("PURGE", "http://example.com/souin-api/souin", nil) + reqHard.Header.Set(SoftPurgeModeHeader, "hard") + if IsSoftPurgeRequest(reqHard) { + t.Fatal("request with hard purge mode should not be detected as soft purge") + } +} + +func TestBulkDeleteSoftPurgePreservesEntryAndMarksMappingStale(t *testing.T) { + api, storer := newTestSouinAPI(t) + cacheKey := "GET-http-example.com-/resource" + mappingKey := core.MappingKeyPrefix + cacheKey + + if err := storer.Set(cacheKey, []byte("payload"), types.OneYearDuration); err != nil { + t.Fatalf("unable to store cache entry: %v", err) + } + + mapping, err := core.MappingUpdater( + cacheKey, + nil, + api.logger, + time.Now(), + time.Now().Add(time.Minute), + time.Now().Add(2*time.Minute), + nil, + "", + cacheKey, + ) + if err != nil { + t.Fatalf("unable to create mapping: %v", err) + } + + if err = storer.Set(mappingKey, mapping, types.OneYearDuration); err != nil { + t.Fatalf("unable to store mapping: %v", err) + } + + api.BulkDelete(cacheKey, false) + + if got := storer.Get(cacheKey); len(got) == 0 { + t.Fatal("soft purge should preserve the stored cache entry") + } + + if got := storer.Get(SoftPurgeMarkerKey(cacheKey)); len(got) == 0 { + t.Fatal("soft purge should create a soft purge marker") + } + + rawMapping := storer.Get(mappingKey) + if len(rawMapping) == 0 { + t.Fatal("soft purge should preserve the mapping entry") + } + + decoded := &core.StorageMapper{} + if err = proto.Unmarshal(rawMapping, decoded); err != nil { + t.Fatalf("unable to decode mapping: %v", err) + } + + item := decoded.GetMapping()[cacheKey] + if item == nil { + t.Fatal("soft purge should preserve the cache key mapping") + } + + if item.GetFreshTime().AsTime().After(time.Now()) { + t.Fatal("soft purge should mark the mapping as no longer fresh") + } +} + +func TestBulkDeleteHardPurgeRemovesMarkerMappingAndEntry(t *testing.T) { + api, storer := newTestSouinAPI(t) + cacheKey := "GET-http-example.com-/resource" + + _ = storer.Set(cacheKey, []byte("payload"), types.OneYearDuration) + _ = storer.Set(core.MappingKeyPrefix+cacheKey, []byte("mapping"), types.OneYearDuration) + _ = storer.Set(SoftPurgeMarkerKey(cacheKey), []byte("marker"), types.OneYearDuration) + + api.BulkDelete(cacheKey, true) + + if got := storer.Get(cacheKey); len(got) != 0 { + t.Fatal("hard purge should remove the stored cache entry") + } + + if got := storer.Get(core.MappingKeyPrefix + cacheKey); len(got) != 0 { + t.Fatal("hard purge should remove the mapping entry") + } + + if got := storer.Get(SoftPurgeMarkerKey(cacheKey)); len(got) != 0 { + t.Fatal("hard purge should remove the soft purge marker") + } +} + +func TestSoftPurgeFlushIsRejected(t *testing.T) { + api, _ := newTestSouinAPI(t) + api.basePath = "/souin" + + req := httptest.NewRequest("PURGE", "/souin/flush", nil) + req.Header.Set(SoftPurgeModeHeader, "soft") + res := httptest.NewRecorder() + + api.HandleRequest(res, req) + + if res.Code != http.StatusBadRequest { + t.Fatalf("expected soft purge flush to return %d, got %d", http.StatusBadRequest, res.Code) + } + + if body := res.Body.String(); body != "soft purge is not supported for flush" { + t.Fatalf("unexpected soft purge flush response body %q", body) + } +} + +func TestFlushClearsSharedSurrogateStorageWithoutResettingStorer(t *testing.T) { + cfg := tests.MockConfiguration(tests.BaseConfiguration) + storer, _ := storage.Factory(cfg) + core.RegisterStorage(storer) + + surrogateStorage := providers.SurrogateFactory(cfg, storer.Name()) + api := initializeSouin(cfg, []types.Storer{storer}, surrogateStorage) + api.basePath = "/souin" + + cacheKey := "GET-http-example.com-/" + if err := storer.Set(cacheKey, []byte("payload"), types.OneYearDuration); err != nil { + t.Fatalf("unable to store cache entry: %v", err) + } + if err := storer.Set("SURROGATE_blog-1-home", []byte(","+cacheKey), types.OneYearDuration); err != nil { + t.Fatalf("unable to store surrogate entry: %v", err) + } + + req := httptest.NewRequest("PURGE", "/souin/flush", nil) + res := httptest.NewRecorder() + api.HandleRequest(res, req) + + if res.Code != http.StatusNoContent { + t.Fatalf("expected flush to return %d, got %d", http.StatusNoContent, res.Code) + } + + if got := storer.Get(cacheKey); len(got) != 0 { + t.Fatal("flush should remove cached entries") + } + + if got := storer.Get("SURROGATE_blog-1-home"); len(got) != 0 { + t.Fatal("flush should remove surrogate entries") + } + + if err := storer.Set("still-open", []byte("ok"), time.Minute); err != nil { + t.Fatalf("flush should not reset the shared storage: %v", err) + } +} diff --git a/pkg/middleware/README.md b/pkg/middleware/README.md new file mode 100644 index 000000000..49b889e20 --- /dev/null +++ b/pkg/middleware/README.md @@ -0,0 +1,46 @@ +# HTTP cache middleware + +## What is the middleware? +The middleware is the HTTP cache entrypoint of Souin. It sits in front of the upstream application, computes cache keys, serves cached responses when possible, revalidates stale content, stores fresh upstream responses, and exposes the cache status to the client. + +## How to deal with it? + +### In a regular HTTP request +The client sends a cacheable request to the server. The middleware computes the cache key, checks the configured storages, and either: +* serves a fresh cached response +* serves a stale response according to the cache directives +* forwards the request upstream and stores the response when it is cacheable + +The middleware also sets the `Cache-Status` response header so the client can understand if the response was a cache hit, a stale hit, a revalidation, or a miss. + +### In a soft purge request +The client sends a `PURGE` request to the API endpoint, either for surrogate keys or for a direct cache key pattern, and sets: +* `Souin-Purge-Mode: soft` + +The middleware-related behavior is different from a hard purge: +* the cached object is kept in storage +* the associated mapping is marked stale +* a soft purge marker is attached to the stored cache entry + +The next matching request is then served as stale immediately. If the cached response has validators or `stale-while-revalidate`, the middleware also starts a detached background refresh. + +### During the background refresh +When a soft-purged response is served, the middleware only refreshes it in background when the cached response can be revalidated or refreshed safely: +* if the cached response has `ETag` or `Last-Modified`, it prefers a conditional revalidation +* if the cached response exposes `stale-while-revalidate`, it can do a background fetch +* if neither validators nor `stale-while-revalidate` are present, the stale response is still served but no refresh is started + +Concurrent refreshes are deduplicated per stored key so only one background refresh runs for the same soft-purged object at a time. + +### In the client response +When the response comes from a soft-purged entry, the middleware returns it as stale and adds a dedicated `Cache-Status` detail such as: +* `SOFT-PURGE-REVALIDATE` +* `SOFT-PURGE-SWR` +* `SOFT-PURGE-SIE` +* `SOFT-PURGED` + +Once the background refresh succeeds, the middleware stores the refreshed response and clears the soft purge marker so future requests behave like normal cache hits again. + +## Hard purge vs soft purge +A hard purge removes the cached response and the related mapping immediately. +A soft purge keeps the cached response available for stale serving and marks the mapping stale. The next request serves that stale response immediately, and only triggers a background refresh when validators or `stale-while-revalidate` allow it. diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index c9a02d07b..e5aad19b7 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -230,6 +230,7 @@ type SouinBaseHandler struct { context *context.Context singleflightPool singleflight.Group bufPool *sync.Pool + backgroundRefreshes sync.Map storersLen int } @@ -399,8 +400,7 @@ func (s *SouinBaseHandler) Store( return nil } res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length")) - response, err := httputil.DumpResponse(&res, true) - if err == nil && (bLen > 0 || rq.Method == http.MethodHead || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode)) { + if bLen > 0 || rq.Method == http.MethodHead || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode) { variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header) if isVaryStar { // "Implies that the response is uncacheable" @@ -412,6 +412,14 @@ func (s *SouinBaseHandler) Store( variedKey = fmt.Sprint(xxhash.Sum64String(variedKey)) } s.Configuration.GetLogger().Debugf("Store the response for %s with duration %v", variedKey, ma) + res.Header.Set(rfc.StoredKeyHeader, variedKey) + response, err := httputil.DumpResponse(&res, true) + if err != nil { + status += "; detail=UPSTREAM-ERROR-OR-EMPTY-RESPONSE" + customWriter.Header().Set("Cache-Status", status+"; key="+rfc.GetCacheKeyFromCtx(rq.Context())) + + return nil + } var wg sync.WaitGroup mu := sync.Mutex{} @@ -443,6 +451,7 @@ func (s *SouinBaseHandler) Store( res.Header.Get("Etag"), ma, variedKey, ) == nil { + s.clearSoftPurgeMarker(overridedStorer, variedKey) s.Configuration.GetLogger().Debugf("Stored the key %s in the %s provider", variedKey, overridedStorer.Name()) res.Request = rq } else { @@ -461,6 +470,7 @@ func (s *SouinBaseHandler) Store( currentRes.Header.Get("Etag"), ma, variedKey, ) == nil { + s.clearSoftPurgeMarker(currentStorer, variedKey) s.Configuration.GetLogger().Debugf("Stored the key %s in the %s provider", variedKey, currentStorer.Name()) currentRes.Request = rq } else { @@ -488,7 +498,6 @@ func (s *SouinBaseHandler) Store( } } } - } else { status += "; detail=UPSTREAM-ERROR-OR-EMPTY-RESPONSE" } @@ -878,6 +887,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n validator := rfc.ParseRequest(req) var fresh, stale *http.Response var storerName string + var matchedStorer types.Storer finalKey := cachedKey if req.Context().Value(context.Hashed).(bool) { finalKey = fmt.Sprint(xxhash.Sum64String(finalKey)) @@ -887,6 +897,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n if fresh != nil || stale != nil { storerName = currentStorer.Name() + matchedStorer = currentStorer s.Configuration.GetLogger().Debugf("Found at least one valid response in the %s storage", storerName) break } @@ -901,6 +912,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n }() response := fresh + if s.isSoftPurgedResponse(matchedStorer, response) { + return s.serveSoftPurgedResponse(customWriter, response, storerName, req, next, requestCc, cachedKey, uri) + } if validator.ResponseETag != "" && validator.Matched { rfc.SetCacheStatusHeader(response, storerName) @@ -954,7 +968,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - } else if !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { + } else if s.isSoftPurgedResponse(matchedStorer, stale) { + return s.serveSoftPurgedResponse(customWriter, stale, storerName, req, next, requestCc, cachedKey, uri) + } else if stale != nil { response := stale if nil != response && (!modeContext.Strict || rfc.ValidateCacheControl(response, requestCc)) { @@ -963,27 +979,12 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) if responseCc.StaleWhileRevalidate > 0 { - for h, v := range response.Header { - customWriter.Header()[h] = v + if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, responseCc, response, int(addTime.Seconds())) != nil { + return s.serveStaleWhileRevalidateResponse(customWriter, response, req, next, validator, requestCc, cachedKey, uri) } - customWriter.WriteHeader(response.StatusCode) - rfc.HitStaleCache(&response.Header) - customWriter.handleBuffer(func(b *bytes.Buffer) { - _, _ = io.Copy(b, response.Body) - }) - _, err := customWriter.Send() - customWriter = NewCustomWriter(req, rw, bufPool) - go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { - _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri) - }(validator, customWriter, req, next, requestCc, cachedKey, uri) - buf := s.bufPool.Get().(*bytes.Buffer) - buf.Reset() - defer s.bufPool.Put(buf) - - return err } - if modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { + if !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) && (modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation) { req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) statusCode := customWriter.GetStatusCode() diff --git a/pkg/middleware/soft_purge.go b/pkg/middleware/soft_purge.go new file mode 100644 index 000000000..65436d749 --- /dev/null +++ b/pkg/middleware/soft_purge.go @@ -0,0 +1,209 @@ +package middleware + +import ( + "bytes" + baseCtx "context" + "io" + "maps" + "net/http" + "net/http/httptest" + + "github.com/darkweak/souin/pkg/api" + "github.com/darkweak/souin/pkg/api/prometheus" + "github.com/darkweak/souin/pkg/rfc" + "github.com/darkweak/souin/pkg/storage/types" + "github.com/pquerna/cachecontrol/cacheobject" +) + +func (s *SouinBaseHandler) clearSoftPurgeMarker(storer types.Storer, storedKey string) { + if storedKey == "" { + return + } + + storer.Delete(api.SoftPurgeMarkerKey(storedKey)) +} + +func (s *SouinBaseHandler) isSoftPurgedResponse(storer types.Storer, response *http.Response) bool { + if storer == nil || response == nil { + return false + } + + storedKey := response.Header.Get(rfc.StoredKeyHeader) + if storedKey == "" { + return false + } + + return len(storer.Get(api.SoftPurgeMarkerKey(storedKey))) > 0 +} + +func hasSoftPurgeValidators(response *http.Response) bool { + if response == nil { + return false + } + + return response.Header.Get("Etag") != "" || response.Header.Get("Last-Modified") != "" +} + +func (s *SouinBaseHandler) getSoftPurgeDetail( + response *http.Response, + requestCc *cacheobject.RequestCacheDirectives, +) (string, bool) { + if response == nil { + return "SOFT-PURGED", false + } + + if hasSoftPurgeValidators(response) { + return "SOFT-PURGE-REVALIDATE", true + } + + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) + if responseCc != nil && responseCc.StaleWhileRevalidate > 0 { + return "SOFT-PURGE-SWR", true + } + + if (responseCc != nil && responseCc.StaleIfError > -1) || requestCc.StaleIfError > 0 { + return "SOFT-PURGE-SIE", false + } + + return "SOFT-PURGED", false +} + +func cloneBodyForSoftPurge(response *http.Response) []byte { + if response == nil || response.Body == nil { + return nil + } + + body, _ := io.ReadAll(response.Body) + response.Body = io.NopCloser(bytes.NewReader(body)) + + return body +} + +func mergeRevalidatedHeaders(staleHeaders, revalidatedHeaders http.Header) http.Header { + merged := staleHeaders.Clone() + maps.Copy(merged, revalidatedHeaders) + + return merged +} + +func (s *SouinBaseHandler) storeRevalidatedStaleResponse( + req *http.Request, + requestCc *cacheobject.RequestCacheDirectives, + cachedKey string, + uri string, + statusCode int, + headers http.Header, + body []byte, +) error { + recorder := httptest.NewRecorder() + customWriter := NewCustomWriter(req, recorder, new(bytes.Buffer)) + maps.Copy(customWriter.Header(), headers) + customWriter.WriteHeader(statusCode) + _, _ = customWriter.Write(body) + + return s.Store(customWriter, req, requestCc, cachedKey, uri) +} + +func (s *SouinBaseHandler) triggerSoftPurgeBackgroundRefresh( + storedKey string, + req *http.Request, + next handlerFunc, + requestCc *cacheobject.RequestCacheDirectives, + cachedKey string, + uri string, + response *http.Response, + body []byte, +) { + if storedKey == "" { + return + } + + if _, loaded := s.backgroundRefreshes.LoadOrStore(storedKey, struct{}{}); loaded { + s.Configuration.GetLogger().Infof("Skipping duplicate background refresh for soft-purged key %s", storedKey) + prometheus.Increment(prometheus.SoftPurgeRefreshDeduped) + return + } + + prometheus.Increment(prometheus.SoftPurgeRefreshCounter) + backgroundReq := req.Clone(baseCtx.WithoutCancel(req.Context())) + backgroundReq.Header = req.Header.Clone() + backgroundReq.Header.Del("Cache-Control") + + if etag := response.Header.Get("Etag"); etag != "" { + backgroundReq.Header.Set("If-None-Match", etag) + } + if lastModified := response.Header.Get("Last-Modified"); lastModified != "" { + backgroundReq.Header.Set("If-Modified-Since", lastModified) + } + + go func() { + defer s.backgroundRefreshes.Delete(storedKey) + + s.Configuration.GetLogger().Infof("Starting background refresh for soft-purged key %s", storedKey) + recorder := httptest.NewRecorder() + backgroundWriter := NewCustomWriter(backgroundReq, recorder, new(bytes.Buffer)) + + refreshErr := s.Upstream(backgroundWriter, backgroundReq, next, requestCc, cachedKey, uri, false) + statusCode := backgroundWriter.GetStatusCode() + + if refreshErr != nil { + s.Configuration.GetLogger().Warnf("Background refresh failed for soft-purged key %s: %v", storedKey, refreshErr) + prometheus.Increment(prometheus.SoftPurgeRefreshFailure) + return + } + + if statusCode == http.StatusNotModified { + mergedHeaders := mergeRevalidatedHeaders(response.Header, backgroundWriter.Header()) + if err := s.storeRevalidatedStaleResponse(backgroundReq, requestCc, cachedKey, uri, response.StatusCode, mergedHeaders, body); err != nil { + s.Configuration.GetLogger().Warnf("Background 304 revalidation failed for soft-purged key %s: %v", storedKey, err) + prometheus.Increment(prometheus.SoftPurgeRefreshFailure) + return + } + + s.Configuration.GetLogger().Infof("Background refresh revalidated soft-purged key %s with 304", storedKey) + prometheus.Increment(prometheus.SoftPurgeRefreshSuccess) + return + } + + s.Configuration.GetLogger().Infof("Background refresh completed for soft-purged key %s with status %d", storedKey, statusCode) + prometheus.Increment(prometheus.SoftPurgeRefreshSuccess) + }() +} + +func (s *SouinBaseHandler) serveSoftPurgedResponse( + customWriter *CustomWriter, + response *http.Response, + storerName string, + req *http.Request, + next handlerFunc, + requestCc *cacheobject.RequestCacheDirectives, + cachedKey string, + uri string, +) error { + storedKey := response.Header.Get(rfc.StoredKeyHeader) + body := cloneBodyForSoftPurge(response) + detail, shouldRefresh := s.getSoftPurgeDetail(response, requestCc) + rfc.SetCacheStatusHeader(response, storerName) + rfc.HitStaleCache(&response.Header) + response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+"; detail="+detail) + maps.Copy(customWriter.Header(), response.Header) + customWriter.WriteHeader(response.StatusCode) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = b.Write(body) + }) + + _, err := customWriter.Send() + prometheus.Increment(prometheus.CachedResponseCounter) + prometheus.Increment(prometheus.SoftPurgeHitCounter) + if err != nil { + return err + } + + if shouldRefresh { + s.triggerSoftPurgeBackgroundRefresh(storedKey, req, next, requestCc, cachedKey, uri, response, body) + } else { + s.Configuration.GetLogger().Infof("Soft-purged key %s served stale without background refresh because no validators or stale-while-revalidate directives were present", storedKey) + } + + return nil +} diff --git a/pkg/middleware/stale_while_revalidate.go b/pkg/middleware/stale_while_revalidate.go new file mode 100644 index 000000000..15753dd96 --- /dev/null +++ b/pkg/middleware/stale_while_revalidate.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "bytes" + baseCtx "context" + "io" + "maps" + "net/http" + "net/http/httptest" + + "github.com/pquerna/cachecontrol/cacheobject" + + "github.com/darkweak/souin/pkg/rfc" + "github.com/darkweak/storages/core" +) + +func (s *SouinBaseHandler) triggerBackgroundRevalidation( + validator *core.Revalidator, + req *http.Request, + next handlerFunc, + requestCc *cacheobject.RequestCacheDirectives, + cachedKey string, + uri string, +) { + backgroundReq := req.Clone(baseCtx.WithoutCancel(req.Context())) + backgroundReq.Header = req.Header.Clone() + + go func() { + recorder := httptest.NewRecorder() + backgroundWriter := NewCustomWriter(backgroundReq, recorder, new(bytes.Buffer)) + _ = s.Revalidate(validator, next, backgroundWriter, backgroundReq, requestCc, cachedKey, uri) + }() +} + +func (s *SouinBaseHandler) serveStaleWhileRevalidateResponse( + customWriter *CustomWriter, + response *http.Response, + req *http.Request, + next handlerFunc, + validator *core.Revalidator, + requestCc *cacheobject.RequestCacheDirectives, + cachedKey string, + uri string, +) error { + customWriter.WriteHeader(response.StatusCode) + rfc.HitStaleCache(&response.Header) + maps.Copy(customWriter.Header(), response.Header) + customWriter.handleBuffer(func(b *bytes.Buffer) { + _, _ = io.Copy(b, response.Body) + }) + _, err := customWriter.Send() + s.triggerBackgroundRevalidation(validator, req, next, requestCc, cachedKey, uri) + + return err +} diff --git a/pkg/middleware/writer.go b/pkg/middleware/writer.go index c4e836861..4e64bdaba 100644 --- a/pkg/middleware/writer.go +++ b/pkg/middleware/writer.go @@ -217,6 +217,7 @@ Content-Range: bytes %d-%d/%d r.Header().Del(rfc.StoredLengthHeader) r.Header().Del(rfc.StoredTTLHeader) + r.Header().Del(rfc.StoredKeyHeader) if !r.headersSent { r.Rw.WriteHeader(r.GetStatusCode()) diff --git a/pkg/rfc/age.go b/pkg/rfc/age.go index 03f12acbc..1693f54ea 100644 --- a/pkg/rfc/age.go +++ b/pkg/rfc/age.go @@ -36,6 +36,12 @@ func ValidateMaxAgeCachedStaleResponse(co *cacheobject.RequestCacheDirectives, r return res } + if resCo != nil && resCo.StaleWhileRevalidate > -1 { + if response := validateMaxAgeCachedResponse(res, int(resCo.StaleWhileRevalidate), addTime); response != nil { + return response + } + } + if resCo != nil && (resCo.StaleIfError > -1 || co.StaleIfError > 0) { if resCo.StaleIfError > -1 { if response := validateMaxAgeCachedResponse(res, int(resCo.StaleIfError), addTime); response != nil { diff --git a/pkg/rfc/age_test.go b/pkg/rfc/age_test.go index b1fc5fe3d..2df7cab19 100644 --- a/pkg/rfc/age_test.go +++ b/pkg/rfc/age_test.go @@ -80,3 +80,30 @@ func Test_ValidateMaxStaleCachedResponse(t *testing.T) { t.Errorf("The max-stale validation should return the response instead of nil with the given parameters:\nRequestCacheDirectives: %+v\nResponse: %+v\n", coWithMaxStale, expiredMaxAge) } } + +func Test_ValidateMaxStaleCachedResponseWithStaleWhileRevalidate(t *testing.T) { + coWithoutMaxStale := cacheobject.RequestCacheDirectives{ + MaxStale: -1, + } + responseCc := cacheobject.ResponseCacheDirectives{ + StaleWhileRevalidate: 30, + } + + validStaleWhileRevalidate := http.Response{ + Header: http.Header{ + "Age": []string{"20"}, + }, + } + expiredStaleWhileRevalidate := http.Response{ + Header: http.Header{ + "Age": []string{"50"}, + }, + } + + if ValidateMaxAgeCachedStaleResponse(&coWithoutMaxStale, &responseCc, &validStaleWhileRevalidate, 1) == nil { + t.Errorf("The stale-while-revalidate validation should return the response instead of nil with the given parameters:\nRequestCacheDirectives: %+v\nResponseCacheDirectives: %+v\nResponse: %+v\n", coWithoutMaxStale, responseCc, validStaleWhileRevalidate) + } + if ValidateMaxAgeCachedStaleResponse(&coWithoutMaxStale, &responseCc, &expiredStaleWhileRevalidate, 1) != nil { + t.Errorf("The stale-while-revalidate validation should return nil instead of the response with the given parameters:\nRequestCacheDirectives: %+v\nResponseCacheDirectives: %+v\nResponse: %+v\n", coWithoutMaxStale, responseCc, expiredStaleWhileRevalidate) + } +} diff --git a/pkg/rfc/cache_status.go b/pkg/rfc/cache_status.go index 804f8da16..c21155d26 100644 --- a/pkg/rfc/cache_status.go +++ b/pkg/rfc/cache_status.go @@ -15,6 +15,7 @@ import ( const ( StoredTTLHeader = "X-Souin-Stored-TTL" StoredLengthHeader = "X-Souin-Stored-Length" + StoredKeyHeader = "X-Souin-Stored-Key" ) var emptyHeaders = []string{"Expires", "Last-Modified"} diff --git a/pkg/surrogate/providers/common.go b/pkg/surrogate/providers/common.go index 802156920..40ef35a49 100644 --- a/pkg/surrogate/providers/common.go +++ b/pkg/surrogate/providers/common.go @@ -304,6 +304,13 @@ func (s *baseStorage) List() map[string]string { return s.Storage.MapKeys(surrogatePrefix) } +// Clear removes only surrogate index entries without resetting the underlying storage. +func (s *baseStorage) Clear() { + for key := range s.List() { + s.Storage.Delete(surrogatePrefix + key) + } +} + // Destruct method will shutdown properly the provider func (s *baseStorage) Destruct() error { return s.Storage.Reset() diff --git a/pkg/surrogate/providers/types.go b/pkg/surrogate/providers/types.go index bbd299199..03098ace6 100644 --- a/pkg/surrogate/providers/types.go +++ b/pkg/surrogate/providers/types.go @@ -20,5 +20,6 @@ type SurrogateInterface interface { ParseHeaders(string) []string List() map[string]string candidateStore(string) bool + Clear() Destruct() error } diff --git a/plugins/caddy/admin.go b/plugins/caddy/admin.go index a95b373f8..456b94164 100644 --- a/plugins/caddy/admin.go +++ b/plugins/caddy/admin.go @@ -35,8 +35,9 @@ func (adminAPI) CaddyModule() caddy.ModuleInfo { func (a *adminAPI) handleAPIEndpoints(writer http.ResponseWriter, request *http.Request) error { if a.InternalEndpointHandlers != nil { + requestPath := request.URL.Path for k, handler := range *a.InternalEndpointHandlers.Handlers { - if strings.Contains(request.RequestURI, k) { + if strings.Contains(requestPath, k) { handler(writer, request) return nil } @@ -63,6 +64,10 @@ func (a *adminAPI) Provision(ctx caddy.Context) error { currentApp := app.(*SouinApp) item := <-currentApp.onMiddlewareLoaded() + surrogateStorage := item.SurrogateStorage + if surrogateStorage == nil { + surrogateStorage = currentApp.SurrogateStorage + } config := Configuration{ API: item.API, @@ -75,7 +80,8 @@ func (a *adminAPI) Provision(ctx caddy.Context) error { }, }, } - a.InternalEndpointHandlers = api.GenerateHandlerMap(&config, currentApp.Storers, item.SurrogateStorage) + config.SetLogger(a.logger) + a.InternalEndpointHandlers = api.GenerateHandlerMap(&config, currentApp.Storers, surrogateStorage) }() return nil diff --git a/plugins/caddy/httpcache.go b/plugins/caddy/httpcache.go index 41e0226d1..9cc2fd74a 100644 --- a/plugins/caddy/httpcache.go +++ b/plugins/caddy/httpcache.go @@ -90,8 +90,8 @@ func (SouinCaddyMiddleware) CaddyModule() caddy.ModuleInfo { // ServeHTTP implements caddyhttp.MiddlewareHandler. func (s *SouinCaddyMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { - return s.SouinBaseHandler.ServeHTTP(rw, r, func(w http.ResponseWriter, _ *http.Request) error { - return next.ServeHTTP(w, r) + return s.SouinBaseHandler.ServeHTTP(rw, r, func(w http.ResponseWriter, req *http.Request) error { + return next.ServeHTTP(w, req) }) } diff --git a/plugins/caddy/httpcache_softpurge_test.go b/plugins/caddy/httpcache_softpurge_test.go new file mode 100644 index 000000000..bdf35fa0d --- /dev/null +++ b/plugins/caddy/httpcache_softpurge_test.go @@ -0,0 +1,273 @@ +package httpcache + +import ( + "io" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +type softPurgeOrigin struct { + mu sync.Mutex + version string + hits int + etag string + conditionalHits int +} + +func (s *softPurgeOrigin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + s.hits++ + if s.etag != "" && r.Header.Get("If-None-Match") == s.etag { + s.conditionalHits++ + w.Header().Set("Cache-Control", "max-age=300, stale-while-revalidate=30") + w.Header().Set("Surrogate-Key", "post-1") + w.Header().Set("Etag", s.etag) + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("Cache-Control", "max-age=300, stale-while-revalidate=30") + w.Header().Set("Surrogate-Key", "post-1") + if s.etag != "" { + w.Header().Set("Etag", s.etag) + } + _, _ = w.Write([]byte(s.version)) +} + +func (s *softPurgeOrigin) setVersion(version string) { + s.mu.Lock() + defer s.mu.Unlock() + s.version = version + s.etag = `"` + version + `"` +} + +func (s *softPurgeOrigin) conditionalHitCount() int { + s.mu.Lock() + defer s.mu.Unlock() + + return s.conditionalHits +} + +func TestSoftPurgeServesStaleThenRefreshesInBackground(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` + { + admin localhost:2999 + http_port 9080 + cache { + api { + souin + } + stale 1m + } + } + localhost:9080 { + route /soft-purge { + cache + reverse_proxy localhost:9088 + } + }`, "caddyfile") + + origin := &softPurgeOrigin{version: "version-1"} + go func() { + _ = http.ListenAndServe(":9088", origin) + }() + + time.Sleep(time.Second) + + resp1, _ := tester.AssertGetResponse("http://localhost:9080/soft-purge", http.StatusOK, "version-1") + if resp1.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored; key=GET-http-localhost:9080-/soft-purge" { + t.Fatalf("unexpected initial Cache-Status header %q", resp1.Header.Get("Cache-Status")) + } + + origin.setVersion("version-2") + + purgeReq, _ := http.NewRequest("PURGE", "http://localhost:2999/souin-api/souin", nil) + purgeReq.Header.Set("Surrogate-Key", "post-1") + purgeReq.Header.Set("Souin-Purge-Mode", "soft") + _, _ = tester.AssertResponse(purgeReq, http.StatusNoContent, "") + + resp2, _ := tester.AssertGetResponse("http://localhost:9080/soft-purge", http.StatusOK, "version-1") + cacheStatus := resp2.Header.Get("Cache-Status") + if cacheStatus == "" || !containsAll(cacheStatus, "Souin; hit;", "; fwd=stale", "; detail=SOFT-PURGE-SWR") { + t.Fatalf("unexpected soft purge Cache-Status header %q", cacheStatus) + } + + deadline := time.Now().Add(5 * time.Second) + for { + resp, err := http.Get("http://localhost:9080/soft-purge") + if err != nil { + t.Fatalf("unable to fetch refreshed response: %v", err) + } + + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if string(body) == "version-2" && containsAll(resp.Header.Get("Cache-Status"), "Souin; hit;", "key=GET-http-localhost:9080-/soft-purge") && !strings.Contains(resp.Header.Get("Cache-Status"), "SOFT-PURGED") { + break + } + + if time.Now().After(deadline) { + t.Fatalf("background refresh did not replace the soft-purged object in time, last body %q last Cache-Status %q", string(body), resp.Header.Get("Cache-Status")) + } + + time.Sleep(100 * time.Millisecond) + } +} + +func containsAll(value string, parts ...string) bool { + for _, part := range parts { + if !strings.Contains(value, part) { + return false + } + } + + return true +} + +func TestSoftPurgeConditionalRevalidationWithNotModified(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` + { + admin localhost:2999 + http_port 9080 + cache { + api { + souin + } + stale 1m + } + } + localhost:9080 { + route /soft-purge-conditional { + cache + reverse_proxy localhost:9089 + } + }`, "caddyfile") + + origin := &softPurgeOrigin{} + origin.setVersion("version-1") + go func() { + _ = http.ListenAndServe(":9089", origin) + }() + + time.Sleep(time.Second) + + _, _ = tester.AssertGetResponse("http://localhost:9080/soft-purge-conditional", http.StatusOK, "version-1") + + purgeReq, _ := http.NewRequest("PURGE", "http://localhost:2999/souin-api/souin", nil) + purgeReq.Header.Set("Surrogate-Key", "post-1") + purgeReq.Header.Set("Souin-Purge-Mode", "soft") + _, _ = tester.AssertResponse(purgeReq, http.StatusNoContent, "") + + resp2, _ := tester.AssertGetResponse("http://localhost:9080/soft-purge-conditional", http.StatusOK, "version-1") + cacheStatus := resp2.Header.Get("Cache-Status") + if !containsAll(cacheStatus, "Souin; hit;", "; fwd=stale", "; detail=SOFT-PURGE-REVALIDATE") { + t.Fatalf("unexpected conditional soft purge Cache-Status header %q", cacheStatus) + } + + deadline := time.Now().Add(5 * time.Second) + for { + resp, err := http.Get("http://localhost:9080/soft-purge-conditional") + if err != nil { + t.Fatalf("unable to fetch revalidated response: %v", err) + } + + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if string(body) == "version-1" && containsAll(resp.Header.Get("Cache-Status"), "Souin; hit;", "key=GET-http-localhost:9080-/soft-purge-conditional") && !strings.Contains(resp.Header.Get("Cache-Status"), "SOFT-PURGE") { + break + } + + if time.Now().After(deadline) { + t.Fatalf("conditional background refresh did not clear the soft purge marker in time, last body %q last Cache-Status %q", string(body), resp.Header.Get("Cache-Status")) + } + + time.Sleep(100 * time.Millisecond) + } + + if origin.conditionalHitCount() == 0 { + t.Fatal("expected background refresh to use conditional revalidation") + } +} + +func TestSoftPurgeWithBypassRequestModeStillServesStaleAndRefreshes(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` + { + admin localhost:2999 + http_port 9080 + cache { + api { + souin + } + mode bypass_request + stale 1m + } + } + localhost:9080 { + route /soft-purge-bypass-request { + cache + reverse_proxy localhost:9090 + } + }`, "caddyfile") + + origin := &softPurgeOrigin{version: "version-1"} + go func() { + _ = http.ListenAndServe(":9090", origin) + }() + + time.Sleep(time.Second) + + primeReq, _ := http.NewRequest(http.MethodGet, "http://localhost:9080/soft-purge-bypass-request", nil) + primeReq.Header.Set("Cache-Control", "no-cache") + resp1, _ := tester.AssertResponse(primeReq, http.StatusOK, "version-1") + if resp1.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored; key=GET-http-localhost:9080-/soft-purge-bypass-request" { + t.Fatalf("unexpected initial Cache-Status header %q", resp1.Header.Get("Cache-Status")) + } + + origin.setVersion("version-2") + + purgeReq, _ := http.NewRequest("PURGE", "http://localhost:2999/souin-api/souin", nil) + purgeReq.Header.Set("Surrogate-Key", "post-1") + purgeReq.Header.Set("Souin-Purge-Mode", "soft") + _, _ = tester.AssertResponse(purgeReq, http.StatusNoContent, "") + + staleReq, _ := http.NewRequest(http.MethodGet, "http://localhost:9080/soft-purge-bypass-request", nil) + staleReq.Header.Set("Cache-Control", "no-cache") + resp2, _ := tester.AssertResponse(staleReq, http.StatusOK, "version-1") + cacheStatus := resp2.Header.Get("Cache-Status") + if !containsAll(cacheStatus, "Souin; hit;", "; fwd=stale", "; detail=SOFT-PURGE-SWR") { + t.Fatalf("unexpected soft purge Cache-Status header with bypass_request %q", cacheStatus) + } + + deadline := time.Now().Add(5 * time.Second) + for { + req, _ := http.NewRequest(http.MethodGet, "http://localhost:9080/soft-purge-bypass-request", nil) + req.Header.Set("Cache-Control", "no-cache") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to fetch refreshed response: %v", err) + } + + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if string(body) == "version-2" && containsAll(resp.Header.Get("Cache-Status"), "Souin; hit;", "key=GET-http-localhost:9080-/soft-purge-bypass-request") && !strings.Contains(resp.Header.Get("Cache-Status"), "SOFT-PURGED") { + break + } + + if time.Now().After(deadline) { + t.Fatalf("background refresh did not replace the soft-purged object in bypass_request mode, last body %q last Cache-Status %q", string(body), resp.Header.Get("Cache-Status")) + } + + time.Sleep(100 * time.Millisecond) + } +} diff --git a/plugins/caddy/httpcache_swr_test.go b/plugins/caddy/httpcache_swr_test.go new file mode 100644 index 000000000..8b96e854e --- /dev/null +++ b/plugins/caddy/httpcache_swr_test.go @@ -0,0 +1,102 @@ +package httpcache + +import ( + "io" + "net/http" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +type staleWhileRevalidateOrigin struct { + mu sync.Mutex + version string + hits int +} + +func (s *staleWhileRevalidateOrigin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + s.hits++ + w.Header().Set("Cache-Control", "max-age=1, stale-while-revalidate=30") + _, _ = w.Write([]byte(s.version)) +} + +func (s *staleWhileRevalidateOrigin) setVersion(version string) { + s.mu.Lock() + defer s.mu.Unlock() + s.version = version +} + +func testStaleWhileRevalidateByMode(t *testing.T, mode string, path string, upstreamPort string) { + t.Helper() + + tester := caddytest.NewTester(t) + tester.InitServer(` + { + admin localhost:2999 + http_port 9080 + cache { + mode `+mode+` + stale 1m + } + } + localhost:9080 { + route `+path+` { + cache + reverse_proxy localhost:`+upstreamPort+` + } + }`, "caddyfile") + + origin := &staleWhileRevalidateOrigin{version: "version-1"} + go func() { + _ = http.ListenAndServe(":"+upstreamPort, origin) + }() + + time.Sleep(time.Second) + + resp1, _ := tester.AssertGetResponse("http://localhost:9080"+path, http.StatusOK, "version-1") + if resp1.Header.Get("Cache-Status") != "Souin; fwd=uri-miss; stored; key=GET-http-localhost:9080-"+path { + t.Fatalf("unexpected initial Cache-Status header %q", resp1.Header.Get("Cache-Status")) + } + + time.Sleep(2 * time.Second) + origin.setVersion("version-2") + + resp2, _ := tester.AssertGetResponse("http://localhost:9080"+path, http.StatusOK, "version-1") + if !containsAll(resp2.Header.Get("Cache-Status"), "Souin; hit;", "; fwd=stale") { + t.Fatalf("expected stale response in %s mode, got Cache-Status %q", mode, resp2.Header.Get("Cache-Status")) + } + + deadline := time.Now().Add(5 * time.Second) + for { + resp, err := http.Get("http://localhost:9080" + path) + if err != nil { + t.Fatalf("unable to fetch refreshed response in %s mode: %v", mode, err) + } + + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if string(body) == "version-2" && containsAll(resp.Header.Get("Cache-Status"), "Souin; hit;", "key=GET-http-localhost:9080-"+path) && !containsAll(resp.Header.Get("Cache-Status"), "; fwd=stale") { + break + } + + if time.Now().After(deadline) { + t.Fatalf("background revalidation did not refresh the stale response in %s mode, last body %q last Cache-Status %q", mode, string(body), resp.Header.Get("Cache-Status")) + } + + time.Sleep(100 * time.Millisecond) + } +} + +func TestStaleWhileRevalidateInBypassRequestMode(t *testing.T) { + testStaleWhileRevalidateByMode(t, "bypass_request", "/stale-while-revalidate-bypass-request", "9091") +} + +func TestStaleWhileRevalidateInStrictMode(t *testing.T) { + testStaleWhileRevalidateByMode(t, "strict", "/stale-while-revalidate-strict", "9092") +}