diff --git a/internal/fusefs/fuse_unix.go b/internal/fusefs/fuse_unix.go index f830cba..fef603e 100644 --- a/internal/fusefs/fuse_unix.go +++ b/internal/fusefs/fuse_unix.go @@ -6,7 +6,6 @@ import ( "context" "errors" "fmt" - "io" iofs "io/fs" "os" "path/filepath" @@ -475,26 +474,6 @@ func (fs *ArtifactFuse) Rename(ctx context.Context, op *fuseops.RenameOp) error // symlink can point at anyway. const maxSymlinkTargetBytes = 4096 -// readSymlinkTarget reads a cached Git blob and returns its contents as a -// symlink target. Git stores the target as the raw blob body, so nothing -// stops a repo from parking a very large payload behind a mode 120000 entry. -// We read with a bounded LimitReader and reject anything past PATH_MAX. -func readSymlinkTarget(cachePath string) (string, error) { - f, err := os.Open(cachePath) - if err != nil { - return "", err - } - defer f.Close() - data, err := io.ReadAll(io.LimitReader(f, maxSymlinkTargetBytes+1)) - if err != nil { - return "", err - } - if len(data) > maxSymlinkTargetBytes { - return "", syscall.ENAMETOOLONG - } - return string(data), nil -} - func (fs *ArtifactFuse) ReadSymlink(ctx context.Context, op *fuseops.ReadSymlinkOp) error { ref, err := fs.requireInode(op.Inode, syscall.ESTALE) if err != nil { @@ -505,26 +484,35 @@ func (fs *ArtifactFuse) ReadSymlink(ctx context.Context, op *fuseops.ReadSymlink return syscall.ENOENT } if n.Base.ObjectOID != "" { - if n.Base.SizeState == "known" && n.Base.SizeBytes > maxSymlinkTargetBytes { - return syscall.ENAMETOOLONG - } - cachePath, _, err := fs.engine.Hydrator.EnsureHydrated(ctx, fs.repo, n.Base) - if err != nil { - return syscall.EIO + if err := validateKnownSymlinkTargetSize(n.Base); err != nil { + return err } - target, err := readSymlinkTarget(cachePath) + data, err := fs.engine.Hydrator.ReadBlob(ctx, fs.repo, n.Base, maxSymlinkTargetBytes) if err != nil { - if errors.Is(err, syscall.ENAMETOOLONG) { + if errors.Is(err, model.ErrBlobTooLarge) { return syscall.ENAMETOOLONG } return syscall.EIO } - op.Target = target + op.Target = string(data) return nil } return syscall.ENOENT } +func validateKnownSymlinkTargetSize(node model.BaseNode) error { + if node.SizeState != "known" { + return nil + } + if node.SizeBytes < 0 { + return syscall.EIO + } + if node.SizeBytes > maxSymlinkTargetBytes { + return syscall.ENAMETOOLONG + } + return nil +} + func (fs *ArtifactFuse) FlushFile(_ context.Context, _ *fuseops.FlushFileOp) error { return nil } diff --git a/internal/fusefs/readsymlink_unix_test.go b/internal/fusefs/readsymlink_unix_test.go index 5f5deb6..23e9580 100644 --- a/internal/fusefs/readsymlink_unix_test.go +++ b/internal/fusefs/readsymlink_unix_test.go @@ -5,9 +5,6 @@ package fusefs import ( "context" "errors" - "os" - "path/filepath" - "strings" "syscall" "testing" @@ -16,10 +13,13 @@ import ( ) type fakeSymlinkHydrator struct { - calls int - cachePath string - size int64 - err error + calls int + readBlobCalls int + cachePath string + size int64 + err error + readBlobData []byte + readBlobErr error } func (f *fakeSymlinkHydrator) Enqueue(model.HydrationTask) {} @@ -29,91 +29,93 @@ func (f *fakeSymlinkHydrator) EnsureHydrated(_ context.Context, _ model.RepoConf return f.cachePath, f.size, f.err } +func (f *fakeSymlinkHydrator) ReadBlob(_ context.Context, _ model.RepoConfig, _ model.BaseNode, _ int64) ([]byte, error) { + f.readBlobCalls++ + return f.readBlobData, f.readBlobErr +} + func (f *fakeSymlinkHydrator) QueueDepth(model.RepoID) int { return 0 } -func writeBlob(t *testing.T, dir, name string, data []byte) string { - t.Helper() - p := filepath.Join(dir, name) - if err := os.WriteFile(p, data, 0o644); err != nil { - t.Fatalf("write %s: %v", p, err) - } - return p -} +func TestReadSymlinkRejectsKnownOversizedBlobBeforeHydration(t *testing.T) { + hydrator := &fakeSymlinkHydrator{} + repoID := model.RepoID("repo") + resolver := newResolver( + &fakeSnapshot{nodes: map[string]model.BaseNode{ + "link": { + RepoID: repoID, + Path: "link", + Type: "symlink", + Mode: 0o120000, + ObjectOID: "blob", + SizeState: "known", + SizeBytes: int64(maxSymlinkTargetBytes + 1), + }, + }}, + &fakeOverlay{entries: map[string]model.OverlayEntry{}}, + ) + fs := NewArtifactFuse(model.RepoConfig{ID: repoID}, resolver, &Engine{Hydrator: hydrator}) -func TestReadSymlinkTarget_EmptyTarget(t *testing.T) { - dir := t.TempDir() - p := writeBlob(t, dir, "empty", nil) - got, err := readSymlinkTarget(p) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != "" { - t.Fatalf("target = %q, want empty string", got) - } -} + fs.mu.Lock() + ref := fs.allocInode("link", "symlink", 0o120000) + fs.mu.Unlock() -func TestReadSymlinkTarget_ShortTarget(t *testing.T) { - dir := t.TempDir() - p := writeBlob(t, dir, "short", []byte("../relative/path")) - got, err := readSymlinkTarget(p) - if err != nil { - t.Fatalf("unexpected error: %v", err) + op := &fuseops.ReadSymlinkOp{Inode: ref.ID} + err := fs.ReadSymlink(context.Background(), op) + if !errors.Is(err, syscall.ENAMETOOLONG) { + t.Fatalf("err = %v, want ENAMETOOLONG", err) } - if got != "../relative/path" { - t.Fatalf("target = %q, want %q", got, "../relative/path") + if hydrator.calls != 0 { + t.Fatalf("EnsureHydrated calls = %d, want 0", hydrator.calls) } -} - -func TestReadSymlinkTarget_AtLimit(t *testing.T) { - dir := t.TempDir() - data := []byte(strings.Repeat("a", maxSymlinkTargetBytes)) - p := writeBlob(t, dir, "at-limit", data) - got, err := readSymlinkTarget(p) - if err != nil { - t.Fatalf("unexpected error at %d bytes: %v", maxSymlinkTargetBytes, err) + if hydrator.readBlobCalls != 0 { + t.Fatalf("ReadBlob calls = %d, want 0", hydrator.readBlobCalls) } - if len(got) != maxSymlinkTargetBytes { - t.Fatalf("target length = %d, want %d", len(got), maxSymlinkTargetBytes) + if op.Target != "" { + t.Fatalf("target = %q, want empty", op.Target) } } -func TestReadSymlinkTarget_OverLimit(t *testing.T) { - dir := t.TempDir() - data := []byte(strings.Repeat("a", maxSymlinkTargetBytes+1)) - p := writeBlob(t, dir, "over-limit", data) - _, err := readSymlinkTarget(p) - if !errors.Is(err, syscall.ENAMETOOLONG) { - t.Fatalf("err = %v, want ENAMETOOLONG", err) - } -} +func TestReadSymlinkRejectsNegativeKnownBlobBeforeHydration(t *testing.T) { + hydrator := &fakeSymlinkHydrator{} + repoID := model.RepoID("repo") + resolver := newResolver( + &fakeSnapshot{nodes: map[string]model.BaseNode{ + "link": { + RepoID: repoID, + Path: "link", + Type: "symlink", + Mode: 0o120000, + ObjectOID: "blob", + SizeState: "known", + SizeBytes: -1, + }, + }}, + &fakeOverlay{entries: map[string]model.OverlayEntry{}}, + ) + fs := NewArtifactFuse(model.RepoConfig{ID: repoID}, resolver, &Engine{Hydrator: hydrator}) + + fs.mu.Lock() + ref := fs.allocInode("link", "symlink", 0o120000) + fs.mu.Unlock() -func TestReadSymlinkTarget_FarOverLimit(t *testing.T) { - // A blob that's orders of magnitude past PATH_MAX should still be read - // into a bounded slice and rejected, not slurped whole. - dir := t.TempDir() - data := make([]byte, 1<<20) // 1 MiB - for i := range data { - data[i] = 'x' + op := &fuseops.ReadSymlinkOp{Inode: ref.ID} + err := fs.ReadSymlink(context.Background(), op) + if !errors.Is(err, syscall.EIO) { + t.Fatalf("err = %v, want EIO", err) } - p := writeBlob(t, dir, "huge", data) - _, err := readSymlinkTarget(p) - if !errors.Is(err, syscall.ENAMETOOLONG) { - t.Fatalf("err = %v, want ENAMETOOLONG", err) + if hydrator.calls != 0 { + t.Fatalf("EnsureHydrated calls = %d, want 0", hydrator.calls) } -} - -func TestReadSymlinkTarget_MissingFile(t *testing.T) { - _, err := readSymlinkTarget(filepath.Join(t.TempDir(), "does-not-exist")) - if err == nil { - t.Fatal("expected error for missing cache file, got nil") + if hydrator.readBlobCalls != 0 { + t.Fatalf("ReadBlob calls = %d, want 0", hydrator.readBlobCalls) } - if errors.Is(err, syscall.ENAMETOOLONG) { - t.Fatalf("err = %v, want non-ENAMETOOLONG for missing file", err) + if op.Target != "" { + t.Fatalf("target = %q, want empty", op.Target) } } -func TestReadSymlinkRejectsKnownOversizedBlobBeforeHydration(t *testing.T) { - hydrator := &fakeSymlinkHydrator{} +func TestReadSymlinkRejectsUnknownOversizedBlobWithoutHydration(t *testing.T) { + hydrator := &fakeSymlinkHydrator{readBlobErr: model.ErrBlobTooLarge} repoID := model.RepoID("repo") resolver := newResolver( &fakeSnapshot{nodes: map[string]model.BaseNode{ @@ -123,8 +125,7 @@ func TestReadSymlinkRejectsKnownOversizedBlobBeforeHydration(t *testing.T) { Type: "symlink", Mode: 0o120000, ObjectOID: "blob", - SizeState: "known", - SizeBytes: int64(maxSymlinkTargetBytes + 1), + SizeState: "unknown", }, }}, &fakeOverlay{entries: map[string]model.OverlayEntry{}}, @@ -143,7 +144,48 @@ func TestReadSymlinkRejectsKnownOversizedBlobBeforeHydration(t *testing.T) { if hydrator.calls != 0 { t.Fatalf("EnsureHydrated calls = %d, want 0", hydrator.calls) } + if hydrator.readBlobCalls != 1 { + t.Fatalf("ReadBlob calls = %d, want 1", hydrator.readBlobCalls) + } if op.Target != "" { t.Fatalf("target = %q, want empty", op.Target) } } + +func TestReadSymlinkReadsUnknownBlobThroughBoundedRead(t *testing.T) { + hydrator := &fakeSymlinkHydrator{readBlobData: []byte("../target")} + repoID := model.RepoID("repo") + resolver := newResolver( + &fakeSnapshot{nodes: map[string]model.BaseNode{ + "link": { + RepoID: repoID, + Path: "link", + Type: "symlink", + Mode: 0o120000, + ObjectOID: "blob", + SizeState: "unknown", + }, + }}, + &fakeOverlay{entries: map[string]model.OverlayEntry{}}, + ) + fs := NewArtifactFuse(model.RepoConfig{ID: repoID}, resolver, &Engine{Hydrator: hydrator}) + + fs.mu.Lock() + ref := fs.allocInode("link", "symlink", 0o120000) + fs.mu.Unlock() + + op := &fuseops.ReadSymlinkOp{Inode: ref.ID} + err := fs.ReadSymlink(context.Background(), op) + if err != nil { + t.Fatalf("ReadSymlink: %v", err) + } + if hydrator.calls != 0 { + t.Fatalf("EnsureHydrated calls = %d, want 0", hydrator.calls) + } + if hydrator.readBlobCalls != 1 { + t.Fatalf("ReadBlob calls = %d, want 1", hydrator.readBlobCalls) + } + if op.Target != "../target" { + t.Fatalf("target = %q, want ../target", op.Target) + } +} diff --git a/internal/gitstore/gitstore.go b/internal/gitstore/gitstore.go index c8054ad..5c4418f 100644 --- a/internal/gitstore/gitstore.go +++ b/internal/gitstore/gitstore.go @@ -27,6 +27,13 @@ type Store struct { pools map[string]*batchPool // gitDir -> pool } +type readBlobResult struct { + data []byte + err error +} + +const maxReadBlobBytes int64 = 1<<31 - 1 + func New(logger *slog.Logger) *Store { if logger == nil { logger = slog.Default() @@ -251,6 +258,58 @@ func (s *Store) BlobToCache(ctx context.Context, repo model.RepoConfig, objectOI return size, err } +func (s *Store) ReadBlob(ctx context.Context, repo model.RepoConfig, objectOID string, maxBytes int64) ([]byte, error) { + if maxBytes < 0 { + return nil, fmt.Errorf("negative max bytes: %d", maxBytes) + } + pool := s.getPool(repo.GitDir) + batch, err := pool.acquire() + if err != nil { + return nil, err + } + data, err := readBatchBlob(ctx, batch, objectOID, maxBytes) + if err == nil { + pool.release(batch) + return data, nil + } + if errors.Is(err, model.ErrBlobTooLarge) { + batch.kill() + return nil, err + } + batch.close() + + batch, err = pool.acquire() + if err != nil { + return nil, err + } + data, err = readBatchBlob(ctx, batch, objectOID, maxBytes) + if err != nil { + if errors.Is(err, model.ErrBlobTooLarge) { + batch.kill() + return nil, err + } + batch.close() + return nil, err + } + pool.release(batch) + return data, nil +} + +func readBatchBlob(ctx context.Context, batch *batchCatFile, objectOID string, maxBytes int64) ([]byte, error) { + ch := make(chan readBlobResult, 1) + go func() { + data, err := batch.readBlob(objectOID, maxBytes) + ch <- readBlobResult{data: data, err: err} + }() + select { + case r := <-ch: + return r.data, r.err + case <-ctx.Done(): + batch.kill() + return nil, ctx.Err() + } +} + func (s *Store) VerifyBlob(ctx context.Context, repo model.RepoConfig, objectOID string, cachePath string) (bool, error) { out, err := runGit(ctx, repo.GitDir, "hash-object", "--no-filters", cachePath) if err != nil { @@ -382,6 +441,13 @@ func (b *batchCatFile) close() { } } +func (b *batchCatFile) kill() { + if b.cmd != nil && b.cmd.Process != nil { + _ = b.cmd.Process.Kill() + } + b.close() +} + // fetchToFile writes oid to the batch process stdin, reads the response header // and streams the blob content directly to dstPath. Binary-safe (no string // conversion of blob content). @@ -395,25 +461,9 @@ func (b *batchCatFile) fetchToFile(oid string, dstPath string) (int64, error) { return 0, fmt.Errorf("batch write: %w", err) } - // Read response header: " SP SP LF" or " SP missing LF" - header, err := b.stdout.ReadString('\n') + size, err := b.readObjectSize(oid) if err != nil { - return 0, fmt.Errorf("batch read header: %w", err) - } - header = strings.TrimRight(header, "\n") - fields := strings.Fields(header) - if len(fields) < 2 { - return 0, fmt.Errorf("unexpected batch header: %q", header) - } - if fields[1] == "missing" { - return 0, fmt.Errorf("object %s missing", oid) - } - if len(fields) < 3 { - return 0, fmt.Errorf("unexpected batch header: %q", header) - } - size, err := strconv.ParseInt(fields[2], 10, 64) - if err != nil { - return 0, fmt.Errorf("parse size %q: %w", fields[2], err) + return 0, err } // Stream blob content to a temp file, then atomic rename. The blob cache is @@ -453,6 +503,60 @@ func (b *batchCatFile) fetchToFile(oid string, dstPath string) (int64, error) { return size, nil } +func (b *batchCatFile) readBlob(oid string, maxBytes int64) ([]byte, error) { + if b.cmd == nil || b.stdin == nil { + return nil, errors.New("batch cat-file process not running") + } + if _, err := fmt.Fprintf(b.stdin, "%s\n", oid); err != nil { + return nil, fmt.Errorf("batch write: %w", err) + } + size, err := b.readObjectSize(oid) + if err != nil { + return nil, err + } + if size < 0 { + return nil, fmt.Errorf("negative blob size: %d", size) + } + if size > maxBytes { + return nil, model.ErrBlobTooLarge + } + if size > maxReadBlobBytes { + return nil, model.ErrBlobTooLarge + } + data := make([]byte, int(size)) + if _, err := io.ReadFull(b.stdout, data); err != nil { + return nil, fmt.Errorf("batch read content: %w", err) + } + if _, err := b.stdout.ReadByte(); err != nil { + return nil, fmt.Errorf("batch read trailing LF: %w", err) + } + return data, nil +} + +func (b *batchCatFile) readObjectSize(oid string) (int64, error) { + // Read response header: " SP SP LF" or " SP missing LF" + header, err := b.stdout.ReadString('\n') + if err != nil { + return 0, fmt.Errorf("batch read header: %w", err) + } + header = strings.TrimRight(header, "\n") + fields := strings.Fields(header) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected batch header: %q", header) + } + if fields[1] == "missing" { + return 0, fmt.Errorf("object %s missing", oid) + } + if len(fields) < 3 { + return 0, fmt.Errorf("unexpected batch header: %q", header) + } + size, err := strconv.ParseInt(fields[2], 10, 64) + if err != nil { + return 0, fmt.Errorf("parse size %q: %w", fields[2], err) + } + return size, nil +} + // CommitTimestamp returns the committer timestamp of the given commit OID. func (s *Store) CommitTimestamp(ctx context.Context, repo model.RepoConfig, oid string) (int64, error) { out, err := runGit(ctx, repo.GitDir, "show", "-s", "--format=%ct", oid) diff --git a/internal/gitstore/gitstore_test.go b/internal/gitstore/gitstore_test.go index f779064..b6e7d22 100644 --- a/internal/gitstore/gitstore_test.go +++ b/internal/gitstore/gitstore_test.go @@ -2,6 +2,7 @@ package gitstore import ( "context" + "errors" "os" "os/exec" "path/filepath" @@ -88,6 +89,50 @@ func TestBlobToCacheBinarySafe(t *testing.T) { } } +func TestReadBlobRespectsMaxBytes(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + repo := filepath.Join(tmp, "repo") + run(t, "git", "init", repo) + os.WriteFile(filepath.Join(repo, "file.txt"), []byte("line\n"), 0o644) + run(t, "git", "-C", repo, "add", "file.txt") + run(t, "git", "-C", repo, "-c", "user.name=test", "-c", "user.email=test@example.com", "commit", "-m", "init") + + cfg := model.RepoConfig{ID: "x", GitDir: filepath.Join(repo, ".git")} + store := New(nil) + ctx := context.Background() + oid, _, _ := store.ResolveHEAD(ctx, cfg) + nodes, _ := store.BuildTreeIndex(ctx, cfg, oid) + var blobOID string + for _, n := range nodes { + if n.Path == "file.txt" { + blobOID = n.ObjectOID + } + } + if blobOID == "" { + t.Fatal("no blob OID found") + } + + data, err := store.ReadBlob(ctx, cfg, blobOID, 5) + if err != nil { + t.Fatalf("ReadBlob at limit: %v", err) + } + if string(data) != "line\n" { + t.Fatalf("data = %q, want line\\n", data) + } + _, err = store.ReadBlob(ctx, cfg, blobOID, 4) + if !errors.Is(err, model.ErrBlobTooLarge) { + t.Fatalf("err = %v, want ErrBlobTooLarge", err) + } + data, err = store.ReadBlob(ctx, cfg, blobOID, 5) + if err != nil { + t.Fatalf("ReadBlob after oversized read: %v", err) + } + if string(data) != "line\n" { + t.Fatalf("data after oversized read = %q, want line\\n", data) + } +} + func TestBuildTreeIndexNonASCIIPaths(t *testing.T) { t.Parallel() tmp := t.TempDir() diff --git a/internal/hydrator/hydrator.go b/internal/hydrator/hydrator.go index 22a3187..e36343c 100644 --- a/internal/hydrator/hydrator.go +++ b/internal/hydrator/hydrator.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "os" "path/filepath" "sync" @@ -15,6 +16,7 @@ import ( type BlobFetcher interface { BlobToCache(ctx context.Context, repo model.RepoConfig, objectOID string, dstPath string) (size int64, err error) + ReadBlob(ctx context.Context, repo model.RepoConfig, objectOID string, maxBytes int64) ([]byte, error) VerifyBlob(ctx context.Context, repo model.RepoConfig, objectOID string, cachePath string) (ok bool, err error) } @@ -135,6 +137,48 @@ func (s *Service) EnsureHydrated(ctx context.Context, repo model.RepoConfig, nod } } +func (s *Service) ReadBlob(ctx context.Context, repo model.RepoConfig, node model.BaseNode, maxBytes int64) ([]byte, error) { + if maxBytes < 0 { + return nil, fmt.Errorf("negative max bytes: %d", maxBytes) + } + if node.SizeState == "known" && node.SizeBytes > maxBytes { + return nil, model.ErrBlobTooLarge + } + cachePath := cachePathFor(repo, node.ObjectOID) + if st, err := os.Stat(cachePath); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, err + } + } else if st.Size() > maxBytes { + return s.fetcher.ReadBlob(ctx, repo, node.ObjectOID, maxBytes) + } + if size, ok, err := s.validateCachedBlob(ctx, repo, cachePath, node); err != nil { + return nil, err + } else if ok { + return readCachedBlob(cachePath, size, maxBytes) + } + return s.fetcher.ReadBlob(ctx, repo, node.ObjectOID, maxBytes) +} + +func readCachedBlob(cachePath string, size int64, maxBytes int64) ([]byte, error) { + if size > maxBytes { + return nil, model.ErrBlobTooLarge + } + f, err := os.Open(cachePath) + if err != nil { + return nil, err + } + defer f.Close() + data, err := io.ReadAll(io.LimitReader(f, maxBytes+1)) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, model.ErrBlobTooLarge + } + return data, nil +} + func (s *Service) validateCachedBlob(ctx context.Context, repo model.RepoConfig, cachePath string, node model.BaseNode) (size int64, ok bool, err error) { st, err := os.Stat(cachePath) if err != nil { diff --git a/internal/hydrator/hydrator_test.go b/internal/hydrator/hydrator_test.go index b44b6e0..31cbd99 100644 --- a/internal/hydrator/hydrator_test.go +++ b/internal/hydrator/hydrator_test.go @@ -292,11 +292,83 @@ func TestValidateCachedBlobKeepsFileOnVerifyError(t *testing.T) { } } +func TestReadBlobRejectsKnownOversizedWithoutFetch(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + node := model.BaseNode{RepoID: cfg.ID, Path: "link", ObjectOID: "blob", SizeState: "known", SizeBytes: 6} + fetcher := &fakeBlobFetcher{payload: []byte("target")} + h := New(fetcher) + + _, err := h.ReadBlob(context.Background(), cfg, node, 5) + if !errors.Is(err, model.ErrBlobTooLarge) { + t.Fatalf("err = %v, want ErrBlobTooLarge", err) + } + if fetcher.Calls() != 0 { + t.Fatalf("BlobToCache calls = %d, want 0", fetcher.Calls()) + } + if fetcher.ReadBlobCalls() != 0 { + t.Fatalf("ReadBlob calls = %d, want 0", fetcher.ReadBlobCalls()) + } +} + +func TestReadBlobUsesBoundedFetcherForUnknownSize(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + node := model.BaseNode{RepoID: cfg.ID, Path: "link", ObjectOID: "blob", SizeState: "unknown"} + fetcher := &fakeBlobFetcher{payload: []byte("target")} + h := New(fetcher) + + _, err := h.ReadBlob(context.Background(), cfg, node, 5) + if !errors.Is(err, model.ErrBlobTooLarge) { + t.Fatalf("err = %v, want ErrBlobTooLarge", err) + } + if fetcher.Calls() != 0 { + t.Fatalf("BlobToCache calls = %d, want 0", fetcher.Calls()) + } + if fetcher.ReadBlobCalls() != 1 { + t.Fatalf("ReadBlob calls = %d, want 1", fetcher.ReadBlobCalls()) + } +} + +func TestReadBlobSkipsVerificationForOversizedCache(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + cfg := model.RepoConfig{ID: "repo", BlobCacheDir: tmp} + node := model.BaseNode{RepoID: cfg.ID, Path: "link", ObjectOID: "blob", SizeState: "unknown"} + cachePath := filepath.Join(tmp, node.ObjectOID) + if err := os.WriteFile(cachePath, []byte("oversized"), 0o644); err != nil { + t.Fatal(err) + } + fetcher := &fakeBlobFetcher{payload: []byte("ok"), verifyOK: true} + h := New(fetcher) + + data, err := h.ReadBlob(context.Background(), cfg, node, 5) + if err != nil { + t.Fatalf("ReadBlob: %v", err) + } + if string(data) != "ok" { + t.Fatalf("data = %q, want ok", data) + } + if fetcher.VerifyCalls() != 0 { + t.Fatalf("VerifyBlob calls = %d, want 0", fetcher.VerifyCalls()) + } + if fetcher.ReadBlobCalls() != 1 { + t.Fatalf("ReadBlob calls = %d, want 1", fetcher.ReadBlobCalls()) + } + if _, err := os.Stat(cachePath); err != nil { + t.Fatalf("cache file should be left alone: %v", err) + } +} + type fakeBlobFetcher struct { mu sync.Mutex calls int + readBlobCalls int verifyCalls int payload []byte + readBlobErr error verifyOK bool verifyErr error verifyStarted chan struct{} @@ -317,6 +389,19 @@ func (f *fakeBlobFetcher) BlobToCache(_ context.Context, _ model.RepoConfig, _ s return int64(len(f.payload)), nil } +func (f *fakeBlobFetcher) ReadBlob(_ context.Context, _ model.RepoConfig, _ string, maxBytes int64) ([]byte, error) { + f.mu.Lock() + f.readBlobCalls++ + f.mu.Unlock() + if f.readBlobErr != nil { + return nil, f.readBlobErr + } + if int64(len(f.payload)) > maxBytes { + return nil, model.ErrBlobTooLarge + } + return f.payload, nil +} + func (f *fakeBlobFetcher) VerifyBlob(_ context.Context, _ model.RepoConfig, _ string, _ string) (bool, error) { f.mu.Lock() f.verifyCalls++ @@ -339,6 +424,12 @@ func (f *fakeBlobFetcher) Calls() int { return f.calls } +func (f *fakeBlobFetcher) ReadBlobCalls() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.readBlobCalls +} + func (f *fakeBlobFetcher) VerifyCalls() int { f.mu.Lock() defer f.mu.Unlock() diff --git a/internal/model/types.go b/internal/model/types.go index 9e7372d..c49f1e5 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -2,12 +2,15 @@ package model import ( "context" + "errors" "fmt" "path/filepath" "strings" "time" ) +var ErrBlobTooLarge = errors.New("blob too large") + type RepoID string type RepoConfig struct { @@ -143,6 +146,7 @@ type GitStore interface { ResolveHEAD(ctx context.Context, repo RepoConfig) (oid string, ref string, err error) BuildTreeIndex(ctx context.Context, repo RepoConfig, headOID string) ([]BaseNode, error) BlobToCache(ctx context.Context, repo RepoConfig, objectOID string, dstPath string) (size int64, err error) + ReadBlob(ctx context.Context, repo RepoConfig, objectOID string, maxBytes int64) ([]byte, error) ComputeAheadBehind(ctx context.Context, repo RepoConfig) (ahead int, behind int, diverged bool, err error) CommitTimestamp(ctx context.Context, repo RepoConfig, oid string) (int64, error) ReadTreeHEAD(ctx context.Context, repo RepoConfig) error @@ -171,5 +175,6 @@ type OverlayStore interface { type Hydrator interface { Enqueue(task HydrationTask) EnsureHydrated(ctx context.Context, repo RepoConfig, node BaseNode) (cachePath string, size int64, err error) + ReadBlob(ctx context.Context, repo RepoConfig, node BaseNode, maxBytes int64) ([]byte, error) QueueDepth(repoID RepoID) int }