diff --git a/.gitignore b/.gitignore index 03b85ba193..be0e7c92d7 100644 --- a/.gitignore +++ b/.gitignore @@ -23,10 +23,6 @@ flag_file bin/* .beads/ -# Test directory for weights testing -/test-weights/ -weights.lock - # Auto-generated version files from setuptools-scm python/cog/_version.py coglet/python/coglet/_version.py diff --git a/architecture/06-cli.md b/architecture/06-cli.md index e009163274..64550bd44a 100644 --- a/architecture/06-cli.md +++ b/architecture/06-cli.md @@ -168,7 +168,6 @@ Stores credentials for `cog push`. These commands exist but are hidden from `cog --help`: - **`cog debug`** -- Generates the Dockerfile from cog.yaml without building (useful for debugging build issues) -- **`cog inspect`** -- Inspects model images and OCI indices - **`cog weights`** -- Parent command for `weights build`, `weights push`, `weights inspect` There's also a separate `base-image` binary (`cmd/base-image/`) with subcommands for managing Cog base images (`dockerfile`, `build`, `generate-matrix`). This isn't a `cog` subcommand. diff --git a/examples/managed-weights/.dockerignore b/examples/managed-weights/.dockerignore new file mode 100644 index 0000000000..9ed59a3002 --- /dev/null +++ b/examples/managed-weights/.dockerignore @@ -0,0 +1,17 @@ +# Keep the weights/ directory out of the docker build context. +# +# With v1 managed weights, `cog.yaml`'s weights: entries are packed into +# separate OCI layers and land at their `target` paths at runtime via the +# image index — they must NOT be baked into the model image by `cog build`. +# Without this exclude, buildkit ships the full (multi-GB) weights/ directory +# to the docker daemon on every build. +weights/ + +# Packed layer cache written by `cog weights build`. Do NOT exclude all of +# .cog/ — `cog build` stages the SDK + coglet wheels and CA cert under +# .cog/tmp/ and references them from its generated Dockerfile, so excluding +# the whole directory breaks the image build. 
+.cog/weights-cache/ + +# Git metadata doesn't belong in the image. +.git/ diff --git a/examples/managed-weights/.gitignore b/examples/managed-weights/.gitignore new file mode 100644 index 0000000000..f6b9c5ed94 --- /dev/null +++ b/examples/managed-weights/.gitignore @@ -0,0 +1,5 @@ +# Weight files (multi-GB, populated manually — see README.md) +weights/ + +# Cog build artifacts (packed layers, cached wheels, etc.) +.cog/ diff --git a/examples/managed-weights/README.md b/examples/managed-weights/README.md new file mode 100644 index 0000000000..14fe78abee --- /dev/null +++ b/examples/managed-weights/README.md @@ -0,0 +1,92 @@ +# examples/managed-weights + +A minimal cog model used to exercise the v1 managed-weights OCI pipeline +end-to-end. It produces an OCI image index carrying a model image manifest +and per-weight manifests. + +The predictor validates weight files on disk against `weights.lock` +(generated by `cog weights import`), errors on any missing files, and +returns a per-weight status summary from predict(). + +## Populating `weights/` + +The weight directory is git-ignored because it's ~5 GB. Populate it by +cloning the HuggingFace repo and copying everything except `.git/`: + +```bash +# One-time: clone the weights somewhere outside this repo +git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 ~/hf/parakeet + +# Copy everything except .git into examples/managed-weights/weights/ +mkdir -p examples/managed-weights/weights +rsync -a --exclude=.git/ ~/hf/parakeet/ examples/managed-weights/weights/ +``` + +You can substitute any directory of model files; the pipeline is +content-agnostic. + +## Importing weights + +After populating (or changing) `weights/`, regenerate the lockfile: + +```bash +cd examples/managed-weights +cog weights import +``` + +This writes `weights.lock`. The predictor's `setup()` reads this file and +validates that all expected files exist at the declared targets. 
+ +## Running the pipeline + +Start a local registry (or point at any registry you control): + +```bash +docker run -d --rm -p 5000:5000 --name cog-test-registry registry:3 +``` + +Build and push the full bundle. Presence of `weights:` in `cog.yaml` +triggers the OCI bundle format automatically. + +```bash +cd examples/managed-weights +cog push localhost:5000/managed-weights +``` + +Or run the weight pipeline in isolation (no model image): + +```bash +cd examples/managed-weights +cog weights build +cog weights push localhost:5000/managed-weights +``` + +## Testing locally + +Build the image and run it with weights bind-mounted: + +```bash +cd examples/managed-weights +cog build -t managed-weights-local +docker run --rm -p 5050:5000 \ + -v $(pwd)/weights:/src/weights/parakeet:ro \ + managed-weights-local +``` + +Then hit predict: + +```bash +curl -s -X POST http://localhost:5050/predictions \ + -H 'Content-Type: application/json' \ + -d '{"input":{}}' | jq '.output | fromjson' +``` + +## Inspecting the output + +```bash +crane manifest localhost:5000/managed-weights:latest | jq . +crane ls localhost:5000/managed-weights +``` + +Weight manifests are pushed under tags of the shape +`weights--<12-hex-digest>` (see `pkg/model/weight_pusher.go`). diff --git a/examples/managed-weights/cog.yaml b/examples/managed-weights/cog.yaml new file mode 100644 index 0000000000..b3601bbc0f --- /dev/null +++ b/examples/managed-weights/cog.yaml @@ -0,0 +1,48 @@ +# Example model for testing the v1 managed-weights OCI artifact format. +# +# The weights/ directory is populated by a human (see README.md) with +# nvidia/parakeet-tdt-0.6b-v3 from HuggingFace. It's listed in .gitignore +# so the ~5GB payload never hits git. 
+# +# Import weights and generate the lockfile: +# cog weights import +# +# Build and push the full bundle: +# cog push localhost:5000/managed-weights + +image: registry.cloudflare.com/3515b24d58ec616d11f4ce4290a02ac4/md/examples/managed-weights +# image: localhost:5000/md/examples/managed-weights + +build: + gpu: false + python_version: "3.12" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" + +weights: + - name: parakeet + source: + uri: weights + include: + - "*.safetensors" # HF-format weights (skip the .nemo bundle) + - "*.json" # model + tokenizer configs + target: /src/weights/parakeet + - name: minilm + source: + uri: hf://sentence-transformers/all-MiniLM-L6-v2 + exclude: + - "onnx/" # ONNX runtime variants (~474 MB) + - "openvino/" # OpenVINO runtime variants (~113 MB) + - "pytorch_model.bin" # legacy format, redundant with model.safetensors + - "tf_model.h5" # TensorFlow weights + - "rust_model.ot" # Rust (tch-rs) weights + - "train_script.py" # training artifact + - "data_config.json" # training data config + - "README.md" + - ".gitattributes" + target: /src/weights/minilm + # - name: qwen3.6-27b-fp8 + # source: + # uri: hf://Qwen/Qwen3.6-27B-FP8 + # target: /src/weights/qwen diff --git a/examples/managed-weights/predict.py b/examples/managed-weights/predict.py new file mode 100644 index 0000000000..fc9a6d407f --- /dev/null +++ b/examples/managed-weights/predict.py @@ -0,0 +1,168 @@ +# Infra verification predictor for the v1 managed-weights OCI pipeline. +# Validates weight files on disk against weights.lock at setup; predict() +# returns a per-weight status summary. 
+ +import hashlib +import json +import sys +from pathlib import Path +from typing import Any + +from cog import BasePredictor + +LOCK_PATH = Path("/src/weights.lock") + + +def _file_sha256(path: Path) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + while chunk := f.read(8 * 1024 * 1024): + h.update(chunk) + return f"sha256:{h.hexdigest()}" + + +def _validate_weight( + name: str, target: str, expected_files: list[dict[str, Any]] +) -> dict[str, Any]: + """Validate a single weight entry from the lockfile. + + Checks presence and size first (cheap), then hashes only files whose + size matches (expensive). This way missing or truncated files fail fast + without reading gigabytes of data. + """ + target_dir = Path(target) + + if not target_dir.is_dir(): + return { + "name": name, + "target": target, + "errors": [f"weight directory {target} does not exist"], + "warnings": [], + "ok": [], + "missing": [f["path"] for f in expected_files], + "extra": [], + "mismatch": [], + } + + # Walk the directory once — just stat, no hashing yet. + actual_by_path: dict[str, Path] = {} + actual_sizes: dict[str, int] = {} + for p in sorted(target_dir.rglob("*")): + if not p.is_file(): + continue + rel = str(p.relative_to(target_dir)) + actual_by_path[rel] = p + actual_sizes[rel] = p.stat().st_size + + ok: list[str] = [] + missing: list[str] = [] + mismatch: list[str] = [] + errors: list[str] = [] + + for entry in expected_files: + path = entry["path"] + + if path not in actual_by_path: + missing.append(path) + errors.append(f"missing: {path}") + continue + + disk_size = actual_sizes[path] + if disk_size != entry["size"]: + mismatch.append(path) + errors.append( + f"size mismatch: {path} (expected {entry['size']}, got {disk_size})" + ) + actual_by_path.pop(path) + continue + + # Size matches — hash to confirm content. 
+ digest = _file_sha256(actual_by_path.pop(path)) + if digest != entry["digest"]: + mismatch.append(path) + errors.append(f"digest mismatch: {path}") + else: + ok.append(path) + + extra = sorted(actual_by_path.keys()) + warnings = [f"extra file: {p}" for p in extra] + + return { + "name": name, + "target": target, + "errors": errors, + "warnings": warnings, + "ok": ok, + "missing": missing, + "extra": extra, + "mismatch": mismatch, + } + + +class Predictor(BasePredictor): + def setup(self) -> None: + if not LOCK_PATH.exists(): + raise RuntimeError(f"{LOCK_PATH} not found — cannot validate weights") + + lock = json.loads(LOCK_PATH.read_text()) + + self.results: list[dict[str, Any]] = [] + all_errors: list[str] = [] + + for entry in lock["weights"]: + name = entry["name"] + target = entry["target"] + expected_files = [ + {"path": f["path"], "size": f["size"], "digest": f["digest"]} + for f in entry["files"] + ] + + # Dump directory contents before validation for debugging. + target_dir = Path(target) + if target_dir.is_dir(): + print(f"--- find {target} ---", file=sys.stderr) + for p in sorted(target_dir.rglob("*")): + suffix = "/" if p.is_dir() else f" ({p.stat().st_size})" + print(f" {p.relative_to(target_dir)}{suffix}", file=sys.stderr) + print("---", file=sys.stderr) + else: + print(f"--- {target}: does not exist ---", file=sys.stderr) + + print( + f"validating weight '{name}' at {target} ({len(expected_files)} files)", + file=sys.stderr, + ) + result = _validate_weight(name, target, expected_files) + self.results.append(result) + + for w in result["warnings"]: + print(f" WARNING: {w}", file=sys.stderr) + + if result["errors"]: + for e in result["errors"]: + all_errors.append(f"[{name}] {e}") + else: + print(f" OK ({len(result['ok'])} files)", file=sys.stderr) + + if all_errors: + msg = "weight validation failed:\n" + "\n".join( + f" - {e}" for e in all_errors + ) + raise RuntimeError(msg) + + print("all weights validated", file=sys.stderr) + + def predict(self) 
-> str: + summary = [] + for r in self.results: + entry: dict[str, Any] = { + "name": r["name"], + "target": r["target"], + "status": "ok" if not r["errors"] else "error", + "ok": len(r["ok"]), + "missing": r["missing"], + "extra": r["extra"], + "mismatch": r["mismatch"], + } + summary.append(entry) + return json.dumps(summary) diff --git a/examples/managed-weights/requirements.txt b/examples/managed-weights/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/managed-weights/weights.lock b/examples/managed-weights/weights.lock new file mode 100644 index 0000000000..a61dda3245 --- /dev/null +++ b/examples/managed-weights/weights.lock @@ -0,0 +1,177 @@ +{ + "version": 1, + "envelopeFormat": "sha256:ce2d53f8dd962ace393450e0abadbe227304897be87753a503b61f9c8525726e", + "weights": [ + { + "name": "parakeet", + "target": "/src/weights/parakeet", + "source": { + "uri": "file://./weights", + "fingerprint": "sha256:7ebdcfebeca9a959621601570f1beec20729f23ec721694f2432939789cfad86", + "include": [ + "*.json", + "*.safetensors" + ], + "exclude": [], + "importedAt": "2026-04-29T17:23:00.414483Z" + }, + "digest": "sha256:f2c88df678c449d8fca7def7afd5d9c17bb8963096108a51fd8fb9a92bf6fc8a", + "setDigest": "sha256:05933849eadc5067660cf0d164375fc9a80cf11a2be19a663017af62f5284bf2", + "size": 2509473151, + "sizeCompressed": 2508595098, + "files": [ + { + "path": "config.json", + "size": 1153, + "digest": "sha256:e747b85e1bdfd300c8b8ac63bac8dd5221f8fe9bc275b48d06c735fcd6971b6e", + "layer": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a" + }, + { + "path": "generation_config.json", + "size": 268, + "digest": "sha256:fc78f636b071231420356dbf70140a81a389132eb49bc81bcb2efdbe8293e7ad", + "layer": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a" + }, + { + "path": "model.safetensors", + "size": 2508311120, + "digest": "sha256:3a2026366188c8c68598edbbff92f8d11590a08e0ae2e6775544e7b07d6a5e11", + "layer": 
"sha256:65f11b0713429c604eb159d02a1c805815d599414e10c1c295e1937472db3a2f" + }, + { + "path": "processor_config.json", + "size": 392, + "digest": "sha256:8346a93a3b987fa1dec57a78f045cd0817d21786589a5a096b41a57a446fd1d7", + "layer": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a" + }, + { + "path": "tokenizer.json", + "size": 1159960, + "digest": "sha256:bd321b096832a3f270bd3b2a88823957920f1a5c5ada71114a26ea729d0cbe91", + "layer": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a" + }, + { + "path": "tokenizer_config.json", + "size": 258, + "digest": "sha256:5e04ae3487a5533224c295622cb206cc6e53914be527503978ce81f2cc75c559", + "layer": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a" + } + ], + "layers": [ + { + "digest": "sha256:5f33cdfadfaf568afd41ce78ac566bda102564454e8211f11370281e9ee9619a", + "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", + "size": 280986, + "sizeUncompressed": 1162031 + }, + { + "digest": "sha256:65f11b0713429c604eb159d02a1c805815d599414e10c1c295e1937472db3a2f", + "mediaType": "application/vnd.oci.image.layer.v1.tar", + "size": 2508314112, + "sizeUncompressed": 2508311120 + } + ] + }, + { + "name": "minilm", + "target": "/src/weights/minilm", + "source": { + "uri": "hf://sentence-transformers/all-MiniLM-L6-v2", + "fingerprint": "commit:c9745ed1d9f207416be6d2e6f8de32d1f16199bf", + "include": [], + "exclude": [ + ".gitattributes", + "README.md", + "data_config.json", + "onnx/", + "openvino/", + "pytorch_model.bin", + "rust_model.ot", + "tf_model.h5", + "train_script.py" + ], + "importedAt": "2026-04-29T22:43:36.709798Z" + }, + "digest": "sha256:6b15a4ac4f7e4dec39939043d32e7c51238516356f810b496735c50d0d7310be", + "setDigest": "sha256:18074874da7fd77e7cf1fefcbb2fc7edee0c38417f2255962e8a4fff5567b1d2", + "size": 91567913, + "sizeCompressed": 91174861, + "files": [ + { + "path": "1_Pooling/config.json", + "size": 190, + "digest": 
"sha256:4be450dde3b0273bb9787637cfbd28fe04a7ba6ab9d36ac48e92b11e350ffc23", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "config.json", + "size": 612, + "digest": "sha256:953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "config_sentence_transformers.json", + "size": 116, + "digest": "sha256:061ca9d39661d6c6d6de5ba27f79a1cd5770ea247f8d46412a68a498dc5ac9f3", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "model.safetensors", + "size": 90868376, + "digest": "sha256:53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db", + "layer": "sha256:69482828980bf1c74078f5a7bf0ba03bf719f525a60e2e32a5b369417d19bcef" + }, + { + "path": "modules.json", + "size": 349, + "digest": "sha256:84e40c8e006c9b1d6c122e02cba9b02458120b5fb0c87b746c41e0207cf642cf", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "sentence_bert_config.json", + "size": 53, + "digest": "sha256:fc1993fde0a95c24ec6c022539d41cf6e2f7c9721e5415d6fb6897472a9cd4b7", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "special_tokens_map.json", + "size": 112, + "digest": "sha256:303df45a03609e4ead04bc3dc1536d0ab19b5358db685b6f3da123d05ec200e3", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "tokenizer.json", + "size": 466247, + "digest": "sha256:be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "tokenizer_config.json", + "size": 350, + "digest": "sha256:acb92769e8195aabd29b7b2137a9e6d6e25c476a4f15aa4355c233426c61576b", + "layer": 
"sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + }, + { + "path": "vocab.txt", + "size": 231508, + "digest": "sha256:07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + "layer": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a" + } + ], + "layers": [ + { + "digest": "sha256:69482828980bf1c74078f5a7bf0ba03bf719f525a60e2e32a5b369417d19bcef", + "mediaType": "application/vnd.oci.image.layer.v1.tar", + "size": 90871296, + "sizeUncompressed": 90868376 + }, + { + "digest": "sha256:7ff2b72da0831598b639d05ed0e1ea64f2623cc5055245501b2c50f63a28cd0a", + "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", + "size": 303565, + "sizeUncompressed": 699537 + } + ] + } + ] +} \ No newline at end of file diff --git a/examples/resnet/.dockerignore b/examples/resnet/.dockerignore new file mode 100644 index 0000000000..2548cd88ca --- /dev/null +++ b/examples/resnet/.dockerignore @@ -0,0 +1,12 @@ +# Keep the weights/ directory out of the docker build context. +# +# With v1 managed weights, cog.yaml's weights: entries are packed into +# separate OCI layers and land at their target paths at runtime via the +# image index — they must NOT be baked into the model image by `cog build`. +weights/ + +# Packed layer cache written by `cog weights build`. +.cog/weights-cache/ + +# Git metadata doesn't belong in the image. +.git/ diff --git a/examples/resnet/.gitignore b/examples/resnet/.gitignore new file mode 100644 index 0000000000..6b8dd0b60d --- /dev/null +++ b/examples/resnet/.gitignore @@ -0,0 +1,5 @@ +# Weight files (downloaded by `cog weights import` — see README.md) +weights/ + +# Cog build artifacts (packed layers, cached wheels, etc.) 
+.cog/ diff --git a/examples/resnet/README.md b/examples/resnet/README.md new file mode 100644 index 0000000000..e3701f64a5 --- /dev/null +++ b/examples/resnet/README.md @@ -0,0 +1,25 @@ +# examples/resnet + +ResNet50 image classifier (microsoft/resnet-50 from HuggingFace) packaged +with v1 managed weights. Takes an image, returns top-3 ImageNet classes. + +## Usage + +Import weights from HuggingFace and generate the lockfile: + +```sh +cd examples/resnet +cog weights import +``` + +Run a prediction locally (weights are bind-mounted): + +```sh +cog predict -i image=@hotdog.png +``` + +Build and push to a registry: + +```sh +cog push /resnet +``` diff --git a/examples/resnet/cat.png b/examples/resnet/cat.png new file mode 100644 index 0000000000..15296784ac Binary files /dev/null and b/examples/resnet/cat.png differ diff --git a/examples/resnet/cog.yaml b/examples/resnet/cog.yaml new file mode 100644 index 0000000000..413d3ee998 --- /dev/null +++ b/examples/resnet/cog.yaml @@ -0,0 +1,28 @@ +# ResNet50 image classifier using v1 managed weights. 
+# +# Weights are pulled from HuggingFace at import time: +# cog weights import +# +# Build and push: +# cog push + +image: /resnet + +build: + gpu: true + python_version: "3.13" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" + +weights: + - name: resnet50 + source: + uri: hf://microsoft/resnet-50 + exclude: + - "pytorch_model.bin" # legacy format, redundant with model.safetensors + - "flax_model.msgpack" # Flax/JAX weights + - "tf_model.h5" # TensorFlow weights + - "README.md" + - ".gitattributes" + target: /src/weights/resnet50 diff --git a/examples/resnet/hotdog.png b/examples/resnet/hotdog.png new file mode 100644 index 0000000000..16a2719693 Binary files /dev/null and b/examples/resnet/hotdog.png differ diff --git a/examples/resnet/predict.py b/examples/resnet/predict.py new file mode 100644 index 0000000000..767b1cc77e --- /dev/null +++ b/examples/resnet/predict.py @@ -0,0 +1,26 @@ +import torch +from cog import BasePredictor, Input, Path +from PIL import Image +from transformers import AutoImageProcessor, ResNetForImageClassification + +WEIGHTS_DIR = "/src/weights/resnet50" + + +class Predictor(BasePredictor): + def setup(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.processor = AutoImageProcessor.from_pretrained(WEIGHTS_DIR) + self.model = ResNetForImageClassification.from_pretrained(WEIGHTS_DIR) + self.model = self.model.to(self.device) + self.model.eval() + + def predict(self, image: Path = Input(description="Image to classify")) -> dict: + img = Image.open(image).convert("RGB") + inputs = self.processor(img, return_tensors="pt").to(self.device) + + with torch.no_grad(): + logits = self.model(**inputs).logits + + top3 = logits[0].softmax(0).topk(3) + labels = self.model.config.id2label + return {labels[i.item()]: p.item() for p, i in zip(*top3)} diff --git a/examples/resnet/requirements.txt b/examples/resnet/requirements.txt new file mode 100644 index 0000000000..72eea68af8 --- 
/dev/null +++ b/examples/resnet/requirements.txt @@ -0,0 +1,3 @@ +pillow==12.1.1 +torch==2.8.0 +transformers==4.52.3 diff --git a/examples/resnet/weights.lock b/examples/resnet/weights.lock new file mode 100644 index 0000000000..3c5f63730c --- /dev/null +++ b/examples/resnet/weights.lock @@ -0,0 +1,61 @@ +{ + "version": 1, + "envelopeFormat": "sha256:ce2d53f8dd962ace393450e0abadbe227304897be87753a503b61f9c8525726e", + "weights": [ + { + "name": "resnet50", + "target": "/src/weights/resnet50", + "source": { + "uri": "hf://microsoft/resnet-50", + "fingerprint": "commit:34c2154c194f829b11125337b98c8f5f9965ff19", + "include": [], + "exclude": [ + ".gitattributes", + "README.md", + "flax_model.msgpack", + "pytorch_model.bin", + "tf_model.h5" + ], + "importedAt": "2026-04-30T21:49:23.515142Z" + }, + "digest": "sha256:d2daafad96409df82d69df3c92192d2e651344f579a12683a59e4a6140a5abf5", + "setDigest": "sha256:52924993c7eff45d5d1deaecf1f375d774c30faa1b4ce61379f5d552fd376744", + "size": 102552676, + "sizeCompressed": 102509231, + "files": [ + { + "path": "config.json", + "size": 69556, + "digest": "sha256:a1be0f56d516d0b55f9844ce0fbfe5fb359195cba6c6c102d41ae6c202c84d5e", + "layer": "sha256:d3458267e3c3179135ab97c11fb1a0da870c44be7c32d65c5d2b35bc91c46bf8" + }, + { + "path": "model.safetensors", + "size": 102482854, + "digest": "sha256:9c6061af1f450bb0847e529fd742aa5066017be379c71bbf5546b198e5b13a1e", + "layer": "sha256:dec23f069d100ce2979619413133cd208a35069daed2b13e77879ac2404994c5" + }, + { + "path": "preprocessor_config.json", + "size": 266, + "digest": "sha256:fd575b890da5a949493e1d1e7a70bfcb9e4b99fe444004d2dbfa253add254741", + "layer": "sha256:d3458267e3c3179135ab97c11fb1a0da870c44be7c32d65c5d2b35bc91c46bf8" + } + ], + "layers": [ + { + "digest": "sha256:d3458267e3c3179135ab97c11fb1a0da870c44be7c32d65c5d2b35bc91c46bf8", + "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", + "size": 23727, + "sizeUncompressed": 69822 + }, + { + "digest": 
"sha256:dec23f069d100ce2979619413133cd208a35069daed2b13e77879ac2404994c5", + "mediaType": "application/vnd.oci.image.layer.v1.tar", + "size": 102485504, + "sizeUncompressed": 102482854 + } + ] + } + ] +} \ No newline at end of file diff --git a/go.mod b/go.mod index 88fcdcc975..2e6d1f38d0 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/docker/go-connections v0.6.0 github.com/getkin/kin-openapi v0.135.0 github.com/google/go-containerregistry v0.21.4 + github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/go-version v1.9.0 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-isatty v0.0.21 @@ -84,6 +85,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/in-toto/attestation v1.1.2 // indirect github.com/in-toto/in-toto-golang v0.10.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 40e1c7bf50..4017fa5ee8 100644 --- a/go.sum +++ b/go.sum @@ -130,6 +130,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.8 
h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-version v1.9.0 h1:CeOIz6k+LoN3qX9Z0tyQrPtiB1DFYRPfCIBtaXPSCnA= github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/in-toto/attestation v1.1.2 h1:MBFn6lsMq6dptQZJBhalXTcWMb/aJy3V+GX3VYj/V1E= diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 57e0e3d3af..36fbeb381c 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -4,12 +4,10 @@ package harness import ( "context" cryptorand "crypto/rand" - "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "io" - mathrand "math/rand/v2" "net" "net/http" "os" @@ -252,7 +250,6 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool NewCommand("registry-inspect", h.cmdRegistryInspect), NewCommand("registry-seed", h.cmdRegistrySeed), NewCommand("docker-push", h.cmdDockerPush), - NewCommand("mock-weights", h.cmdMockWeights), // Mock upload server commands NewCommand("upload-server-start", h.cmdUploadServerStart), @@ -1063,183 +1060,6 @@ func (h *Harness) cmdDockerPush(ts *testscript.TestScript, neg bool, args []stri ts.Logf("docker-push: pushed %s to %s", localImage, remoteRef) } -// ============================================================================= -// Mock weights command -// ============================================================================= - -// mockWeightsLock mirrors the structure from pkg/model/weights_lock.go -// SYNC: If pkg/model/WeightsLock changes, update this copy. -// We duplicate it here to avoid importing pkg/model which transitively imports pkg/wheels. 
-type mockWeightsLock struct { - Version string `json:"version"` - Created time.Time `json:"created"` - Files []mockWeightFile `json:"files"` -} - -// mockWeightFile mirrors WeightFile from pkg/model/weights.go -// SYNC: If pkg/model/WeightFile changes, update this copy. -type mockWeightFile struct { - Name string `json:"name"` - Dest string `json:"dest"` - DigestOriginal string `json:"digestOriginal"` - Digest string `json:"digest"` - Size int64 `json:"size"` - SizeUncompressed int64 `json:"sizeUncompressed"` - MediaType string `json:"mediaType"` - ContentType string `json:"contentType,omitempty"` -} - -// cmdMockWeights generates mock weight files and a weights.lock file. -// Usage: mock-weights [--count N] [--min-size S] [--max-size S] -// Defaults: -// - count: 2 -// - min-size: 1kb -// - max-size: 10kb -// -// Creates files in $WORK/weights/ and writes $WORK/weights.lock -func (h *Harness) cmdMockWeights(ts *testscript.TestScript, neg bool, args []string) { - if neg { - ts.Fatalf("mock-weights: does not support negation") - } - - // Parse arguments - count := 2 - minSize := int64(1024) // 1KB - maxSize := int64(10 * 1024) // 10KB - - for i := 0; i < len(args); i++ { - switch args[i] { - case "--count", "-n": - if i+1 < len(args) { - if n, err := strconv.Atoi(args[i+1]); err == nil { - count = n - } - i++ - } - case "--min-size": - if i+1 < len(args) { - if size, err := parseSize(args[i+1]); err == nil { - minSize = size - } - i++ - } - case "--max-size": - if i+1 < len(args) { - if size, err := parseSize(args[i+1]); err == nil { - maxSize = size - } - i++ - } - } - } - - workDir := ts.Getenv("WORK") - weightsDir := filepath.Join(workDir, "weights") - lockPath := filepath.Join(workDir, "weights.lock") - - // Create weights directory - if err := os.MkdirAll(weightsDir, 0o755); err != nil { - ts.Fatalf("mock-weights: failed to create weights dir: %v", err) - } - - var files []mockWeightFile - - for i := 1; i <= count; i++ { - // Random size between min and max - 
size := minSize - if maxSize > minSize { - size = minSize + mathrand.Int64N(maxSize-minSize+1) //nolint:gosec // test data, not security-sensitive - } - - // Generate identifier (e.g., "weights-001") - weightName := fmt.Sprintf("weights-%03d", i) - filename := weightName + ".bin" - filePath := filepath.Join(weightsDir, filename) - - // Generate random data - data := make([]byte, size) - if _, err := cryptorand.Read(data); err != nil { - ts.Fatalf("mock-weights: failed to generate random data: %v", err) - } - - // Write file - if err := os.WriteFile(filePath, data, 0o644); err != nil { - ts.Fatalf("mock-weights: failed to write %s: %v", filename, err) - } - - // Compute digest (uncompressed, since we're not actually compressing for tests) - hash := sha256.Sum256(data) - digest := "sha256:" + hex.EncodeToString(hash[:]) - - files = append(files, mockWeightFile{ - Name: weightName, - Dest: "/cache/" + filename, - DigestOriginal: digest, - Digest: digest, // Same as original since we're not compressing - Size: size, - SizeUncompressed: size, - // MediaType matches production WeightBuilder output (uncompressed). - MediaType: "application/vnd.cog.weight.layer.v1", - ContentType: "application/octet-stream", - }) - } - - // Create weights.lock - lock := mockWeightsLock{ - Version: "1.0", - Created: time.Now().UTC(), - Files: files, - } - - lockData, err := json.MarshalIndent(lock, "", " ") - if err != nil { - ts.Fatalf("mock-weights: failed to marshal weights.lock: %v", err) - } - - if err := os.WriteFile(lockPath, lockData, 0o644); err != nil { - ts.Fatalf("mock-weights: failed to write weights.lock: %v", err) - } - - ts.Logf("mock-weights: created %d files in %s", count, weightsDir) -} - -// parseSize parses size strings like "1kb", "10KB", "1mb" into bytes. 
-func parseSize(s string) (int64, error) { - s = strings.TrimSpace(strings.ToLower(s)) - if s == "" { - return 0, fmt.Errorf("empty size string") - } - - var multiplier int64 = 1 - var numStr string - - switch { - case strings.HasSuffix(s, "gb"): - multiplier = 1024 * 1024 * 1024 - numStr = strings.TrimSuffix(s, "gb") - case strings.HasSuffix(s, "mb"): - multiplier = 1024 * 1024 - numStr = strings.TrimSuffix(s, "mb") - case strings.HasSuffix(s, "kb"): - multiplier = 1024 - numStr = strings.TrimSuffix(s, "kb") - case strings.HasSuffix(s, "b"): - numStr = strings.TrimSuffix(s, "b") - default: - numStr = s - } - - num, err := strconv.ParseFloat(strings.TrimSpace(numStr), 64) - if err != nil { - return 0, fmt.Errorf("invalid number: %s", numStr) - } - if num < 0 { - return 0, fmt.Errorf("size cannot be negative") - } - - return int64(num * float64(multiplier)), nil -} - // ============================================================================= // Mock upload server commands // ============================================================================= diff --git a/integration-tests/tests/oci_bundle_build.txtar b/integration-tests/tests/oci_bundle_build.txtar deleted file mode 100644 index 3e831841d6..0000000000 --- a/integration-tests/tests/oci_bundle_build.txtar +++ /dev/null @@ -1,46 +0,0 @@ -# Test building an OCI bundle with declarative weights in cog.yaml. -# Verifies: cog.yaml weights declaration -> cog weights build -> cog build (COG_OCI_INDEX=1) -# The image should build successfully and predictions should work. 
- -# Create weight files (small, deterministic) -mkdir weights -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin' -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin' - -# Step 1: Build weights.lock from cog.yaml declarations -cog weights build -stderr 'Generated weights.lock' -stderr '2 file' -exists weights.lock - -# Step 2: Build with OCI index mode enabled -env COG_OCI_INDEX=1 -cog build -t $TEST_IMAGE -stderr 'Image built as' - -# Verify image was built -exec docker image inspect $TEST_IMAGE -stdout 'run.cog.config' - -# Verify prediction works -cog predict $TEST_IMAGE -i text=hello -stdout 'processed: hello' - --- cog.yaml -- -build: - python_version: "3.12" -predict: "predict.py:Predictor" -weights: - - name: alpha - source: weights/model-a.bin - target: /weights/model-a.bin - - name: beta - source: weights/model-b.bin - target: /weights/model-b.bin - --- predict.py -- -from cog import BasePredictor - -class Predictor(BasePredictor): - def predict(self, text: str) -> str: - return f"processed: {text}" diff --git a/integration-tests/tests/oci_bundle_inspect.txtar b/integration-tests/tests/oci_bundle_inspect.txtar deleted file mode 100644 index 5cc96641d8..0000000000 --- a/integration-tests/tests/oci_bundle_inspect.txtar +++ /dev/null @@ -1,63 +0,0 @@ -# Test cog inspect on a pushed OCI bundle with declarative weights. -# Verifies: push bundle -> cog inspect --remote --json shows correct structure. -# The inspect output should show an OCI index with image + weight manifests. 
- -[short] skip 'requires local registry' - -# Start test registry -registry-start - -# Create weight files (small, deterministic) -mkdir weights -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin' -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin' - -# Build weights.lock -cog weights build -stderr 'Generated weights.lock' -exists weights.lock - -# Build and push with OCI index mode -env COG_OCI_INDEX=1 -cog push $TEST_REGISTRY/test/inspect-model:v1 - -# Inspect the pushed bundle -cog inspect --remote --json $TEST_REGISTRY/test/inspect-model:v1 - -# Verify it's an OCI index -stdout '"type": "index"' - -# Verify image manifest is present -stdout '"type": "image"' - -# Verify weight manifests are present with correct names -stdout '"type": "weights"' -stdout '"name": "alpha"' -stdout '"name": "beta"' - -# Verify weight targets -stdout '"target": "/weights/model-a.bin"' -stdout '"target": "/weights/model-b.bin"' - -# Verify layers are populated -stdout '"layers"' -stdout '"digest": "sha256:' - --- cog.yaml -- -build: - python_version: "3.12" -predict: "predict.py:Predictor" -weights: - - name: alpha - source: weights/model-a.bin - target: /weights/model-a.bin - - name: beta - source: weights/model-b.bin - target: /weights/model-b.bin - --- predict.py -- -from cog import BasePredictor - -class Predictor(BasePredictor): - def predict(self, text: str) -> str: - return f"processed: {text}" diff --git a/integration-tests/tests/oci_bundle_push.txtar b/integration-tests/tests/oci_bundle_push.txtar index cbc35f8816..622d79b76c 100644 --- a/integration-tests/tests/oci_bundle_push.txtar +++ b/integration-tests/tests/oci_bundle_push.txtar @@ -1,49 +1,62 @@ -# Test pushing an OCI bundle with declarative weights via cog push. -# Verifies: cog.yaml weights -> cog weights build -> cog push (BundlePusher path) -# The push should create an OCI index with image + weight manifests. 
+# Exercise the full `cog push` bundle path with v1 managed weights: +# cog.yaml weights -> implicit cog weights build -> cog push -> OCI index. +# The pushed artifact is an OCI image index carrying the model image +# manifest plus one weight manifest (application/vnd.cog.weight.v1) per +# cog.yaml entry. [short] skip 'requires local registry' -# Start test registry registry-start -# Create weight files (small, deterministic) -mkdir weights -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin' -exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin' +# Two weight directories. Each holds a single small file; the packer +# emits a bundle layer for all the small files under the default +# threshold so we end up with one layer per weight entry. +mkdir weights-alpha +mkdir weights-beta +exec sh -c 'printf AAAA > weights-alpha/data.bin' +exec sh -c 'printf BBBB > weights-beta/data.bin' -# Step 1: Build weights.lock -cog weights build -stderr 'Generated weights.lock' -exists weights.lock +# image: is required in cog.yaml when weights are declared, and +# cog weights import reads it from there. +exec sh -c 'printf "image: %s/test/bundle-model\n" "$TEST_REGISTRY" >> cog.yaml' -# Step 2: Build and push with OCI index mode -env COG_OCI_INDEX=1 +# Step 1: import generates weights.lock and pushes weight manifests. +# cog push later requires the lockfile to already exist. +cog weights import + +# Step 2: cog push assembles the OCI index (image + weight manifests) +# and uploads it to the registry under the supplied tag. cog push $TEST_REGISTRY/test/bundle-model:v1 -# Verify push succeeded — should mention pushing -stderr -count=1 'Pushing' +# Image push header — weights are pushed during `cog weights import`, +# not during `cog push`. 
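The small-file bundling behavior the test comment describes can be sketched as below. This is a hypothetical `packLayers` helper with an assumed 4 MiB threshold, not the actual cog packer: files at or under the threshold share one bundle layer, larger files each get their own layer.

```go
package main

import "fmt"

// layer groups files that are packed together into one OCI layer.
type layer struct {
	files []string
	size  int64
}

// packLayers puts every file at or below threshold into a single shared
// bundle layer and gives each larger file its own layer, mirroring the
// "one bundle layer for all the small files" behavior described above.
func packLayers(files map[string]int64, threshold int64) []layer {
	var bundle layer
	var out []layer
	for name, size := range files {
		if size <= threshold {
			bundle.files = append(bundle.files, name)
			bundle.size += size
		} else {
			out = append(out, layer{files: []string{name}, size: size})
		}
	}
	if len(bundle.files) > 0 {
		out = append(out, bundle)
	}
	return out
}

func main() {
	files := map[string]int64{
		"data.bin":          4,       // tiny, goes into the bundle
		"config.json":       2,       // tiny, goes into the bundle
		"model.safetensors": 5 << 20, // over threshold, own layer
	}
	layers := packLayers(files, 4<<20) // assumed 4 MiB threshold
	fmt.Println(len(layers))           // prints 2: one large-file layer, one bundle
}
```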
+stderr 'Pushing image ' -# Step 3: Verify the pushed artifact is an OCI index with image + weight manifests +# Resolve the pushed ref and verify the top-level index shape. registry-inspect $TEST_REGISTRY/test/bundle-model:v1 -# Verify it's an OCI index stdout 'application/vnd.oci.image.index.v1\+json' -# Verify weight annotations are present -stdout 'vnd.cog.reference.type.*weights' -stdout 'vnd.cog.weight.name.*alpha' -stdout 'vnd.cog.weight.name.*beta' +# v1 annotations on each weight descriptor (spec §2.5). +stdout 'run.cog.weight.name.*alpha' +stdout 'run.cog.weight.name.*beta' +stdout 'run.cog.weight.set-digest' -- cog.yaml -- +# image: is appended by the test before `cog push` (the value depends +# on $TEST_REGISTRY, which varies per run). cog push still +# pushes to the supplied ref, but image: must be set when weights are +# configured (validation requirement). build: python_version: "3.12" predict: "predict.py:Predictor" weights: - name: alpha - source: weights/model-a.bin - target: /weights/model-a.bin + target: /src/weights-alpha + source: + uri: file://./weights-alpha - name: beta - source: weights/model-b.bin - target: /weights/model-b.bin + target: /src/weights-beta + source: + uri: file://./weights-beta -- predict.py -- from cog import BasePredictor diff --git a/integration-tests/tests/weights_build.txtar b/integration-tests/tests/weights_build.txtar deleted file mode 100644 index b20f646666..0000000000 --- a/integration-tests/tests/weights_build.txtar +++ /dev/null @@ -1,44 +0,0 @@ -# Test that cog weights build generates weights.lock - -# Build should fail without weights section -! 
cog weights build -stderr 'no weights defined' - -# Add weights section and create weight file -cp cog-with-weights.yaml cog.yaml -mkdir models -exec sh -c 'echo "test model content" > models/model.bin' - -# Build weights.lock -cog weights build -stderr 'Generated weights.lock' -stderr '1 file' - -# Verify weights.lock was created -exists weights.lock - -# Verify weights.lock contains expected content -exec grep -q '"name": "model"' weights.lock -exec grep -q '"dest": "/cache/model.bin"' weights.lock -exec grep -q '"digestOriginal": "sha256:' weights.lock - --- cog.yaml -- -build: - python_version: "3.12" -predict: "predict.py:Predictor" - --- cog-with-weights.yaml -- -build: - python_version: "3.12" -predict: "predict.py:Predictor" -weights: - - name: model - source: models/model.bin - target: /cache/model.bin - --- predict.py -- -from cog import BasePredictor - -class Predictor(BasePredictor): - def predict(self, s: str) -> str: - return "hello " + s diff --git a/integration-tests/tests/weights_filter.txtar b/integration-tests/tests/weights_filter.txtar new file mode 100644 index 0000000000..9a66a806c0 --- /dev/null +++ b/integration-tests/tests/weights_filter.txtar @@ -0,0 +1,69 @@ +# Test that include/exclude patterns in cog.yaml filter the weight set +# during import. A file:// source with mixed file types is created, +# then imported with include/exclude patterns. The resulting lockfile +# and predict output verify that only the matching files were included. + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Build a weight directory with multiple file types. 
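The include/exclude filtering this test exercises can be sketched with a small helper built on `path.Match`. This is an illustration only; the real cog matcher may differ in pattern semantics (e.g. directory handling):

```go
package main

import (
	"fmt"
	"path"
)

// filterFiles keeps names matching any include pattern (or all names
// when includes is empty), then drops names matching any exclude pattern.
func filterFiles(names, includes, excludes []string) []string {
	matchAny := func(patterns []string, name string) bool {
		for _, p := range patterns {
			if ok, _ := path.Match(p, name); ok {
				return true
			}
		}
		return false
	}
	var out []string
	for _, name := range names {
		if len(includes) > 0 && !matchAny(includes, name) {
			continue // not included
		}
		if matchAny(excludes, name) {
			continue // explicitly excluded
		}
		out = append(out, name)
	}
	return out
}

func main() {
	names := []string{
		"model.safetensors", "config.json", "tokenizer.json",
		"model.onnx", "pytorch_model.bin", "README.md",
	}
	// Same patterns as the cog.yaml in this test.
	fmt.Println(filterFiles(names, []string{"*.safetensors", "*.json"}, []string{"README*"}))
	// prints [model.safetensors config.json tokenizer.json]
}
```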
+mkdir weights-src +exec sh -c 'echo "safetensor-data" > weights-src/model.safetensors' +exec sh -c 'echo "config-data" > weights-src/config.json' +exec sh -c 'echo "tokenizer-data" > weights-src/tokenizer.json' +exec sh -c 'echo "onnx-data" > weights-src/model.onnx' +exec sh -c 'echo "pytorch-bin-data" > weights-src/pytorch_model.bin' +exec sh -c 'echo "readme-text" > weights-src/README.md' + +# Patch image to use ephemeral test registry. +exec sh -c 'printf "image: %s/test/filter-model\n" "$TEST_REGISTRY" >> cog.yaml' + +# Import with include/exclude filters. +cog weights import +exists weights.lock + +# The lockfile must contain the included files. +exec sh -c 'grep "model.safetensors" weights.lock' +exec sh -c 'grep "config.json" weights.lock' +exec sh -c 'grep "tokenizer.json" weights.lock' + +# The lockfile must NOT contain excluded / non-included files. +! exec sh -c 'grep "model.onnx" weights.lock' +! exec sh -c 'grep "pytorch_model.bin" weights.lock' +! exec sh -c 'grep "README.md" weights.lock' + +# Predict verifies that only filtered files are mounted. +cog predict -i s=check +stdout 'files=config.json,model.safetensors,tokenizer.json' + +-- cog.yaml -- +# image: is appended by the test before `cog weights import` (the +# value depends on $TEST_REGISTRY, which varies per run). 
+build: + python_version: "3.12" +predict: "predict.py:Predictor" +weights: + - name: filtered + target: /src/weights/filtered + source: + uri: file://./weights-src + include: + - "*.safetensors" + - "*.json" + exclude: + - "README*" + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + weight_dir = "/src/weights/filtered" + files = sorted(os.listdir(weight_dir)) + return f"files={','.join(files)}" diff --git a/integration-tests/tests/weights_import_predict.txtar b/integration-tests/tests/weights_import_predict.txtar new file mode 100644 index 0000000000..baba300fc0 --- /dev/null +++ b/integration-tests/tests/weights_import_predict.txtar @@ -0,0 +1,62 @@ +# End-to-end test that `cog weights import` warms the local store +# enough for `cog predict` to mount the weights without a separate +# `cog weights pull`. This is the user-facing promise of cog-i12u. +# +# Compare to weights_pull_predict.txtar, which exercises the +# import -> pull -> predict flow (still useful for the case where +# weights.lock is checked in but the local cache is cold). + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Patch cog.yaml to set image: to the ephemeral test registry. +exec sh -c 'printf "image: %s/test/import-predict-model\n" "$TEST_REGISTRY" >> cog.yaml' + +# Step 1: import. After this, the local content store under +# $COG_CACHE_DIR/weights/files/sha256/ must contain greeting.txt's +# bytes — that's the cog-i12u guarantee. +cog weights import +exists weights.lock + +# Step 2: predict, with NO intervening `cog weights pull`. The +# import path warmed the store; predict's hardlink-assemble must +# work straight off the back of import. 
+cog predict -i s=world +stdout 'hello world \(weight-size=1024\)' + +# Per-invocation mount dir must be cleaned up after predict exits. +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# image: is appended by the test before `cog weights import` (the +# value depends on $TEST_REGISTRY, which varies per run). +build: + python_version: "3.12" +predict: "predict.py:Predictor" +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. + path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/integration-tests/tests/weights_pull.txtar b/integration-tests/tests/weights_pull.txtar new file mode 100644 index 0000000000..d8fea01d83 --- /dev/null +++ b/integration-tests/tests/weights_pull.txtar @@ -0,0 +1,100 @@ +# Test the `cog weights pull` flow against a local registry. +# Verifies: import -> pull (cold cache) -> pull (warm cache) -> verbose +# output -> unknown-name error handling. + +[short] skip 'requires local registry' + +# Isolate the weight cache under $WORK so the test doesn't touch +# $XDG_CACHE_HOME/cog and so assertions about cache state are reliable +# across runs. +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Build deterministic weight directories (file:// sources must be directories). +mkdir weights-alpha +mkdir weights-beta +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights-alpha/model.bin' +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights-beta/model.bin' + +# Step 1: import weights to the local registry. 
This populates +# weights.lock, pushes the layers, and warms the local cache as a +# side effect (cog-i12u: import == build + push + populate cache). +cog weights import --image $TEST_REGISTRY/test/pull-model +stderr 'Pushing 2 weight' +exists weights.lock + +# Purge the cache so the next step exercises the cold-pull path. +# The realistic scenario for `cog weights pull` is "lockfile checked +# in, but the local cache is empty" — e.g. a fresh clone. +exec sh -c 'rm -rf $WORK/cache/weights/files' +! exists $WORK/cache/weights/files/sha256 + +# Step 2: pull from a cold cache. Expect both weights to be fetched. +cog weights pull --image $TEST_REGISTRY/test/pull-model +stderr 'Pulling alpha' +stderr 'Pulling beta' +stderr 'done' +stderr 'Pulled' + +# Cache now has files under the sha256 prefix layout. +exists $WORK/cache/weights/files/sha256 + +# Step 3: pull again — everything is cached, no registry I/O needed. +cog weights pull --image $TEST_REGISTRY/test/pull-model +stderr 'cached' +stderr 'All 2 weight\(s\) already cached' + +# Step 4: verbose mode surfaces cache dir, manifest ref, and per-file +# lines. +cog weights pull --image $TEST_REGISTRY/test/pull-model --verbose +stderr 'Cache:' +stderr 'Lockfile:' +# Everything is cached so we only see the per-weight "cached" lines +# plus the summary — not layer/file detail. +stderr 'cached \(1/1 files\)' + +# Step 5: delete one blob from the cache and verify pull restores just +# it. We pick any file under the sha256 tree. +exec sh -c 'find $WORK/cache/weights/files/sha256 -type f | head -1 | xargs rm' +cog weights pull --image $TEST_REGISTRY/test/pull-model --verbose +# The weight with the missing file reports 1 missing / 1 total; the +# other reports as cached. Verbose mode also prints per-layer detail +# (short digest + size). +stderr 'missing / 1 total' +stderr ' layer [0-9a-f]+' + +# Step 6: unknown weight name errors with every missing name listed, +# and the cache is unchanged. +! 
cog weights pull --image $TEST_REGISTRY/test/pull-model alpha nope typo +stderr 'nope' +stderr 'typo' + +# Step 7: named filter pulls only the requested weight. Delete both +# cache trees and re-pull only alpha. +exec sh -c 'rm -rf $WORK/cache/weights/files' +cog weights pull --image $TEST_REGISTRY/test/pull-model alpha +stderr 'Pulling alpha' +! stderr 'Pulling beta' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" +image: test/pull-model +weights: + - name: alpha + target: /src/weights/alpha + source: + uri: file://./weights-alpha + - name: beta + target: /src/weights/beta + source: + uri: file://./weights-beta + +-- predict.py -- +from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + return "hello " + s diff --git a/integration-tests/tests/weights_pull_predict.txtar b/integration-tests/tests/weights_pull_predict.txtar new file mode 100644 index 0000000000..26502c2ea0 --- /dev/null +++ b/integration-tests/tests/weights_pull_predict.txtar @@ -0,0 +1,71 @@ +# End-to-end test of the managed-weights local-run flow. +# Verifies: cog weights import -> cog weights pull -> cog predict, where +# the predictor container sees the weight files at the configured target +# path. Also verifies mount cleanup after predict exits. + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Patch cog.yaml to set image: to the ephemeral test registry. This +# makes the `cog predict` path (which has no --image flag) resolve +# the repo via src.Config.Image when constructing the WeightManager. +exec sh -c 'printf "image: %s/test/pull-predict-model\n" "$TEST_REGISTRY" >> cog.yaml' + +# Step 1: import to the local registry. This also warms the local +# cache as a side effect of import (cog-i12u guarantee). 
+cog weights import +exists weights.lock + +# Purge the cache so step 2 exercises pull's cold-fetch path. The +# realistic scenario for `cog weights pull` is "lockfile checked in, +# cache empty" — e.g. a fresh clone. +exec sh -c 'rm -rf $WORK/cache/weights/files' + +# Step 2: pull into the isolated cache. +cog weights pull +stderr 'Pulled' + +# Step 3: predict. The predictor opens the file at the configured +# target path and returns its size, proving the mount is visible +# and readable inside the container. +cog predict -i s=world +stdout 'hello world \(weight-size=1024\)' + +# Step 4: the per-invocation mount dir must be cleaned up after +# predict exits. .cog/mounts/ either doesn't exist (everything +# cleaned) or is empty. +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# image: is appended by the test before `cog weights import` (the +# value depends on $TEST_REGISTRY, which varies per run). +build: + python_version: "3.12" +predict: "predict.py:Predictor" +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. + path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/integration-tests/tests/weights_push_inspect.txtar b/integration-tests/tests/weights_push_inspect.txtar deleted file mode 100644 index 9747f3ff21..0000000000 --- a/integration-tests/tests/weights_push_inspect.txtar +++ /dev/null @@ -1,66 +0,0 @@ -# Test weights push and inspect lifecycle against a local registry. -# Verifies: cog weights build -> cog weights push -> cog weights inspect (synced). 
- -[short] skip 'requires local registry' - -# Start test registry -registry-start - -# Create weight files (small, deterministic) -mkdir weights -exec sh -c 'dd if=/dev/zero bs=512 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin' -exec sh -c 'dd if=/dev/zero bs=512 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin' - -# Step 1: Build weights.lock -cog weights build -stderr 'Generated weights.lock' -stderr '2 file' -exists weights.lock - -# Verify lock file structure -exec grep -q '"name": "alpha"' weights.lock -exec grep -q '"name": "beta"' weights.lock -exec grep -q '"digest": "sha256:' weights.lock - -# Step 2: Push weights to local registry (repo only, no tag) -cog weights push $TEST_REGISTRY/test/weights-model -stderr 'Pushed 2 weight artifact' -# Push output should show the full ref for each weight -stderr 'weights-alpha-' -stderr 'weights-beta-' - -# Verify tags with :tag are rejected -! cog weights push $TEST_REGISTRY/test/weights-model:v1 -stderr 'includes a tag or digest' - -# Step 3: Inspect — both weights should be synced -cog weights inspect $TEST_REGISTRY/test/weights-model --json -stdout '"status": "synced"' -! stdout '"status": "local-only"' -! stdout '"status": "digest-mismatch"' -# Inspect should show the ref and layers for each weight -stdout '"ref":' -stdout '"layers":' - -# Verify tags with :tag are rejected for inspect too -! 
cog weights inspect $TEST_REGISTRY/test/weights-model:v1 -stderr 'includes a tag or digest' - --- cog.yaml -- -build: - python_version: "3.12" -predict: "predict.py:Predictor" -weights: - - name: alpha - source: weights/model-a.bin - target: /weights/model-a.bin - - name: beta - source: weights/model-b.bin - target: /weights/model-b.bin - --- predict.py -- -from cog import BasePredictor - -class Predictor(BasePredictor): - def predict(self, s: str) -> str: - return "hello " + s diff --git a/mise.lock b/mise.lock index 15d53cfe37..57ff0ae500 100644 --- a/mise.lock +++ b/mise.lock @@ -39,37 +39,30 @@ backend = "aqua:golangci/golangci-lint" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64"] checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64-musl"] checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64-musl"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-arm64"] checksum = "sha256:03bfadf67e52b441b7ec21305e501c717df93c959836d66c7f97312654acb297" url 
= "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-arm64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-x64"] checksum = "sha256:66fb0da81b8033b477f97eea420d4b46b230ca172b8bb87c6610109f3772b6b6" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-amd64.tar.gz" -provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.windows-x64"] checksum = "sha256:c60c87695e79db8e320f0e5be885059859de52bb5ee5f11be5577828570bc2a3" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-windows-amd64.zip" -provenance = "github-attestations" [[tools."aqua:gotestyourself/gotestsum"]] version = "1.13.0" @@ -168,16 +161,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-arm64-musl"] -checksum = "sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" +checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64-musl"] -checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" +checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" 
+url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -200,16 +193,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-arm64-musl"] -checksum = "sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" +checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64-musl"] -checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" +checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -229,32 +222,25 @@ backend = "aqua:ziglang/zig" [tools."aqua:ziglang/zig"."platforms.linux-arm64"] url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.linux-arm64-musl"] url = 
"https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.linux-x64"] url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.linux-x64-musl"] url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.macos-arm64"] checksum = "blake3:c7d2fb746701fea2c070f66c29a0300ba42a0d5d2c09c493462b8c0f4f0bd604" url = "https://ziglang.org/download/0.15.2/zig-aarch64-macos-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.macos-x64"] url = "https://ziglang.org/download/0.15.2/zig-x86_64-macos-0.15.2.tar.xz" -provenance = "minisign" [tools."aqua:ziglang/zig"."platforms.windows-x64"] url = "https://ziglang.org/download/0.15.2/zig-x86_64-windows-0.15.2.zip" -provenance = "minisign" [[tools.cargo-binstall]] version = "1.16.6" @@ -264,18 +250,10 @@ backend = "aqua:cargo-bins/cargo-binstall" checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" -[tools.cargo-binstall."platforms.linux-arm64-musl"] -checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" -url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" - [tools.cargo-binstall."platforms.linux-x64"] checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" -[tools.cargo-binstall."platforms.linux-x64-musl"] -checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" -url = 
"https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" - [tools.cargo-binstall."platforms.macos-arm64"] checksum = "sha256:30543b378b96fbddabee1edfaccde7914dd2f851f02c560de859f81a21ab665b" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-apple-darwin.zip" @@ -288,10 +266,22 @@ url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/ca checksum = "sha256:fca962c3d12ae6192280111074db073c15abad3ba162a1a5a2af0f6f01872114" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-pc-windows-msvc.zip" +[[tools."cargo:cargo-deny"]] +version = "0.19.0" +backend = "cargo:cargo-deny" + +[[tools."cargo:cargo-insta"]] +version = "1.46.0" +backend = "cargo:cargo-insta" + [[tools."cargo:cargo-nextest"]] version = "0.9.120" backend = "cargo:cargo-nextest" +[[tools."cargo:cargo-zigbuild"]] +version = "0.20.1" +backend = "cargo:cargo-zigbuild" + [[tools."cargo:maturin"]] version = "1.11.5" backend = "cargo:maturin" @@ -328,6 +318,10 @@ url = "https://dl.google.com/go/go1.25.6.darwin-amd64.tar.gz" checksum = "sha256:19b4733b727ba5c611b5656187f3ac367d278d64c3d4199a845e39c0fdac5335" url = "https://dl.google.com/go/go1.25.6.windows-amd64.zip" +[[tools."go:golang.org/x/tools/cmd/goimports"]] +version = "0.44.0" +backend = "go:golang.org/x/tools/cmd/goimports" + [[tools."npm:markdownlint-cli2"]] version = "0.22.0" backend = "npm:markdownlint-cli2" @@ -344,6 +338,10 @@ backend = "pipx:nox" uvx = "true" uvx_args = "--python-preference=managed -p 3.13" +[[tools.python]] +version = "3.12.12" +backend = "core:python" + [[tools.ruff]] version = "0.14.13" backend = "aqua:astral-sh/ruff" @@ -419,34 +417,51 @@ backend = "aqua:astral-sh/uv" [tools.uv."platforms.linux-arm64"] checksum = "sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = 
"https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.linux-arm64-musl"] checksum = "sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.linux-x64"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.linux-x64-musl"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.macos-arm64"] checksum = "sha256:fcf0a9ea6599c6ae28a4c854ac6da76f2c889354d7c36ce136ef071f7ab9721f" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-apple-darwin.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.macos-x64"] checksum = "sha256:171eb8c518313e157c5b4cec7b4f743bc6bab1bd23e09b646679a02d096a047f" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-apple-darwin.tar.gz" -provenance = "github-attestations" [tools.uv."platforms.windows-x64"] checksum = "sha256:eb02fd95d8e0eed462b4a67ecdd320d865b38c560bffcda9a0b87ec944bdf036" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-pc-windows-msvc.zip" -provenance = "github-attestations" + +[[tools.zig]] +version = "0.15.2" +backend = "core:zig" + +[tools.zig."platforms.linux-arm64"] +checksum = "sha256:958ed7d1e00d0ea76590d27666efbf7a932281b3d7ba0c6b01b0ff26498f667f" +url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" + +[tools.zig."platforms.linux-x64"] +checksum = 
"sha256:02aa270f183da276e5b5920b1dac44a63f1a49e55050ebde3aecc9eb82f93239" +url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" + +[tools.zig."platforms.macos-arm64"] +checksum = "sha256:3cc2bab367e185cdfb27501c4b30b1b0653c28d9f73df8dc91488e66ece5fa6b" +url = "https://ziglang.org/download/0.15.2/zig-aarch64-macos-0.15.2.tar.xz" + +[tools.zig."platforms.macos-x64"] +checksum = "sha256:375b6909fc1495d16fc2c7db9538f707456bfc3373b14ee83fdd3e22b3d43f7f" +url = "https://ziglang.org/download/0.15.2/zig-x86_64-macos-0.15.2.tar.xz" + +[tools.zig."platforms.windows-x64"] +checksum = "sha256:3a0ed1e8799a2f8ce2a6e6290a9ff22e6906f8227865911fb7ddedc3cc14cb0c" +url = "https://ziglang.org/download/0.15.2/zig-x86_64-windows-0.15.2.zip" diff --git a/mise.toml b/mise.toml index 8f9b93f305..1f5d90c7e6 100644 --- a/mise.toml +++ b/mise.toml @@ -36,7 +36,7 @@ go = "latest" uv = "0.9.26" "pipx:nox" = { version = "2025.11.12", uvx = true, uvx_args = "--python-preference=managed -p 3.13" } "aqua:ziglang/zig" = "0.15.2" -"rust" = { version = "1.93.0", components = "rustfmt,clippy", targets ="x86_64-unknown-linux-gnu,aarch64-unknown-linux-gnu,aarch64-apple-darwin" } +"rust" = { version = "1.93.0", components = "rustfmt,clippy,rust-analyzer", targets ="x86_64-unknown-linux-gnu,aarch64-unknown-linux-gnu,aarch64-apple-darwin" } cargo-binstall = "1.16.6" # Cargo tools - use aqua backend where available for faster binary downloads # and better security (cosign/SLSA verification). Remaining cargo: tools use binstall. 
@@ -53,6 +53,7 @@ ruff = "0.14.13" ty = "0.0.10" "npm:prettier" = "3.6.2" "npm:markdownlint-cli2" = "0.22.0" +"go:golang.org/x/tools/cmd/goimports" = "latest" [env] _.path = "./bin" diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 8e52b5924a..6c0dcd58b6 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -13,6 +13,7 @@ import ( "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) var ( @@ -87,6 +88,10 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + imageName := src.Config.Image if buildTag != "" { imageName = buildTag @@ -219,7 +224,6 @@ func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map Strip: buildStrip, Precompile: buildPrecompile, Annotations: annotations, - OCIIndex: model.OCIIndexEnabled(), SkipSchemaValidation: buildSkipSchemaValidation, } } diff --git a/pkg/cli/inspect.go b/pkg/cli/inspect.go deleted file mode 100644 index d6377a323b..0000000000 --- a/pkg/cli/inspect.go +++ /dev/null @@ -1,404 +0,0 @@ -package cli - -import ( - "context" - "encoding/json" - "fmt" - "os" - "strings" - - "github.com/spf13/cobra" - - "github.com/replicate/cog/pkg/docker" - "github.com/replicate/cog/pkg/model" - "github.com/replicate/cog/pkg/registry" -) - -// InspectOutput is the structured output for cog inspect --json. -type InspectOutput struct { - Reference string `json:"reference"` - Type string `json:"type"` // "image" or "index" - CogVersion string `json:"cogVersion"` - Index *InspectIndex `json:"index,omitempty"` - Image *InspectManifest `json:"image,omitempty"` -} - -// InspectIndex represents an OCI index in inspect output. 
-type InspectIndex struct { - Reference string `json:"reference"` - Digest string `json:"digest"` - MediaType string `json:"mediaType"` - Manifests []InspectManifest `json:"manifests"` -} - -// InspectManifest represents a manifest entry in inspect output. -type InspectManifest struct { - Type string `json:"type"` // "image" or "weights" - Name string `json:"name,omitempty"` // weight name from AnnotationWeightName - Digest string `json:"digest"` - MediaType string `json:"mediaType"` - Size int64 `json:"size"` - Platform string `json:"platform,omitempty"` // "linux/amd64" - Target string `json:"target,omitempty"` // weight mount path from AnnotationWeightDest - Annotations map[string]string `json:"annotations,omitempty"` - Layers []InspectLayer `json:"layers"` -} - -// InspectLayer represents a layer in inspect output. -type InspectLayer struct { - Digest string `json:"digest"` - Size int64 `json:"size"` - MediaType string `json:"mediaType"` -} - -func newInspectCommand() *cobra.Command { - var ( - localOnly bool - remoteOnly bool - jsonOutput bool - rawOutput bool - ) - - cmd := &cobra.Command{ - Use: "inspect ", - Short: "Inspect a model image or OCI index", - Args: cobra.ExactArgs(1), - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - return inspectCommand(cmd, args, localOnly, remoteOnly, jsonOutput, rawOutput) - }, - } - - cmd.Flags().BoolVar(&localOnly, "local", false, "Only inspect local docker daemon") - cmd.Flags().BoolVar(&remoteOnly, "remote", false, "Only inspect remote registry") - cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON") - cmd.Flags().BoolVar(&rawOutput, "raw", false, "Output raw JSON fragments (one per line)") - - return cmd -} - -func inspectCommand(cmd *cobra.Command, args []string, localOnly, remoteOnly, jsonOutput, rawOutput bool) error { - ctx := cmd.Context() - - ref, err := model.ParseRef(args[0]) - if err != nil { - return err - } - - dockerClient, err := docker.NewClient(ctx) - if err != nil { 
- return err - } - - regClient := registry.NewRegistryClient() - resolver := model.NewResolver(dockerClient, regClient) - - // Build resolve options - var opts []model.Option - switch { - case localOnly: - opts = append(opts, model.LocalOnly()) - case remoteOnly: - opts = append(opts, model.RemoteOnly()) - } - - m, err := resolver.Inspect(ctx, ref, opts...) - if err != nil { - return err - } - - // Build output - out, err := buildInspectOutput(ctx, ref.String(), m, regClient) - if err != nil { - return err - } - - switch { - case rawOutput: - return streamRaw(ctx, ref.String(), m, regClient) - case jsonOutput: - enc := json.NewEncoder(os.Stdout) - enc.SetIndent("", " ") - return enc.Encode(out) - default: - printInspectText(out) - return nil - } -} - -func buildInspectOutput(ctx context.Context, reference string, m *model.Model, reg registry.Client) (*InspectOutput, error) { - out := &InspectOutput{ - Reference: reference, - CogVersion: m.CogVersion, - } - - if m.Index != nil { - out.Type = "index" - idx := &InspectIndex{ - Reference: m.Index.Reference, - Digest: m.Index.Digest, - MediaType: m.Index.MediaType, - } - - for _, im := range m.Index.Manifests { - manifest := buildManifestEntry(im) - - // Try to fetch layers from registry - layers, err := fetchLayers(ctx, reference, im.Digest, reg) - if err == nil { - manifest.Layers = layers - } - - idx.Manifests = append(idx.Manifests, manifest) - } - - out.Index = idx - } else { - out.Type = "image" - if m.Image != nil { - manifest := &InspectManifest{ - Type: "image", - Digest: m.Image.Digest, - } - if m.Image.Platform != nil { - parts := []string{m.Image.Platform.OS, m.Image.Platform.Architecture} - if m.Image.Platform.Variant != "" { - parts = append(parts, m.Image.Platform.Variant) - } - manifest.Platform = strings.Join(parts, "/") - } - - // Try to fetch layers - if m.Image.Digest != "" { - layers, err := fetchLayers(ctx, reference, m.Image.Digest, reg) - if err == nil { - manifest.Layers = layers - } - } - - 
out.Image = manifest - } - } - - return out, nil -} - -func buildManifestEntry(im model.IndexManifest) InspectManifest { - manifest := InspectManifest{ - Digest: im.Digest, - MediaType: im.MediaType, - Size: im.Size, - Annotations: im.Annotations, - } - - switch im.Type { - case model.ManifestTypeWeights: - manifest.Type = "weights" - manifest.Name = im.Annotations[model.AnnotationWeightName] - manifest.Target = im.Annotations[model.AnnotationWeightDest] - default: - manifest.Type = "image" - if im.Platform != nil { - parts := []string{im.Platform.OS, im.Platform.Architecture} - if im.Platform.Variant != "" { - parts = append(parts, im.Platform.Variant) - } - manifest.Platform = strings.Join(parts, "/") - } - } - - return manifest -} - -func fetchLayers(ctx context.Context, reference, digest string, reg registry.Client) ([]InspectLayer, error) { - // Build a digest reference from the repo - ref, err := model.ParseRef(reference) - if err != nil { - return nil, err - } - digestRef := ref.Ref.Context().String() + "@" + digest - - img, err := reg.GetImage(ctx, digestRef, nil) - if err != nil { - return nil, err - } - - manifest, err := img.Manifest() - if err != nil { - return nil, err - } - - var layers []InspectLayer - for _, l := range manifest.Layers { - layers = append(layers, InspectLayer{ - Digest: l.Digest.String(), - Size: l.Size, - MediaType: string(l.MediaType), - }) - } - - return layers, nil -} - -type rawStep struct { - Step string `json:"step"` - Data any `json:"data,omitempty"` - Manifest any `json:"manifest,omitempty"` -} - -func streamRaw(ctx context.Context, reference string, m *model.Model, reg registry.Client) error { - enc := json.NewEncoder(os.Stdout) - - // Step 1: resolve - _ = enc.Encode(rawStep{ - Step: "resolve", - Data: map[string]any{ - "reference": reference, - "cogVersion": m.CogVersion, - "type": func() string { - if m.Index != nil { - return "index" - } - return "image" - }(), - }, - }) - - if m.Index != nil { - // Step 2: index - _ = 
enc.Encode(rawStep{ - Step: "index", - Data: map[string]any{ - "digest": m.Index.Digest, - "mediaType": m.Index.MediaType, - "count": len(m.Index.Manifests), - }, - }) - - // Step 3: per-child manifests - for _, im := range m.Index.Manifests { - entry := buildManifestEntry(im) - - ref, err := model.ParseRef(reference) - if err == nil { - digestRef := ref.Ref.Context().String() + "@" + im.Digest - img, err := reg.GetImage(ctx, digestRef, nil) - if err == nil { - rawManifest, err := img.RawManifest() - if err == nil { - var parsed any - if jsonErr := json.Unmarshal(rawManifest, &parsed); jsonErr == nil { - _ = enc.Encode(rawStep{ - Step: "manifest", - Data: entry, - Manifest: parsed, - }) - continue - } - } - } - } - - // Fallback: output without raw manifest - _ = enc.Encode(rawStep{ - Step: "manifest", - Data: entry, - }) - } - } - - // Final step: model summary - _ = enc.Encode(rawStep{ - Step: "model", - Data: map[string]any{ - "reference": reference, - "cogVersion": m.CogVersion, - }, - }) - - return nil -} - -func printInspectText(out *InspectOutput) { - fmt.Printf("Model: %s\n", out.Reference) - if out.Type == "index" { - fmt.Println("Type: Model Bundle (OCI Index)") - } else { - fmt.Println("Type: Image") - } - fmt.Printf("Cog: %s\n", out.CogVersion) - fmt.Println() - - if out.Index != nil { - // Build the digest reference: repo@sha256:... 
- digestRef := out.Index.Digest - if out.Index.Reference != "" && out.Index.Digest != "" { - // Extract repo from the reference (strip tag/digest) - repo := out.Index.Reference - if idx := strings.LastIndex(repo, ":"); idx != -1 { - // Only strip if it looks like a tag (no @) - if !strings.Contains(repo[idx:], "@") { - repo = repo[:idx] - } - } - digestRef = repo + "@" + out.Index.Digest - } - fmt.Printf("Index: %s\n", digestRef) - fmt.Printf(" Tag: %s\n", out.Reference) - fmt.Printf(" Digest: %s\n", out.Index.Digest) - fmt.Printf(" MediaType: %s\n", out.Index.MediaType) - fmt.Printf(" Manifests: %d\n", len(out.Index.Manifests)) - fmt.Println() - - for _, m := range out.Index.Manifests { - printManifestText(m, " ") - fmt.Println() - } - } else if out.Image != nil { - printManifestText(*out.Image, "") - } -} - -func printManifestText(m InspectManifest, indent string) { - if m.Type == "weights" { - name := m.Name - if name == "" { - name = "(unnamed)" - } - fmt.Printf("%s[weights] %s\n", indent, name) - } else { - platform := m.Platform - if platform == "" { - platform = "(unknown)" - } - fmt.Printf("%s[image] %s\n", indent, platform) - } - - fmt.Printf("%s Digest: %s\n", indent, m.Digest) - - // Show manifest size + total layer size if layers are available - if len(m.Layers) > 0 { - var layerTotal int64 - for _, l := range m.Layers { - layerTotal += l.Size - } - fmt.Printf("%s Size: %s (Layers: %s)\n", indent, formatSize(m.Size), formatSize(layerTotal)) - } else { - fmt.Printf("%s Size: %s\n", indent, formatSize(m.Size)) - } - - if m.Target != "" { - fmt.Printf("%s Target: %s\n", indent, m.Target) - } - - if m.MediaType != "" { - fmt.Printf("%s Type: %s\n", indent, m.MediaType) - } - - if len(m.Layers) > 0 { - fmt.Printf("%s Layers: %d\n", indent, len(m.Layers)) - for _, l := range m.Layers { - fmt.Printf("%s %s %s %s\n", indent, l.Digest, formatSize(l.Size), l.MediaType) - } - } -} diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index fd399c9e16..996c6fe76a 
100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -30,6 +30,7 @@ import ( "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/pkg/util/files" "github.com/replicate/cog/pkg/util/mime" + "github.com/replicate/cog/pkg/weights" ) const StdinPath = "-" @@ -189,6 +190,11 @@ func cmdPredict(cmd *cobra.Command, args []string) error { volumes := []command.Volume{} gpus := gpusFlag + // The Manager is built only when we have cog.yaml in scope (the + // build-from-source path). Pre-built images are opaque to Cog and + // may grow their own weight-metadata signal later. + var wm *weights.Manager + resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) if len(args) == 0 { @@ -198,6 +204,10 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + console.Info("Building Docker image from environment in cog.yaml...") console.Info("") m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) @@ -215,6 +225,11 @@ func cmdPredict(cmd *cobra.Command, args []string) error { if gpus == "" && m.HasGPU() { gpus = "all" } + + wm, err = newWeightManager(src, "") + if err != nil { + return err + } } else { // Use existing image imageName = args[0] @@ -248,12 +263,17 @@ func cmdPredict(cmd *cobra.Command, args []string) error { env = append(env, "RUST_LOG="+rustLog) } - predictor, err := predict.NewPredictor(ctx, command.RunOptions{ - GPUs: gpus, - Image: imageName, - Volumes: volumes, - Env: env, - }, false, dockerClient) + predictor, err := predict.NewPredictor(ctx, predict.PredictorOptions{ + RunOptions: command.RunOptions{ + GPUs: gpus, + Image: imageName, + Volumes: volumes, + Env: env, + }, + IsTrain: false, + Docker: dockerClient, + WeightManager: wm, + }) if err != nil { return err } @@ -266,11 +286,16 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Info("Missing device driver, re-trying without 
GPU") _ = predictor.Stop(ctx) - predictor, err = predict.NewPredictor(ctx, command.RunOptions{ - Image: imageName, - Volumes: volumes, - Env: env, - }, false, dockerClient) + predictor, err = predict.NewPredictor(ctx, predict.PredictorOptions{ + RunOptions: command.RunOptions{ + Image: imageName, + Volumes: volumes, + Env: env, + }, + IsTrain: false, + Docker: dockerClient, + WeightManager: wm, + }) if err != nil { return err } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 2594cc4abf..ccc27b5c1c 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -13,6 +13,7 @@ import ( "github.com/replicate/cog/pkg/provider/setup" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) func newPushCommand() *cobra.Command { @@ -65,6 +66,10 @@ func push(cmd *cobra.Command, args []string) error { return err } + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + imageName := src.Config.Image if len(args) > 0 { imageName = args[0] @@ -108,9 +113,8 @@ func push(cmd *cobra.Command, args []string) error { } // Log weights info - weights := m.WeightArtifacts() - if len(weights) > 0 { - console.Infof("\n%d weight artifact(s)", len(weights)) + if len(m.Weights) > 0 { + console.Infof("\n%d managed weight(s)", len(m.Weights)) } // Push the model (image + optional weights) @@ -125,7 +129,6 @@ func push(cmd *cobra.Command, args []string) error { pushErr := resolver.Push(ctx, m, model.PushOptions{ ImageProgressFn: func(prog model.PushProgress) { - // Phase transitions: use console.Info for pretty CLI formatting if prog.Phase != "" { switch prog.Phase { case model.PushPhaseExporting: @@ -136,14 +139,7 @@ func push(cmd *cobra.Command, args []string) error { return } - // Byte progress: show per-layer progress bars - // Truncate digest for display: "sha256:abc123..." → "abc123..." 
- displayDigest := prog.LayerDigest - if len(displayDigest) > 7+12 { // "sha256:" + 12 hex chars - displayDigest = displayDigest[7:19] + "..." - } - - pw.Write(displayDigest, "Pushing", prog.Complete, prog.Total) + pw.Write(model.ShortDigest(prog.LayerDigest), "Pushing", prog.Complete, prog.Total) }, OnFallback: func() { // Close progress writer to finalize OCI progress bars before Docker diff --git a/pkg/cli/root.go b/pkg/cli/root.go index c6ccc4087d..ce337c86d8 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -47,7 +47,6 @@ https://github.com/replicate/cog`, newDebugCommand(), newDoctorCommand(), newInitCommand(), - newInspectCommand(), newLoginCommand(), newPredictCommand(), newPushCommand(), diff --git a/pkg/cli/train.go b/pkg/cli/train.go index d3d6227f2f..7f97ffa3ba 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -15,6 +15,7 @@ import ( "github.com/replicate/cog/pkg/predict" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) var ( @@ -66,6 +67,9 @@ func cmdTrain(cmd *cobra.Command, args []string) error { volumes := []command.Volume{} gpus := gpusFlag + // Managed-weight mounts only apply when we have cog.yaml in scope. 
+ var wm *weights.Manager + resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) if len(args) == 0 { @@ -75,6 +79,10 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + console.Info("Building Docker image from environment in cog.yaml...") console.Info("") m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) @@ -92,6 +100,11 @@ func cmdTrain(cmd *cobra.Command, args []string) error { if gpus == "" && m.HasGPU() { gpus = "all" } + + wm, err = newWeightManager(src, "") + if err != nil { + return err + } } else { // Use existing image imageName = args[0] @@ -114,13 +127,18 @@ func cmdTrain(cmd *cobra.Command, args []string) error { console.Info("") console.Info("Starting Docker image and running setup()...") - predictor, err := predict.NewPredictor(ctx, command.RunOptions{ - GPUs: gpus, - Image: imageName, - Volumes: volumes, - Env: trainEnvFlags, - Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, - }, true, dockerClient) + predictor, err := predict.NewPredictor(ctx, predict.PredictorOptions{ + RunOptions: command.RunOptions{ + GPUs: gpus, + Image: imageName, + Volumes: volumes, + Env: trainEnvFlags, + Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, + }, + IsTrain: true, + Docker: dockerClient, + WeightManager: wm, + }) if err != nil { return err } diff --git a/pkg/cli/weights.go b/pkg/cli/weights.go index 0773624faa..db66ca7a96 100644 --- a/pkg/cli/weights.go +++ b/pkg/cli/weights.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "path/filepath" "time" @@ -9,11 +10,13 @@ import ( "github.com/spf13/cobra" "golang.org/x/sync/errgroup" + "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" - "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + 
"github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" ) func newWeightsCommand() *cobra.Command { @@ -24,27 +27,52 @@ func newWeightsCommand() *cobra.Command { Hidden: true, } - cmd.AddCommand(newWeightsBuildCommand()) - cmd.AddCommand(newWeightsInspectCommand()) - cmd.AddCommand(newWeightsPushCommand()) + cmd.AddCommand(newWeightsImportCommand()) + cmd.AddCommand(newWeightsPullCommand()) + cmd.AddCommand(newWeightsStatusCommand()) return cmd } -func newWeightsBuildCommand() *cobra.Command { +func newWeightsImportCommand() *cobra.Command { + var ( + dryRun bool + verbose bool + ) + cmd := &cobra.Command{ - Use: "build", - Short: "Generate weights.lock from weight sources in cog.yaml", - Long: `Reads the weights section from cog.yaml, processes each weight source, -and generates a weights.lock file containing metadata (digests, sizes) for each file.`, - Args: cobra.NoArgs, - RunE: weightsBuildCommand, + Use: "import [name...]", + Short: "Build and push weights to a registry", + Long: `Packages weight sources from cog.yaml into OCI layers, updates weights.lock, +and pushes the layers to a registry. + +Import also warms the local content-addressed weight store as a side +effect, so 'cog predict' can mount the weights immediately without a +separate 'cog weights pull'. Pull is still useful when someone clones +a repo with a checked-in weights.lock but a cold local cache. + +If weight names are provided, only those weights are imported. Otherwise all weights +defined in cog.yaml are imported. + +The registry is determined from the image name, which can be: +- Set in cog.yaml as the 'image' field +- Overridden with the --image flag + +Use --dry-run to preview what would change without importing anything. 
+Add --verbose to see per-file details including which files pass the filter.`, + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return weightsImportCommand(cmd, args, dryRun, verbose) + }, } addConfigFlag(cmd) + cmd.Flags().String("image", "", "Registry repository (overrides cog.yaml image field)") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Show what would be imported without making changes") + cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Show per-file details") return cmd } -func weightsBuildCommand(cmd *cobra.Command, args []string) error { +func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose bool) error { ctx := cmd.Context() src, err := model.NewSource(configFilename) @@ -52,178 +80,242 @@ func weightsBuildCommand(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to read config: %w", err) } - if len(src.Config.Weights) == 0 { - return fmt.Errorf("no weights defined in %s", configFilename) + cfg := src.Config + + imageName, _ := cmd.Flags().GetString("image") + if imageName == "" { + imageName = cfg.Image + } + if imageName == "" { + return fmt.Errorf("To import weights, you must either set the 'image' option in cog.yaml or pass --image. For example, 'cog weights import --image registry.example.com/your-username/model-name'") } - // Extract weight specs from the source - var weightSpecs []*model.WeightSpec - for _, spec := range src.ArtifactSpecs() { - if ws, ok := spec.(*model.WeightSpec); ok { - weightSpecs = append(weightSpecs, ws) - } + repo, err := parseRepoOnly(imageName) + if err != nil { + return err } - console.Infof("Processing %d weight source(s)...", len(weightSpecs)) + weightSpecs, err := collectWeightSpecs(src, args) + if err != nil { + return err + } - lockPath := filepath.Join(src.ProjectDir, model.WeightsLockFilename) - builder := model.NewWeightBuilder(src, global.Version, lockPath) + // Always plan first to show the user what would happen. 
+ lockPath := filepath.Join(src.ProjectDir, lockfile.WeightsLockFilename) + builder := model.NewWeightBuilder(src, nil, lockPath) - // Build each weight artifact (hashes file, updates lockfile) - var totalSize int64 - for _, ws := range weightSpecs { - artifact, buildErr := builder.Build(ctx, ws) - if buildErr != nil { - return fmt.Errorf("failed to build weight %q: %w", ws.Name(), buildErr) - } + plans, err := planWeightImports(ctx, builder, weightSpecs) + if err != nil { + return err + } - wa, ok := artifact.(*model.WeightArtifact) - if !ok { - return fmt.Errorf("unexpected artifact type %T for weight %q", artifact, ws.Name()) - } - size := wa.Descriptor().Size - totalSize += size - console.Infof(" %s -> %s (%s)", wa.Name(), wa.Target, formatSize(size)) + printImportPlan(plans, verbose) + + if dryRun { + return nil } - console.Infof("\nGenerated %s with %d file(s) (%s total)", - model.WeightsLockFilename, len(weightSpecs), formatSize(totalSize)) + // Proceed with the real import. We create a new builder with a real + // store but reuse each plan's resolvedInventory (which captures the + // Source and filtered file list, independent of the builder). 
+ fileStore, err := store.OpenDefault() + if err != nil { + return fmt.Errorf("open weights store: %w", err) + } - return nil -} + builder = model.NewWeightBuilder(src, fileStore, lockPath) -func formatSize(bytes int64) string { - const ( - kb = 1024 - mb = kb * 1024 - gb = mb * 1024 - ) + console.Infof("Building %d weight(s)...", len(weightSpecs)) - switch { - case bytes >= gb: - return fmt.Sprintf("%.1fGB", float64(bytes)/float64(gb)) - case bytes >= mb: - return fmt.Sprintf("%.1fMB", float64(bytes)/float64(mb)) - case bytes >= kb: - return fmt.Sprintf("%.1fKB", float64(bytes)/float64(kb)) - default: - return fmt.Sprintf("%dB", bytes) + artifacts, err := buildWeightArtifactsFromPlans(ctx, builder, weightSpecs, plans) + if err != nil { + return err } -} -func newWeightsPushCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "push [IMAGE]", - Short: "Push weights to a registry", - Long: `Reads weights.lock and pushes weight files as an OCI artifact to a registry. + // Prune lockfile entries for weights no longer declared in cog.yaml. + // This always uses the full config (not the filtered import set) + // so orphans are corrected regardless of which weights were imported. 
+ if err := lockfile.PruneLockfile(lockPath, config.WeightNames(cfg.Weights)); err != nil { + return fmt.Errorf("prune lockfile: %w", err) + } -The registry is determined from the image name, which can be: -- Specified as an argument: cog weights push registry.example.com/user/model -- Set in cog.yaml as the 'image' field`, - Args: cobra.MaximumNArgs(1), - RunE: weightsPushCommand, + for _, wa := range artifacts { + console.Infof(" %s -> %s (%d layer(s), %s)", + wa.Name(), wa.Entry.Target, len(wa.Layers), formatSize(wa.TotalSize())) } - addConfigFlag(cmd) - return cmd + console.Infof("\nPushing %d weight(s) to %s...", len(artifacts), repo) + + return pushWeightArtifacts(ctx, repo, artifacts, "Imported") } -func weightsPushCommand(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() +// planWeightImports runs PlanImport for each spec without side effects. +func planWeightImports(ctx context.Context, builder *model.WeightBuilder, specs []*model.WeightSpec) ([]*model.WeightImportPlan, error) { + plans := make([]*model.WeightImportPlan, 0, len(specs)) + for _, ws := range specs { + plan, err := builder.PlanImport(ctx, ws) + if err != nil { + return nil, fmt.Errorf("plan weight %q: %w", ws.Name(), err) + } + plans = append(plans, plan) + } + return plans, nil +} - src, err := model.NewSource(configFilename) - if err != nil { - return fmt.Errorf("failed to read config: %w", err) +// buildWeightArtifactsFromPlans builds each weight spec, reusing the +// pre-computed inventories from planning to avoid re-walking sources. 
+func buildWeightArtifactsFromPlans(ctx context.Context, builder *model.WeightBuilder, specs []*model.WeightSpec, plans []*model.WeightImportPlan) ([]*model.WeightArtifact, error) { + artifacts := make([]*model.WeightArtifact, 0, len(specs)) + for i, ws := range specs { + artifact, err := builder.BuildFromPlan(ctx, ws, plans[i]) + if err != nil { + return nil, fmt.Errorf("failed to build weight %q: %w", ws.Name(), err) + } + wa, ok := artifact.(*model.WeightArtifact) + if !ok { + return nil, fmt.Errorf("unexpected artifact type %T for weight %q", artifact, ws.Name()) + } + artifacts = append(artifacts, wa) } + return artifacts, nil +} - cfg := src.Config +// printImportPlan prints a human-readable summary of what would happen. +func printImportPlan(plans []*model.WeightImportPlan, verbose bool) { + for _, p := range plans { + statusIcon := planStatusIcon(p.Status) + console.Infof("%s %s %s → %s", statusIcon, p.Spec.Name(), p.Spec.URI, p.Spec.Target) + console.Infof(" status: %s", p.Status) + + if len(p.Changes) > 0 { + for _, c := range p.Changes { + console.Infof(" changed: %s", c) + } + } + + filtered := p.FilteredFiles() + console.Infof(" files: %d size: %s", len(filtered), formatSize(p.TotalSize())) + + if verbose { + excluded := p.ExcludedFiles() + if len(excluded) > 0 { + console.Infof(" excluded (%d files):", len(excluded)) + for _, f := range excluded { + console.Infof(" - %s (%s)", f.Path, formatSize(f.Size)) + } + } + if len(filtered) > 0 { + console.Infof(" included (%d files):", len(filtered)) + for _, f := range filtered { + console.Infof(" + %s (%s)", f.Path, formatSize(f.Size)) + } + } + } - // Determine image name - imageName := cfg.Image - if len(args) > 0 { - imageName = args[0] + console.Infof("") // blank line between weights } - if imageName == "" { - return fmt.Errorf("To push weights, you must either set the 'image' option in cog.yaml or pass an image name as an argument. 
For example, 'cog weights push registry.example.com/your-username/model-name'") +} + +func planStatusIcon(status model.WeightImportPlanStatus) string { + switch status { + case model.PlanStatusNew: + return "+" + case model.PlanStatusUnchanged: + return "=" + case model.PlanStatusConfigChanged, model.PlanStatusUpstreamChanged: + return "~" + default: + return "?" + } +} + +// collectWeightSpecs extracts WeightSpecs from the source, optionally +// filtered to only the names listed in filterNames. An error is returned +// if no weights match or if a requested name doesn't exist. +func collectWeightSpecs(src *model.Source, filterNames []string) ([]*model.WeightSpec, error) { + if len(src.Config.Weights) == 0 { + return nil, fmt.Errorf("no weights defined in %s", configFilename) } - // Parse as repository only — reject tags/digests since weight tags are auto-generated. - parsedRepo, err := name.NewRepository(imageName, name.Insecure) + artifactSpecs, err := src.ArtifactSpecs() if err != nil { - // NewRepository fails for inputs with :tag or @digest — check if it's a valid ref - if ref, refErr := name.ParseReference(imageName, name.Insecure); refErr == nil { - return fmt.Errorf("image reference %q includes a tag or digest — provide only the repository (e.g., %q)", imageName, ref.Context().Name()) + return nil, err + } + var allSpecs []*model.WeightSpec + for _, spec := range artifactSpecs { + if ws, ok := spec.(*model.WeightSpec); ok { + allSpecs = append(allSpecs, ws) } - return fmt.Errorf("invalid repository %q: %w", imageName, err) } - repo := parsedRepo.Name() - if len(cfg.Weights) == 0 { - return fmt.Errorf("no weights defined in %s", configFilename) + if len(filterNames) == 0 { + return allSpecs, nil } - // Build weight artifacts (reads lockfile as cache, hashes files) - lockPath := filepath.Join(src.ProjectDir, model.WeightsLockFilename) - builder := model.NewWeightBuilder(src, global.Version, lockPath) + specMap := make(map[string]*model.WeightSpec, 
len(allSpecs)) + for _, ws := range allSpecs { + specMap[ws.Name()] = ws + } - var artifacts []*model.WeightArtifact - for _, spec := range src.ArtifactSpecs() { - ws, ok := spec.(*model.WeightSpec) - if !ok { + seen := make(map[string]bool, len(filterNames)) + filtered := make([]*model.WeightSpec, 0, len(filterNames)) + for _, n := range filterNames { + if seen[n] { continue } - artifact, buildErr := builder.Build(ctx, ws) - if buildErr != nil { - return fmt.Errorf("failed to build weight %q: %w", ws.Name(), buildErr) - } - wa, ok := artifact.(*model.WeightArtifact) + seen[n] = true + + ws, ok := specMap[n] if !ok { - return fmt.Errorf("unexpected artifact type %T for weight %q", artifact, ws.Name()) + return nil, fmt.Errorf("weight %q not found in %s", n, configFilename) } - artifacts = append(artifacts, wa) + filtered = append(filtered, ws) } + return filtered, nil +} - if len(artifacts) == 0 { - return fmt.Errorf("no weight artifacts to push") +// parseRepoOnly parses an image string as a bare repository, rejecting +// tags and digests (weight tags are auto-generated). +func parseRepoOnly(imageName string) (string, error) { + parsedRepo, err := name.NewRepository(imageName, name.Insecure) + if err != nil { + if ref, refErr := name.ParseReference(imageName, name.Insecure); refErr == nil { + return "", fmt.Errorf("image reference %q includes a tag or digest — provide only the repository (e.g., %q)", imageName, ref.Context().Name()) + } + return "", fmt.Errorf("invalid repository %q: %w", imageName, err) } + return parsedRepo.Name(), nil +} - console.Infof("Pushing %d weight file(s) to %s...", len(artifacts), repo) - +// pushWeightArtifacts pushes weight artifacts to the registry with +// concurrent layer uploads and progress display. The verb parameter +// controls the summary message (e.g. "Imported" vs "Pushed"). 
+func pushWeightArtifacts(ctx context.Context, repo string, artifacts []*model.WeightArtifact, verb string) error {
 	regClient := registry.NewRegistryClient()
 	pusher := model.NewWeightPusher(regClient)
-	// Set up progress display using Docker's jsonmessage rendering.
 	pw := docker.NewProgressWriter()
 	defer pw.Close()
-	// Push each weight artifact concurrently using errgroup for
-	// bounded concurrency and first-error cancellation.
-	type pushResult struct {
-		ref  string
-		size int64
-	}
-
-	ordered := make([]pushResult, len(artifacts))
+	refs := make([]string, len(artifacts))
 	g, ctx := errgroup.WithContext(ctx)
 	g.SetLimit(model.GetPushConcurrency())
 	for i, wa := range artifacts {
 		artName := wa.Name()
-		artSize := wa.Descriptor().Size
 		g.Go(func() error {
 			result, pushErr := pusher.Push(ctx, repo, wa, model.WeightPushOptions{
-				ProgressFn: func(prog model.PushProgress) {
-					pw.Write(artName, "Pushing", prog.Complete, prog.Total)
+				ProgressFn: func(prog model.WeightLayerProgress) {
+					row := model.ShortDigest(prog.LayerDigest)
+					pw.Write(artName+"/"+row, "Pushing", prog.Complete, prog.Total)
 				},
 				RetryFn: func(event model.WeightRetryEvent) bool {
 					status := fmt.Sprintf("Retrying (%d/%d) in %s",
 						event.Attempt, event.MaxAttempts, event.NextRetryIn.Round(time.Second))
 					pw.WriteStatus(event.Name, status)
-					// In non-TTY mode, also log the error detail since the
-					// progress writer output won't be visible.
 					if !console.IsTerminal() {
 						console.Warnf("  %s: retrying (%d/%d) in %s: %v",
 							event.Name, event.Attempt, event.MaxAttempts,
@@ -239,28 +331,42 @@ func weightsPushCommand(cmd *cobra.Command, args []string) error {
 			}
 			pw.WriteStatus(artName, "Pushed")
-			ordered[i] = pushResult{ref: result.Ref, size: artSize}
+			refs[i] = result.Ref
 			return nil
 		})
 	}
 	if err := g.Wait(); err != nil {
-		pw.Close()
 		return err
 	}
-	// Close progress display
-	pw.Close()
-
-	// Print final summary
 	var totalSize int64
 	for i, wa := range artifacts {
-		console.Infof("  %s: %s", wa.Name(), ordered[i].ref)
-		totalSize += ordered[i].size
+		console.Infof("  %s: %s", wa.Name(), refs[i])
+		totalSize += wa.TotalSize()
 	}
-	console.Infof("\nPushed %d weight artifact(s) to %s", len(artifacts), repo)
+	console.Infof("\n%s %d weight artifact(s) to %s", verb, len(artifacts), repo)
 	console.Infof("Total: %s", formatSize(totalSize))
 	return nil
 }
+
+func formatSize(bytes int64) string {
+	const (
+		kb = 1024
+		mb = kb * 1024
+		gb = mb * 1024
+	)
+
+	switch {
+	case bytes >= gb:
+		return fmt.Sprintf("%.1fGB", float64(bytes)/float64(gb))
+	case bytes >= mb:
+		return fmt.Sprintf("%.1fMB", float64(bytes)/float64(mb))
+	case bytes >= kb:
+		return fmt.Sprintf("%.1fKB", float64(bytes)/float64(kb))
+	default:
+		return fmt.Sprintf("%dB", bytes)
+	}
+}
diff --git a/pkg/cli/weights_inspect.go b/pkg/cli/weights_inspect.go
deleted file mode 100644
index 760cd1c41b..0000000000
--- a/pkg/cli/weights_inspect.go
+++ /dev/null
@@ -1,306 +0,0 @@
-package cli
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"os"
-	"path/filepath"
-
-	"github.com/google/go-containerregistry/pkg/name"
-	"github.com/spf13/cobra"
-
-	"github.com/replicate/cog/pkg/model"
-	"github.com/replicate/cog/pkg/registry"
-)
-
-// localWeight tracks the local state of a weight from cog.yaml + weights.lock.
-type localWeight struct {
-	target   string
-	source   string
-	lockFile *model.WeightFile
-}
-
-// WeightsInspectOutput is the structured output for cog weights inspect --json.
-type WeightsInspectOutput struct {
-	Reference string               `json:"reference"`
-	Weights   []WeightInspectEntry `json:"weights"`
-}
-
-// WeightInspectEntry represents one weight's comparison between local and remote state.
-type WeightInspectEntry struct {
-	Name   string             `json:"name"`
-	Status string             `json:"status"` // synced, local-only, remote-only, digest-mismatch, missing-lockfile
-	Local  *WeightLocalState  `json:"local,omitempty"`
-	Remote *WeightRemoteState `json:"remote,omitempty"`
-}
-
-// WeightLocalState represents the local state of a weight from cog.yaml + weights.lock.
-type WeightLocalState struct {
-	Digest     string `json:"digest"`
-	Size       int64  `json:"size"`
-	Target     string `json:"target"`
-	FileExists bool   `json:"fileExists"`
-}
-
-// WeightRemoteLayer represents a single layer in a remote weight manifest.
-type WeightRemoteLayer struct {
-	Digest    string `json:"digest"`
-	Size      int64  `json:"size"`
-	MediaType string `json:"mediaType"`
-}
-
-// WeightRemoteState represents the remote state of a weight from the registry.
-type WeightRemoteState struct {
-	Ref              string              `json:"ref"`
-	Tag              string              `json:"tag"`
-	Digest           string              `json:"digest"`
-	Size             int64               `json:"size"`
-	MediaType        string              `json:"mediaType"`
-	Layers           []WeightRemoteLayer `json:"layers,omitempty"`
-	MatchedByContent bool                `json:"matchedByContent,omitempty"`
-}
-
-func newWeightsInspectCommand() *cobra.Command {
-	var jsonOutput bool
-
-	cmd := &cobra.Command{
-		Use:   "inspect <repository>",
-		Short: "Compare local weights against remote registry state",
-		Args:  cobra.ExactArgs(1),
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return weightsInspectCommand(cmd, args, jsonOutput)
-		},
-	}
-
-	cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON")
-	addConfigFlag(cmd)
-
-	return cmd
-}
-
-func weightsInspectCommand(cmd *cobra.Command, args []string, jsonOutput bool) error {
-	ctx := cmd.Context()
-
-	// 1. Load local state
-	src, err := model.NewSource(configFilename)
-	if err != nil {
-		return fmt.Errorf("failed to read config: %w", err)
-	}
-
-	lockPath := filepath.Join(src.ProjectDir, model.WeightsLockFilename)
-	lock, lockErr := model.LoadWeightsLock(lockPath)
-	// lockErr is OK — lockfile may not exist yet
-
-	// Build local weight map: name -> (lockfile entry, source file path)
-	localWeights := make(map[string]*localWeight)
-	for _, w := range src.Config.Weights {
-		lw := &localWeight{
-			target: w.Target,
-			source: w.Source,
-		}
-		localWeights[w.Name] = lw
-	}
-
-	// Fill in lockfile data
-	if lockErr == nil && lock != nil {
-		for i := range lock.Files {
-			f := &lock.Files[i]
-			if lw, ok := localWeights[f.Name]; ok {
-				lw.lockFile = f
-			}
-		}
-	}
-
-	// 2. Resolve remote state — accept repo only (tags are auto-generated for weights).
-	parsedRepo, err := name.NewRepository(args[0], name.Insecure)
-	if err != nil {
-		if ref, refErr := name.ParseReference(args[0], name.Insecure); refErr == nil {
-			return fmt.Errorf("image reference %q includes a tag or digest — provide only the repository (e.g., %q)", args[0], ref.Context().Name())
-		}
-		return fmt.Errorf("invalid repository %q: %w", args[0], err)
-	}
-	repo := parsedRepo.Name()
-
-	regClient := registry.NewRegistryClient()
-	remoteWeights := resolveWeightsByTag(ctx, repo, localWeights, regClient)
-
-	// 3. Build comparison
-	out := &WeightsInspectOutput{
-		Reference: repo,
-	}
-
-	// Track which remote weights we've matched
-	matchedRemote := make(map[string]bool)
-
-	// Process local weights
-	for _, w := range src.Config.Weights {
-		entry := WeightInspectEntry{Name: w.Name}
-		lw := localWeights[w.Name]
-
-		if lw.lockFile == nil {
-			// No lockfile entry — needs `cog weights build`
-			entry.Status = "missing-lockfile"
-			entry.Local = &WeightLocalState{
-				Target:     lw.target,
-				FileExists: fileExists(filepath.Join(src.ProjectDir, lw.source)),
-			}
-		} else {
-			// Check if source file exists on disk
-			exists := fileExists(filepath.Join(src.ProjectDir, lw.source))
-			entry.Local = &WeightLocalState{
-				Digest:     lw.lockFile.Digest,
-				Size:       lw.lockFile.Size,
-				Target:     lw.lockFile.Dest,
-				FileExists: exists,
-			}
-
-			if remote, ok := remoteWeights[w.Name]; ok {
-				matchedRemote[w.Name] = true
-				entry.Remote = remote
-
-				if remote.MatchedByContent || lw.lockFile.Digest == remote.Digest {
-					entry.Status = "synced"
-				} else {
-					entry.Status = "digest-mismatch"
-				}
-			} else {
-				entry.Status = "local-only"
-			}
-		}
-
-		out.Weights = append(out.Weights, entry)
-	}
-
-	// Add remote-only weights
-	for name, remote := range remoteWeights {
-		if matchedRemote[name] {
-			continue
-		}
-		out.Weights = append(out.Weights, WeightInspectEntry{
-			Name:   name,
-			Status: "remote-only",
-			Remote: remote,
-		})
-	}
-
-	// 4. Output
-	if jsonOutput {
-		enc := json.NewEncoder(os.Stdout)
-		enc.SetIndent("", "  ")
-		return enc.Encode(out)
-	}
-
-	printWeightsInspectText(out)
-	return nil
-}
-
-// resolveWeightsByTag checks for each local weight's tag in the registry.
-// This is the fallback path when no OCI index exists (e.g., after `cog weights push`
-// but before `cog push`).
-//
-// It looks up the combined tag :weights-<name>-<digest>, which encodes both
-// the weight name and its content digest. A match means the exact content is synced.
-func resolveWeightsByTag(ctx context.Context, repo string, localWeights map[string]*localWeight, reg registry.Client) map[string]*WeightRemoteState {
-	result := make(map[string]*WeightRemoteState)
-	for weightName, lw := range localWeights {
-		if lw.lockFile == nil {
-			continue
-		}
-
-		tag := model.WeightTag(weightName, lw.lockFile.Digest)
-		tagRef := repo + ":" + tag
-
-		// Use GetImage to fetch the full manifest (not just HEAD) so we can read layer sizes.
-		img, err := reg.GetImage(ctx, tagRef, nil)
-		if err != nil {
-			continue
-		}
-
-		manifest, err := img.Manifest()
-		if err != nil {
-			continue
-		}
-
-		digest, err := img.Digest()
-		if err != nil {
-			continue
-		}
-
-		rawManifest, err := img.RawManifest()
-		if err != nil {
-			continue
-		}
-
-		state := &WeightRemoteState{
-			Ref:              tagRef,
-			Tag:              tag,
-			Digest:           digest.String(),
-			Size:             int64(len(rawManifest)),
-			MediaType:        string(manifest.MediaType),
-			MatchedByContent: true,
-		}
-
-		for _, layer := range manifest.Layers {
-			state.Layers = append(state.Layers, WeightRemoteLayer{
-				Digest:    layer.Digest.String(),
-				Size:      layer.Size,
-				MediaType: string(layer.MediaType),
-			})
-		}
-
-		result[weightName] = state
-	}
-	if len(result) == 0 {
-		return nil
-	}
-	return result
-}
-
-func printWeightsInspectText(out *WeightsInspectOutput) {
-	fmt.Printf("Weights for: %s\n\n", out.Reference)
-
-	for _, w := range out.Weights {
-		if w.Remote != nil && w.Remote.Tag != "" {
-			fmt.Printf("  %s  :%s\n", w.Name, w.Remote.Tag)
-		} else {
-			fmt.Printf("  %s\n", w.Name)
-		}
-		fmt.Printf("    Status: %s", w.Status)
-
-		switch w.Status {
-		case "local-only":
-			fmt.Print(" (not pushed)")
-		case "remote-only":
-			fmt.Print(" (not in cog.yaml)")
-		case "missing-lockfile":
-			fmt.Print(" (run cog weights build)")
-		}
-		fmt.Println()
-
-		if w.Local != nil {
-			if w.Local.Digest != "" {
-				fmt.Printf("    Local:  %s (%s) -> %s\n", w.Local.Digest, formatSize(w.Local.Size), w.Local.Target)
-			} else {
-				fmt.Printf("    Local:  (no lockfile entry) -> %s\n", w.Local.Target)
-			}
-		} else {
-			fmt.Println("    Local:  -")
-		}
-
-		if w.Remote != nil {
-			for _, layer := range w.Remote.Layers {
-				fmt.Printf("    Layer:  %s (%s)\n", layer.Digest, formatSize(layer.Size))
-			}
-		} else {
-			fmt.Println("    Remote: -")
-		}
-
-		fmt.Println()
-	}
-}
-
-func fileExists(path string) bool {
-	_, err := os.Stat(path)
-	return err == nil
-}
diff --git a/pkg/cli/weights_manager.go b/pkg/cli/weights_manager.go
new file mode 100644
index 0000000000..87e887b08f
--- /dev/null
+++ b/pkg/cli/weights_manager.go
@@ -0,0 +1,28 @@
+package cli
+
+import (
+	"github.com/replicate/cog/pkg/model"
+	"github.com/replicate/cog/pkg/weights"
+)
+
+// newWeightManager resolves the repo (from imageOverride or
+// src.Config.Image) and delegates construction to weights.NewFromSource.
+// Repo resolution is CLI-input parsing, which is why it lives here and
+// not in pkg/weights.
+func newWeightManager(src *model.Source, imageOverride string) (*weights.Manager, error) {
+	repo := ""
+	if len(src.Config.Weights) > 0 {
+		imageName := imageOverride
+		if imageName == "" {
+			imageName = src.Config.Image
+		}
+		if imageName != "" {
+			parsed, err := parseRepoOnly(imageName)
+			if err != nil {
+				return nil, err
+			}
+			repo = parsed
+		}
+	}
+	return weights.NewFromSource(src, repo)
+}
diff --git a/pkg/cli/weights_pull.go b/pkg/cli/weights_pull.go
new file mode 100644
index 0000000000..81b7eca64d
--- /dev/null
+++ b/pkg/cli/weights_pull.go
@@ -0,0 +1,156 @@
+package cli
+
+import (
+	"fmt"
+	"path/filepath"
+
+	"github.com/spf13/cobra"
+
+	"github.com/replicate/cog/pkg/model"
+	"github.com/replicate/cog/pkg/paths"
+	"github.com/replicate/cog/pkg/util/console"
+	"github.com/replicate/cog/pkg/weights"
+	"github.com/replicate/cog/pkg/weights/lockfile"
+)
+
+func newWeightsPullCommand() *cobra.Command {
+	var (
+		verbose       bool
+		imageOverride string
+	)
+
+	cmd := &cobra.Command{
+		Use:   "pull [NAME...]",
+		Short: "Populate the local weight cache from the registry",
+		Long: `Downloads weight files from the registry into the local content-addressed
+cache so 'cog predict' and 'cog run' can mount them at runtime.
+
+You don't need to run 'cog weights pull' after 'cog weights import' —
+import already warms the local cache. Pull is for the case where
+weights.lock is checked into git and you have a cold local cache (e.g.
+fresh clone, new machine).
+
+If weight names are provided, only those weights are pulled. Otherwise all
+weights defined in cog.yaml are pulled.
+
+Files already present in the local cache are skipped — re-running pull is
+cheap. The local cache defaults to $HOME/.cache/cog/weights; set
+COG_CACHE_DIR (or XDG_CACHE_HOME) to move it elsewhere — useful if your
+home directory is on a different filesystem than your project.
+
+Use --verbose to show per-layer and per-file progress.`,
+		Args: cobra.ArbitraryArgs,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			return weightsPullCommand(cmd, args, verbose, imageOverride)
+		},
+	}
+
+	addConfigFlag(cmd)
+	cmd.Flags().StringVar(&imageOverride, "image", "", "Registry repository (overrides cog.yaml image field)")
+	cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Show per-layer and per-file progress")
+	return cmd
+}
+
+func weightsPullCommand(cmd *cobra.Command, args []string, verbose bool, imageOverride string) error {
+	ctx := cmd.Context()
+
+	src, err := model.NewSource(configFilename)
+	if err != nil {
+		return fmt.Errorf("failed to read config: %w", err)
+	}
+
+	if len(src.Config.Weights) == 0 {
+		return fmt.Errorf("no weights defined in %s", configFilename)
+	}
+
+	mgr, err := newWeightManager(src, imageOverride)
+	if err != nil {
+		return err
+	}
+
+	if verbose {
+		// WeightsStoreDir cannot fail here because newWeightManager
+		// already resolved it; ignore the error to keep the log block
+		// simple.
+		storeDir, _ := paths.WeightsStoreDir() //nolint:errcheck // see comment above
+		lockPath := filepath.Join(src.ProjectDir, lockfile.WeightsLockFilename)
+		console.Infof("Cache:    %s", storeDir)
+		console.Infof("Lockfile: %s", lockPath)
+		console.Info("")
+	}
+
+	results, err := mgr.Pull(ctx, args, pullEventPrinter(verbose))
+	printPullSummary(results, verbose)
+	return err
+}
+
+// pullEventPrinter returns a PullEvent handler that writes progress to
+// the console. Verbose mode adds per-layer / per-file detail.
+func pullEventPrinter(verbose bool) func(weights.PullEvent) {
+	return func(e weights.PullEvent) {
+		switch e.Kind {
+		case weights.PullEventWeightStart:
+			if e.MissingFiles == 0 {
+				console.Infof("Pulling %s... cached (%d/%d files)", e.Weight, e.TotalFiles, e.TotalFiles)
+				return
+			}
+			if verbose {
+				console.Infof("Pulling %s -> %s", e.Weight, e.Target)
+				console.Infof("  manifest: %s", e.ManifestRef)
+				console.Infof("  files:    %d missing / %d total", e.MissingFiles, e.TotalFiles)
+			} else {
+				console.Infof("Pulling %s... (%d file(s))", e.Weight, e.MissingFiles)
+			}
+		case weights.PullEventLayerStart:
+			if !verbose {
+				return
+			}
+			size := "unknown size"
+			if e.LayerSize > 0 {
+				size = formatSize(e.LayerSize)
+			}
+			console.Infof("  layer %s (%s)", model.ShortDigest(e.LayerDigest), size)
+		case weights.PullEventFileStored:
+			if !verbose {
+				return
+			}
+			console.Infof("    %s (%s) %s", e.FilePath, formatSize(e.FileSize), model.ShortDigest(e.FileDigest))
+		case weights.PullEventLayerDone:
+			// Layer boundary is implicit from the per-file lines.
+		case weights.PullEventWeightDone:
+			if e.FullyCached {
+				return
+			}
+			console.Infof("Pulling %s... done (%s, %d file(s), %d layer(s))",
+				e.Weight, formatSize(e.BytesFetched), e.FilesFetched, e.LayersFetched)
+		}
+	}
+}
+
+func printPullSummary(results []weights.PullResult, verbose bool) {
+	if len(results) == 0 {
+		return
+	}
+	var totalBytes int64
+	var totalFiles, totalLayers, cachedWeights int
+	for _, r := range results {
+		if r.FullyCached {
+			cachedWeights++
+			continue
+		}
+		totalBytes += r.BytesFetched
+		totalFiles += r.FilesFetched
+		totalLayers += r.LayersFetched
+	}
+	if verbose {
+		console.Info("")
+	}
+	if totalFiles == 0 {
+		console.Infof("All %d weight(s) already cached.", len(results))
+		return
+	}
+	console.Infof(
+		"Pulled %s across %d file(s) / %d layer(s) for %d weight(s); %d already cached.",
+		formatSize(totalBytes), totalFiles, totalLayers, len(results)-cachedWeights, cachedWeights,
+	)
+}
diff --git a/pkg/cli/weights_status.go b/pkg/cli/weights_status.go
new file mode 100644
index 0000000000..8e5eb185d6
--- /dev/null
+++ b/pkg/cli/weights_status.go
@@ -0,0 +1,235 @@
+package cli
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"text/tabwriter"
+
+	"github.com/spf13/cobra"
+
+	"github.com/replicate/cog/pkg/model"
+	"github.com/replicate/cog/pkg/registry"
+	"github.com/replicate/cog/pkg/util/console"
+	"github.com/replicate/cog/pkg/weights/lockfile"
+)
+
+// The root command has SilenceErrors: true, so Cobra exits 1 without
+// printing the error message.
+var errWeightsNotReady = errors.New("not all weights are ready")
+
+// WeightsStatusOutput is the top-level structured output for cog weights status --json.
+type WeightsStatusOutput struct {
+	Weights []WeightStatusEntry `json:"weights"`
+}
+
+// WeightStatusEntry is one weight's JSON representation.
+type WeightStatusEntry struct {
+	Name           string              `json:"name"`
+	Target         string              `json:"target"`
+	Status         string              `json:"status"`
+	Size           int64               `json:"size,omitempty"`
+	SizeCompressed int64               `json:"sizeCompressed,omitempty"`
+	LayerCount     int                 `json:"layerCount,omitempty"`
+	FileCount      int                 `json:"fileCount,omitempty"`
+	Digest         string              `json:"digest,omitempty"`
+	Source         *WeightStatusSource `json:"source,omitempty"`
+	Layers         []LayerStatusEntry  `json:"layers,omitempty"`
+}
+
+// WeightStatusSource records the source metadata from the lockfile entry.
+type WeightStatusSource struct {
+	URI         string `json:"uri,omitempty"`
+	Fingerprint string `json:"fingerprint,omitempty"`
+}
+
+// LayerStatusEntry is one layer's status in the output.
+type LayerStatusEntry struct {
+	Digest string `json:"digest"`
+	Size   int64  `json:"size"`
+	Status string `json:"status"`
+}
+
+func newWeightsStatusCommand() *cobra.Command {
+	var (
+		jsonOutput bool
+		verbose    bool
+	)
+
+	cmd := &cobra.Command{
+		Use:   "status",
+		Short: "Show the status of configured weights",
+		Long: `Shows each declared weight's state across config (cog.yaml), lockfile
+(weights.lock), and registry.
+
+Status values:
+  ready      - config + lockfile match, all layers in registry
+  incomplete - config + lockfile match, some layers missing from registry
+  stale      - lockfile exists but config has drifted
+  pending    - declared in config, not yet built
+  orphaned   - in lockfile but removed from config
+
+Every non-ready status is resolved by running 'cog weights import'.
+
+Use --verbose to show per-layer status for each weight.
+
+Exit code is 0 when all weights are ready, 1 otherwise.`,
+		Args: cobra.NoArgs,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			return weightsStatusCommand(cmd, jsonOutput, verbose)
+		},
+	}
+
+	cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output as JSON")
+	cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Show per-layer status")
+	addConfigFlag(cmd)
+
+	return cmd
+}
+
+func weightsStatusCommand(cmd *cobra.Command, jsonOutput, verbose bool) error {
+	ctx := cmd.Context()
+
+	src, err := model.NewSource(configFilename)
+	if err != nil {
+		return fmt.Errorf("failed to read config: %w", err)
+	}
+
+	// Load lockfile — missing is fine (weights may not be built yet), but
+	// a present-but-corrupt file gets a warning so it doesn't fail silently.
+	lockPath := filepath.Join(src.ProjectDir, lockfile.WeightsLockFilename)
+	lock, lockErr := lockfile.LoadWeightsLock(lockPath)
+	if lockErr != nil && !errors.Is(lockErr, os.ErrNotExist) {
+		console.Warnf("Failed to load %s: %s", lockfile.WeightsLockFilename, lockErr)
+	}
+
+	// Resolve registry repo — required for status checks.
+	if src.Config.Image == "" {
+		return fmt.Errorf("no 'image' configured in %s — cannot check registry state", configFilename)
+	}
+	repo, err := parseRepoOnly(src.Config.Image)
+	if err != nil {
+		return fmt.Errorf("invalid image %q: %w", src.Config.Image, err)
+	}
+
+	reg := registry.NewRegistryClient()
+
+	ws, err := model.ComputeWeightsStatus(ctx, src.Config, lock, repo, reg)
+	if err != nil {
+		return fmt.Errorf("computing weight status: %w", err)
+	}
+
+	// Format output.
+	entries := statusResultsToEntries(ws.Results())
+	out := &WeightsStatusOutput{Weights: entries}
+
+	if jsonOutput {
+		enc := json.NewEncoder(os.Stdout)
+		enc.SetIndent("", "  ")
+		return enc.Encode(out)
+	}
+
+	printWeightsStatusText(out, verbose)
+
+	if !ws.AllReady() {
+		return errWeightsNotReady
+	}
+
+	return nil
+}
+
+// statusResultsToEntries converts model results to CLI output entries.
+func statusResultsToEntries(results []model.WeightStatusResult) []WeightStatusEntry {
+	entries := make([]WeightStatusEntry, len(results))
+	for i, r := range results {
+		entries[i] = WeightStatusEntry{
+			Name:   r.Name,
+			Target: r.Target,
+			Status: string(r.Status),
+		}
+		if r.LockEntry != nil {
+			le := r.LockEntry
+			entries[i].Size = le.Size
+			entries[i].SizeCompressed = le.SizeCompressed
+			entries[i].LayerCount = len(le.Layers)
+			entries[i].FileCount = len(le.Files)
+			entries[i].Digest = le.Digest
+			entries[i].Source = lockSourceToStatus(le.Source)
+		}
+		for _, l := range r.Layers {
+			entries[i].Layers = append(entries[i].Layers, LayerStatusEntry{
+				Digest: l.Digest,
+				Size:   l.Size,
+				Status: string(l.Status),
+			})
+		}
+	}
+	return entries
+}
+
+func lockSourceToStatus(s lockfile.WeightLockSource) *WeightStatusSource {
+	fp := string(s.Fingerprint)
+	if s.URI == "" && fp == "" {
+		return nil
+	}
+	return &WeightStatusSource{
+		URI:         s.URI,
+		Fingerprint: fp,
+	}
+}
+
+func printWeightsStatusText(out *WeightsStatusOutput, verbose bool) {
+	if len(out.Weights) == 0 {
+		fmt.Println("No weights configured.")
+		return
+	}
+
+	tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+	_, _ = fmt.Fprintln(tw, "NAME\tTARGET\tSTATUS\tSIZE\tLAYERS\tDIGEST")
+
+	for _, e := range out.Weights {
+		size := "-"
+		if e.Size > 0 {
+			size = formatSize(e.Size)
+		}
+
+		layers := "-"
+		if e.LayerCount > 0 {
+			layers = fmt.Sprintf("%d", e.LayerCount)
+		}
+
+		digest := "-"
+		if e.Digest != "" {
+			digest = formatDigestShort(e.Digest)
+		}
+
+		_, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n",
+			e.Name, e.Target, e.Status, size, layers, digest)
+
+		if verbose && len(e.Layers) > 0 {
+			for i, l := range e.Layers {
+				prefix := "├─"
+				if i == len(e.Layers)-1 {
+					prefix = "└─"
+				}
+				_, _ = fmt.Fprintf(tw, "  %s\t\t%s\t%s\t\t%s\n",
+					prefix, l.Status, formatSize(l.Size), formatDigestShort(l.Digest))
+			}
+		}
+	}
+
+	_ = tw.Flush()
+}
+
+// formatDigestShort returns a human-friendly short digest like "sha256:a1b2c3d4e5f6".
+func formatDigestShort(digest string) string {
+	short := model.ShortDigest(digest)
+	if short == "" {
+		return digest
+	}
+	algo, _, _ := strings.Cut(digest, ":")
+	return algo + ":" + short
+}
diff --git a/pkg/cli/weights_status_test.go b/pkg/cli/weights_status_test.go
new file mode 100644
index 0000000000..7b57a8c3ea
--- /dev/null
+++ b/pkg/cli/weights_status_test.go
@@ -0,0 +1,156 @@
+package cli
+
+import (
+	"bytes"
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/replicate/cog/pkg/model"
+	"github.com/replicate/cog/pkg/weights/lockfile"
+)
+
+func TestStatusResultsToEntries(t *testing.T) {
+	results := []model.WeightStatusResult{
+		{
+			Name:   "base",
+			Target: "/weights/base",
+			Status: model.WeightStatusReady,
+			LockEntry: &lockfile.WeightLockEntry{
+				Size:           4096,
+				SizeCompressed: 2048,
+				Layers:         []lockfile.WeightLockLayer{{Digest: "sha256:l1"}, {Digest: "sha256:l2"}},
+				Files:          []lockfile.WeightLockFile{{Path: "a.bin"}, {Path: "b.bin"}, {Path: "c.bin"}},
+				Digest:         "sha256:manifest123",
+				Source:         lockfile.WeightLockSource{URI: "file://./weights"},
+			},
+			Layers: []model.LayerStatusResult{
+				{Digest: "sha256:l1", Size: 2048, Status: model.LayerStatusReady},
+				{Digest: "sha256:l2", Size: 2048, Status: model.LayerStatusReady},
+			},
+		},
+		{
+			Name:   "pending",
+			Target: "/weights/new",
+			Status: model.WeightStatusPending,
+		},
+	}
+
+	entries := statusResultsToEntries(results)
+
+	require.Len(t, entries, 2)
+
+	assert.Equal(t, "base", entries[0].Name)
+	assert.EqualValues(t, model.WeightStatusReady, entries[0].Status)
+	assert.Equal(t, int64(4096), entries[0].Size)
+	assert.Equal(t, int64(2048), entries[0].SizeCompressed)
+	assert.Equal(t, 2, entries[0].LayerCount)
+	assert.Equal(t, 3, entries[0].FileCount)
+	assert.Equal(t, "sha256:manifest123", entries[0].Digest)
+	require.NotNil(t, entries[0].Source)
+	assert.Equal(t, "file://./weights", entries[0].Source.URI)
+	require.Len(t, entries[0].Layers, 2)
+	assert.EqualValues(t, model.LayerStatusReady, entries[0].Layers[0].Status)
+
+	assert.Equal(t, "pending", entries[1].Name)
+	assert.EqualValues(t, model.WeightStatusPending, entries[1].Status)
+	assert.Equal(t, int64(0), entries[1].Size)
+	assert.Nil(t, entries[1].Source)
+	assert.Empty(t, entries[1].Layers)
+}
+
+func TestWeightsStatusJSONOutput(t *testing.T) {
+	out := &WeightsStatusOutput{
+		Weights: []WeightStatusEntry{
+			{
+				Name:           "base",
+				Target:         "/weights/base",
+				Status:         "ready",
+				Size:           4096,
+				SizeCompressed: 2048,
+				LayerCount:     2,
+				FileCount:      3,
+				Digest:         "sha256:abc123",
+				Source:         &WeightStatusSource{URI: "file://./weights", Fingerprint: "sha256:def456"},
+				Layers: []LayerStatusEntry{
+					{Digest: "sha256:l1", Size: 2048, Status: "ready"},
+					{Digest: "sha256:l2", Size: 2048, Status: "ready"},
+				},
+			},
+			{
+				Name:   "pending-weight",
+				Target: "/weights/new",
+				Status: "pending",
+			},
+		},
+	}
+
+	var buf bytes.Buffer
+	enc := json.NewEncoder(&buf)
+	enc.SetIndent("", "  ")
+	require.NoError(t, enc.Encode(out))
+
+	var decoded WeightsStatusOutput
+	require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded))
+
+	require.Len(t, decoded.Weights, 2)
+
+	ready := decoded.Weights[0]
+	assert.Equal(t, "base", ready.Name)
+	assert.EqualValues(t, model.WeightStatusReady, ready.Status)
+	assert.Len(t, ready.Layers, 2)
+
+	pending := decoded.Weights[1]
+	assert.Equal(t, "pending-weight", pending.Name)
+	assert.EqualValues(t, model.WeightStatusPending, pending.Status)
+	assert.Empty(t, pending.Layers)
+}
+
+func TestFormatDigestShort(t *testing.T) {
+	tests := []struct {
+		input string
+		want  string
+	}{
+		{"sha256:a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6", "sha256:a1b2c3d4e5f6"},
+		{"sha256:short", "sha256:short"},
+		{"noprefix", "noprefix"},
+		{"", ""},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			assert.Equal(t, tt.want, formatDigestShort(tt.input))
+		})
+	}
+}
+
+func TestPrintWeightsStatusText(t *testing.T) {
+	// Smoke test — just make sure it doesn't panic on various inputs.
+	printWeightsStatusText(&WeightsStatusOutput{}, false)
+	printWeightsStatusText(&WeightsStatusOutput{
+		Weights: []WeightStatusEntry{
+			{Name: "a", Target: "/a", Status: "ready", Size: 1073741824, LayerCount: 3, Digest: "sha256:abcdef123456"},
+			{Name: "b", Target: "/b", Status: "pending"},
+			{Name: "c", Target: "/c", Status: "orphaned", Size: 512, LayerCount: 1, Digest: "sha256:orphan999999"},
+		},
+	}, false)
+}
+
+func TestPrintWeightsStatusText_Verbose(t *testing.T) {
+	// Smoke test for verbose output with layer tree.
+	printWeightsStatusText(&WeightsStatusOutput{
+		Weights: []WeightStatusEntry{
+			{
+				Name: "base", Target: "/w", Status: "incomplete",
+				Size: 5000000000, LayerCount: 3, Digest: "sha256:abc123",
+				Layers: []LayerStatusEntry{
+					{Digest: "sha256:l1", Size: 2000000000, Status: "ready"},
+					{Digest: "sha256:l2", Size: 2000000000, Status: "missing"},
+					{Digest: "sha256:l3", Size: 1000000000, Status: "ready"},
+				},
+			},
+		},
+	}, true)
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index ba1fe8b681..440ded5fdd 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -66,11 +66,35 @@ type Concurrency struct {
 	Max int `json:"max,omitempty" yaml:"max"`
 }
 
-// WeightSource defines a weight file or directory to include in the model.
+// WeightSourceConfig describes where to import weights from.
+// This is the "source" sub-object inside a weights entry.
+type WeightSourceConfig struct {
+	URI     string   `json:"uri,omitempty" yaml:"uri,omitempty"`
+	Include []string `json:"include,omitempty" yaml:"include,omitempty"`
+	Exclude []string `json:"exclude,omitempty" yaml:"exclude,omitempty"`
+}
+
+// WeightSource defines a weight directory to include in the model.
 type WeightSource struct {
-	Name   string `json:"name,omitempty" yaml:"name,omitempty"`
-	Source string `json:"source" yaml:"source"`
-	Target string `json:"target,omitempty" yaml:"target,omitempty"`
+	Name   string              `json:"name" yaml:"name"`
+	Target string              `json:"target" yaml:"target"`
+	Source *WeightSourceConfig `json:"source,omitempty" yaml:"source,omitempty"`
+}
+
+func (w *WeightSource) SourceURI() string {
+	if w.Source == nil {
+		return ""
+	}
+	return w.Source.URI
+}
+
+// WeightNames returns the names of the given weight sources.
+func WeightNames(ws []WeightSource) []string {
+	names := make([]string, len(ws))
+	for i, w := range ws {
+		names[i] = w.Name
+	}
+	return names
 }
 
 type Config struct {
diff --git a/pkg/config/config_file.go b/pkg/config/config_file.go
index a3ee069680..4f78c7c227 100644
--- a/pkg/config/config_file.go
+++ b/pkg/config/config_file.go
@@ -50,11 +50,13 @@ type mountFile struct {
 	Target string `json:"target,omitempty" yaml:"target,omitempty"`
 }
 
-// weightFile represents a weight source configuration.
+// weightFile represents a weight entry in cog.yaml.
+// Uses WeightSourceConfig directly since it has no pointer fields that
+// would need "not set" vs "zero value" distinction.
 type weightFile struct {
-	Name   string `json:"name,omitempty" yaml:"name,omitempty"`
-	Source string `json:"source" yaml:"source"`
-	Target string `json:"target,omitempty" yaml:"target,omitempty"`
+	Name   string              `json:"name" yaml:"name"`
+	Target string              `json:"target" yaml:"target"`
+	Source *WeightSourceConfig `json:"source,omitempty" yaml:"source,omitempty"`
 }
 
 // concurrencyFile represents concurrency configuration.
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index ba99b89d0a..b00e23fbb3 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -676,68 +676,86 @@ func TestAbsolutePathInPythonRequirements(t *testing.T) { require.True(t, ok) } -func TestWeightsWithNameYAML(t *testing.T) { +func TestWeightsWithSourceYAML(t *testing.T) { yamlString := `build: python_version: "3.12" +image: "registry.example.com/acme/my-model" predict: "predict.py:Predictor" weights: - name: model-v1 - source: file://./weights/model-v1.zip target: "/weights/model-v1" + source: + uri: "hf://acme/model-v1" + exclude: ["*.onnx"] - name: model-v2 - source: file://./weights/model-v2.zip target: "/weights/model-v2" + source: + uri: "./local-weights/" ` config, err := FromYAML([]byte(yamlString)) require.NoError(t, err) require.Len(t, config.Weights, 2) + require.Equal(t, "registry.example.com/acme/my-model", config.Image) require.Equal(t, "model-v1", config.Weights[0].Name) - require.Equal(t, "file://./weights/model-v1.zip", config.Weights[0].Source) require.Equal(t, "/weights/model-v1", config.Weights[0].Target) + require.NotNil(t, config.Weights[0].Source) + require.Equal(t, "hf://acme/model-v1", config.Weights[0].Source.URI) + require.Equal(t, "hf://acme/model-v1", config.Weights[0].SourceURI()) + require.Equal(t, []string{"*.onnx"}, config.Weights[0].Source.Exclude) + require.Empty(t, config.Weights[0].Source.Include) require.Equal(t, "model-v2", config.Weights[1].Name) - require.Equal(t, "file://./weights/model-v2.zip", config.Weights[1].Source) require.Equal(t, "/weights/model-v2", config.Weights[1].Target) + require.NotNil(t, config.Weights[1].Source) + require.Equal(t, "./local-weights/", config.Weights[1].Source.URI) } -func TestWeightsWithoutNameYAML(t *testing.T) { +func TestWeightsWithoutSourceYAML(t *testing.T) { yamlString := `build: python_version: "3.12" +image: "registry.example.com/acme/my-model" predict: "predict.py:Predictor" weights: - - source: 
file://./weights/model.zip - target: "/weights/model" + - name: base-model + target: "/src/weights" ` config, err := FromYAML([]byte(yamlString)) require.NoError(t, err) require.Len(t, config.Weights, 1) - require.Equal(t, "", config.Weights[0].Name) - require.Equal(t, "file://./weights/model.zip", config.Weights[0].Source) - require.Equal(t, "/weights/model", config.Weights[0].Target) + require.Equal(t, "base-model", config.Weights[0].Name) + require.Equal(t, "/src/weights", config.Weights[0].Target) + require.Nil(t, config.Weights[0].Source) + require.Equal(t, "", config.Weights[0].SourceURI()) } -func TestWeightsWithNameJSON(t *testing.T) { +func TestWeightsWithSourceJSON(t *testing.T) { jsonString := `{ "build": { "python_version": "3.12" }, + "image": "registry.example.com/acme/my-model", "predict": "predict.py:Predictor", "weights": [ { "name": "model-v1", - "source": "file://./weights/model-v1.zip", - "target": "/weights/model-v1" + "target": "/weights/model-v1", + "source": { + "uri": "hf://acme/model-v1", + "exclude": ["*.onnx"] + } }, { "name": "model-v2", - "source": "file://./weights/model-v2.zip", - "target": "/weights/model-v2" + "target": "/weights/model-v2", + "source": { + "uri": "./local-weights/" + } } ] }` @@ -746,14 +764,18 @@ func TestWeightsWithNameJSON(t *testing.T) { err := json.Unmarshal([]byte(jsonString), &config) require.NoError(t, err) require.Len(t, config.Weights, 2) + require.Equal(t, "registry.example.com/acme/my-model", config.Image) require.Equal(t, "model-v1", config.Weights[0].Name) - require.Equal(t, "file://./weights/model-v1.zip", config.Weights[0].Source) require.Equal(t, "/weights/model-v1", config.Weights[0].Target) + require.NotNil(t, config.Weights[0].Source) + require.Equal(t, "hf://acme/model-v1", config.Weights[0].Source.URI) + require.Equal(t, []string{"*.onnx"}, config.Weights[0].Source.Exclude) require.Equal(t, "model-v2", config.Weights[1].Name) - require.Equal(t, "file://./weights/model-v2.zip", 
config.Weights[1].Source) require.Equal(t, "/weights/model-v2", config.Weights[1].Target) + require.NotNil(t, config.Weights[1].Source) + require.Equal(t, "./local-weights/", config.Weights[1].Source.URI) } func TestSDKVersionConfig(t *testing.T) { diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json index 91ca934607..645fe511d0 100644 --- a/pkg/config/data/config_schema_v1.0.json +++ b/pkg/config/data/config_schema_v1.0.json @@ -211,23 +211,46 @@ "array", "null" ], - "description": "A list of weight files or directories to include in the model.", + "description": "A list of weight directories to include in the model.", "items": { "type": "object", - "required": ["source"], + "required": ["name", "target"], "additionalProperties": false, "properties": { "name": { "type": "string", - "description": "A unique identifier for this weight entry." - }, - "source": { - "type": "string", - "description": "Path to a weight file or directory (relative to cog.yaml)." + "description": "A unique identifier for this weight entry. Maps to /weights/<name> in the registry.", + "pattern": "^[a-z0-9]+([._-][a-z0-9]+)*$" }, "target": { "type": "string", - "description": "Target path in the container (must be under /cache/). Defaults to /cache/<name>." + "description": "Absolute path where the weight directory is mounted in the container." + }, + "source": { + "type": "object", + "description": "Import configuration for this weight entry. Documents provenance and re-import settings.", + "required": ["uri"], + "additionalProperties": false, + "properties": { + "uri": { + "type": "string", + "description": "Source location. Supported schemes: hf://, s3://, http://, https://, oci://, or a local filesystem path." 
+ }, + "include": { + "type": "array", + "description": "Glob patterns for files to include (allowlist mode).", + "items": { + "type": "string" + } + }, + "exclude": { + "type": "array", + "description": "Glob patterns for files to skip during import.", + "items": { + "type": "string" + } + } + } } } } diff --git a/pkg/config/validate.go b/pkg/config/validate.go index e5ab37ebcc..69c931cba2 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "path/filepath" + "regexp" "slices" "strconv" "strings" @@ -72,6 +73,7 @@ func ValidateConfigFile(cfg *configFile, opts ...ValidateOption) *ValidationResu validateBuild(cfg, options, result) validateEnvironment(cfg, result) validateConcurrency(cfg, result) + validateWeights(cfg, result) // Check deprecated fields checkDeprecatedFields(cfg, result) @@ -455,6 +457,146 @@ func validateConcurrency(cfg *configFile, result *ValidationResult) { } } +// weightNameRegex matches OCI-safe path components: lowercase alphanumeric, +// separated by hyphens, dots, or underscores. Weight names become registry +// path components (/weights/<name>), so they must follow OCI rules. +var weightNameRegex = regexp.MustCompile(`^[a-z0-9]+(?:[._-][a-z0-9]+)*$`) + +// validateWeights validates the weights stanza. +func validateWeights(cfg *configFile, result *ValidationResult) { + if len(cfg.Weights) == 0 { + return + } + + // Weights require an image field. + if cfg.Image == nil || *cfg.Image == "" { + result.AddError(&ValidationError{ + Field: "image", + Message: "image is required when weights are configured", + }) + } + + seenNames := make(map[string]bool) + seenTargets := make(map[string]bool) + var cleanedTargets []string + + for i, w := range cfg.Weights { + idx := fmt.Sprintf("weights[%d]", i) + + // Name is required, must be OCI-safe, and must be unique. 
+ switch { + case w.Name == "": + result.AddError(&ValidationError{ + Field: idx + ".name", + Message: "name is required", + }) + case !weightNameRegex.MatchString(w.Name): + result.AddError(&ValidationError{ + Field: idx + ".name", + Value: w.Name, + Message: "must contain only lowercase alphanumeric characters, hyphens, dots, or underscores (e.g. \"my-model-weights\")", + }) + case seenNames[w.Name]: + result.AddError(&ValidationError{ + Field: idx + ".name", + Value: w.Name, + Message: "duplicate weight name", + }) + default: + seenNames[w.Name] = true + } + + // Validate include/exclude patterns if source is present. + if w.Source != nil { + validateWeightPatterns(idx+".source.include", w.Source.Include, result) + validateWeightPatterns(idx+".source.exclude", w.Source.Exclude, result) + } + + // Target is required, must be absolute, and must be unique. + if w.Target == "" { + result.AddError(&ValidationError{ + Field: idx + ".target", + Message: "target is required", + }) + } else { + if !filepath.IsAbs(w.Target) { + result.AddError(&ValidationError{ + Field: idx + ".target", + Value: w.Target, + Message: "target must be an absolute path", + }) + } + + cleaned := filepath.Clean(w.Target) + if seenTargets[cleaned] { + result.AddError(&ValidationError{ + Field: idx + ".target", + Value: w.Target, + Message: "duplicate weight target", + }) + } else { + seenTargets[cleaned] = true + + // Check disjoint subtrees: no target may be an ancestor + // or descendant of another target. + for _, prev := range cleanedTargets { + if isSubpath(cleaned, prev) || isSubpath(prev, cleaned) { + result.AddError(&ValidationError{ + Field: idx + ".target", + Value: w.Target, + Message: fmt.Sprintf("target overlaps with %q; weight targets must be disjoint", prev), + }) + break + } + } + cleanedTargets = append(cleanedTargets, cleaned) + } + } + } +} + +// validateWeightPatterns validates a list of include or exclude glob patterns. 
+// It rejects empty-string patterns (including whitespace-only), !-prefixed +// patterns (gitignore negation), and patterns containing backslashes. +// Patterns are checked after trimming whitespace, but the input slice is +// not mutated — the caller must normalize patterns separately. +func validateWeightPatterns(field string, patterns []string, result *ValidationResult) { + for i, raw := range patterns { + p := strings.TrimSpace(raw) + + if p == "" { + result.AddError(&ValidationError{ + Field: fmt.Sprintf("%s[%d]", field, i), + Message: "pattern must not be empty", + }) + continue + } + if strings.HasPrefix(p, "!") { + result.AddError(&ValidationError{ + Field: fmt.Sprintf("%s[%d]", field, i), + Value: p, + Message: "negation patterns (starting with '!') are not supported", + }) + } + if strings.Contains(p, `\`) { + result.AddError(&ValidationError{ + Field: fmt.Sprintf("%s[%d]", field, i), + Value: p, + Message: "patterns must use forward slashes, not backslashes", + }) + } + } +} + +// isSubpath reports whether child is a strict subdirectory of parent. +// Both paths must be cleaned absolute paths. +func isSubpath(child, parent string) bool { + if child == parent { + return false + } + return strings.HasPrefix(child, parent+"/") +} + // checkDeprecatedFields checks for deprecated fields and adds warnings. 
func checkDeprecatedFields(cfg *configFile, result *ValidationResult) { if cfg.Build == nil { diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index 5f88a4fdd5..8029f60723 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -176,5 +176,239 @@ func TestValidateConfigFileNilBuildSkipsPythonVersionCheck(t *testing.T) { require.False(t, result.HasErrors(), "expected no errors for nil build, got: %v", result.Errors) } +func TestValidateWeights(t *testing.T) { + image := ptr("registry.example.com/acme/my-model") + + tests := []struct { + name string + image *string + weights []weightFile + wantErr string // empty means expect no error + }{ + { + name: "valid with two weights", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + {Name: "lora", Target: "/src/lora"}, + }, + }, + { + name: "valid with source", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{URI: "hf://acme/model", Exclude: []string{"*.onnx"}}}, + }, + }, + { + name: "valid without source", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + }, + }, + { + name: "weights without image", + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + }, + wantErr: "image is required when weights are configured", + }, + { + name: "missing name", + image: image, + weights: []weightFile{ + {Name: "", Target: "/src/weights"}, + }, + wantErr: "name is required", + }, + { + name: "uppercase name", + image: image, + weights: []weightFile{ + {Name: "MyModel", Target: "/src/weights"}, + }, + wantErr: "must contain only lowercase", + }, + { + name: "name with spaces", + image: image, + weights: []weightFile{ + {Name: "my model", Target: "/src/weights"}, + }, + wantErr: "must contain only lowercase", + }, + { + name: "name starting with hyphen", + image: image, + weights: []weightFile{ + {Name: "-base", Target: "/src/weights"}, + }, + wantErr: 
"must contain only lowercase", + }, + { + name: "valid name with separators", + image: image, + weights: []weightFile{ + {Name: "z-image.turbo_v1", Target: "/src/weights"}, + }, + }, + { + name: "duplicate name", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + {Name: "base", Target: "/src/other"}, + }, + wantErr: "duplicate weight name", + }, + { + name: "missing target", + image: image, + weights: []weightFile{ + {Name: "base", Target: ""}, + }, + wantErr: "target is required", + }, + { + name: "relative target", + image: image, + weights: []weightFile{ + {Name: "base", Target: "src/weights"}, + }, + wantErr: "target must be an absolute path", + }, + { + name: "duplicate target", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + {Name: "lora", Target: "/src/weights"}, + }, + wantErr: "duplicate weight target", + }, + { + name: "overlapping targets parent then child", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + {Name: "lora", Target: "/src/weights/lora"}, + }, + wantErr: "target overlaps with", + }, + { + name: "overlapping targets child then parent", + image: image, + weights: []weightFile{ + {Name: "lora", Target: "/src/weights/lora"}, + {Name: "base", Target: "/src/weights"}, + }, + wantErr: "target overlaps with", + }, + { + name: "disjoint targets no false positive", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights"}, + {Name: "lora", Target: "/src/weights2"}, + }, + }, + { + name: "valid include and exclude patterns", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Include: []string{"*.safetensors", "*.json"}, + Exclude: []string{"*.onnx", "*.bin"}, + }}, + }, + }, + { + name: "empty string in include pattern", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: 
"hf://acme/model", + Include: []string{"*.safetensors", ""}, + }}, + }, + wantErr: "pattern must not be empty", + }, + { + name: "empty string in exclude pattern", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Exclude: []string{""}, + }}, + }, + wantErr: "pattern must not be empty", + }, + { + name: "negation pattern in include", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Include: []string{"!*.bin"}, + }}, + }, + wantErr: "negation patterns", + }, + { + name: "negation pattern in exclude", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Exclude: []string{"!*.safetensors"}, + }}, + }, + wantErr: "negation patterns", + }, + { + name: "whitespace-only pattern rejected after trim", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Include: []string{" "}, + }}, + }, + wantErr: "pattern must not be empty", + }, + { + name: "backslash in pattern rejected", + image: image, + weights: []weightFile{ + {Name: "base", Target: "/src/weights", Source: &WeightSourceConfig{ + URI: "hf://acme/model", + Exclude: []string{`onnx\*.bin`}, + }}, + }, + wantErr: "must use forward slashes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &configFile{ + Build: &buildFile{PythonVersion: ptr("3.12")}, + Image: tt.image, + Weights: tt.weights, + } + result := ValidateConfigFile(cfg) + if tt.wantErr == "" { + require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors) + } else { + require.True(t, result.HasErrors()) + require.Contains(t, result.Err().Error(), tt.wantErr) + } + }) + } +} + // ptr returns a pointer to the given value. 
func ptr[T any](v T) *T { return &v } diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 1d3539a042..dc7c140192 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -74,4 +74,5 @@ type Port struct { type Volume struct { Source string Destination string + ReadOnly bool } diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index cfc9a578b6..e4d3c88008 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -448,7 +448,11 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions if len(options.Volumes) > 0 { hostCfg.Binds = make([]string, len(options.Volumes)) for i, volume := range options.Volumes { - hostCfg.Binds[i] = fmt.Sprintf("%s:%s", volume.Source, volume.Destination) + bind := fmt.Sprintf("%s:%s", volume.Source, volume.Destination) + if volume.ReadOnly { + bind += ":ro" + } + hostCfg.Binds[i] = bind } } diff --git a/pkg/image/build.go b/pkg/image/build.go index 056a37e3f8..c16b132043 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -30,6 +30,7 @@ import ( "github.com/replicate/cog/pkg/schema/python" "github.com/replicate/cog/pkg/util/console" cogversion "github.com/replicate/cog/pkg/util/version" + "github.com/replicate/cog/pkg/weights/lockfile" "github.com/replicate/cog/pkg/weightslegacy" "github.com/replicate/cog/pkg/wheels" ) @@ -37,6 +38,7 @@ import ( const dockerignoreBackupPath = ".dockerignore.cog.bak" const weightsManifestPath = ".cog/cache/weights_manifest.json" const bundledSchemaFile = ".cog/openapi_schema.json" +const bundledWeightsFile = ".cog/weights.json" var errGit = errors.New("git error") @@ -65,8 +67,9 @@ func Build( annotations map[string]string, dockerCommand command.Command, client registry.Client) (string, error) { - // remove bundled schema files that may be left from previous builds + // remove bundled files that may be left from previous builds _ = os.Remove(bundledSchemaFile) + _ = os.Remove(bundledWeightsFile) if err := 
checkCompatibleDockerIgnore(dir); err != nil { return "", err @@ -132,6 +135,16 @@ func Build( } } + // --- Runtime weights manifest (/.cog/weights.json) --- + // When managed weights are configured and a lockfile exists, project the + // lockfile to the minimal runtime manifest (spec §3.3) and write it into + // the build context so it ends up at /.cog/weights.json in the image. + if len(cfg.Weights) > 0 { + if err := writeRuntimeWeightsManifest(dir); err != nil { + return "", err + } + } + // --- Docker build --- var cogBaseImageName string @@ -294,18 +307,14 @@ func Build( // so we don't need metadata labels, pip freeze, or git info. // We still need the schema bundled, so do a minimal second build to add it. if skipLabels { - if len(schemaJSON) > 0 { - // Use trailing "/" on the destination so Docker creates the .cog/ - // directory even in ExcludeSource images where COPY . /src was - // skipped and .cog/ does not yet exist. - schemaDockerfile := fmt.Sprintf("FROM %s\nCOPY %s .cog/\n", tmpImageId, bundledSchemaFile) + if files := collectBundleFiles(schemaJSON); len(files) > 0 { buildOpts := command.ImageBuildOptions{ - DockerfileContents: schemaDockerfile, + DockerfileContents: bundleDockerfile(tmpImageId, files), ImageName: tmpImageId, ProgressOutput: progressOutput, } if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil { - return "", fmt.Errorf("Failed to bundle schema into image: %w", err) + return "", fmt.Errorf("Failed to bundle .cog files into image: %w", err) } } return tmpImageId, nil @@ -389,12 +398,7 @@ func Build( maps.Copy(labels, annotations) // The final image ID comes from the label-adding step. - // When schema validation is skipped (cog exec), there is no schema file to bundle. 
- schemaFileToBundle := bundledSchemaFile - if skipSchemaValidation { - schemaFileToBundle = "" - } - imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, schemaFileToBundle, progressOutput) + imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, collectBundleFiles(schemaJSON), progressOutput) if err != nil { return "", fmt.Errorf("Failed to add labels to image: %w", err) } @@ -409,21 +413,15 @@ func Build( return imageID, nil } -// BuildAddLabelsAndSchemaToImage builds a cog model with labels and schema. -// Returns the image ID (sha256:...) of the final image. +// BuildAddLabelsAndSchemaToImage builds a cog model with labels and bundled +// .cog/ files. Returns the image ID (sha256:...) of the final image. // -// The new image is based on the provided image with the labels and schema file appended to it. +// The new image is based on the provided image with the labels and any +// bundled files (schema, weights manifest, etc.) appended to it. // tmpName is the source image to build from, image is the final image name/tag. 
-func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Command, tmpName, image string, labels map[string]string, bundledSchemaFile string, progressOutput string) (string, error) { - var dockerfile string - if bundledSchemaFile != "" { - dockerfile = fmt.Sprintf("FROM %s\nCOPY %s .cog\n", tmpName, bundledSchemaFile) - } else { - dockerfile = fmt.Sprintf("FROM %s\n", tmpName) - } - +func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Command, tmpName, image string, labels map[string]string, bundleFiles []string, progressOutput string) (string, error) { buildOpts := command.ImageBuildOptions{ - DockerfileContents: dockerfile, + DockerfileContents: bundleDockerfile(tmpName, bundleFiles), ImageName: image, Labels: labels, ProgressOutput: progressOutput, @@ -431,7 +429,7 @@ func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Co imageID, err := dockerClient.ImageBuild(ctx, buildOpts) if err != nil { - return "", fmt.Errorf("Failed to add labels and schema to image: %w", err) + return "", fmt.Errorf("Failed to add labels to image: %w", err) } return imageID, nil } @@ -540,6 +538,56 @@ func writeAndValidateSchema(schemaJSON []byte) error { return nil } +// writeRuntimeWeightsManifest projects the lockfile to /.cog/weights.json (spec §3.3). 
+func writeRuntimeWeightsManifest(dir string) error { + lockPath := filepath.Join(dir, lockfile.WeightsLockFilename) + lock, err := lockfile.LoadWeightsLock(lockPath) + if err != nil { + return fmt.Errorf("managed weights configured but no lockfile found: %w\nRun 'cog weights import' first", err) + } + + manifest := lock.RuntimeManifest() + data, err := json.MarshalIndent(manifest, "", " ") + if err != nil { + return fmt.Errorf("failed to serialize runtime weights manifest: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(bundledWeightsFile), 0o755); err != nil { + return fmt.Errorf("failed to create directory for %s: %w", bundledWeightsFile, err) + } + if err := os.WriteFile(bundledWeightsFile, data, 0o644); err != nil { //nolint:gosec // bundled into image, not a secret + return fmt.Errorf("failed to write runtime weights manifest %s: %w", bundledWeightsFile, err) + } + console.Debugf("Wrote runtime weights manifest to %s (%d weights)", bundledWeightsFile, len(manifest.Weights)) + return nil +} + +// collectBundleFiles returns the list of .cog/ files that should be COPYed +// into the final image layer. It checks schemaJSON (non-nil = schema was +// generated) and probes the filesystem for the weights manifest. +func collectBundleFiles(schemaJSON []byte) []string { + var files []string + if len(schemaJSON) > 0 { + files = append(files, bundledSchemaFile) + } + if _, err := os.Stat(bundledWeightsFile); err == nil { + files = append(files, bundledWeightsFile) + } + return files +} + +// bundleDockerfile returns a Dockerfile that COPYs .cog/ files into the +// image. Trailing "/" on the destination ensures Docker creates the .cog/ +// directory even when COPY . /src was skipped (ExcludeSource images). 
+func bundleDockerfile(baseImage string, files []string) string { + var b strings.Builder + fmt.Fprintf(&b, "FROM %s\n", baseImage) + for _, f := range files { + fmt.Fprintf(&b, "COPY %s .cog/\n", f) + } + return b.String() +} + func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/image/build_test.go b/pkg/image/build_test.go index f60787df45..e518965834 100644 --- a/pkg/image/build_test.go +++ b/pkg/image/build_test.go @@ -2,15 +2,18 @@ package image import ( "context" + "encoding/json" "os" "os/exec" "path/filepath" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/weights/lockfile" ) var hasGit = (func() bool { @@ -258,3 +261,98 @@ func TestUseStaticSchemaGen(t *testing.T) { }) } } + +func TestWriteRuntimeWeightsManifest(t *testing.T) { + dir := t.TempDir() + + lock := &lockfile.WeightsLock{ + Version: lockfile.Version, + Weights: []lockfile.WeightLockEntry{ + { + Name: "model-a", + Target: "/src/weights/a", + SetDigest: "sha256:aaa111", + Digest: "sha256:manifest-a", + }, + { + Name: "model-b", + Target: "/src/weights/b", + SetDigest: "sha256:bbb222", + Digest: "sha256:manifest-b", + }, + }, + } + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) + + // writeRuntimeWeightsManifest writes to the CWD-relative bundledWeightsFile. 
+ t.Chdir(t.TempDir()) + + require.NoError(t, writeRuntimeWeightsManifest(dir)) + + data, err := os.ReadFile(bundledWeightsFile) + require.NoError(t, err) + + var manifest lockfile.RuntimeWeightsManifest + require.NoError(t, json.Unmarshal(data, &manifest)) + require.Len(t, manifest.Weights, 2) + + assert.Equal(t, "model-a", manifest.Weights[0].Name) + assert.Equal(t, "/src/weights/a", manifest.Weights[0].Target) + assert.Equal(t, "sha256:aaa111", manifest.Weights[0].SetDigest) + + assert.Equal(t, "model-b", manifest.Weights[1].Name) + assert.Equal(t, "/src/weights/b", manifest.Weights[1].Target) + assert.Equal(t, "sha256:bbb222", manifest.Weights[1].SetDigest) + + // Verify the JSON contains only the spec §3.3 fields (no lockfile extras). + var raw map[string]json.RawMessage + require.NoError(t, json.Unmarshal(data, &raw)) + + var entries []map[string]json.RawMessage + require.NoError(t, json.Unmarshal(raw["weights"], &entries)) + for i, entry := range entries { + keys := make([]string, 0, len(entry)) + for k := range entry { + keys = append(keys, k) + } + assert.ElementsMatch(t, []string{"name", "target", "setDigest"}, keys, + "entry %d must have exactly the spec §3.3 fields", i) + } +} + +func TestWriteRuntimeWeightsManifest_MissingLockfile(t *testing.T) { + err := writeRuntimeWeightsManifest(t.TempDir()) + require.Error(t, err) + assert.Contains(t, err.Error(), "managed weights configured but no lockfile found") +} + +func TestCollectBundleFiles_SchemaOnly(t *testing.T) { + t.Chdir(t.TempDir()) + + files := collectBundleFiles([]byte(`{"openapi":"3.0.0"}`)) + assert.Equal(t, []string{bundledSchemaFile}, files) +} + +func TestCollectBundleFiles_Nothing(t *testing.T) { + t.Chdir(t.TempDir()) + + files := collectBundleFiles(nil) + assert.Empty(t, files) +} + +func TestCollectBundleFiles_WithWeightsFile(t *testing.T) { + t.Chdir(t.TempDir()) + + require.NoError(t, os.MkdirAll(".cog", 0o755)) + require.NoError(t, os.WriteFile(bundledWeightsFile, 
[]byte(`{"weights":[]}`), 0o644)) + + files := collectBundleFiles([]byte(`{"openapi":"3.0.0"}`)) + assert.Equal(t, []string{bundledSchemaFile, bundledWeightsFile}, files) +} + +func TestBundleDockerfile(t *testing.T) { + df := bundleDockerfile("myimage:latest", []string{bundledSchemaFile, bundledWeightsFile}) + assert.Contains(t, df, "FROM myimage:latest") + assert.Contains(t, df, "COPY .cog/openapi_schema.json .cog/") + assert.Contains(t, df, "COPY .cog/weights.json .cog/") +} diff --git a/pkg/model/artifact_image.go b/pkg/model/artifact_image.go index 9894d40149..38cdc25600 100644 --- a/pkg/model/artifact_image.go +++ b/pkg/model/artifact_image.go @@ -26,6 +26,11 @@ type Platform struct { Variant string } +// PlatformUnknown is the OCI placeholder ("unknown") used for non-platform +// artifacts such as weight manifests, distinguishing them from the model +// image manifest in the OCI index. +const PlatformUnknown = "unknown" + // Label keys for Cog-specific metadata stored in image labels. var ( LabelConfig = global.LabelNamespace + "config" diff --git a/pkg/model/artifact_weight.go b/pkg/model/artifact_weight.go index 382f53a5c7..ff71a4ee45 100644 --- a/pkg/model/artifact_weight.go +++ b/pkg/model/artifact_weight.go @@ -1,99 +1,170 @@ package model import ( - "time" + "context" + "fmt" + "slices" + "strings" v1 "github.com/google/go-containerregistry/pkg/v1" -) -// Media types for weight artifacts (OCI 1.1 conventions). -const ( - // MediaTypeWeightArtifact is the artifactType for weight manifests. - MediaTypeWeightArtifact = "application/vnd.cog.weight.v1" - // MediaTypeWeightConfig is the media type for weight config blobs. - MediaTypeWeightConfig = "application/vnd.cog.weight.config.v1+json" - // MediaTypeWeightLayer is the media type for uncompressed weight layers. - MediaTypeWeightLayer = "application/vnd.cog.weight.layer.v1" - // MediaTypeWeightLayerGzip is the media type for gzip-compressed weight layers. 
- MediaTypeWeightLayerGzip = "application/vnd.cog.weight.layer.v1+gzip" - // MediaTypeWeightLayerZstd is the media type for zstd-compressed weight layers (future). - MediaTypeWeightLayerZstd = "application/vnd.cog.weight.layer.v1+zstd" + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" ) -// Annotation keys for weight file layers in OCI manifests. -const ( - AnnotationWeightName = "vnd.cog.weight.name" - AnnotationWeightDest = "vnd.cog.weight.dest" - AnnotationWeightDigestOriginal = "vnd.cog.weight.digest.original" - AnnotationWeightSizeUncompressed = "vnd.cog.weight.size.uncompressed" -) +// MediaTypeWeightArtifact is the artifactType on a weight manifest. Layers +// use standard OCI layer media types; this constant lives on the manifest +// itself so clients can distinguish weight manifests from image manifests +// without parsing annotations. +const MediaTypeWeightArtifact = "application/vnd.cog.weight.v1" -// WeightSpec declares a weight artifact to be built. -// It implements ArtifactSpec. +// WeightSpec is the normalized, user-declared description of a weight: +// target mount path, source URI, and include/exclude filters. Construct +// via WeightSpecFromConfig or WeightSpecFromLock; compare with Equal. +// +// Include and Exclude are sorted at construction time. They describe a +// set of glob patterns applied by the packer, so order is not part of +// the user's intent — reordering patterns in cog.yaml must not trigger +// a rebuild. type WeightSpec struct { - name string - // Source is the local file path to the weight file. - Source string - // Target is the container mount path for this weight. 
- Target string + name string + Target string // container mount path + URI string // normalized source URI (file://./weights, hf://org/repo) + Include []string // sorted glob patterns + Exclude []string // sorted glob patterns +} + +// WeightSpecFromConfig builds a WeightSpec from a cog.yaml weight entry, +// normalizing the URI and cloning+sorting Include/Exclude. Returns an +// error if the URI is empty or uses an unknown scheme. +func WeightSpecFromConfig(w config.WeightSource) (*WeightSpec, error) { + uri, err := weightsource.NormalizeURI(w.SourceURI()) + if err != nil { + return nil, fmt.Errorf("weight %q: %w", w.Name, err) + } + var include, exclude []string + if w.Source != nil { + include = sortedClone(w.Source.Include) + exclude = sortedClone(w.Source.Exclude) + } + return &WeightSpec{ + name: w.Name, + Target: w.Target, + URI: uri, + Include: include, + Exclude: exclude, + }, nil } -// NewWeightSpec creates a WeightSpec with the given name, source path, and target mount path. -func NewWeightSpec(name, source, target string) *WeightSpec { +// WeightSpecFromLock extracts the user-intent fields (target, URI, +// include/exclude) from a lockfile entry. Fields are copied as stored: +// no re-normalization. A lockfile whose on-disk form differs from what +// we would write today — whether in URI form, include/exclude order, or +// anything else — must report as drift so the next build rewrites it. +func WeightSpecFromLock(e lockfile.WeightLockEntry) *WeightSpec { return &WeightSpec{ - name: name, - Source: source, - Target: target, + name: e.Name, + Target: e.Target, + URI: e.Source.URI, + Include: slices.Clone(e.Source.Include), + Exclude: slices.Clone(e.Source.Exclude), } } -// Type returns ArtifactTypeWeight. -func (s *WeightSpec) Type() ArtifactType { return ArtifactTypeWeight } +// sortedClone returns a sorted copy of s with whitespace-trimmed elements, +// or nil if s is nil. 
Trimming normalizes patterns that may have stray +// whitespace from YAML parsing; sorting removes order-dependence so +// reordering patterns in cog.yaml does not trigger a rebuild. +func sortedClone(s []string) []string { + if s == nil { + return nil + } + out := make([]string, len(s)) + for i, v := range s { + out[i] = strings.TrimSpace(v) + } + slices.Sort(out) + return out +} -// Name returns the spec's logical name. -func (s *WeightSpec) Name() string { return s.name } +// Equal reports whether two specs describe the same user intent. +// Name is excluded: callers only compare specs for the same weight name. +func (s *WeightSpec) Equal(other *WeightSpec) bool { + return s.Target == other.Target && + s.URI == other.URI && + slices.Equal(s.Include, other.Include) && + slices.Equal(s.Exclude, other.Exclude) +} -// WeightArtifact is a built weight artifact ready to push as an OCI artifact. +func (s *WeightSpec) Type() ArtifactType { return ArtifactTypeWeight } +func (s *WeightSpec) Name() string { return s.name } + +// WeightArtifact is a built weight artifact ready to push as an OCI manifest. // It implements Artifact. +// +// The lockfile entry (Entry) is the single source of truth for all +// metadata. Each layer carries its layerPlan; layer bytes are +// reproduced on demand by streaming from store at push time. type WeightArtifact struct { - name string descriptor v1.Descriptor - // FilePath is the local file path to the weight data (for pushing layers). - FilePath string - // Target is the container mount path for this weight. - Target string - // Config is the weight metadata for the config blob. - Config WeightConfig + // Entry is the lockfile entry describing this weight's metadata. + // Must not be mutated after construction. + Entry lockfile.WeightLockEntry + + // Layers are the packed layer descriptors. 
The pusher reads bytes + // for each layer by replaying its layerPlan against store; their + // metadata (digest, size, mediaType) matches Entry.Layers. + Layers []packedLayer + + // store is the content-addressed store from which layer bytes are + // re-streamed during push. Required for any path that reads + // layer bytes; tests that only inspect Entry/Layers metadata may + // leave it nil. + store store.Store } -// NewWeightArtifact creates a WeightArtifact from a build result. -func NewWeightArtifact(name string, desc v1.Descriptor, filePath, target string, cfg WeightConfig) *WeightArtifact { +// buildWeightArtifact builds a WeightArtifact from a lockfile entry, +// packed layer descriptors, and the store from which the layers can +// be re-streamed during push. It assembles the manifest *for digest +// computation only* (so entry.Digest can be backfilled), then +// discards it: the manifest carries fileLayers wired to a particular +// context, so reusing it across Push calls would defeat +// cancellation. Push rebuilds the manifest with the push context. +func buildWeightArtifact(entry *lockfile.WeightLockEntry, layers []packedLayer, st store.Store) (*WeightArtifact, error) { + img, err := buildWeightManifestV1(context.Background(), *entry, layers, st) + if err != nil { + return nil, fmt.Errorf("build weight manifest: %w", err) + } + desc, err := descriptorFromImage(img) + if err != nil { + return nil, fmt.Errorf("compute manifest descriptor: %w", err) + } + entry.Digest = desc.Digest.String() return &WeightArtifact{ - name: name, descriptor: desc, - FilePath: filePath, - Target: target, - Config: cfg, - } + Entry: *entry, + Layers: layers, + store: st, + }, nil } -// Type returns ArtifactTypeWeight. -func (a *WeightArtifact) Type() ArtifactType { return ArtifactTypeWeight } - -// Name returns the artifact's logical name. -func (a *WeightArtifact) Name() string { return a.name } +// newWeightArtifact creates a WeightArtifact from a pre-computed descriptor. 
+// Prefer buildWeightArtifact for production use; this is for tests that +// need a minimal artifact without a real manifest. +func newWeightArtifact(entry lockfile.WeightLockEntry, desc v1.Descriptor, layers []packedLayer) *WeightArtifact { + return &WeightArtifact{ + descriptor: desc, + Entry: entry, + Layers: layers, + } +} -// Descriptor returns the OCI descriptor for this weight artifact. +func (a *WeightArtifact) Type() ArtifactType { return ArtifactTypeWeight } +func (a *WeightArtifact) Name() string { return a.Entry.Name } func (a *WeightArtifact) Descriptor() v1.Descriptor { return a.descriptor } -// WeightConfig contains metadata about a weight artifact. -// This is serialized as the config blob in the OCI manifest. -// The schema is versioned via SchemaVersion to allow evolution. -type WeightConfig struct { - SchemaVersion string `json:"schemaVersion"` - CogVersion string `json:"cogVersion"` - Name string `json:"name"` - Target string `json:"target"` - Created time.Time `json:"created"` // RFC 3339 format when serialized to JSON -} +// TotalSize returns the sum of all layer blob sizes (bytes over the wire). 
+func (a *WeightArtifact) TotalSize() int64 { return a.Entry.SizeCompressed } diff --git a/pkg/model/artifact_weight_test.go b/pkg/model/artifact_weight_test.go index 30348a8430..bbca7cf480 100644 --- a/pkg/model/artifact_weight_test.go +++ b/pkg/model/artifact_weight_test.go @@ -1,16 +1,24 @@ package model import ( - "encoding/json" "testing" "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" ) -func TestWeightSpec_ImplementsArtifactSpec(t *testing.T) { - spec := NewWeightSpec("my-model-weights", "/data/weights.bin", "/weights/model.bin") +func TestWeightSpecFromConfig_ImplementsArtifactSpec(t *testing.T) { + spec, err := WeightSpecFromConfig(config.WeightSource{ + Name: "my-model-weights", + Target: "/src/weights", + Source: &config.WeightSourceConfig{URI: "/data/weights"}, + }) + require.NoError(t, err) var _ ArtifactSpec = spec // compile-time interface check @@ -18,11 +26,166 @@ func TestWeightSpec_ImplementsArtifactSpec(t *testing.T) { require.Equal(t, "my-model-weights", spec.Name()) } -func TestWeightSpec_Fields(t *testing.T) { - spec := NewWeightSpec("llama-7b", "/data/llama-7b.safetensors", "/weights/llama-7b.safetensors") +func TestWeightSpecFromConfig_NormalizesURI(t *testing.T) { + spec, err := WeightSpecFromConfig(config.WeightSource{ + Name: "llama-7b", + Target: "/src/weights/llama-7b", + Source: &config.WeightSourceConfig{URI: "weights/llama-7b"}, + }) + require.NoError(t, err) + + require.Equal(t, "file://./weights/llama-7b", spec.URI) + require.Equal(t, "/src/weights/llama-7b", spec.Target) + require.Empty(t, spec.Include) + require.Empty(t, spec.Exclude) +} + +func TestWeightSpecFromConfig_CopiesIncludeExclude(t *testing.T) { + src := &config.WeightSourceConfig{ + URI: "weights/mw", + Include: []string{"*.safetensors"}, + Exclude: []string{"*.onnx"}, + } + 
spec, err := WeightSpecFromConfig(config.WeightSource{ + Name: "mw", + Target: "/src/weights/mw", + Source: src, + }) + require.NoError(t, err) + + require.Equal(t, []string{"*.safetensors"}, spec.Include) + require.Equal(t, []string{"*.onnx"}, spec.Exclude) + + // Mutating the config after construction must not affect the spec. + src.Include[0] = "changed" + require.Equal(t, []string{"*.safetensors"}, spec.Include) +} + +func TestWeightSpecFromConfig_EmptyURIError(t *testing.T) { + _, err := WeightSpecFromConfig(config.WeightSource{Name: "w", Target: "/w"}) + require.Error(t, err) +} + +func TestWeightSpecFromConfig_InvalidSchemeError(t *testing.T) { + _, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", Target: "/w", + Source: &config.WeightSourceConfig{URI: "bogus://nope"}, + }) + require.Error(t, err) +} + +func TestWeightSpecFromLock_ExtractsIntent(t *testing.T) { + entry := lockfile.WeightLockEntry{ + Name: "w", + Target: "/src/w", + Source: lockfile.WeightLockSource{ + URI: "file://./w", + Fingerprint: weightsource.Fingerprint("sha256:abc"), + Include: []string{"*.safetensors"}, + Exclude: []string{"*.onnx"}, + ImportedAt: time.Now(), + }, + Digest: "sha256:manifest", + } + + spec := WeightSpecFromLock(entry) + + require.Equal(t, "w", spec.Name()) + require.Equal(t, "/src/w", spec.Target) + require.Equal(t, "file://./w", spec.URI) + require.Equal(t, []string{"*.safetensors"}, spec.Include) + require.Equal(t, []string{"*.onnx"}, spec.Exclude) +} + +func TestWeightSpec_Equal(t *testing.T) { + base := func() *WeightSpec { + s, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", + Target: "/src/w", + Source: &config.WeightSourceConfig{ + URI: "weights", + Include: []string{"*.safetensors"}, + Exclude: []string{"*.onnx"}, + }, + }) + require.NoError(t, err) + return s + } + + require.True(t, base().Equal(base())) + + cases := []struct { + name string + mutate func(*WeightSpec) + }{ + {"target differs", func(s *WeightSpec) { s.Target = 
"/src/other" }}, + {"URI differs", func(s *WeightSpec) { s.URI = "file://./other" }}, + {"include differs", func(s *WeightSpec) { s.Include = []string{"*.bin"} }}, + {"exclude differs", func(s *WeightSpec) { s.Exclude = []string{"*.md"} }}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + a := base() + b := base() + tc.mutate(b) + require.False(t, a.Equal(b), "specs should differ: %s", tc.name) + }) + } +} + +func TestWeightSpec_EqualIgnoresIncludeExcludeOrder(t *testing.T) { + // Include/Exclude are sets of glob patterns; reordering them in + // cog.yaml must not count as drift. + a, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", Target: "/w", + Source: &config.WeightSourceConfig{ + URI: "weights", + Include: []string{"*.safetensors", "*.json"}, + Exclude: []string{"*.onnx", "*.md"}, + }, + }) + require.NoError(t, err) + b, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", Target: "/w", + Source: &config.WeightSourceConfig{ + URI: "weights", + Include: []string{"*.json", "*.safetensors"}, + Exclude: []string{"*.md", "*.onnx"}, + }, + }) + require.NoError(t, err) + require.True(t, a.Equal(b)) +} - require.Equal(t, "/data/llama-7b.safetensors", spec.Source) - require.Equal(t, "/weights/llama-7b.safetensors", spec.Target) +func TestWeightSpec_EqualURINormalization(t *testing.T) { + a, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", + Target: "/w", + Source: &config.WeightSourceConfig{URI: "weights"}, + }) + require.NoError(t, err) + b, err := WeightSpecFromConfig(config.WeightSource{ + Name: "w", + Target: "/w", + Source: &config.WeightSourceConfig{URI: "file://./weights"}, + }) + require.NoError(t, err) + require.True(t, a.Equal(b)) +} + +func TestWeightSpec_EqualIgnoresName(t *testing.T) { + a, err := WeightSpecFromConfig(config.WeightSource{ + Name: "a", Target: "/w", + Source: &config.WeightSourceConfig{URI: "weights"}, + }) + require.NoError(t, err) + b, err := 
WeightSpecFromConfig(config.WeightSource{ + Name: "b", Target: "/w", + Source: &config.WeightSourceConfig{URI: "weights"}, + }) + require.NoError(t, err) + require.True(t, a.Equal(b)) } func TestWeightArtifact_ImplementsArtifact(t *testing.T) { @@ -30,14 +193,15 @@ func TestWeightArtifact_ImplementsArtifact(t *testing.T) { Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"}, Size: 4096, } - cfg := WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "my-weights", - Target: "/weights/model.bin", - Created: time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC), + layers := []packedLayer{ + { + Digest: v1.Hash{Algorithm: "sha256", Hex: "aaa"}, + Size: 15000, + MediaType: mediaTypeOCILayerTarGzip, + }, } - artifact := NewWeightArtifact("my-weights", desc, "/data/weights.bin", "/weights/model.bin", cfg) + entry := lockfile.WeightLockEntry{Name: "my-weights", Target: "/src/weights"} + artifact := newWeightArtifact(entry, desc, layers) var _ Artifact = artifact // compile-time interface check @@ -51,56 +215,25 @@ func TestWeightArtifact_Fields(t *testing.T) { Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"}, Size: 4096, } - cfg := WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "my-weights", - Target: "/weights/model.bin", - Created: time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC), + layers := []packedLayer{ + { + Digest: v1.Hash{Algorithm: "sha256", Hex: "bbb"}, + Size: 2048, + MediaType: mediaTypeOCILayerTar, + }, } - artifact := NewWeightArtifact("my-weights", desc, "/data/weights.bin", "/weights/model.bin", cfg) + entry := lockfile.WeightLockEntry{Name: "my-weights", Target: "/src/weights"} + artifact := newWeightArtifact(entry, desc, layers) - require.Equal(t, "/data/weights.bin", artifact.FilePath) - require.Equal(t, "/weights/model.bin", artifact.Target) - require.Equal(t, cfg, artifact.Config) -} - -func TestWeightConfig_JSONRoundTrip(t *testing.T) { - original := WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: 
"llama-7b", - Target: "/weights/llama-7b", - Created: time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC), - } - - data, err := json.Marshal(original) - require.NoError(t, err) - - // Verify JSON structure - var raw map[string]any - err = json.Unmarshal(data, &raw) - require.NoError(t, err) - require.Equal(t, "1.0", raw["schemaVersion"]) - require.Equal(t, "0.15.0", raw["cogVersion"]) - require.Equal(t, "llama-7b", raw["name"]) - require.Equal(t, "/weights/llama-7b", raw["target"]) - - // Round-trip - var decoded WeightConfig - err = json.Unmarshal(data, &decoded) - require.NoError(t, err) - require.Equal(t, original.SchemaVersion, decoded.SchemaVersion) - require.Equal(t, original.CogVersion, decoded.CogVersion) - require.Equal(t, original.Name, decoded.Name) - require.Equal(t, original.Target, decoded.Target) - require.True(t, original.Created.Equal(decoded.Created)) + require.Equal(t, "/src/weights", artifact.Entry.Target) + require.Equal(t, layers, artifact.Layers) + require.Empty(t, artifact.Entry.SetDigest) } func TestWeightMediaTypeConstants(t *testing.T) { - // Verify media type constants have expected values + // The artifactType on the manifest is the only v1 media type with a + // Cog-specific name; layers use standard OCI types. 
require.Equal(t, "application/vnd.cog.weight.v1", MediaTypeWeightArtifact) - require.Equal(t, "application/vnd.cog.weight.config.v1+json", MediaTypeWeightConfig) - require.Equal(t, "application/vnd.cog.weight.layer.v1", MediaTypeWeightLayer) - require.Equal(t, "application/vnd.cog.weight.layer.v1+gzip", MediaTypeWeightLayerGzip) + require.Equal(t, "application/vnd.oci.image.layer.v1.tar", mediaTypeOCILayerTar) + require.Equal(t, "application/vnd.oci.image.layer.v1.tar+gzip", mediaTypeOCILayerTarGzip) } diff --git a/pkg/model/envelope.go b/pkg/model/envelope.go new file mode 100644 index 0000000000..8887c937f0 --- /dev/null +++ b/pkg/model/envelope.go @@ -0,0 +1,94 @@ +package model + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "slices" +) + +// envelopeRevision counts byte-level packer changes the envelope +// struct can't capture on its own. +// +// SYNC: bump when the bytes a packer writes for a given (inventory, +// parameters) tuple change in a way the struct fields below don't +// reflect. Examples: +// +// - tar header format flips (FormatPAX → FormatGNU) or changes to +// deterministic header fields (mode, uid/gid, mtime, typeflag). +// - file ordering inside a layer (today: small-file bundles sort +// by path; large-file layers contain a single file). +// - directory-entry insertion rules (today: every parent of every +// packed file gets a deterministic dir header before the files). +// - compressor framing changes (different gzip impl, header tweaks). +// - default flips on existing parameters where the field stays the +// same shape but the meaning of the value changes. +// +// Pure parameter changes — defaultBundleFileMax, incompressibleExts, +// gzip levels — are captured automatically by envelopeFromOptions and +// do NOT require a revision bump. Bias toward bumping when in doubt: +// missed bumps mean silent lockfile drift; unnecessary bumps mean +// one round of lockfile churn on next `cog weights import`. 
+const envelopeRevision = 1 + +// envelope captures every input that determines the packer's byte +// output for a given inventory: thresholds, gzip levels, +// incompressible extensions, layer media types, and envelopeRevision. +// Equal envelopes ⇒ byte-identical layer blobs for the same inventory. +// +// JSON tags are on-disk identifiers; renaming a field requires +// updating the snapshot digests in envelope_test.go. +type envelope struct { + BundleFileMax int64 `json:"bundleFileMax"` + BundleSizeMax int64 `json:"bundleSizeMax"` + GzipLevelBundle int `json:"gzipLevelBundle"` + GzipLevelLarge int `json:"gzipLevelLarge"` + IncompressibleExts []string `json:"incompressibleExts"` // sorted ascending + MediaTypeCompressed string `json:"mediaTypeCompressed"` + MediaTypeRaw string `json:"mediaTypeRaw"` + Revision int `json:"revision"` +} + +// envelopeFromOptions builds the envelope describing the current +// packer behavior under opts. Every field reads the live value at +// the call site (defaults via packOptions methods, package-level +// constants for gzip levels and media types, the live +// incompressibleExts map) so a parameter change propagates to the +// digest without anyone having to remember to update the envelope. +// +// TODO: gzip levels and incompressibleExts live on package-level +// state in packer.go rather than on packOptions. Behavior is correct +// but the separation is muddled — revisit when packer grows +// configurable gzip. 
+func envelopeFromOptions(opts packOptions) envelope { + exts := make([]string, 0, len(incompressibleExts)) + for ext := range incompressibleExts { + exts = append(exts, ext) + } + slices.Sort(exts) + + return envelope{ + BundleFileMax: opts.bundleFileMax(), + BundleSizeMax: opts.bundleSizeMax(), + GzipLevelBundle: gzipLevelBundle, + GzipLevelLarge: gzipLevelLarge, + IncompressibleExts: exts, + MediaTypeCompressed: mediaTypeOCILayerTarGzip, + MediaTypeRaw: mediaTypeOCILayerTar, + Revision: envelopeRevision, + } +} + +// computeEnvelopeFormat returns the canonical sha256 digest of env +// (with "sha256:" prefix). Determinism rests on encoding/json's +// stable struct-field ordering and on IncompressibleExts being +// sorted at construction time. +func computeEnvelopeFormat(env envelope) (string, error) { + data, err := json.Marshal(env) + if err != nil { + return "", fmt.Errorf("marshal envelope: %w", err) + } + sum := sha256.Sum256(data) + return "sha256:" + hex.EncodeToString(sum[:]), nil +} diff --git a/pkg/model/envelope_test.go b/pkg/model/envelope_test.go new file mode 100644 index 0000000000..fcc14992db --- /dev/null +++ b/pkg/model/envelope_test.go @@ -0,0 +1,136 @@ +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// envelopeFormatChangeMessage explains why a snapshot mismatch +// matters and how to resolve it. The envelope digest is stamped into +// weights.lock; a mismatch on next `cog weights import` forces a +// recompute pass and rewrites the lockfile. Touching this snapshot is +// meaningful by design. +// +// When this assertion fails: +// +// 1. Intentional packer parameter change (defaultBundleFileMax, +// incompressibleExts, gzip level)? Update the snapshot. Expect +// lockfile churn on the next import — that's the whole point. +// +// 2. 
Changed the *bytes* the packer writes (tar header format, file +// ordering, directory-entry emission, compressor framing) without +// a parameter change? Bump envelopeRevision (see the SYNC block +// on that constant) and update the snapshot. +// +// 3. Reverting an intentional change? Revert the snapshot too. +// +// 4. Don't know why this is failing? Probably touched envelope or +// envelopeFromOptions accidentally — field shape, JSON tags, +// field order. Fix that, don't update the snapshot. +const envelopeFormatChangeMessage = "envelope digest changed: see envelopeFormatChangeMessage in envelope_test.go for context and resolution steps" + +// defaultEnvelopeFormatDigest is the snapshot digest of the +// zero-value packOptions envelope under envelopeRevision = 1. +// +// Update this only after reading envelopeFormatChangeMessage above. +const defaultEnvelopeFormatDigest = "sha256:ce2d53f8dd962ace393450e0abadbe227304897be87753a503b61f9c8525726e" + +func TestEnvelopeFormat_DefaultIsStable(t *testing.T) { + got, err := computeEnvelopeFormat(envelopeFromOptions(packOptions{})) + require.NoError(t, err) + assert.Equal(t, defaultEnvelopeFormatDigest, got, envelopeFormatChangeMessage) +} + +func TestEnvelopeFormat_NonDefaults(t *testing.T) { + // Snapshot table for non-default envelopes. Each row freezes the + // digest under one explicit packOptions tweak. If a row breaks, + // see envelopeFormatChangeMessage for resolution. 
+ cases := []struct { + name string + opts packOptions + digest string + }{ + { + name: "custom bundle file max", + opts: packOptions{BundleFileMax: 32 * 1024 * 1024}, + digest: "sha256:42f28ed027f791b53cbd282663de73971c32d3cad9cbb64de6504cadf42b248f", + }, + { + name: "custom bundle size max", + opts: packOptions{BundleSizeMax: 128 * 1024 * 1024}, + digest: "sha256:3e009edfdfc3c4371ed17572723c4f19d7646a1282646fe008fcb95c955c1547", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := computeEnvelopeFormat(envelopeFromOptions(tc.opts)) + require.NoError(t, err) + assert.Equal(t, tc.digest, got, envelopeFormatChangeMessage) + }) + } +} + +func TestEnvelopeFormat_Deterministic(t *testing.T) { + // Two computations of the same envelope must produce the same + // digest. Catches accidental nondeterminism (map iteration order + // leaking into the JSON, time stamps, etc.). + env := envelopeFromOptions(packOptions{}) + a, err := computeEnvelopeFormat(env) + require.NoError(t, err) + b, err := computeEnvelopeFormat(env) + require.NoError(t, err) + assert.Equal(t, a, b, "computeEnvelopeFormat must be deterministic") +} + +func TestEnvelopeFormat_RevisionBumpChangesDigest(t *testing.T) { + // Sanity check that bumping Revision changes the digest. Without + // this, the SYNC block on envelopeRevision would be a comment + // pointing at machinery that doesn't actually work. + env := envelopeFromOptions(packOptions{}) + bumped := env + bumped.Revision = env.Revision + 1 + + a, err := computeEnvelopeFormat(env) + require.NoError(t, err) + b, err := computeEnvelopeFormat(bumped) + require.NoError(t, err) + assert.NotEqual(t, a, b, + "bumping envelopeRevision must produce a different digest") +} + +func TestEnvelopeFormat_FieldsCaptured(t *testing.T) { + // Each field in envelope must contribute to the digest — otherwise + // adding a field to track a new packer input is a silent no-op. 
+ // Mutate one field at a time and assert the digest changes. + base := envelopeFromOptions(packOptions{}) + baseDigest, err := computeEnvelopeFormat(base) + require.NoError(t, err) + + mutations := []struct { + name string + fn func(*envelope) + }{ + {"BundleFileMax", func(e *envelope) { e.BundleFileMax++ }}, + {"BundleSizeMax", func(e *envelope) { e.BundleSizeMax++ }}, + {"GzipLevelBundle", func(e *envelope) { e.GzipLevelBundle++ }}, + {"GzipLevelLarge", func(e *envelope) { e.GzipLevelLarge++ }}, + {"IncompressibleExts", func(e *envelope) { e.IncompressibleExts = append(e.IncompressibleExts, ".new") }}, + {"MediaTypeCompressed", func(e *envelope) { e.MediaTypeCompressed = "x/different" }}, + {"MediaTypeRaw", func(e *envelope) { e.MediaTypeRaw = "x/different" }}, + {"Revision", func(e *envelope) { e.Revision++ }}, + } + for _, m := range mutations { + t.Run(m.name, func(t *testing.T) { + mutated := base + // Clone slice so the mutation doesn't bleed into base. + mutated.IncompressibleExts = append([]string(nil), base.IncompressibleExts...) + m.fn(&mutated) + got, err := computeEnvelopeFormat(mutated) + require.NoError(t, err) + assert.NotEqual(t, baseDigest, got, + "mutating %s must change the digest; if not, the field isn't actually part of the envelope", m.name) + }) + } +} diff --git a/pkg/model/format.go b/pkg/model/format.go deleted file mode 100644 index aa491e3972..0000000000 --- a/pkg/model/format.go +++ /dev/null @@ -1,11 +0,0 @@ -package model - -import "os" - -// TODO(md): OCIIndexEnabled is a temporary gate for the OCI Image Index push path. -// When COG_OCI_INDEX=1, builds produce weight artifacts and pushes create an OCI -// Image Index instead of a single image manifest. Remove this gate (and always use -// the index path) once we've validated index compatibility with all registries. 
-func OCIIndexEnabled() bool { - return os.Getenv("COG_OCI_INDEX") == "1" -} diff --git a/pkg/model/format_test.go b/pkg/model/format_test.go deleted file mode 100644 index a85b557973..0000000000 --- a/pkg/model/format_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package model - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOCIIndexEnabled_Default(t *testing.T) { - t.Setenv("COG_OCI_INDEX", "") - require.False(t, OCIIndexEnabled()) -} - -func TestOCIIndexEnabled_Enabled(t *testing.T) { - t.Setenv("COG_OCI_INDEX", "1") - require.True(t, OCIIndexEnabled()) -} - -func TestOCIIndexEnabled_OtherValue(t *testing.T) { - t.Setenv("COG_OCI_INDEX", "0") - require.False(t, OCIIndexEnabled()) -} diff --git a/pkg/model/hash.go b/pkg/model/hash.go deleted file mode 100644 index b8e8d7422e..0000000000 --- a/pkg/model/hash.go +++ /dev/null @@ -1,26 +0,0 @@ -package model - -import ( - "crypto/sha256" - "encoding/hex" - "io" - "os" -) - -// hashFile computes SHA256 digest and size of a file by streaming. 
-func hashFile(path string) (digest string, size int64, err error) { - f, err := os.Open(path) - if err != nil { - return "", 0, err - } - defer f.Close() - - h := sha256.New() - size, err = io.Copy(h, f) - if err != nil { - return "", 0, err - } - - digest = "sha256:" + hex.EncodeToString(h.Sum(nil)) - return digest, size, nil -} diff --git a/pkg/model/image_builder_test.go b/pkg/model/image_builder_test.go index 2f6f19f027..b96a80bd32 100644 --- a/pkg/model/image_builder_test.go +++ b/pkg/model/image_builder_test.go @@ -76,8 +76,13 @@ func TestImageBuilder_ErrorWrongSpecType(t *testing.T) { ib := NewImageBuilder(&mockFactory{}, &mockDocker{}, src, BuildOptions{}) // Pass a WeightSpec instead of ImageSpec - weightSpec := NewWeightSpec("model", "model.bin", "/weights/model.bin") - _, err := ib.Build(context.Background(), weightSpec) + weightSpec, err := WeightSpecFromConfig(config.WeightSource{ + Name: "model", + Target: "/weights/model.bin", + Source: &config.WeightSourceConfig{URI: "model.bin"}, + }) + require.NoError(t, err) + _, err = ib.Build(context.Background(), weightSpec) require.Error(t, err) require.Contains(t, err.Error(), "expected *ImageSpec") } diff --git a/pkg/model/image_pusher_test.go b/pkg/model/image_pusher_test.go index c6f91321f1..0f93dffa9c 100644 --- a/pkg/model/image_pusher_test.go +++ b/pkg/model/image_pusher_test.go @@ -75,6 +75,9 @@ func (m *ociMockClient) GetDescriptor(context.Context, string) (v1.Descriptor, e return v1.Descriptor{}, nil } func (m *ociMockClient) PushIndex(context.Context, string, v1.ImageIndex) error { return nil } +func (m *ociMockClient) BlobExists(context.Context, string, string) (bool, error) { + return false, nil +} // testArtifact creates an *ImageArtifact for testing with the given reference string. 
func testArtifact(ref string) *ImageArtifact { diff --git a/pkg/model/index.go b/pkg/model/index.go deleted file mode 100644 index f5e88c5d4e..0000000000 --- a/pkg/model/index.go +++ /dev/null @@ -1,58 +0,0 @@ -// pkg/model/index.go -package model - -// Index represents an OCI Image Index containing multiple manifests. -type Index struct { - // Digest is the index digest (sha256:...). - Digest string - // Reference is the full image reference. - Reference string - // MediaType is typically application/vnd.oci.image.index.v1+json. - MediaType string - // Manifests are the child manifests in this index. - Manifests []IndexManifest -} - -// IndexManifest represents a single manifest within an index. -type IndexManifest struct { - // Digest is the manifest digest. - Digest string - // MediaType is the manifest media type. - MediaType string - // Size is the manifest size in bytes. - Size int64 - // Platform is the target platform (nil for artifacts). - Platform *Platform - // Annotations are OCI annotations on this manifest. - Annotations map[string]string - // Type is derived from platform/annotations (image or weights). - Type ManifestType -} - -// ManifestType identifies the type of manifest in an index. -type ManifestType string - -const ( - // ManifestTypeImage is a runnable container image. - ManifestTypeImage ManifestType = "image" - // ManifestTypeWeights is a weights artifact. - ManifestTypeWeights ManifestType = "weights" -) - -// Annotation keys for weights manifests. -const ( - AnnotationReferenceType = "vnd.cog.reference.type" - AnnotationReferenceDigest = "vnd.cog.reference.digest" -) - -// Annotation values. -const ( - // AnnotationValueWeights is the value for AnnotationReferenceType indicating a weights manifest. - AnnotationValueWeights = "weights" -) - -// Platform values for non-platform-specific artifacts. -const ( - // PlatformUnknown is used for artifacts that are not platform-specific (e.g., weights). 
- PlatformUnknown = "unknown" -) diff --git a/pkg/model/index_factory.go b/pkg/model/index_factory.go index edfa3f8b8e..1800d7550a 100644 --- a/pkg/model/index_factory.go +++ b/pkg/model/index_factory.go @@ -2,6 +2,7 @@ package model import ( "fmt" + "strconv" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" @@ -16,12 +17,12 @@ type IndexBuilder struct { weightDescriptors []weightDescEntry } -// weightDescEntry pairs a weight descriptor with the image digest it references. +// weightDescEntry pairs a weight descriptor with its index-level metadata. type weightDescEntry struct { - descriptor v1.Descriptor - imageDigest string - name string - target string + descriptor v1.Descriptor + name string + setDigest string + uncompressedSize int64 } // NewIndexBuilder creates a new index builder. @@ -36,14 +37,14 @@ func (b *IndexBuilder) SetImageDescriptor(desc v1.Descriptor, platform *v1.Platf } // AddWeightDescriptor adds a weight manifest descriptor. -// imageDigest is the digest of the model image, used in the reference annotation. -// name and target are optional weight metadata for index annotations. -func (b *IndexBuilder) AddWeightDescriptor(desc v1.Descriptor, imageDigest, name, target string) { +// Metadata (name, setDigest, uncompressedSize) is hoisted into index-level +// annotations so the index is inspectable without fetching child manifests (§2.6). +func (b *IndexBuilder) AddWeightDescriptor(desc v1.Descriptor, name, setDigest string, uncompressedSize int64) { b.weightDescriptors = append(b.weightDescriptors, weightDescEntry{ - descriptor: desc, - imageDigest: imageDigest, - name: name, - target: target, + descriptor: desc, + name: name, + setDigest: setDigest, + uncompressedSize: uncompressedSize, }) } @@ -68,27 +69,31 @@ func (b *IndexBuilder) BuildFromDescriptors() (v1.ImageIndex, error) { }, }) - // Add weight manifest(s) + // Add weight manifest(s). 
Per spec §2.6, weight descriptors in the + // index carry name and set-digest so the index is inspectable without + // fetching child manifests. for _, entry := range b.weightDescriptors { - annotations := map[string]string{ - AnnotationReferenceType: AnnotationValueWeights, - } - if entry.imageDigest != "" { - annotations[AnnotationReferenceDigest] = entry.imageDigest - } + annotations := make(map[string]string, 3) if entry.name != "" { - annotations[AnnotationWeightName] = entry.name + annotations[AnnotationV1WeightName] = entry.name + } + if entry.setDigest != "" { + annotations[AnnotationV1WeightSetDigest] = entry.setDigest } - if entry.target != "" { - annotations[AnnotationWeightDest] = entry.target + if entry.uncompressedSize > 0 { + annotations[AnnotationV1WeightSizeUncomp] = strconv.FormatInt(entry.uncompressedSize, 10) } + weightDesc := entry.descriptor + weightDesc.ArtifactType = MediaTypeWeightArtifact + idx = mutate.AppendManifests(idx, mutate.IndexAddendum{ - Add: &descriptorAppendable{desc: entry.descriptor}, + Add: &descriptorAppendable{desc: weightDesc}, Descriptor: v1.Descriptor{ - MediaType: entry.descriptor.MediaType, - Size: entry.descriptor.Size, - Digest: entry.descriptor.Digest, + MediaType: entry.descriptor.MediaType, + Size: entry.descriptor.Size, + Digest: entry.descriptor.Digest, + ArtifactType: MediaTypeWeightArtifact, Platform: &v1.Platform{ OS: PlatformUnknown, Architecture: PlatformUnknown, @@ -118,3 +123,7 @@ func (d *descriptorAppendable) Digest() (v1.Hash, error) { func (d *descriptorAppendable) Size() (int64, error) { return d.desc.Size, nil } + +func (d *descriptorAppendable) ArtifactType() (string, error) { + return d.desc.ArtifactType, nil +} diff --git a/pkg/model/index_factory_test.go b/pkg/model/index_factory_test.go index 7fff629f49..a833127a1b 100644 --- a/pkg/model/index_factory_test.go +++ b/pkg/model/index_factory_test.go @@ -29,7 +29,7 @@ func TestIndexBuilder_BuildFromDescriptors(t *testing.T) { builder := 
NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", Architecture: "amd64"}) - builder.AddWeightDescriptor(weightDesc, imgDesc.Digest.String(), "model-v1", "/cache/model.safetensors") + builder.AddWeightDescriptor(weightDesc, "model-v1", "sha256:abcdef1234567890", 1073741824) idx, err := builder.BuildFromDescriptors() require.NoError(t, err) @@ -44,15 +44,16 @@ func TestIndexBuilder_BuildFromDescriptors(t *testing.T) { require.Equal(t, "linux", idxManifest.Manifests[0].Platform.OS) require.Equal(t, "amd64", idxManifest.Manifests[0].Platform.Architecture) - // Second entry: weight artifact with unknown platform and annotations - require.Equal(t, weightDesc.Digest, idxManifest.Manifests[1].Digest) - require.Equal(t, weightDesc.Size, idxManifest.Manifests[1].Size) - require.Equal(t, PlatformUnknown, idxManifest.Manifests[1].Platform.OS) - require.Equal(t, PlatformUnknown, idxManifest.Manifests[1].Platform.Architecture) - require.Equal(t, AnnotationValueWeights, idxManifest.Manifests[1].Annotations[AnnotationReferenceType]) - require.Equal(t, imgDesc.Digest.String(), idxManifest.Manifests[1].Annotations[AnnotationReferenceDigest]) - require.Equal(t, "model-v1", idxManifest.Manifests[1].Annotations[AnnotationWeightName]) - require.Equal(t, "/cache/model.safetensors", idxManifest.Manifests[1].Annotations[AnnotationWeightDest]) + // Second entry: weight artifact with unknown platform, artifactType, and annotations + wm := idxManifest.Manifests[1] + require.Equal(t, weightDesc.Digest, wm.Digest) + require.Equal(t, weightDesc.Size, wm.Size) + require.Equal(t, MediaTypeWeightArtifact, wm.ArtifactType) + require.Equal(t, PlatformUnknown, wm.Platform.OS) + require.Equal(t, PlatformUnknown, wm.Platform.Architecture) + require.Equal(t, "model-v1", wm.Annotations[AnnotationV1WeightName]) + require.Equal(t, "sha256:abcdef1234567890", wm.Annotations[AnnotationV1WeightSetDigest]) + require.Equal(t, "1073741824", 
wm.Annotations[AnnotationV1WeightSizeUncomp]) }) t.Run("builds index with multiple weight descriptors", func(t *testing.T) { @@ -74,8 +75,8 @@ func TestIndexBuilder_BuildFromDescriptors(t *testing.T) { builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", Architecture: "amd64"}) - builder.AddWeightDescriptor(weight1, imgDesc.Digest.String(), "weight-1", "/weights/w1.bin") - builder.AddWeightDescriptor(weight2, imgDesc.Digest.String(), "weight-2", "/weights/w2.bin") + builder.AddWeightDescriptor(weight1, "weight-1", "sha256:set1", 500) + builder.AddWeightDescriptor(weight2, "weight-2", "sha256:set2", 600) idx, err := builder.BuildFromDescriptors() require.NoError(t, err) diff --git a/pkg/model/index_test.go b/pkg/model/index_test.go deleted file mode 100644 index 913fa567fa..0000000000 --- a/pkg/model/index_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// pkg/model/index_test.go -package model - -import ( - "testing" - - v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/stretchr/testify/require" -) - -func TestModel_IsBundle(t *testing.T) { - t.Run("returns false with no artifacts", func(t *testing.T) { - m := &Model{} - require.False(t, m.IsBundle()) - }) - - t.Run("returns false with only image artifact", func(t *testing.T) { - m := &Model{ - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - }, - } - require.False(t, m.IsBundle()) - }) - - t.Run("returns true with weight artifacts", func(t *testing.T) { - m := &Model{ - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, "/tmp/w1", "/weights/w1", WeightConfig{}), - }, - } - require.True(t, m.IsBundle()) - }) -} - -func TestManifestType(t *testing.T) { - require.Equal(t, ManifestType("image"), ManifestTypeImage) - require.Equal(t, ManifestType("weights"), ManifestTypeWeights) -} diff --git a/pkg/model/model.go b/pkg/model/model.go index 
2aeb01a0c4..3d07d832bb 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -2,10 +2,13 @@ package model import ( "encoding/json" + "fmt" + "path/filepath" "github.com/getkin/kin-openapi/openapi3" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/weights/lockfile" ) // Model represents a Cog model extracted from an image. @@ -15,17 +18,59 @@ type Model struct { Schema *openapi3.T // OpenAPI schema CogVersion string // Version of cog used to build - // Index is the OCI Image Index (populated when inspecting a pushed model). - Index *Index - - // TODO(md): OCIIndex is a temporary gate. When true, Push() creates an OCI - // Image Index with weight artifacts. When false, Push() does a plain docker push. - // Remove this field once index pushes are validated with all registries. - OCIIndex bool - // Artifacts is the collection of all artifacts produced by building this model. - // Populated by Resolver.Build(). Contains ImageArtifact and WeightArtifact instances. + // Populated by Resolver.Build(). Contains ImageArtifact instances only. Artifacts []Artifact + + // Weights are the model's managed weights, loaded from the lockfile + // during Build. Each Weight carries all lockfile metadata (name, + // target, digest, set digest, sizes). The push path uses these to + // HEAD-check weight manifests in the registry; it never streams + // layer bytes. + Weights []Weight +} + +// Weight is the model's representation of a managed weight, projected +// from a lockfile entry. Fields mirror lockfile.WeightLockEntry but +// this type belongs to the model domain and carries only what the +// build and push paths need. +type Weight struct { + Name string + Target string + Digest string // OCI manifest digest + SetDigest string // content-addressable file set identity (spec §2.4) + Size int64 + // SizeCompressed is the total compressed (over-the-wire) size. + SizeCompressed int64 +} + +// WeightFromLockEntry creates a Weight from a lockfile entry. 
+func WeightFromLockEntry(e lockfile.WeightLockEntry) Weight { + return Weight{ + Name: e.Name, + Target: e.Target, + Digest: e.Digest, + SetDigest: e.SetDigest, + Size: e.Size, + SizeCompressed: e.SizeCompressed, + } +} + +// WeightsFromLockfile loads a lockfile from projectDir and returns +// the corresponding Weight slice. Returns an error if the lockfile +// is missing or corrupt. +func WeightsFromLockfile(projectDir string) ([]Weight, error) { + lock, err := lockfile.LoadWeightsLock( + filepath.Join(projectDir, lockfile.WeightsLockFilename), + ) + if err != nil { + return nil, fmt.Errorf("load weights.lock: %w", err) + } + weights := make([]Weight, len(lock.Weights)) + for i, e := range lock.Weights { + weights[i] = WeightFromLockEntry(e) + } + return weights, nil } // HasGPU returns true if the model requires GPU. @@ -49,9 +94,9 @@ func (m *Model) ImageRef() string { return m.Image.Reference } -// IsBundle returns true if this model has weight artifacts. +// IsBundle returns true if this model has managed weights. 
func (m *Model) IsBundle() bool { - return len(m.WeightArtifacts()) > 0 + return len(m.Weights) > 0 } // GetImageArtifact returns the first ImageArtifact from the artifacts collection, diff --git a/pkg/model/model_test.go b/pkg/model/model_test.go index 559140161a..894832aa66 100644 --- a/pkg/model/model_test.go +++ b/pkg/model/model_test.go @@ -2,13 +2,13 @@ package model import ( "testing" - "time" "github.com/getkin/kin-openapi/openapi3" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/weights/lockfile" ) func TestModel_HasGPU(t *testing.T) { @@ -125,10 +125,10 @@ func TestModel_GetImageArtifact(t *testing.T) { v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, "r8.im/user/model@sha256:abc123", ) - weightArtifact := NewWeightArtifact("weights", + weightArtifact := newWeightArtifact( + lockfile.WeightLockEntry{Name: "weights", Target: "/src/weights"}, v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"}, Size: 4096}, - "/data/weights.bin", "/weights/model.bin", - WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "weights", Target: "/weights/model.bin", Created: time.Now()}, + []packedLayer{{Digest: v1.Hash{Algorithm: "sha256", Hex: "aaa"}, Size: 100, MediaType: mediaTypeOCILayerTarGzip}}, ) tests := []struct { @@ -177,15 +177,15 @@ func TestModel_WeightArtifacts(t *testing.T) { v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, "r8.im/user/model@sha256:abc123", ) - w1 := NewWeightArtifact("llama", + w1 := newWeightArtifact( + lockfile.WeightLockEntry{Name: "llama", Target: "/src/weights/llama"}, v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w1"}, Size: 4096}, - "/data/llama.bin", "/weights/llama.bin", - WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "llama", Target: "/weights/llama.bin", Created: time.Now()}, + []packedLayer{{Digest: 
v1.Hash{Algorithm: "sha256", Hex: "aaa"}, Size: 100, MediaType: mediaTypeOCILayerTar}}, ) - w2 := NewWeightArtifact("embeddings", + w2 := newWeightArtifact( + lockfile.WeightLockEntry{Name: "embeddings", Target: "/src/weights/embed"}, v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w2"}, Size: 2048}, - "/data/embed.bin", "/weights/embed.bin", - WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "embeddings", Target: "/weights/embed.bin", Created: time.Now()}, + []packedLayer{{Digest: v1.Hash{Algorithm: "sha256", Hex: "bbb"}, Size: 100, MediaType: mediaTypeOCILayerTar}}, ) tests := []struct { @@ -212,10 +212,10 @@ func TestModel_ArtifactsByType(t *testing.T) { v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, "r8.im/user/model@sha256:abc123", ) - w1 := NewWeightArtifact("llama", + w1 := newWeightArtifact( + lockfile.WeightLockEntry{Name: "llama", Target: "/src/weights/llama"}, v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w1"}, Size: 4096}, - "/data/llama.bin", "/weights/llama.bin", - WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "llama", Target: "/weights/llama.bin", Created: time.Now()}, + []packedLayer{{Digest: v1.Hash{Algorithm: "sha256", Hex: "aaa"}, Size: 100, MediaType: mediaTypeOCILayerTar}}, ) m := &Model{Artifacts: []Artifact{imgArtifact, w1}} @@ -228,3 +228,31 @@ func TestModel_ArtifactsByType(t *testing.T) { require.Len(t, weights, 1) require.Equal(t, "llama", weights[0].Name()) } + +func TestModel_IsBundle(t *testing.T) { + t.Run("returns false with no artifacts", func(t *testing.T) { + m := &Model{} + require.False(t, m.IsBundle()) + }) + + t.Run("returns false with only image artifact", func(t *testing.T) { + m := &Model{ + Artifacts: []Artifact{ + &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, + }, + } + require.False(t, m.IsBundle()) + }) + + t.Run("returns true with weights", func(t *testing.T) { + m := &Model{ + Artifacts: []Artifact{ + 
&ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, + }, + Weights: []Weight{ + {Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc123"}, + }, + } + require.True(t, m.IsBundle()) + }) +} diff --git a/pkg/model/options.go b/pkg/model/options.go index 64f19fd09c..855ab99582 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -41,15 +41,6 @@ type BuildOptions struct { // DockerfileFile is a custom Dockerfile path. DockerfileFile string - // WeightsLockPath is the path to weights.lock file. - // Default: weights.lock in project directory. - WeightsLockPath string - - // TODO(md): OCIIndex is a temporary gate. When true, builds produce weight - // artifacts and pushes create an OCI Image Index. Set via COG_OCI_INDEX=1. - // Remove this field once index pushes are validated with all registries. - OCIIndex bool - // ExcludeSource skips the COPY . /src step in the generated Dockerfile. // Used by `cog serve` to produce an image identical to `cog build` minus // the source copy — the source directory is volume-mounted at runtime. 
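The `pkg/model/model.go` changes above replace artifact-based bundle detection with a lockfile-backed `Weights` slice: each lockfile entry is projected field-for-field into a model-domain `Weight`, and `IsBundle` now just checks slice length. A minimal self-contained sketch of that projection, using local stand-in types (`weightLockEntry`, `weight`, and `isBundle` are illustrative mirrors of the diff's types, not the real cog identifiers):

```go
package main

import "fmt"

// weightLockEntry is a local stand-in for lockfile.WeightLockEntry,
// carrying only the fields the diff's projection copies.
type weightLockEntry struct {
	Name, Target, Digest, SetDigest string
	Size, SizeCompressed            int64
}

// weight is a local stand-in for model.Weight. Its field layout is
// identical to weightLockEntry, so Go permits a direct struct
// conversion in the projection.
type weight struct {
	Name, Target, Digest, SetDigest string
	Size, SizeCompressed            int64
}

// weightFromLockEntry mirrors the diff's field-for-field copy.
func weightFromLockEntry(e weightLockEntry) weight {
	return weight(e) // identical field layout allows a plain conversion
}

// isBundle mirrors the new Model.IsBundle: a model is a bundle iff
// it carries managed weights, regardless of its artifacts.
func isBundle(weights []weight) bool { return len(weights) > 0 }

func main() {
	entries := []weightLockEntry{
		{Name: "llama", Target: "/src/weights/llama", SetDigest: "sha256:abc", Size: 4096},
	}
	weights := make([]weight, len(entries))
	for i, e := range entries {
		weights[i] = weightFromLockEntry(e)
	}
	fmt.Println(isBundle(weights), weights[0].Name) // prints: true llama
}
```

Keeping a separate model-domain type (rather than reusing the lockfile type) means the push path depends only on what it needs, at the cost of this small copy.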
diff --git a/pkg/model/options_test.go b/pkg/model/options_test.go index 4c3fa3ab43..de622a402c 100644 --- a/pkg/model/options_test.go +++ b/pkg/model/options_test.go @@ -105,7 +105,6 @@ func TestBuildOptions_AllFieldsPreserved(t *testing.T) { Annotations: map[string]string{"key": "value"}, SchemaFile: "/path/to/schema.json", DockerfileFile: "/path/to/Dockerfile", - WeightsLockPath: "/path/to/weights.lock", } result := opts.WithDefaults(src) @@ -123,12 +122,4 @@ func TestBuildOptions_AllFieldsPreserved(t *testing.T) { require.Equal(t, map[string]string{"key": "value"}, result.Annotations) require.Equal(t, "/path/to/schema.json", result.SchemaFile) require.Equal(t, "/path/to/Dockerfile", result.DockerfileFile) - require.Equal(t, "/path/to/weights.lock", result.WeightsLockPath) -} - -func TestBuildOptions_WeightsLockPath(t *testing.T) { - opts := BuildOptions{ - WeightsLockPath: "/custom/weights.lock", - } - require.Equal(t, "/custom/weights.lock", opts.WeightsLockPath) } diff --git a/pkg/model/packer.go b/pkg/model/packer.go new file mode 100644 index 0000000000..8bcfd77425 --- /dev/null +++ b/pkg/model/packer.go @@ -0,0 +1,555 @@ +package model + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "time" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/types" + "golang.org/x/sync/errgroup" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/store" +) + +// Default packing thresholds per spec §1.2. +const ( + defaultBundleFileMax = 64 * 1024 * 1024 // 64 MB + defaultBundleSizeMax = 256 * 1024 * 1024 // 256 MB +) + +// gzip compression levels. Bundles use BestCompression (small +// text-heavy files reward extra CPU); large compressible files use +// DefaultCompression (marginal savings rarely justify the cost). 
+// Named constants so envelope.go references them by symbol and +// changes reach the envelope digest automatically. +const ( + gzipLevelBundle = gzip.BestCompression + gzipLevelLarge = gzip.DefaultCompression +) + +// OCI layer media types per spec §2.1. +const ( + mediaTypeOCILayerTar = "application/vnd.oci.image.layer.v1.tar" + mediaTypeOCILayerTarGzip = "application/vnd.oci.image.layer.v1.tar+gzip" +) + +// incompressibleExts lists extensions for dense binary formats that don't +// benefit from gzip compression. Per spec §1.2. +var incompressibleExts = map[string]bool{ + ".safetensors": true, + ".bin": true, + ".gguf": true, + ".onnx": true, + ".parquet": true, + ".pt": true, + ".pth": true, +} + +// packOptions configures the packing algorithm. +type packOptions struct { + // BundleFileMax is the size threshold separating small from large files. + // Files below this are bundled; files at or above get their own layer. + // Defaults to defaultBundleFileMax (64 MB). + BundleFileMax int64 + + // BundleSizeMax is the maximum cumulative size of a single bundle tar. + // Defaults to defaultBundleSizeMax (256 MB). + BundleSizeMax int64 +} + +func (o packOptions) bundleFileMax() int64 { + if o.BundleFileMax > 0 { + return o.BundleFileMax + } + return defaultBundleFileMax +} + +func (o packOptions) bundleSizeMax() int64 { + if o.BundleSizeMax > 0 { + return o.BundleSizeMax + } + return defaultBundleSizeMax +} + +// isGzip reports whether a layer media type is a gzip-compressed tar. +func isGzip(mt types.MediaType) bool { + return mt == mediaTypeOCILayerTarGzip +} + +// packedLayer describes one packed tar layer: enough metadata to write +// an OCI manifest descriptor, plus the layerPlan needed to re-stream +// the layer bytes on demand from a content-addressed store. +// +// There is no on-disk tar file. Layer bytes are produced by streaming +// from the store every time something asks for them — once during +// digest computation, again during push. 
Determinism rests on the +// store contents being immutable and the tar/compressor framing being +// byte-stable for a given (plan, store) pair. +type packedLayer struct { + // Plan is the layerPlan that produced this layer's bytes. Held so + // fileLayer.Compressed() can reconstruct the bytes by re-running + // the same packing pipeline against the store. + Plan layerPlan + // Digest is the SHA256 digest of the tar bytes (the OCI blob digest). + Digest v1.Hash + // Size is the size of the tar bytes in bytes (post-compression for + // gzip layers). + Size int64 + // UncompressedSize is the total uncompressed size of the files in + // this layer. + UncompressedSize int64 + // MediaType is the OCI media type for this layer. + MediaType types.MediaType +} + +// packResult is the output of packer.execute: layer descriptors and +// per-file content digests. +type packResult struct { + // Layers are the packed layer descriptors. + Layers []packedLayer + // Files are per-file content digests, sorted by path. + Files []packedFile +} + +// packedFile records a file's path, size, content digest, and which layer it +// landed in. Used to build the config blob (§2.3) and set digest (§2.4). +type packedFile struct { + // Path is the file path relative to the weight target directory. + Path string + // Size is the uncompressed file size in bytes. + Size int64 + // Digest is the SHA-256 content digest of the file (hex-encoded with + // "sha256:" prefix). + Digest string + // LayerDigest is the digest of the layer containing this file + // (populated after packing). + LayerDigest string +} + +// plan describes the target layer layout for an inventory. It is a pure +// function of the inventory plus packing thresholds, so layer-assignment +// logic can be inspected and cache-probed without writing tar bytes. +type plan struct { + // Layers is the ordered set of layers to build. 
Order is + // deterministic: bundles first (sorted small files), then large + // files in inventory order. + Layers []layerPlan +} + +// layerPlan describes a single planned layer. +type layerPlan struct { + // Files are the inventory entries packed into this layer, in the + // order they will appear in the tar stream. Small-file bundles + // sort by Path; large-file layers contain a single entry. + Files []weightsource.InventoryFile + + // MediaType is the OCI media type for the produced blob. + MediaType types.MediaType +} + +// packer plans and executes the build of tar layers from a weight +// source inventory. +// +// Planning (planLayers) is pure and inspectable. Execution streams +// file bytes from a content-addressed store through the tar+gzip +// pipeline to compute layer digests; no on-disk scratch is written. +// The same plan can later be replayed against the same store to +// reproduce the layer bytes for push (see fileLayer). +type packer struct { + opts packOptions +} + +// newPacker constructs a packer. A nil opts yields spec-default +// thresholds. +func newPacker(opts *packOptions) *packer { + var o packOptions + if opts != nil { + o = *opts + } + return &packer{opts: o} +} + +// planLayers computes the target layer layout for inv. It performs no I/O +// and does not read source bytes. The returned plan is deterministic +// for a given (inv, opts) pair. +// +// An empty inventory yields an empty plan. execute rejects empty +// plans; planLayers itself does not, so callers can reason about the empty +// case without invoking execute. 
+func (p *packer) planLayers(inv weightsource.Inventory) plan { + if len(inv.Files) == 0 { + return plan{} + } + + threshold := p.opts.bundleFileMax() + bundleMax := p.opts.bundleSizeMax() + + var smallFiles, largeFiles []weightsource.InventoryFile + for _, f := range inv.Files { + if f.Size < threshold { + smallFiles = append(smallFiles, f) + } else { + largeFiles = append(largeFiles, f) + } + } + + // Stable-sort small files by path for deterministic bundling. + sort.SliceStable(smallFiles, func(i, j int) bool { + return smallFiles[i].Path < smallFiles[j].Path + }) + + var layers []layerPlan + + // Bundle small files, flushing whenever adding the next would + // exceed bundleMax. A lone small file larger than bundleMax still + // gets its own bundle (guarded by currentSize > 0). + var current []weightsource.InventoryFile + var currentSize int64 + flush := func() { + if len(current) == 0 { + return + } + layers = append(layers, layerPlan{ + Files: current, + MediaType: mediaTypeOCILayerTarGzip, + }) + current = nil + currentSize = 0 + } + for _, f := range smallFiles { + if currentSize > 0 && currentSize+f.Size > bundleMax { + flush() + } + current = append(current, f) + currentSize += f.Size + } + flush() + + // Large files: one layer each, compressed unless the extension + // marks the content as incompressible. + for _, f := range largeFiles { + mt := types.MediaType(mediaTypeOCILayerTarGzip) + if incompressibleExts[strings.ToLower(filepath.Ext(f.Path))] { + mt = mediaTypeOCILayerTar + } + layers = append(layers, layerPlan{ + Files: []weightsource.InventoryFile{f}, + MediaType: mt, + }) + } + + return plan{Layers: layers} +} + +// computeLayerDigests builds each planned layer in memory by streaming +// file bytes from store through the tar+gzip pipeline into a sha256 +// hasher and a byte counter. No bytes are written to disk. +// +// The store MUST already contain every file referenced by the plan; +// callers must run ingressFromInventory before calling this.
+// +// Layers are processed concurrently (bounded by GOMAXPROCS) since each +// layer reads independent files from the store and writes to io.Discard. +// +// On success returns one packedLayer per layerPlan, with Digest, Size, +// UncompressedSize, MediaType, and the originating Plan filled in. +// The Plan field lets callers later reconstruct the layer bytes for +// push without re-walking the inventory or recomputing digests. +func (p *packer) computeLayerDigests(ctx context.Context, st store.Store, pl plan) ([]packedLayer, error) { + if len(pl.Layers) == 0 { + return nil, fmt.Errorf("no layers in plan") + } + + results := make([]packedLayer, len(pl.Layers)) + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) + + for i, lp := range pl.Layers { + g.Go(func() error { + lr, err := p.streamLayer(ctx, st, lp, io.Discard) + if err != nil { + return err + } + results[i] = lr + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + return results, nil +} + +// streamLayer writes the tar bytes for one layer through the +// tar+(gzip?)+sha256+counter pipeline into sink. Used by both digest +// computation (sink = io.Discard) and push (sink = registry uploader, +// via fileLayer.Compressed()). +// +// The returned packedLayer carries the layer's Plan so callers that +// only need to compute the digest can later replay the same stream +// for push without holding tar bytes in memory. +func (p *packer) streamLayer(ctx context.Context, st store.Store, lp layerPlan, sink io.Writer) (packedLayer, error) { + gzipped := isGzip(lp.MediaType) + + // Writer sandwich: tar → (compressor?) → counter → (sink + hasher). + // counter reports on-wire (compressed) bytes; hasher feeds the + // OCI blob digest. 
+ hasher := sha256.New() + counter := &countingWriter{w: io.MultiWriter(sink, hasher)} + + var gzw *gzip.Writer + var tarSink io.Writer = counter + if gzipped { + level := gzipLevelLarge + if len(lp.Files) > 1 { + level = gzipLevelBundle + } + var err error + gzw, err = gzip.NewWriterLevel(counter, level) + if err != nil { + return packedLayer{}, fmt.Errorf("create gzip writer: %w", err) + } + tarSink = gzw + } + + tw := tar.NewWriter(tarSink) + if err := writeLayer(ctx, st, tw, lp.Files); err != nil { + return packedLayer{}, err + } + + if err := tw.Close(); err != nil { + return packedLayer{}, fmt.Errorf("close tar writer: %w", err) + } + if gzw != nil { + if err := gzw.Close(); err != nil { + return packedLayer{}, fmt.Errorf("close gzip writer: %w", err) + } + } + + var uncompressed int64 + for _, f := range lp.Files { + uncompressed += f.Size + } + + return packedLayer{ + Plan: lp, + Digest: v1.Hash{Algorithm: "sha256", Hex: hex.EncodeToString(hasher.Sum(nil))}, + Size: counter.n, + UncompressedSize: uncompressed, + MediaType: lp.MediaType, + }, nil +} + +// packedFilesFromPlan returns the per-file index for a fully-built set +// of layers. Each file in each layerPlan gets a packedFile pointing at +// the layer's computed digest. Output is sorted by path. +// +// Not a method on packer: this is bookkeeping over the plan + computed +// layer digests, not part of the planning or streaming logic. +func packedFilesFromPlan(layers []packedLayer) []packedFile { + var out []packedFile + for _, lr := range layers { + layerDigest := lr.Digest.String() + for _, f := range lr.Plan.Files { + out = append(out, packedFile{ + Path: f.Path, + Size: f.Size, + Digest: f.Digest, + LayerDigest: layerDigest, + }) + } + } + sort.Slice(out, func(i, j int) bool { return out[i].Path < out[j].Path }) + return out +} + +// ingressFromInventory streams each file in inv from src into st, +// hash-verifying as bytes flow through. 
Files already present in the +// store are skipped — store.PutFile is idempotent and drains the +// reader to io.Discard for already-stored digests, but we don't even +// open the source for those. Open() on remote sources is expensive +// (HTTP round trip); the cheap Exists() probe avoids it. +// +// Hash mismatches surface here, loudly, instead of silently producing +// a tar whose member digest disagrees with the inventory. +func ingressFromInventory(ctx context.Context, src weightsource.Source, st store.Store, inv weightsource.Inventory) error { + for _, f := range inv.Files { + if err := ctx.Err(); err != nil { + return err + } + ok, err := st.Exists(ctx, f.Digest) + if err != nil { + return fmt.Errorf("check store for %s: %w", f.Path, err) + } + if ok { + continue + } + if err := ingressOne(ctx, src, st, f); err != nil { + return fmt.Errorf("ingress %s: %w", f.Path, err) + } + } + return nil +} + +func ingressOne(ctx context.Context, src weightsource.Source, st store.Store, f weightsource.InventoryFile) error { + rc, err := src.Open(ctx, f.Path) + if err != nil { + return fmt.Errorf("open source: %w", err) + } + defer rc.Close() //nolint:errcheck // best-effort close on read path + return st.PutFile(ctx, f.Digest, f.Size, rc) +} + +// writeLayer writes the in-tar layout for a layer: deterministic +// directory entries for every parent directory referenced by any +// file, followed by the files themselves in supplied order. File +// bytes come from the content-addressed store, keyed by digest. 
+func writeLayer(ctx context.Context, st store.Store, tw *tar.Writer, files []weightsource.InventoryFile) error { + for _, dir := range collectDirs(files) { + if err := ctx.Err(); err != nil { + return err + } + if err := tw.WriteHeader(deterministicDirHeader(dir)); err != nil { + return fmt.Errorf("write dir header %s: %w", dir, err) + } + } + + for _, f := range files { + if err := ctx.Err(); err != nil { + return err + } + if err := writeFileToTar(ctx, st, tw, f); err != nil { + return fmt.Errorf("write file %s: %w", f.Path, err) + } + } + return nil +} + +// unixEpoch is the Unix epoch time, used for deterministic tar headers. +var unixEpoch = time.Unix(0, 0) + +// deterministicDirHeader returns a tar header for a directory with deterministic properties. +func deterministicDirHeader(name string) *tar.Header { + return &tar.Header{ + Typeflag: tar.TypeDir, + Name: name + "/", + Mode: 0o755, + ModTime: unixEpoch, + AccessTime: unixEpoch, + ChangeTime: unixEpoch, + Format: tar.FormatPAX, + } +} + +// writeFileToTar writes a single file entry to a tar writer with +// deterministic headers (spec §1.3: PAX format, zero timestamps, +// uid/gid 0, 0644 perms). File bytes come from the store, opened by +// digest. +func writeFileToTar(ctx context.Context, st store.Store, tw *tar.Writer, f weightsource.InventoryFile) error { + hdr := &tar.Header{ + Typeflag: tar.TypeReg, + Name: f.Path, + Size: f.Size, + Mode: 0o644, + ModTime: unixEpoch, + AccessTime: unixEpoch, + ChangeTime: unixEpoch, + Format: tar.FormatPAX, + // UID/GID: 0/0 — Go zero values. 
+ } + + if err := tw.WriteHeader(hdr); err != nil { + return fmt.Errorf("write header: %w", err) + } + + rc, err := openFromStore(ctx, st, f.Digest) + if err != nil { + return fmt.Errorf("open from store: %w", err) + } + defer rc.Close() //nolint:errcheck // best-effort close on read path + + if _, err := io.Copy(tw, rc); err != nil { + return fmt.Errorf("copy file data: %w", err) + } + + return nil +} + +// openFromStore returns a ReadCloser for the file at digest, resolved +// through the store's path. Wrapping store.Path + os.Open here keeps +// the tar-write loop short and gives the store a single read-time +// hook in case backends grow more elaborate (network-attached +// containerd content store, for example). +func openFromStore(ctx context.Context, st store.Store, digest string) (io.ReadCloser, error) { + path, err := st.Path(ctx, digest) + if err != nil { + return nil, fmt.Errorf("resolve store path for %s: %w", digest, err) + } + // gosec G304: path comes from store.Path, which validates the + // digest and composes the path inside the store root. + f, err := os.Open(path) //nolint:gosec // see comment above + if err != nil { + return nil, fmt.Errorf("open store file %s: %w", path, err) + } + return f, nil +} + +// collectDirs returns the sorted, deduplicated set of directory paths +// needed for the given files. Each intermediate directory is included. +func collectDirs(files []weightsource.InventoryFile) []string { + seen := make(map[string]bool) + var dirs []string + + for _, f := range files { + for _, d := range collectDirsForPath(f.Path) { + if !seen[d] { + seen[d] = true + dirs = append(dirs, d) + } + } + } + + sort.Strings(dirs) + return dirs +} + +// collectDirsForPath returns all parent directory components for a relative path. +// For "a/b/c.txt" it returns ["a", "a/b"]. +func collectDirsForPath(relPath string) []string { + dir := filepath.ToSlash(filepath.Dir(relPath)) + if dir == "." 
|| dir == "" { + return nil + } + + var dirs []string + parts := strings.Split(dir, "/") + for i := range parts { + dirs = append(dirs, strings.Join(parts[:i+1], "/")) + } + return dirs +} + +// countingWriter wraps a writer and counts bytes written. +type countingWriter struct { + w io.Writer + n int64 +} + +func (cw *countingWriter) Write(p []byte) (int, error) { + n, err := cw.w.Write(p) + cw.n += int64(n) + return n, err +} diff --git a/pkg/model/packer_plan_test.go b/pkg/model/packer_plan_test.go new file mode 100644 index 0000000000..376fe71790 --- /dev/null +++ b/pkg/model/packer_plan_test.go @@ -0,0 +1,183 @@ +package model + +import ( + "testing" + + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" +) + +// planLayers tests operate purely on Inventory — no disk, no source. +// They exercise layer-assignment logic (classification, bundling, +// media-type selection) that the execute tests only observe indirectly +// through tar output. 
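The planner rules these tests pin down (strict less-than size threshold, path-sorted greedy bundling, flush-before-add on the bundle cap, incompressible-extension detection) can be sketched standalone. Here `file`, `planLayers`, and `incompressible` are local stand-ins for the real packer types, reduced to the classification logic only:

```go
package main

import (
	"fmt"
	"path/filepath"
	"sort"
	"strings"
)

// file is a local stand-in for weightsource.InventoryFile.
type file struct {
	Path string
	Size int64
}

// incompressible mirrors the packer's dense-binary extension set.
var incompressible = map[string]bool{
	".safetensors": true, ".bin": true, ".gguf": true,
	".onnx": true, ".parquet": true, ".pt": true, ".pth": true,
}

// planLayers applies the diff's core rules: files strictly under
// fileMax are sorted by path and greedily bundled up to sizeMax
// (flushing before an add that would exceed the cap, so a lone
// oversized small file still gets its own bundle); files at or
// above fileMax each become their own layer.
func planLayers(files []file, fileMax, sizeMax int64) (bundles [][]file, singles []file) {
	var small []file
	for _, f := range files {
		if f.Size < fileMax {
			small = append(small, f)
		} else {
			singles = append(singles, f)
		}
	}
	sort.SliceStable(small, func(i, j int) bool { return small[i].Path < small[j].Path })

	var cur []file
	var curSize int64
	flush := func() {
		if len(cur) > 0 {
			bundles = append(bundles, cur)
			cur, curSize = nil, 0
		}
	}
	for _, f := range small {
		if curSize > 0 && curSize+f.Size > sizeMax {
			flush()
		}
		cur = append(cur, f)
		curSize += f.Size
	}
	flush()
	return bundles, singles
}

func main() {
	bundles, singles := planLayers([]file{
		{"z.json", 10}, {"a.json", 10}, {"c.json", 10},
		{"model.safetensors", 1 << 30},
	}, 64<<20, 20)
	fmt.Println(len(bundles), len(singles)) // prints: 2 1
	for _, b := range bundles {
		var names []string
		for _, f := range b {
			names = append(names, f.Path)
		}
		fmt.Println(strings.Join(names, ",")) // prints: a.json,c.json then z.json
	}
	ext := strings.ToLower(filepath.Ext(singles[0].Path))
	fmt.Println(incompressible[ext]) // prints: true
}
```

Note the cap check uses `>` rather than `>=`, matching TestPacker_Plan_BundleSplitsOnSizeMax: a bundle may land exactly at sizeMax before splitting.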
+ +func invFile(path string, size int64) weightsource.InventoryFile { + return weightsource.InventoryFile{Path: path, Size: size, Digest: "sha256:deadbeef"} +} + +func TestPacker_Plan_EmptyInventory(t *testing.T) { + plan := newPacker(nil).planLayers(weightsource.Inventory{}) + assert.Empty(t, plan.Layers, "empty inventory should plan zero layers") +} + +func TestPacker_Plan_SingleSmallFile(t *testing.T) { + plan := newPacker(nil).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{invFile("config.json", 100)}, + }) + require.Len(t, plan.Layers, 1) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), plan.Layers[0].MediaType, + "small-file bundle should be gzipped") + require.Len(t, plan.Layers[0].Files, 1) + assert.Equal(t, "config.json", plan.Layers[0].Files[0].Path) +} + +func TestPacker_Plan_SingleLargeFileIncompressible(t *testing.T) { + plan := newPacker(nil).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{invFile("model.safetensors", 100*1024*1024)}, + }) + require.Len(t, plan.Layers, 1) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), plan.Layers[0].MediaType, + ".safetensors should be uncompressed") +} + +func TestPacker_Plan_SingleLargeFileCompressible(t *testing.T) { + plan := newPacker(nil).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{invFile("model.dat", 100*1024*1024)}, + }) + require.Len(t, plan.Layers, 1) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), plan.Layers[0].MediaType, + ".dat is not in the incompressible set") +} + +func TestPacker_Plan_MixedFilesOrdering(t *testing.T) { + // Small files arrive in unsorted order; the planner must sort + // them within the bundle for deterministic output. Large files + // follow in the order they appear in the inventory. 
+ plan := newPacker(nil).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{ + invFile("z-small.json", 100), + invFile("large-01.safetensors", 100*1024*1024), + invFile("a-small.json", 100), + invFile("large-02.safetensors", 100*1024*1024), + }, + }) + require.Len(t, plan.Layers, 3, "1 bundle + 2 large files") + + // Bundle layer first, small files sorted by path. + require.Len(t, plan.Layers[0].Files, 2) + assert.Equal(t, "a-small.json", plan.Layers[0].Files[0].Path) + assert.Equal(t, "z-small.json", plan.Layers[0].Files[1].Path) + + // Large files in inventory order. + assert.Equal(t, "large-01.safetensors", plan.Layers[1].Files[0].Path) + assert.Equal(t, "large-02.safetensors", plan.Layers[2].Files[0].Path) +} + +func TestPacker_Plan_BundleSplitsOnSizeMax(t *testing.T) { + // Everything is small (under BundleFileMax=1024), but bundle + // size is capped at 20 bytes. Expect a+b in one bundle, c in + // another. + plan := newPacker(&packOptions{ + BundleFileMax: 1024, + BundleSizeMax: 20, + }).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{ + invFile("a.txt", 10), + invFile("b.txt", 10), + invFile("c.txt", 10), + }, + }) + require.Len(t, plan.Layers, 2) + require.Len(t, plan.Layers[0].Files, 2) + assert.Equal(t, "a.txt", plan.Layers[0].Files[0].Path) + assert.Equal(t, "b.txt", plan.Layers[0].Files[1].Path) + require.Len(t, plan.Layers[1].Files, 1) + assert.Equal(t, "c.txt", plan.Layers[1].Files[0].Path) +} + +func TestPacker_Plan_FileAtExactThreshold(t *testing.T) { + // A file equal to BundleFileMax is "large" (strict less-than). 
+	plan := newPacker(nil).planLayers(weightsource.Inventory{
+		Files: []weightsource.InventoryFile{invFile("model.bin", defaultBundleFileMax)},
+	})
+	require.Len(t, plan.Layers, 1)
+	require.Len(t, plan.Layers[0].Files, 1)
+	assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), plan.Layers[0].MediaType,
+		"at-threshold large file should not be bundled")
+}
+
+func TestPacker_Plan_FileJustBelowThreshold(t *testing.T) {
+	plan := newPacker(nil).planLayers(weightsource.Inventory{
+		Files: []weightsource.InventoryFile{invFile("model.bin", defaultBundleFileMax-1)},
+	})
+	require.Len(t, plan.Layers, 1)
+	assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), plan.Layers[0].MediaType,
+		"below-threshold file should land in a bundle")
+}
+
+func TestPacker_Plan_IncompressibleExtensions(t *testing.T) {
+	tests := []struct {
+		ext       string
+		mediaType types.MediaType
+	}{
+		{".safetensors", mediaTypeOCILayerTar},
+		{".bin", mediaTypeOCILayerTar},
+		{".gguf", mediaTypeOCILayerTar},
+		{".onnx", mediaTypeOCILayerTar},
+		{".parquet", mediaTypeOCILayerTar},
+		{".pt", mediaTypeOCILayerTar},
+		{".pth", mediaTypeOCILayerTar},
+		{".dat", mediaTypeOCILayerTarGzip},
+		{".json", mediaTypeOCILayerTarGzip},
+		{".pickle", mediaTypeOCILayerTarGzip},
+	}
+	for _, tt := range tests {
+		t.Run(tt.ext, func(t *testing.T) {
+			plan := newPacker(nil).planLayers(weightsource.Inventory{
+				Files: []weightsource.InventoryFile{invFile("model"+tt.ext, 100*1024*1024)},
+			})
+			require.Len(t, plan.Layers, 1)
+			assert.Equal(t, tt.mediaType, plan.Layers[0].MediaType)
+		})
+	}
+}
+
+func TestPacker_Plan_SmallFileExceedingBundleSizeMax(t *testing.T) {
+	// A small-classified file that still exceeds BundleSizeMax gets its
+	// own bundle (the flush-before-add guard skips when currentSize
+	// is 0). This is unusual in practice, but it is the documented behavior.
+ plan := newPacker(&packOptions{ + BundleFileMax: 1024, + BundleSizeMax: 20, + }).planLayers(weightsource.Inventory{ + Files: []weightsource.InventoryFile{invFile("big-small.txt", 100)}, + }) + require.Len(t, plan.Layers, 1) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), plan.Layers[0].MediaType) + require.Len(t, plan.Layers[0].Files, 1) + assert.Equal(t, "big-small.txt", plan.Layers[0].Files[0].Path) +} + +func TestPacker_Plan_Deterministic(t *testing.T) { + // Same inventory → identical Plan, run twice. + inv := weightsource.Inventory{ + Files: []weightsource.InventoryFile{ + invFile("b/c.json", 100), + invFile("a.json", 200), + invFile("model.safetensors", 100*1024*1024), + }, + } + p := newPacker(nil) + plan1 := p.planLayers(inv) + plan2 := p.planLayers(inv) + assert.Equal(t, plan1, plan2) +} + +func TestPacker_ComputeLayerDigests_RejectsEmptyPlan(t *testing.T) { + _, err := newPacker(nil).computeLayerDigests(t.Context(), nil, plan{}) + assert.ErrorContains(t, err, "no layers in plan") +} diff --git a/pkg/model/packer_test.go b/pkg/model/packer_test.go new file mode 100644 index 0000000000..8d42d08a1e --- /dev/null +++ b/pkg/model/packer_test.go @@ -0,0 +1,655 @@ +package model + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/store" +) + +// packTestDir is a convenience test helper that wires a local +// directory through the new Source/Inventory + ingress + +// computeLayerDigests pipeline. It hides the boilerplate so test +// bodies can focus on packer behavior. 
+// +// Returns the packResult plus the store the layers reference, so +// callers can stream layer bytes back out via readLayerEntries +// (mirrors the path push uses). +func packTestDir(t *testing.T, dir string, opts *packOptions) (*packResult, *store.FileStore, error) { + t.Helper() + return packTestDirCtx(t, t.Context(), dir, opts) +} + +// packTestDirCtx is the ctx-accepting variant of packTestDir for tests +// that need a context independent of the test lifetime (typically for +// cancellation tests). +func packTestDirCtx(t *testing.T, ctx context.Context, dir string, opts *packOptions) (*packResult, *store.FileStore, error) { + t.Helper() + st, err := store.NewFileStore(t.TempDir()) + if err != nil { + return nil, nil, err + } + src, err := weightsource.NewFileSource("file://"+dir, "") + if err != nil { + return nil, nil, err + } + inv, err := src.Inventory(ctx) + if err != nil { + return nil, nil, err + } + if err := ingressFromInventory(ctx, src, st, inv); err != nil { + return nil, nil, err + } + pkr := newPacker(opts) + pl := pkr.planLayers(inv) + if len(pl.Layers) == 0 { + return nil, nil, fmt.Errorf("no files in inventory") + } + layers, err := pkr.computeLayerDigests(ctx, st, pl) + if err != nil { + return nil, nil, err + } + return &packResult{ + Layers: layers, + Files: packedFilesFromPlan(layers), + }, st, nil +} + +// createTestFile creates a file at the given path (relative to dir) with the given size. +func createTestFile(t *testing.T, dir, relPath string, size int64) { + t.Helper() + absPath := filepath.Join(dir, relPath) + require.NoError(t, os.MkdirAll(filepath.Dir(absPath), 0o755)) + f, err := os.Create(absPath) + require.NoError(t, err) + defer f.Close() + if size > 0 { + require.NoError(t, f.Truncate(size)) + } +} + +// filesInLayer returns the relative paths packed into the given layer, +// derived from a packResult. 
The packer no longer tags layers with +// content-type annotations — the file→layer mapping lives on the +// packedFile slice instead. +func filesInLayer(pr *packResult, layerDigest string) []string { + var out []string + for _, f := range pr.Files { + if f.LayerDigest == layerDigest { + out = append(out, f.Path) + } + } + sort.Strings(out) + return out +} + +// isBundleLayer reports whether a layer carries more than one file — +// i.e. it is a "bundle" rather than a single-file layer. This replaces +// the old run.cog.weight.content annotation check. +func isBundleLayer(pr *packResult, layerDigest string) bool { + return len(filesInLayer(pr, layerDigest)) > 1 +} + +func TestPack_EmptyDirectory(t *testing.T) { + dir := t.TempDir() + _, _, err := packTestDir(t, dir, nil) + assert.ErrorContains(t, err, "no files in inventory") +} + +func TestPack_SingleSmallFile(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "config.json", 100) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + r := results.Layers[0] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), r.MediaType) + assert.True(t, r.Size > 0) + assert.Equal(t, int64(100), r.UncompressedSize) + assert.NotEmpty(t, r.Digest.Hex) + assert.Equal(t, "sha256", r.Digest.Algorithm) + + // Single small file is a single-entry bundle layer. + assert.Equal(t, []string{"config.json"}, filesInLayer(results, r.Digest.String())) + + // Verify tar contents. 
+ entries := readLayerEntries(t, r, st) + require.Len(t, entries, 1) + assert.Equal(t, "config.json", entries[0]) +} + +func TestPack_SingleLargeFile_Incompressible(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "model.safetensors", 100*1024*1024) // 100 MB + + results, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + r := results.Layers[0] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), r.MediaType) + assert.Equal(t, []string{"model.safetensors"}, filesInLayer(results, r.Digest.String())) + assert.Equal(t, int64(100*1024*1024), r.UncompressedSize) +} + +func TestPack_SingleLargeFile_Compressible(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "model.dat", 100*1024*1024) // 100 MB, not in skip set + + results, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + r := results.Layers[0] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), r.MediaType) + assert.Equal(t, []string{"model.dat"}, filesInLayer(results, r.Digest.String())) +} + +func TestPack_MixedFiles(t *testing.T) { + dir := t.TempDir() + + // Small files (< 64 MB default threshold). + createTestFile(t, dir, "config.json", 500) + createTestFile(t, dir, "tokenizer.json", 1000) + createTestFile(t, dir, "special_tokens_map.json", 200) + + // Large files. + createTestFile(t, dir, "model-00001.safetensors", 100*1024*1024) + createTestFile(t, dir, "model-00002.safetensors", 100*1024*1024) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 3) // 1 bundle + 2 large files + + // First result should be the bundle (small files come first in output). 
+ bundle := results.Layers[0] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), bundle.MediaType) + assert.True(t, isBundleLayer(results, bundle.Digest.String()), "first layer should hold the bundled small files") + + bundleEntries := readLayerEntries(t, bundle, st) + // Files should be sorted by path. + assert.Equal(t, []string{"config.json", "special_tokens_map.json", "tokenizer.json"}, bundleEntries) + + // Large files should be uncompressed tars (safetensors is incompressible) + // and carry exactly one file each. + for _, r := range results.Layers[1:] { + assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), r.MediaType) + assert.Len(t, filesInLayer(results, r.Digest.String()), 1, "single-file layer should contain exactly one file") + } +} + +func TestPack_NestedDirectories(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "text_encoder/config.json", 100) + createTestFile(t, dir, "text_encoder/tokenizer.json", 200) + createTestFile(t, dir, "vae/config.json", 150) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) // All small, one bundle. + + entries := readLayerEntries(t, results.Layers[0], st) + // Directories come first (sorted), then files (sorted). 
+ expected := []string{ + "text_encoder/", + "vae/", + "text_encoder/config.json", + "text_encoder/tokenizer.json", + "vae/config.json", + } + assert.Equal(t, expected, entries) +} + +func TestPack_LargeFileInSubdir(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "text_encoder/model-00001.safetensors", 100*1024*1024) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + r := results.Layers[0] + assert.Equal(t, []string{"text_encoder/model-00001.safetensors"}, filesInLayer(results, r.Digest.String())) + + entries := readLayerEntries(t, r, st) + expected := []string{ + "text_encoder/", + "text_encoder/model-00001.safetensors", + } + assert.Equal(t, expected, entries) +} + +func TestPack_BundleSizeMaxSplits(t *testing.T) { + dir := t.TempDir() + + // Create 3 files of 10 bytes each. Set bundle max to 20 so it splits. + createTestFile(t, dir, "a.txt", 10) + createTestFile(t, dir, "b.txt", 10) + createTestFile(t, dir, "c.txt", 10) + + opts := &packOptions{ + BundleFileMax: 1024, // Everything is "small". + BundleSizeMax: 20, // Forces split: a+b in one bundle, c in another. + } + + results, st, err := packTestDir(t, dir, opts) + require.NoError(t, err) + require.Len(t, results.Layers, 2) + + // Both should be gzipped bundles. + for _, r := range results.Layers { + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), r.MediaType) + } + + // First bundle should have a.txt and b.txt. + entries1 := readLayerEntries(t, results.Layers[0], st) + assert.Equal(t, []string{"a.txt", "b.txt"}, entries1) + + // Second bundle should have c.txt. 
+ entries2 := readLayerEntries(t, results.Layers[1], st) + assert.Equal(t, []string{"c.txt"}, entries2) +} + +func TestPack_CustomThresholds(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "small.txt", 50) + createTestFile(t, dir, "large.bin", 200) + + opts := &packOptions{ + BundleFileMax: 100, // 50 is small, 200 is large + } + + results, _, err := packTestDir(t, dir, opts) + require.NoError(t, err) + require.Len(t, results.Layers, 2) + + // Bundle for small file: single-entry bundle. + assert.Equal(t, []string{"small.txt"}, filesInLayer(results, results.Layers[0].Digest.String())) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), results.Layers[0].MediaType) + + // Individual layer for large file (.bin is in incompressible set, so uncompressed). + assert.Equal(t, []string{"large.bin"}, filesInLayer(results, results.Layers[1].Digest.String())) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), results.Layers[1].MediaType) +} + +func TestPack_SkipsDotCogDirectory(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "config.json", 100) + createTestFile(t, dir, ".cog/manifest.json", 50) + createTestFile(t, dir, ".cog/ready", 0) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + entries := readLayerEntries(t, results.Layers[0], st) + assert.Equal(t, []string{"config.json"}, entries) +} + +func TestPack_DeterministicTarProperties(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "data.txt", 100) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + // Stream the layer back out via fileLayer (the same path push + // uses) and inspect tar headers. 
+ l := newFileLayer(t.Context(), results.Layers[0], st) + rc, err := l.Compressed() + require.NoError(t, err) + defer rc.Close() //nolint:errcheck + + gr, err := gzip.NewReader(rc) + require.NoError(t, err) + defer gr.Close() //nolint:errcheck + + epoch := time.Unix(0, 0) + + tr := tar.NewReader(gr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + require.NoError(t, err) + + assert.Equal(t, epoch, hdr.ModTime, "mtime should be Unix epoch") + assert.Equal(t, 0, hdr.Uid, "uid should be 0") + assert.Equal(t, 0, hdr.Gid, "gid should be 0") + + switch hdr.Typeflag { + case tar.TypeReg: + assert.Equal(t, int64(0o644), hdr.Mode, "file mode should be 0644") + case tar.TypeDir: + assert.Equal(t, int64(0o755), hdr.Mode, "dir mode should be 0755") + } + } +} + +func TestPack_DigestDeterminism(t *testing.T) { + // Pack the same directory twice and verify digests match. + dir := t.TempDir() + createTestFile(t, dir, "a.txt", 100) + createTestFile(t, dir, "b.txt", 200) + + results1, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + + results2, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + + require.Len(t, results1.Layers, len(results2.Layers)) + for i := range results1.Layers { + assert.Equal(t, results1.Layers[i].Digest, results2.Layers[i].Digest, + "digest mismatch for result %d", i) + assert.Equal(t, results1.Layers[i].Size, results2.Layers[i].Size, + "size mismatch for result %d", i) + } +} + +func TestPack_ContextCancellation(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "file.txt", 100) + + // Independent cancellable context: we need to cancel before the call. 
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel()
+
+	_, _, err := packTestDirCtx(t, ctx, dir, nil)
+	require.Error(t, err)
+	assert.ErrorIs(t, err, context.Canceled)
+}
+
+func TestPack_IncompressibleExtensions(t *testing.T) {
+	tests := []struct {
+		ext       string
+		mediaType types.MediaType
+	}{
+		{".safetensors", mediaTypeOCILayerTar},
+		{".bin", mediaTypeOCILayerTar},
+		{".gguf", mediaTypeOCILayerTar},
+		{".onnx", mediaTypeOCILayerTar},
+		{".parquet", mediaTypeOCILayerTar},
+		{".pt", mediaTypeOCILayerTar},
+		{".pth", mediaTypeOCILayerTar},
+		{".dat", mediaTypeOCILayerTarGzip}, // compressible
+		{".json", mediaTypeOCILayerTarGzip}, // compressible
+		{".pickle", mediaTypeOCILayerTarGzip}, // compressible
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.ext, func(t *testing.T) {
+			dir := t.TempDir()
+			createTestFile(t, dir, "model"+tt.ext, 100*1024*1024)
+
+			results, _, err := packTestDir(t, dir, nil)
+			require.NoError(t, err)
+			require.Len(t, results.Layers, 1)
+			assert.Equal(t, tt.mediaType, results.Layers[0].MediaType)
+		})
+	}
+}
+
+func TestPack_FileAtExactThreshold(t *testing.T) {
+	dir := t.TempDir()
+	// File exactly at the threshold should be "large" (>= BundleFileMax)
+	// and land in its own uncompressed-tar layer (.bin is incompressible).
+	createTestFile(t, dir, "model.bin", defaultBundleFileMax)
+
+	results, _, err := packTestDir(t, dir, nil)
+	require.NoError(t, err)
+	require.Len(t, results.Layers, 1)
+	assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), results.Layers[0].MediaType,
+		"at-threshold large file should be a single-file uncompressed tar layer")
+	assert.Equal(t, []string{"model.bin"}, filesInLayer(results, results.Layers[0].Digest.String()))
+}
+
+func TestPack_FileJustBelowThreshold(t *testing.T) {
+	dir := t.TempDir()
+	// File just below the threshold should be bundled (tar+gzip).
+ createTestFile(t, dir, "model.bin", defaultBundleFileMax-1) + + results, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), results.Layers[0].MediaType, + "below-threshold file should land in a bundle (tar+gzip)") + assert.Equal(t, []string{"model.bin"}, filesInLayer(results, results.Layers[0].Digest.String())) +} + +func TestPack_LayerBytesAreReproducible(t *testing.T) { + // After cog-i12u there are no tar files on disk — layer bytes + // are streamed on demand. Verify that streaming the same layer + // twice produces byte-identical output (deterministic from the + // (plan, store) pair). + dir := t.TempDir() + createTestFile(t, dir, "a.txt", 100) + createTestFile(t, dir, "big.safetensors", 100*1024*1024) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + + for _, r := range results.Layers { + first := readLayerTar(t, r, st) + second := readLayerTar(t, r, st) + assert.Equal(t, first, second, + "streaming layer %s twice must yield identical bytes", r.Digest) + + // And the byte length matches what the packer recorded. + assert.Equal(t, r.Size, int64(len(first)), + "streamed layer %s size must match recorded Size", r.Digest) + } +} + +// readLayerTar streams a layer's full byte stream and returns it. 
+func readLayerTar(t *testing.T, lr packedLayer, st store.Store) []byte { + t.Helper() + l := newFileLayer(t.Context(), lr, st) + rc, err := l.Compressed() + require.NoError(t, err) + defer rc.Close() //nolint:errcheck + data, err := io.ReadAll(rc) + require.NoError(t, err) + return data +} + +func TestCollectDirsForPath(t *testing.T) { + tests := []struct { + path string + expected []string + }{ + {"file.txt", nil}, + {"a/file.txt", []string{"a"}}, + {"a/b/file.txt", []string{"a", "a/b"}}, + {"a/b/c/file.txt", []string{"a", "a/b", "a/b/c"}}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := collectDirsForPath(tt.path) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestCollectDirs(t *testing.T) { + files := []weightsource.InventoryFile{ + {Path: "b/c/file.txt"}, + {Path: "a/file.txt"}, + {Path: "b/file.txt"}, + {Path: "root.txt"}, + } + got := collectDirs(files) + expected := []string{"a", "b", "b/c"} + assert.Equal(t, expected, got, "dirs should be sorted and deduplicated") +} + +// readLayerEntries streams a packedLayer's tar bytes back out via +// fileLayer (the same code path push uses) and returns the entry +// names in emission order. Handles both compressed and uncompressed +// tars based on lr.MediaType. +// +// This replaces the old readTarGzEntries/readTarEntries that took a +// path: layers no longer have on-disk paths post-cog-i12u. 
+func readLayerEntries(t *testing.T, lr packedLayer, st store.Store) []string { + t.Helper() + l := newFileLayer(t.Context(), lr, st) + rc, err := l.Compressed() + require.NoError(t, err) + defer rc.Close() //nolint:errcheck // best-effort + data, err := io.ReadAll(rc) + require.NoError(t, err) + + var r io.Reader = bytes.NewReader(data) + if lr.MediaType == mediaTypeOCILayerTarGzip { + gr, err := gzip.NewReader(r) + require.NoError(t, err) + defer gr.Close() //nolint:errcheck // best-effort + r = gr + } + return readTarNames(t, tar.NewReader(r)) +} + +// readTarNames reads all entry names from a tar reader. +func readTarNames(t *testing.T, tr *tar.Reader) []string { + t.Helper() + var names []string + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + require.NoError(t, err) + names = append(names, hdr.Name) + } + return names +} + +// Verify that sorting files produces stable, deterministic ordering. +func TestSmallFileSortingStability(t *testing.T) { + files := []weightsource.InventoryFile{ + {Path: "z.txt", Size: 10}, + {Path: "a.txt", Size: 10}, + {Path: "m/b.txt", Size: 10}, + {Path: "m/a.txt", Size: 10}, + } + + sort.SliceStable(files, func(i, j int) bool { + return files[i].Path < files[j].Path + }) + + expected := []string{"a.txt", "m/a.txt", "m/b.txt", "z.txt"} + var got []string + for _, f := range files { + got = append(got, f.Path) + } + assert.Equal(t, expected, got) +} + +func TestPack_DeepNestedDirsInLargeFile(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "a/b/c/model.safetensors", 100*1024*1024) + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + require.Len(t, results.Layers, 1) + + entries := readLayerEntries(t, results.Layers[0], st) + expected := []string{ + "a/", + "a/b/", + "a/b/c/", + "a/b/c/model.safetensors", + } + assert.Equal(t, expected, entries) +} + +func TestPack_WorkedExample(t *testing.T) { + // Simulate the z-image-turbo layout from spec §4. 
+ dir := t.TempDir() + + // Small files (configs, tokenizers) — all < 64 MB. + smallFiles := []string{ + "config.json", + "model_index.json", + "tokenizer/tokenizer_config.json", + "tokenizer/special_tokens_map.json", + "tokenizer/vocab.json", + "tokenizer/merges.txt", + } + for _, f := range smallFiles { + createTestFile(t, dir, f, 1024) // 1 KB each + } + + // Large files (safetensors) — each > 64 MB. + largeFiles := []string{ + "text_encoder/model-00001-of-00003.safetensors", + "text_encoder/model-00002-of-00003.safetensors", + "text_encoder/model-00003-of-00003.safetensors", + "vae/diffusion_pytorch_model.safetensors", + "transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + "transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + } + for _, f := range largeFiles { + createTestFile(t, dir, f, 100*1024*1024) + } + + results, st, err := packTestDir(t, dir, nil) + require.NoError(t, err) + + // 1 bundle for small files + 7 individual layers for large files = 8 total. + require.Len(t, results.Layers, 8) + + // First result is the bundle (all small files landed in one layer). + bundle := results.Layers[0] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTarGzip), bundle.MediaType) + assert.True(t, isBundleLayer(results, bundle.Digest.String()), "first layer should be a bundle") + + // Remaining 7 are individual files, each a standalone uncompressed + // .safetensors layer. + for i := 1; i <= 7; i++ { + r := results.Layers[i] + assert.Equal(t, types.MediaType(mediaTypeOCILayerTar), r.MediaType) + paths := filesInLayer(results, r.Digest.String()) + require.Len(t, paths, 1, "layer %d should carry exactly one file", i) + assert.True(t, strings.HasSuffix(paths[0], ".safetensors"), + "layer %d file %q should be a .safetensors", i, paths[0]) + } + + // Verify no path appears in more than one layer (order-independence). 
+ allPaths := make(map[string]int) + for i, r := range results.Layers { + // readLayerEntries handles both compressed and + // uncompressed media types. + entries := readLayerEntries(t, r, st) + for _, e := range entries { + if strings.HasSuffix(e, "/") { + continue // Skip directory entries for this check. + } + if prev, ok := allPaths[e]; ok { + t.Errorf("path %q appears in both layer %d and %d", e, prev, i) + } + allPaths[e] = i + } + } +} diff --git a/pkg/model/pusher.go b/pkg/model/pusher.go index 56bcb97c02..73f4e8fa5c 100644 --- a/pkg/model/pusher.go +++ b/pkg/model/pusher.go @@ -1,4 +1,3 @@ -// pkg/model/pusher.go package model import ( @@ -15,16 +14,6 @@ import ( // PushOptions configures push behavior. type PushOptions struct { - // ProjectDir is the base directory for resolving weight file paths. - // - // Deprecated: Artifacts carry their own file paths. - ProjectDir string - - // FilePaths maps weight name identifiers to their file paths. - // - // Deprecated: Use Model.Artifacts instead — WeightArtifact carries FilePath. - FilePaths map[string]string - // Platform specifies the target platform for bundle indexes. // Default: linux/amd64 Platform *Platform @@ -40,46 +29,35 @@ type PushOptions struct { OnFallback func() } -// ============================================================================= -// BundlePusher - pushes OCI Index with image + weights -// ============================================================================= - -// BundlePusher pushes bundles (OCI Index with image + weight artifacts). -// It orchestrates ImagePusher and WeightPusher, then assembles the OCI index -// from the pushed manifest descriptors. +// BundlePusher pushes an OCI Image Index containing a model image + its +// weight manifests. It pushes the image, HEAD-checks each weight +// manifest (which was pushed by `cog weights import`), then assembles +// the index from the descriptors. 
type BundlePusher struct { - imagePusher *ImagePusher - weightPusher *WeightPusher - registry registry.Client + imagePusher *ImagePusher + registry registry.Client } -// NewBundlePusher creates a new BundlePusher from docker and registry clients. -// Both sub-pushers (image and weight) are created internally to keep -// construction unified — callers don't need to know about ImagePusher or -// WeightPusher directly. +// NewBundlePusher creates a BundlePusher. func NewBundlePusher(docker command.Command, reg registry.Client) *BundlePusher { return &BundlePusher{ - imagePusher: newImagePusher(docker, reg), - weightPusher: NewWeightPusher(reg), - registry: reg, + imagePusher: newImagePusher(docker, reg), + registry: reg, } } -// Push pushes the model as an OCI Index with weight artifacts. -// It reads Model.Artifacts to find the image and weight artifacts to push. +// Push pushes the model as an OCI Index. The image is pushed via +// ImagePusher. Weight manifests are verified via HEAD (they were +// pushed by `cog weights import`); if any are missing, the push +// fails with a message to re-run import. func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) error { - // Extract artifacts from model imgArtifact := m.GetImageArtifact() if imgArtifact == nil { return fmt.Errorf("no image artifact in model") } - weightArtifacts := m.WeightArtifacts() - - // Derive repo from image reference (strip tag/digest for weight pushes) repo := repoFromReference(imgArtifact.Reference) - // 1. Push image via OCI chunked push (falls back to Docker push on error) var imagePushOpts []ImagePushOption if opts.ImageProgressFn != nil { imagePushOpts = append(imagePushOpts, WithProgressFn(opts.ImageProgressFn)) @@ -91,22 +69,28 @@ func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) err return fmt.Errorf("push image %q: %w", imgArtifact.Reference, err) } - // 2. 
Get image manifest descriptor (lightweight HEAD request) - imgDesc, err := p.registry.GetDescriptor(ctx, imgArtifact.Reference) - if err != nil { - return fmt.Errorf("get image descriptor: %w", err) - } + // HEAD the image and verify weight manifests concurrently. + var imgDesc v1.Descriptor + var weightDescs []v1.Descriptor - // 3. Push weight artifacts concurrently (if any) - var weightResults []*WeightPushResult - if len(weightArtifacts) > 0 { - weightResults, err = p.pushWeights(ctx, repo, weightArtifacts) - if err != nil { - return err + g, gctx := errgroup.WithContext(ctx) + g.Go(func() error { + var descErr error + imgDesc, descErr = p.registry.GetDescriptor(gctx, imgArtifact.Reference) + if descErr != nil { + return fmt.Errorf("get image descriptor: %w", descErr) } + return nil + }) + g.Go(func() error { + var verifyErr error + weightDescs, verifyErr = p.verifyWeights(gctx, repo, m.Weights) + return verifyErr + }) + if err := g.Wait(); err != nil { + return err } - // 4. Build OCI index from pushed descriptors platform := opts.Platform if platform == nil { platform = &Platform{OS: "linux", Architecture: "amd64"} @@ -118,9 +102,9 @@ func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) err Architecture: platform.Architecture, Variant: platform.Variant, }) - for i, wr := range weightResults { - builder.AddWeightDescriptor(wr.Descriptor, imgDesc.Digest.String(), - weightArtifacts[i].Name(), weightArtifacts[i].Target) + for i, desc := range weightDescs { + w := m.Weights[i] + builder.AddWeightDescriptor(desc, w.Name, w.SetDigest, w.Size) } idx, err := builder.BuildFromDescriptors() @@ -128,7 +112,7 @@ func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) err return fmt.Errorf("build OCI index: %w", err) } - // 5. Push OCI index (overwrites the tag with the index) + // Overwrites the tag with the index. 
if err := p.registry.PushIndex(ctx, imgArtifact.Reference, idx); err != nil { return fmt.Errorf("push OCI index: %w", err) } @@ -136,22 +120,35 @@ func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) err return nil } -// pushWeights pushes all weight artifacts concurrently (bounded by GetPushConcurrency) -// and returns their results in the same order as the input slice. -// If any weight push fails, remaining pushes are canceled and the first error is returned. -func (p *BundlePusher) pushWeights(ctx context.Context, repo string, weights []*WeightArtifact) ([]*WeightPushResult, error) { - ordered := make([]*WeightPushResult, len(weights)) +// verifyWeights HEAD-checks each weight manifest in the registry, +// returning descriptors in input order. Returns an error if any +// manifest is not found (the user needs to run `cog weights import`). +func (p *BundlePusher) verifyWeights( + ctx context.Context, + repo string, + weights []Weight, +) ([]v1.Descriptor, error) { + if len(weights) == 0 { + return nil, nil + } + + descs := make([]v1.Descriptor, len(weights)) g, ctx := errgroup.WithContext(ctx) g.SetLimit(GetPushConcurrency()) - for i, wa := range weights { + for i, w := range weights { g.Go(func() error { - result, err := p.weightPusher.Push(ctx, repo, wa) + tag := WeightTag(w.Name, w.SetDigest) + ref := repo + ":" + tag + desc, err := p.registry.GetDescriptor(ctx, ref) if err != nil { - return fmt.Errorf("push weight %q: %w", wa.Name(), err) + return fmt.Errorf( + "weight %q not found in registry (%s); run 'cog weights import' to push weights first: %w", + w.Name, ref, err, + ) } - ordered[i] = result + descs[i] = desc return nil }) } @@ -160,13 +157,11 @@ func (p *BundlePusher) pushWeights(ctx context.Context, repo string, weights []* return nil, err } - return ordered, nil + return descs, nil } -// repoFromReference extracts the repository (without tag or digest) from an image reference. 
-// "r8.im/user/model:latest" -> "r8.im/user/model" -// "r8.im/user/model@sha256:abc" -> "r8.im/user/model" -// "localhost:5000/model:latest" -> "localhost:5000/model" +// repoFromReference extracts the repository (without tag or digest) from an +// image reference. "r8.im/user/model:latest" -> "r8.im/user/model". func repoFromReference(ref string) string { parsed, err := name.ParseReference(ref, name.Insecure) if err != nil { diff --git a/pkg/model/pusher_test.go b/pkg/model/pusher_test.go index 25a0225950..f61dec49f7 100644 --- a/pkg/model/pusher_test.go +++ b/pkg/model/pusher_test.go @@ -3,18 +3,28 @@ package model import ( "context" "errors" - "os" - "path/filepath" "sync" "sync/atomic" "testing" - "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/require" ) +const testImageRef = "r8.im/user/model:latest" + +// testBundleModel builds a Model with a fixed image ref and optional weights. +func testBundleModel(weights ...Weight) *Model { + return &Model{ + Image: &ImageArtifact{Reference: testImageRef}, + Artifacts: []Artifact{ + &ImageArtifact{name: "model", Reference: testImageRef}, + }, + Weights: weights, + } +} + // ============================================================================= // BundlePusher tests // ============================================================================= @@ -62,23 +72,19 @@ func TestBundlePusher_Push(t *testing.T) { } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - // no weight artifacts — image-only model - }, - } - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel(), PushOptions{}) require.NoError(t, err) }) t.Run("full push flow succeeds with single weight", func(t *testing.T) { - // Create temp 
weight file - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.safetensors") - require.NoError(t, os.WriteFile(weightPath, []byte("fake weight data"), 0o644)) + w := Weight{ + Name: "model-v1", + Target: "/src/weights/model-v1", + Digest: "sha256:weightdigest123", + SetDigest: "sha256:setdigestabc", + Size: 4096, + } // Track call sequence (mutex-protected for goroutine safety) var mu sync.Mutex @@ -99,18 +105,24 @@ func TestBundlePusher_Push(t *testing.T) { imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, - Digest: v1.Hash{Algorithm: "sha256", Hex: "imgdigest"}, + Digest: v1.Hash{Algorithm: "sha256", Hex: "imgdigestabc1234567"}, + } + + weightDesc := v1.Descriptor{ + MediaType: types.OCIManifestSchema1, + Size: 500, + Digest: v1.Hash{Algorithm: "sha256", Hex: "weightdigest123"}, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { track("registry:getDescriptor:" + ref) + expectedWeightTag := WeightTag(w.Name, w.SetDigest) + if ref == "r8.im/user/model:"+expectedWeightTag { + return weightDesc, nil + } return imgDesc, nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - track("registry:pushImage:" + ref) - return nil - }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { track("registry:pushIndex:" + ref) @@ -126,53 +138,33 @@ func TestBundlePusher_Push(t *testing.T) { // Second manifest: weight with annotations require.Equal(t, PlatformUnknown, idxManifest.Manifests[1].Platform.OS) - require.Equal(t, AnnotationValueWeights, idxManifest.Manifests[1].Annotations[AnnotationReferenceType]) - require.Equal(t, imgDesc.Digest.String(), idxManifest.Manifests[1].Annotations[AnnotationReferenceDigest]) + require.NotEmpty(t, idxManifest.Manifests[1].Annotations[AnnotationV1WeightName]) + require.NotEmpty(t, idxManifest.Manifests[1].Annotations[AnnotationV1WeightSetDigest]) return nil }, } pusher := NewBundlePusher(docker, 
reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("model-v1", v1.Descriptor{ - Digest: v1.Hash{Algorithm: "sha256", Hex: "aabbccddee112233445566778899aabb"}, - }, weightPath, "/weights/model.safetensors", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.safetensors", - Created: time.Now().UTC(), - }), - }, - } - err := pusher.Push(context.Background(), m, PushOptions{ + err := pusher.Push(context.Background(), testBundleModel(w), PushOptions{ Platform: &Platform{OS: "linux", Architecture: "amd64"}, }) require.NoError(t, err) - // Verify the call sequence: - // 1. Push image via docker - // 2. Get image descriptor from registry (lightweight HEAD) - // 3. Push weight via registry (single combined tag) - // 4. Push OCI index to registry + // Verify call sequence. Image push is first, index push is last. + // The two HEAD checks (image + weight) run concurrently so their + // relative order is nondeterministic. 
require.Len(t, callOrder, 4) require.Equal(t, "docker:push:r8.im/user/model:latest", callOrder[0]) - require.Equal(t, "registry:getDescriptor:r8.im/user/model:latest", callOrder[1]) - require.Equal(t, "registry:pushImage:r8.im/user/model:weights-model-v1-aabbccddee11", callOrder[2]) + expectedTag := WeightTag(w.Name, w.SetDigest) + require.Contains(t, callOrder, "registry:getDescriptor:r8.im/user/model:latest") + require.Contains(t, callOrder, "registry:getDescriptor:r8.im/user/model:"+expectedTag) require.Equal(t, "registry:pushIndex:r8.im/user/model:latest", callOrder[3]) }) t.Run("uses default platform when not specified", func(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } @@ -185,7 +177,6 @@ func TestBundlePusher_Push(t *testing.T) { Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { idxManifest, _ := idx.IndexManifest() // Default platform should be linux/amd64 @@ -196,26 +187,14 @@ func TestBundlePusher_Push(t *testing.T) { } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: "/weights/model.bin", Created: time.Now().UTC(), - }), - }, - } - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel( + Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"}, + ), PushOptions{}) 
require.NoError(t, err) }) t.Run("returns error when image push fails", func(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return errors.New("unauthorized: authentication required") @@ -224,18 +203,9 @@ func TestBundlePusher_Push(t *testing.T) { reg := &mockRegistry{} pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: "/weights/model.bin", Created: time.Now().UTC(), - }), - }, - } + w1 := Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"} - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel(w1), PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "push image") @@ -243,10 +213,6 @@ func TestBundlePusher_Push(t *testing.T) { }) t.Run("returns error when get descriptor fails", func(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } @@ -257,68 +223,47 @@ func TestBundlePusher_Push(t *testing.T) { } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: 
"/weights/model.bin", Created: time.Now().UTC(), - }), - }, - } + w1 := Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"} - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel(w1), PushOptions{}) require.Error(t, err) - require.Contains(t, err.Error(), "get image descriptor") + require.Contains(t, err.Error(), "manifest not found") }) - t.Run("returns error when weight push fails", func(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) + t.Run("returns error when weight manifest not in registry", func(t *testing.T) { + imgDesc := v1.Descriptor{ + MediaType: types.OCIManifestSchema1, + Size: 100, + Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, + } docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { - return v1.Descriptor{ - MediaType: types.OCIManifestSchema1, - Size: 100, - Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, - }, nil - }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - return errors.New("weight push failed: quota exceeded") + // Image HEAD succeeds, weight HEAD fails + if ref == "r8.im/user/model:latest" { + return imgDesc, nil + } + return v1.Descriptor{}, errors.New("manifest unknown") }, } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: "/weights/model.bin", Created: time.Now().UTC(), - }), - }, - } + w1 := Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: 
"sha256:abc"} - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel(w1), PushOptions{}) require.Error(t, err) - require.Contains(t, err.Error(), "push weight") require.Contains(t, err.Error(), "w1") + require.Contains(t, err.Error(), "not found in registry") + require.Contains(t, err.Error(), "cog weights import") + require.Contains(t, err.Error(), "manifest unknown") }) t.Run("returns error when index push fails", func(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } @@ -330,55 +275,39 @@ func TestBundlePusher_Push(t *testing.T) { Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { return errors.New("index push failed") }, } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: "/weights/model.bin", Created: time.Now().UTC(), - }), - }, - } + w1 := Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"} - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel(w1), PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "push OCI index") }) - t.Run("pushes multiple weights concurrently", func(t *testing.T) { - dir := t.TempDir() - weight1Path := filepath.Join(dir, "model1.bin") - weight2Path := filepath.Join(dir, 
"model2.bin") - require.NoError(t, os.WriteFile(weight1Path, []byte("weight 1 data"), 0o644)) - require.NoError(t, os.WriteFile(weight2Path, []byte("weight 2 data"), 0o644)) - + t.Run("verifies multiple weights concurrently", func(t *testing.T) { docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } // Use atomic counter — safe for concurrent access from goroutines - var pushedWeightCount atomic.Int32 + var headCheckCount atomic.Int32 reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { + // Count HEAD checks for weight tags (not the image tag) + if ref != "r8.im/user/model:latest" { + headCheckCount.Add(1) + } return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - pushedWeightCount.Add(1) - return nil - }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { idxManifest, _ := idx.IndexManifest() require.Len(t, idxManifest.Manifests, 3) // 1 image + 2 weights @@ -387,30 +316,16 @@ func TestBundlePusher_Push(t *testing.T) { } pusher := NewBundlePusher(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - NewWeightArtifact("w1", v1.Descriptor{ - Digest: v1.Hash{Algorithm: "sha256", Hex: "aaaa111122223333444455556666777788889999aaaabbbbccccddddeeee0000"}, - }, weight1Path, "/weights/model1.bin", WeightConfig{ - SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", - Target: "/weights/model1.bin", Created: time.Now().UTC(), - }), - NewWeightArtifact("w2", v1.Descriptor{ - Digest: v1.Hash{Algorithm: "sha256", Hex: "bbbb111122223333444455556666777788889999aaaabbbbccccddddeeee0000"}, - }, weight2Path, "/weights/model2.bin", WeightConfig{ - SchemaVersion: "1.0", 
CogVersion: "0.15.0", Name: "w2", - Target: "/weights/model2.bin", Created: time.Now().UTC(), - }), - }, - } - err := pusher.Push(context.Background(), m, PushOptions{}) + err := pusher.Push(context.Background(), testBundleModel( + Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:set1"}, + Weight{Name: "w2", Target: "/src/weights/w2", SetDigest: "sha256:set2"}, + ), PushOptions{}) require.NoError(t, err) - require.Equal(t, int32(2), pushedWeightCount.Load()) // both weights pushed (1 tag each) + require.Equal(t, int32(2), headCheckCount.Load()) // both weights HEAD-checked }) + } // ============================================================================= @@ -429,19 +344,12 @@ func TestResolver_Push(t *testing.T) { reg := &mockRegistry{} resolver := NewResolver(docker, reg) - m := &Model{ - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - }, - } - - err := resolver.Push(context.Background(), m, PushOptions{}) + err := resolver.Push(context.Background(), testBundleModel(), PushOptions{}) require.NoError(t, err) require.True(t, dockerPushed, "standalone should use docker push") }) - t.Run("OCIIndex false uses docker push", func(t *testing.T) { + t.Run("no weights uses docker push", func(t *testing.T) { var dockerPushed bool docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { @@ -452,20 +360,12 @@ func TestResolver_Push(t *testing.T) { reg := &mockRegistry{} resolver := NewResolver(docker, reg) - m := &Model{ - // OCIIndex not set (false by default) - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - }, - } - - err := resolver.Push(context.Background(), m, PushOptions{}) + err := resolver.Push(context.Background(), testBundleModel(), PushOptions{}) require.NoError(t, err) - require.True(t, dockerPushed, 
"default format should use docker push") + require.True(t, dockerPushed, "model without weights should use docker push") }) - t.Run("OCIIndex true produces an OCI index", func(t *testing.T) { + t.Run("bundle with weights produces an OCI index", func(t *testing.T) { var indexPushed bool docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, @@ -485,17 +385,11 @@ func TestResolver_Push(t *testing.T) { } resolver := NewResolver(docker, reg) - m := &Model{ - OCIIndex: true, - Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, - Artifacts: []Artifact{ - &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, - }, - } - - err := resolver.Push(context.Background(), m, PushOptions{}) + err := resolver.Push(context.Background(), testBundleModel( + Weight{Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"}, + ), PushOptions{}) require.NoError(t, err) - require.True(t, indexPushed, "OCIIndex=true should push an OCI index") + require.True(t, indexPushed, "bundle with weights should push an OCI index") }) t.Run("standalone returns error when image nil", func(t *testing.T) { @@ -513,15 +407,15 @@ func TestResolver_Push(t *testing.T) { require.Contains(t, err.Error(), "no image artifact") }) - t.Run("OCIIndex true returns error when no image artifact", func(t *testing.T) { + t.Run("bundle returns error when no image artifact", func(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} resolver := NewResolver(docker, reg) m := &Model{ - OCIIndex: true, - Image: nil, - Artifacts: []Artifact{}, + Weights: []Weight{ + {Name: "w1", Target: "/src/weights/w1", SetDigest: "sha256:abc"}, + }, } err := resolver.Push(context.Background(), m, PushOptions{}) @@ -535,13 +429,6 @@ func TestResolver_Push(t *testing.T) { // ============================================================================= func TestPushOptions(t *testing.T) { - t.Run("ProjectDir field", func(t *testing.T) { - opts := PushOptions{ - 
ProjectDir: "/path/to/project", - } - require.Equal(t, "/path/to/project", opts.ProjectDir) - }) - t.Run("Platform field", func(t *testing.T) { opts := PushOptions{ Platform: &Platform{OS: "linux", Architecture: "arm64"}, diff --git a/pkg/model/resolver.go b/pkg/model/resolver.go index e2d0fa941d..12f4103742 100644 --- a/pkg/model/resolver.go +++ b/pkg/model/resolver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "path/filepath" "strings" "github.com/docker/docker/api/types/image" @@ -222,7 +221,6 @@ func (r *Resolver) Build(ctx context.Context, src *Source, opts BuildOptions) (* } opts = opts.WithDefaults(src) - // Build image artifact via ImageBuilder ib := NewImageBuilder(r.factory, r.docker, src, opts) imageSpec := NewImageSpec("model", opts.ImageName) imgResult, err := ib.Build(ctx, imageSpec) @@ -239,26 +237,14 @@ func (r *Resolver) Build(ctx context.Context, src *Source, opts BuildOptions) (* return nil, err } - m.OCIIndex = opts.OCIIndex m.Artifacts = []Artifact{ia} - // Build weight artifacts if OCI index mode is enabled - lockPath := opts.WeightsLockPath - if lockPath == "" { - lockPath = filepath.Join(src.ProjectDir, WeightsLockFilename) - } - - if opts.OCIIndex && len(src.Config.Weights) > 0 { - wb := NewWeightBuilder(src, m.CogVersion, lockPath) - for _, ws := range src.Config.Weights { - spec := NewWeightSpec(ws.Name, ws.Source, ws.Target) - artifact, buildErr := wb.Build(ctx, spec) - if buildErr != nil { - return nil, fmt.Errorf("build weight %q: %w", ws.Name, buildErr) - } - m.Artifacts = append(m.Artifacts, artifact) + if len(src.Config.Weights) > 0 { + weights, weightErr := WeightsFromLockfile(src.ProjectDir) + if weightErr != nil { + return nil, weightErr } - + m.Weights = weights } return m, nil @@ -270,7 +256,7 @@ func (r *Resolver) Build(ctx context.Context, src *Source, opts BuildOptions) (* // monolithic push and supports layers of any size through chunked uploads. // Falls back to legacy Docker push if OCI push is not available. 
func (r *Resolver) Push(ctx context.Context, m *Model, opts PushOptions) error { - if m.OCIIndex { + if m.IsBundle() { pusher := NewBundlePusher(r.docker, r.registry) return pusher.Push(ctx, m, opts) } @@ -397,37 +383,6 @@ func (r *Resolver) modelFromIndex(ref *ParsedRef, manifest *registry.ManifestRes return nil, fmt.Errorf("image %s: %w", ref.Original, err) } - m.Index = &Index{ - Digest: manifest.Digest, // Content-addressable digest from registry - Reference: ref.String(), - MediaType: manifest.MediaType, - Manifests: make([]IndexManifest, len(manifest.Manifests)), - } - - // Populate index manifests - for i, pm := range manifest.Manifests { - im := IndexManifest{ - Digest: pm.Digest, - MediaType: pm.MediaType, - Size: pm.Size, - Annotations: pm.Annotations, - } - if pm.OS != "" { - im.Platform = &Platform{ - OS: pm.OS, - Architecture: pm.Architecture, - Variant: pm.Variant, - } - } - // Determine manifest type - if pm.OS == PlatformUnknown && pm.Annotations != nil && pm.Annotations[AnnotationReferenceType] == AnnotationValueWeights { - im.Type = ManifestTypeWeights - } else { - im.Type = ManifestTypeImage - } - m.Index.Manifests[i] = im - } - return m, nil } @@ -436,18 +391,6 @@ func isOCIIndex(mr *registry.ManifestResult) bool { return mr.IsIndex() } -// findWeightsManifest finds the weights manifest in an index. -// Returns nil if no weights manifest is found. -func findWeightsManifest(manifests []registry.PlatformManifest) *registry.PlatformManifest { - for i := range manifests { - m := &manifests[i] - if m.Annotations != nil && m.Annotations[AnnotationReferenceType] == AnnotationValueWeights { - return m - } - } - return nil -} - // findImageManifest finds the model image manifest in an index. // If platform is specified, matches on OS/Architecture. // Skips artifacts (platform: unknown/unknown). 
diff --git a/pkg/model/resolver_test.go b/pkg/model/resolver_test.go index 17e7b73282..0692ecadd9 100644 --- a/pkg/model/resolver_test.go +++ b/pkg/model/resolver_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "io" - "os" "path/filepath" "testing" @@ -18,6 +17,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" ) // mockDocker implements command.Command for testing. @@ -149,6 +149,10 @@ func (m *mockRegistry) WriteLayer(ctx context.Context, opts registry.WriteLayerO return nil } +func (m *mockRegistry) BlobExists(_ context.Context, _ string, _ string) (bool, error) { + return false, nil +} + // mockFactory implements Factory for testing. type mockFactory struct { name string @@ -1016,7 +1020,7 @@ func TestResolver_Build_NoWeightsManifestWithoutWeights(t *testing.T) { require.NoError(t, err) require.False(t, m.IsBundle()) - require.Empty(t, m.WeightArtifacts()) + require.Empty(t, m.Weights) } func TestResolver_Build_PopulatesArtifacts(t *testing.T) { @@ -1072,7 +1076,7 @@ func TestResolver_Build_PopulatesArtifacts(t *testing.T) { require.Equal(t, imageDigest, imgArtifact.Descriptor().Digest.String()) } -func TestResolver_Build_PopulatesWeightArtifacts(t *testing.T) { +func TestResolver_Build_PopulatesWeights(t *testing.T) { imageDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ @@ -1082,7 +1086,6 @@ func TestResolver_Build_PopulatesWeightArtifacts(t *testing.T) { Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ - LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.15.0", }, @@ -1103,16 +1106,32 @@ func TestResolver_Build_PopulatesWeightArtifacts(t *testing.T) { } resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) - // Create a temp directory with a real weight file dir := t.TempDir() 
- weightContent := []byte("test weight for resolver build") - require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), weightContent, 0o644)) + + // Pre-create a lockfile (as if `cog weights import` had run). + lock := &lockfile.WeightsLock{ + Version: lockfile.Version, + Weights: []lockfile.WeightLockEntry{ + { + Name: "my-model", + Target: "/srv/weights/model", + Digest: "sha256:deadbeef", + SetDigest: "sha256:setdigest123", + Size: 4096, + SizeCompressed: 2048, + Source: lockfile.WeightLockSource{ + URI: "file://./weights/model", + }, + }, + }, + } + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) src := &Source{ Config: &config.Config{ Build: &config.Build{}, Weights: []config.WeightSource{ - {Name: "my-model", Source: "model.safetensors", Target: "/srv/weights/model.safetensors"}, + {Name: "my-model", Target: "/srv/weights/model", Source: &config.WeightSourceConfig{URI: "weights/model"}}, }, }, ProjectDir: dir, @@ -1120,37 +1139,26 @@ func TestResolver_Build_PopulatesWeightArtifacts(t *testing.T) { m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image:latest", - OCIIndex: true, }) require.NoError(t, err) - require.NotNil(t, m.Artifacts) - // Should have 2 artifacts: 1 image + 1 weight - require.Len(t, m.Artifacts, 2, "should have image + weight artifacts") + // Artifacts should contain only the image (weights are not artifacts). + require.Len(t, m.Artifacts, 1, "should have image artifact only") + require.NotNil(t, m.GetImageArtifact()) - // Verify image artifact - imgArtifact := m.GetImageArtifact() - require.NotNil(t, imgArtifact) - require.Equal(t, "model", imgArtifact.Name()) + // Weights loaded from lockfile. 
+ require.Len(t, m.Weights, 1) + require.Equal(t, "my-model", m.Weights[0].Name) + require.Equal(t, "/srv/weights/model", m.Weights[0].Target) + require.Equal(t, "sha256:deadbeef", m.Weights[0].Digest) + require.Equal(t, "sha256:setdigest123", m.Weights[0].SetDigest) + require.Equal(t, int64(4096), m.Weights[0].Size) - // Verify weight artifact - weightArtifacts := m.WeightArtifacts() - require.Len(t, weightArtifacts, 1) - wa := weightArtifacts[0] - require.Equal(t, "my-model", wa.Name()) - require.Equal(t, ArtifactTypeWeight, wa.Type()) - require.Equal(t, "/srv/weights/model.safetensors", wa.Target) - require.Equal(t, filepath.Join(dir, "model.safetensors"), wa.FilePath) - - // Weight config should be populated - require.Equal(t, "1.0", wa.Config.SchemaVersion) - require.Equal(t, "my-model", wa.Config.Name) - require.Equal(t, "/srv/weights/model.safetensors", wa.Config.Target) - require.False(t, wa.Config.Created.IsZero()) + require.True(t, m.IsBundle()) } -func TestResolver_Build_WithWeightsLoadsManifest(t *testing.T) { +func TestResolver_Build_WithWeightsIsBundle(t *testing.T) { imageDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ @@ -1160,7 +1168,6 @@ func TestResolver_Build_WithWeightsLoadsManifest(t *testing.T) { Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ - LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.15.0", }, @@ -1182,13 +1189,30 @@ func TestResolver_Build_WithWeightsLoadsManifest(t *testing.T) { resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) dir := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(dir, "model.bin"), []byte("test weights"), 0o644)) + + // Pre-create lockfile. 
+ lock := &lockfile.WeightsLock{ + Version: lockfile.Version, + Weights: []lockfile.WeightLockEntry{ + { + Name: "my-model", + Target: "/src/weights", + Digest: "sha256:abc", + SetDigest: "sha256:set123", + Size: 1024, + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + }, + }, + }, + } + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) src := &Source{ Config: &config.Config{ Build: &config.Build{}, Weights: []config.WeightSource{ - {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, + {Name: "my-model", Target: "/src/weights", Source: &config.WeightSourceConfig{URI: "weights"}}, }, }, ProjectDir: dir, @@ -1196,50 +1220,19 @@ func TestResolver_Build_WithWeightsLoadsManifest(t *testing.T) { m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image:latest", - OCIIndex: true, }) require.NoError(t, err) require.True(t, m.IsBundle()) - require.True(t, m.OCIIndex) - // Should have 2 artifacts: image + weight - require.Len(t, m.Artifacts, 2) + // 1 image artifact, 1 weight in Weights + require.Len(t, m.Artifacts, 1) require.NotNil(t, m.GetImageArtifact()) - require.Len(t, m.WeightArtifacts(), 1) - - // Weight artifacts should be populated - require.Len(t, m.WeightArtifacts(), 1) + require.Len(t, m.Weights, 1) + require.Equal(t, "my-model", m.Weights[0].Name) } func TestIndexDetectionHelpers(t *testing.T) { - t.Run("findWeightsManifest", func(t *testing.T) { - manifests := []registry.PlatformManifest{ - {Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, - { - Digest: "sha256:weights456", - OS: PlatformUnknown, - Architecture: PlatformUnknown, - Annotations: map[string]string{ - AnnotationReferenceType: AnnotationValueWeights, - }, - }, - } - - wm := findWeightsManifest(manifests) - require.NotNil(t, wm) - require.Equal(t, "sha256:weights456", wm.Digest) - }) - - t.Run("findWeightsManifest not found", func(t *testing.T) { - manifests := []registry.PlatformManifest{ - 
{Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, - } - - wm := findWeightsManifest(manifests) - require.Nil(t, wm) - }) - t.Run("findImageManifest", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, diff --git a/pkg/model/source.go b/pkg/model/source.go index 979748a357..f435c9aad9 100644 --- a/pkg/model/source.go +++ b/pkg/model/source.go @@ -71,23 +71,24 @@ func NewSourceFromConfig(cfg *config.Config, projectDir string) *Source { } } -// ArtifactSpecs returns the artifact declarations derived from this source. -// Always produces at least one ImageSpec. Produces a WeightSpec for each -// weight declared in the config. Returns nil if Config is nil. -func (s *Source) ArtifactSpecs() []ArtifactSpec { +// ArtifactSpecs returns the artifact declarations derived from this +// source: one ImageSpec plus one WeightSpec per configured weight. +// Returns an error if any weight's source URI is malformed. Returns +// (nil, nil) if Config is nil. 
+func (s *Source) ArtifactSpecs() ([]ArtifactSpec, error) { if s.Config == nil { - return nil + return nil, nil } - var specs []ArtifactSpec + specs := []ArtifactSpec{NewImageSpec("model", s.Config.Image)} - // Always have an image artifact - specs = append(specs, NewImageSpec("model", s.Config.Image)) - - // Add weight specs from config for _, w := range s.Config.Weights { - specs = append(specs, NewWeightSpec(w.Name, w.Source, w.Target)) + ws, err := WeightSpecFromConfig(w) + if err != nil { + return nil, err + } + specs = append(specs, ws) } - return specs + return specs, nil } diff --git a/pkg/model/source_test.go b/pkg/model/source_test.go index 8bdac05375..f0a06f1def 100644 --- a/pkg/model/source_test.go +++ b/pkg/model/source_test.go @@ -42,7 +42,8 @@ func TestSource_ArtifactSpecs_NoWeights(t *testing.T) { } src := NewSourceFromConfig(cfg, "/path/to/project") - specs := src.ArtifactSpecs() + specs, err := src.ArtifactSpecs() + require.NoError(t, err) require.Len(t, specs, 1) @@ -59,13 +60,14 @@ func TestSource_ArtifactSpecs_WithWeights(t *testing.T) { Image: "r8.im/user/model", Build: &config.Build{PythonVersion: "3.11"}, Weights: []config.WeightSource{ - {Name: "llama-7b", Source: "/data/llama-7b.safetensors", Target: "/weights/llama-7b.safetensors"}, - {Name: "embeddings", Source: "/data/embeddings.bin", Target: "/weights/embeddings.bin"}, + {Name: "llama-7b", Target: "/weights/llama-7b", Source: &config.WeightSourceConfig{URI: "/data/llama-7b"}}, + {Name: "embeddings", Target: "/weights/embeddings", Source: &config.WeightSourceConfig{URI: "/data/embeddings"}}, }, } src := NewSourceFromConfig(cfg, "/path/to/project") - specs := src.ArtifactSpecs() + specs, err := src.ArtifactSpecs() + require.NoError(t, err) require.Len(t, specs, 3) // 1 image + 2 weights @@ -79,14 +81,14 @@ func TestSource_ArtifactSpecs_WithWeights(t *testing.T) { require.True(t, ok, "second spec should be *WeightSpec") require.Equal(t, ArtifactTypeWeight, w1.Type()) require.Equal(t, 
"llama-7b", w1.Name()) - require.Equal(t, "/data/llama-7b.safetensors", w1.Source) - require.Equal(t, "/weights/llama-7b.safetensors", w1.Target) + require.Equal(t, "file:///data/llama-7b", w1.URI) + require.Equal(t, "/weights/llama-7b", w1.Target) w2, ok := specs[2].(*WeightSpec) require.True(t, ok, "third spec should be *WeightSpec") require.Equal(t, "embeddings", w2.Name()) - require.Equal(t, "/data/embeddings.bin", w2.Source) - require.Equal(t, "/weights/embeddings.bin", w2.Target) + require.Equal(t, "file:///data/embeddings", w2.URI) + require.Equal(t, "/weights/embeddings", w2.Target) } func TestSource_ArtifactSpecs_EmptyImageName(t *testing.T) { @@ -95,7 +97,8 @@ func TestSource_ArtifactSpecs_EmptyImageName(t *testing.T) { } src := NewSourceFromConfig(cfg, "/path/to/project") - specs := src.ArtifactSpecs() + specs, err := src.ArtifactSpecs() + require.NoError(t, err) require.Len(t, specs, 1) imgSpec, ok := specs[0].(*ImageSpec) @@ -106,7 +109,22 @@ func TestSource_ArtifactSpecs_EmptyImageName(t *testing.T) { func TestSource_ArtifactSpecs_NilConfig(t *testing.T) { src := NewSourceFromConfig(nil, "/path/to/project") - specs := src.ArtifactSpecs() + specs, err := src.ArtifactSpecs() + require.NoError(t, err) require.Nil(t, specs) } + +func TestSource_ArtifactSpecs_MalformedWeightURI(t *testing.T) { + cfg := &config.Config{ + Image: "r8.im/user/model", + Build: &config.Build{PythonVersion: "3.11"}, + Weights: []config.WeightSource{ + {Name: "bad", Target: "/w", Source: &config.WeightSourceConfig{URI: "bogus://nope"}}, + }, + } + src := NewSourceFromConfig(cfg, "/path/to/project") + + _, err := src.ArtifactSpecs() + require.Error(t, err) +} diff --git a/pkg/model/weight_builder.go b/pkg/model/weight_builder.go index 8bb26cc806..78bcc726af 100644 --- a/pkg/model/weight_builder.go +++ b/pkg/model/weight_builder.go @@ -5,161 +5,365 @@ import ( "errors" "fmt" "os" - "path/filepath" + "slices" + "strings" "time" v1 "github.com/google/go-containerregistry/pkg/v1" + 
"github.com/google/go-containerregistry/pkg/v1/types" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" ) -// WeightBuilder builds WeightArtifact from WeightSpec. -// It hashes the source file, creates a WeightConfig, and manages a lockfile as build cache. +// resolvedInventory holds the results of walking a weight source and +// applying include/exclude filters. Both PlanImport and Build share +// this step; caching the result avoids re-walking the source when the +// CLI plans first and then builds. +type resolvedInventory struct { + source weightsource.Source + full weightsource.Inventory // before filtering (carries fingerprint) + filtered weightsource.Inventory // after include/exclude +} + +// resolveInventory walks the source, applies filters, and returns the +// resolved inventory. This is the shared preamble for PlanImport and Build. +func (b *WeightBuilder) resolveInventory(ctx context.Context, ws *WeightSpec) (*resolvedInventory, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + src, err := weightsource.For(ws.URI, b.projectDir()) + if err != nil { + return nil, err + } + + full, err := src.Inventory(ctx) + if err != nil { + return nil, fmt.Errorf("inventory weight %q: %w", ws.Name(), err) + } + + filtered, err := weightsource.FilterInventory(full, ws.Include, ws.Exclude) + if err != nil { + return nil, fmt.Errorf("filter weight %q: %w", ws.Name(), err) + } + + return &resolvedInventory{source: src, full: full, filtered: filtered}, nil +} + +// WeightBuilder is the weight factory: given a WeightSpec (source URI + +// target), it ingresses the source files into the local +// content-addressed store, plans tar layers, derives layer digests +// (cache-fast or recompute), assembles the v1 OCI manifest, and +// returns a WeightArtifact carrying the layer descriptors and +// manifest digest. 
+// +// Running `cog weights import` is equivalent to running it and then +// `cog weights pull`: the build path leaves the local store warm, so +// subsequent `cog predict` invocations can hardlink-assemble without +// a separate pull. +// +// The builder is offline: it never talks to a registry. The manifest +// digest it writes into the artifact descriptor is a sha256 of the +// serialized manifest bytes. type WeightBuilder struct { - source *Source - cogVersion string - lockPath string + source *Source + store store.Store + lockPath string } // NewWeightBuilder creates a WeightBuilder. -// lockPath is where the weights.lock file is read/written as a build cache. -func NewWeightBuilder(source *Source, cogVersion, lockPath string) *WeightBuilder { - return &WeightBuilder{ - source: source, - cogVersion: cogVersion, - lockPath: lockPath, - } +// +// st is the local content-addressed weight store. lockPath is where +// weights.lock is read/written. +func NewWeightBuilder(source *Source, st store.Store, lockPath string) *WeightBuilder { + return &WeightBuilder{source: source, store: st, lockPath: lockPath} } -// Build builds a WeightArtifact from a WeightSpec. -// It resolves the source file, computes its SHA256 digest, and creates the artifact -// with a versioned WeightConfig. +// Build runs the full import pipeline for one weight: +// +// 1. Inventory the source. +// 2. Ingress every file into the local store (skipping already-present +// digests). Hash mismatches surface here. +// 3. Decide whether to trust the lockfile's recorded layer digests +// (fast path) or recompute them by streaming from the store +// (recompute path). +// 4. Assemble the OCI manifest. +// 5. Stamp the current envelope format into the lockfile and rewrite +// it iff anything actually changed. +// +// The push path is independent of Build; the caller is responsible +// for handing the returned artifact to the pusher (which checks +// per-layer with BlobExists before uploading). 
func (b *WeightBuilder) Build(ctx context.Context, spec ArtifactSpec) (Artifact, error) { + return b.buildWithResolved(ctx, spec, nil) +} + +// BuildFromPlan is like Build but reuses a pre-computed inventory from +// PlanImport, avoiding a second walk/hash of the source. +func (b *WeightBuilder) BuildFromPlan(ctx context.Context, spec ArtifactSpec, plan *WeightImportPlan) (Artifact, error) { + return b.buildWithResolved(ctx, spec, plan.Resolved) +} + +func (b *WeightBuilder) buildWithResolved(ctx context.Context, spec ArtifactSpec, cached *resolvedInventory) (Artifact, error) { ws, ok := spec.(*WeightSpec) if !ok { return nil, fmt.Errorf("weight builder: expected *WeightSpec, got %T", spec) } + if b.store == nil { + return nil, fmt.Errorf("weight builder: store is required") + } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + if err := ctx.Err(); err != nil { + return nil, err } - // Resolve file path - absPath := filepath.Join(b.source.ProjectDir, ws.Source) + // Step 1: resolve inventory (reuse cached if available). + resolved := cached + if resolved == nil { + var err error + resolved, err = b.resolveInventory(ctx, ws) + if err != nil { + return nil, err + } + } + inv := resolved.filtered + + // Step 2: ingress the filtered files into the local store. + if err := ingressFromInventory(ctx, resolved.source, b.store, inv); err != nil { + return nil, fmt.Errorf("populate store for weight %q: %w", ws.Name(), err) + } + + // Step 3: decide fast-path vs recompute. 
+ lock, err := loadLockfileOrEmpty(b.lockPath) + if err != nil { + return nil, err + } - // Stat the file to check existence and size - fi, err := os.Stat(absPath) + currentEnvelope, err := computeEnvelopeFormat(envelopeFromOptions(packOptions{})) if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil, fmt.Errorf("weight source not found: %s", ws.Source) + return nil, fmt.Errorf("compute envelope format: %w", err) + } + + existing := lock.FindWeight(ws.Name()) + pkr := newPacker(nil) + plan := pkr.planLayers(inv) + if len(plan.Layers) == 0 { + return nil, fmt.Errorf("weight %q: inventory is empty", ws.Name()) + } + + var layers []packedLayer + if canFastPath(lock, currentEnvelope, existing, ws, inv) { + layers, err = layersFromLockfile(existing, plan) + if err != nil { + // Lockfile and freshly-planned layers disagree on + // shape. Treat as a cache miss: recompute. This can + // happen if the user edited weights.lock by hand. + layers = nil } - return nil, fmt.Errorf("stat weight file %s: %w", ws.Source, err) - } - - // Check lockfile cache: if we have a matching entry (name + size), skip hashing. - // NOTE: This cache only checks name + file size. Same-size modifications (rare for - // weight files) won't be detected. Delete the lockfile to force re-hashing. - // TODO: Consider adding mtime to the cache key for stronger invalidation. 
- var digestStr string - var size int64 - if cached := b.findCachedEntry(ws.Name(), fi.Size()); cached != nil { - digestStr = cached.Digest - size = cached.Size - } else { - // Cache miss: hash the file - digestStr, size, err = hashFile(absPath) + } + if layers == nil { + layers, err = pkr.computeLayerDigests(ctx, b.store, plan) if err != nil { - return nil, fmt.Errorf("hash weight file %s: %w", ws.Source, err) + return nil, fmt.Errorf("compute layer digests for weight %q: %w", ws.Name(), err) } } - // Parse as v1.Hash for the descriptor - digest, err := v1.NewHash(digestStr) + entry := newWeightLockEntry( + ws.Name(), ws.Target, + lockfile.WeightLockSource{ + URI: ws.URI, + Fingerprint: inv.Fingerprint, + Include: ws.Include, + Exclude: ws.Exclude, + ImportedAt: time.Now().UTC(), + }, + packedFilesFromPlan(layers), + layers, + ) + + // buildWeightArtifact populates entry.Digest (the manifest + // digest), which EntriesEqual needs to compare meaningfully. + artifact, err := buildWeightArtifact(&entry, layers, b.store) if err != nil { - return nil, fmt.Errorf("parse digest: %w", err) + return nil, fmt.Errorf("weight %q: %w", ws.Name(), err) } - // Build the WeightConfig - cfg := WeightConfig{ - SchemaVersion: "1.0", - CogVersion: b.cogVersion, - Name: ws.Name(), - Target: ws.Target, - Created: time.Now().UTC(), + // Preserve the original ImportedAt on a content-equal rewrite so + // a format-bump-only rewrite doesn't churn the timestamp. + // EntriesEqual ignores ImportedAt by design, so this comparison + // answers "would the lockfile diff be only the timestamp?" + entryEqualsExisting := lockfile.EntriesEqual(existing, &entry) + if entryEqualsExisting { + entry.Source.ImportedAt = existing.Source.ImportedAt } - // Build the descriptor - desc := v1.Descriptor{ - Digest: digest, - Size: size, - MediaType: MediaTypeWeightLayer, + // Step 5: stamp envelope + rewrite iff anything changed. 
+ formatChanged := lock.EnvelopeFormat != currentEnvelope + lock.EnvelopeFormat = currentEnvelope + if formatChanged || !entryEqualsExisting { + lock.Upsert(entry) + if err := lock.Save(b.lockPath); err != nil { + return nil, fmt.Errorf("update lockfile: %w", err) + } } - // Update lockfile - if err := b.updateLockfile(ws, digestStr, size); err != nil { - return nil, fmt.Errorf("update lockfile: %w", err) - } + return artifact, nil +} - return NewWeightArtifact(ws.Name(), desc, absPath, ws.Target, cfg), nil +// projectDir returns the builder's project directory, or "" if the +// builder was constructed without a Source. +func (b *WeightBuilder) projectDir() string { + if b.source == nil { + return "" + } + return b.source.ProjectDir } -// findCachedEntry checks the lockfile for an entry matching name and fileSize. -// Returns the cached WeightFile if found and size matches, nil otherwise. -func (b *WeightBuilder) findCachedEntry(name string, fileSize int64) *WeightFile { - if _, err := os.Stat(b.lockPath); err != nil { - return nil +// canFastPath reports whether the recorded lockfile entry can be +// trusted as-is for this build, allowing us to skip the +// digest-recomputation pass. +// +// Every input that determines layer bytes must agree: +// +// - The recorded EnvelopeFormat matches the current packer config. +// A miss here means cog itself produces different bytes for the +// same inventory than the version that wrote the lockfile did. +// - An entry with this name exists. +// - The user-intent fields (target, URI, include/exclude) match. +// - The source's fingerprint matches what the lockfile recorded. +// For file:// the fingerprint is the dirhash of the file set, +// so matching fingerprint ⇒ matching files. For hf:// it's the +// commit SHA, so matching fingerprint ⇒ same canonical files. +// +// Anything missing pushes us onto the recompute path. 
Recompute is +// cheap because the store is already warm (local I/O + sha256 + +// gzip) — the cost is local CPU, not network or source-side I/O. +func canFastPath( + lock *lockfile.WeightsLock, + currentEnvelope string, + existing *lockfile.WeightLockEntry, + ws *WeightSpec, + inv weightsource.Inventory, +) bool { + if lock.EnvelopeFormat != currentEnvelope { + return false } - lock, err := LoadWeightsLock(b.lockPath) - if err != nil { - return nil + if existing == nil { + return false } - for i, f := range lock.Files { - if f.Name == name && f.Size == fileSize { - return &lock.Files[i] - } + if existing.Target != ws.Target || + existing.Source.URI != ws.URI || + !slices.Equal(existing.Source.Include, ws.Include) || + !slices.Equal(existing.Source.Exclude, ws.Exclude) { + return false + } + if existing.Source.Fingerprint != inv.Fingerprint { + return false } - return nil + return true } -// updateLockfile loads the existing lockfile (if any), adds or updates -// the entry for the given weight, and saves it back. -func (b *WeightBuilder) updateLockfile(ws *WeightSpec, digest string, size int64) error { - // Load existing lockfile, or start fresh. - // LoadWeightsLock wraps the underlying error, so we check the raw file first. - lock := &WeightsLock{ - Version: "1.0", - Created: time.Now().UTC(), - } - if _, err := os.Stat(b.lockPath); err == nil { - existing, loadErr := LoadWeightsLock(b.lockPath) - if loadErr != nil { - return fmt.Errorf("load existing lockfile: %w", loadErr) +// layersFromLockfile reconstructs []packedLayer from a lockfile entry, +// pairing each recorded layer with the corresponding layerPlan from +// the freshly-planned layout. The plan is what reproduces layer bytes +// during push, so the planning result must be available even on the +// fast path. 
+// +// Returns an error if the lockfile and plan disagree on how many +// layers there are or which files they contain — a strong signal +// that the lockfile is out of sync, which makes the fast path +// unsafe. +func layersFromLockfile(entry *lockfile.WeightLockEntry, pl plan) ([]packedLayer, error) { + if len(entry.Layers) != len(pl.Layers) { + return nil, fmt.Errorf("layer count mismatch: lockfile has %d, plan has %d", len(entry.Layers), len(pl.Layers)) + } + + // Index plan layers by their content signature so we can match + // them to lockfile layers regardless of ordering. Signature is + // the sorted file digests within the layer; that's what + // determines the tar bytes. + planByKey := make(map[string]layerPlan, len(pl.Layers)) + for _, lp := range pl.Layers { + planByKey[layerKey(lp)] = lp + } + + out := make([]packedLayer, 0, len(entry.Layers)) + for _, lk := range entry.Layers { + // Find the plan layer whose files match this locked layer. + // We compare by file digests (the file→layer mapping in + // entry.Files tells us which files belong to this layer). 
+ want := lockedLayerKey(entry, lk.Digest) + lp, ok := planByKey[want] + if !ok { + return nil, fmt.Errorf("locked layer %s has no matching plan layer", lk.Digest) } - lock = existing - } - - entry := WeightFile{ - Name: ws.Name(), - Dest: ws.Target, - Digest: digest, - DigestOriginal: digest, - Size: size, - SizeUncompressed: size, - MediaType: MediaTypeWeightLayer, - } - - // Update existing entry or append - updated := false - for i, f := range lock.Files { - if f.Name == ws.Name() { - lock.Files[i] = entry - updated = true - break + hash, err := v1.NewHash(lk.Digest) + if err != nil { + return nil, fmt.Errorf("parse locked layer digest %s: %w", lk.Digest, err) } + out = append(out, packedLayer{ + Plan: lp, + Digest: hash, + Size: lk.Size, + UncompressedSize: lk.SizeUncompressed, + MediaType: types.MediaType(lk.MediaType), + }) } - if !updated { - lock.Files = append(lock.Files, entry) + return out, nil +} + +// loadLockfileOrEmpty loads the lockfile at path. A missing file is not +// an error — it yields a fresh empty lockfile. +func loadLockfileOrEmpty(path string) (*lockfile.WeightsLock, error) { + lock, err := lockfile.LoadWeightsLock(path) + if err == nil { + return lock, nil } + if errors.Is(err, os.ErrNotExist) { + return &lockfile.WeightsLock{Version: lockfile.Version}, nil + } + return nil, err +} - return lock.Save(b.lockPath) +// layerKey returns a content signature for a layerPlan: the joined +// file digests in tar-emission order. Two planLayers with identical +// keys produce identical tar bytes and therefore identical layer +// digests (modulo envelope-format concerns, which the caller handles +// separately). 
+func layerKey(lp layerPlan) string { + digests := make([]string, len(lp.Files)) + for i, f := range lp.Files { + digests[i] = f.Digest + } + return strings.Join(digests, "\n") +} + +// lockedLayerKey returns the layerKey for a recorded layer in entry, +// reconstructed by collecting the file digests of every entry.Files +// member that points at the given layerDigest. +// +// Result is sorted by inventory path so it matches a planLayer whose +// Files are in the packer's emission order (small-file bundles are +// path-sorted; single-file layers carry one file). For multi-file +// bundles, both this function and planLayers sort by path; for +// single-file layers, both have one entry. Either way the keys match +// when the underlying file set matches. +func lockedLayerKey(entry *lockfile.WeightLockEntry, layerDigest string) string { + type fd struct { + path string + digest string + } + var fs []fd + for _, f := range entry.Files { + if f.Layer == layerDigest { + fs = append(fs, fd{path: f.Path, digest: f.Digest}) + } + } + slices.SortFunc(fs, func(a, b fd) int { return strings.Compare(a.path, b.path) }) + digests := make([]string, len(fs)) + for i, f := range fs { + digests[i] = f.digest + } + return strings.Join(digests, "\n") } diff --git a/pkg/model/weight_builder_test.go b/pkg/model/weight_builder_test.go index 17c188d2a5..3d066ae5a7 100644 --- a/pkg/model/weight_builder_test.go +++ b/pkg/model/weight_builder_test.go @@ -2,286 +2,526 @@ package model import ( "context" - "crypto/sha256" - "encoding/hex" "os" "path/filepath" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" ) -func TestWeightBuilder_HappyPath(t *testing.T) { - // Setup: real temp file as a weight source - tmpDir := t.TempDir() - weightContent := []byte("test weight data for builder") - weightFile := filepath.Join(tmpDir, 
"model.safetensors") - err := os.WriteFile(weightFile, weightContent, 0o644) - require.NoError(t, err) +// makeWeightDir writes files into / and returns both +// absolute and relative paths. The contents are small enough to land in a +// single bundle layer under the default pack thresholds. +func makeWeightDir(t *testing.T, projectDir, relDir string, files map[string][]byte) { + t.Helper() + abs := filepath.Join(projectDir, relDir) + require.NoError(t, os.MkdirAll(abs, 0o755)) + for name, data := range files { + full := filepath.Join(abs, name) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, data, 0o644)) + } +} - // Compute expected digest - hash := sha256.Sum256(weightContent) - expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) +// newTestBuilder constructs a WeightBuilder rooted in projectDir with a +// fresh FileStore in t.TempDir() — the canonical fixture for builder +// tests. Returns the builder and the store so tests that need to +// inspect or pre-populate the store can reach it. 
+func newTestBuilder(t *testing.T, projectDir string, weights []config.WeightSource) (*WeightBuilder, *store.FileStore) { + t.Helper() + src := NewSourceFromConfig(&config.Config{Weights: weights}, projectDir) + st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + lockPath := filepath.Join(projectDir, "weights.lock") + return NewWeightBuilder(src, st, lockPath), st +} - // Create source with config that has one weight - src := NewSourceFromConfig(&config.Config{ - Weights: []config.WeightSource{ - {Name: "my-model", Source: "model.safetensors", Target: "/srv/weights/model.safetensors"}, - }, - }, tmpDir) +func testWeightSpec(t *testing.T, name, uri, target string) *WeightSpec { + t.Helper() + spec, err := WeightSpecFromConfig(config.WeightSource{ + Name: name, + Target: target, + Source: &config.WeightSourceConfig{URI: uri}, + }) + require.NoError(t, err) + return spec +} - // Create a WeightBuilder - lockPath := filepath.Join(tmpDir, "weights.lock") - wb := NewWeightBuilder(src, "0.15.0", lockPath) +func TestWeightBuilder_HappyPath(t *testing.T) { + projectDir := t.TempDir() + makeWeightDir(t, projectDir, "weights/my-model", map[string][]byte{ + "config.json": []byte(`{"hidden_size": 768}`), + "tokenizer.json": []byte(`{"vocab_size": 50257}`), + }) - // Build from the weight spec - spec := NewWeightSpec("my-model", "model.safetensors", "/srv/weights/model.safetensors") + wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "my-model", Target: "/src/weights/my-model", Source: &config.WeightSourceConfig{URI: "weights/my-model"}}, + }) + spec := testWeightSpec(t, "my-model", "weights/my-model", "/src/weights/my-model") artifact, err := wb.Build(context.Background(), spec) require.NoError(t, err) - require.NotNil(t, artifact) - // Type assertion: should be a *WeightArtifact wa, ok := artifact.(*WeightArtifact) require.True(t, ok, "expected *WeightArtifact, got %T", artifact) - // Check artifact interface require.Equal(t, 
ArtifactTypeWeight, wa.Type()) require.Equal(t, "my-model", wa.Name()) + require.Equal(t, "/src/weights/my-model", wa.Entry.Target) + require.NotEmpty(t, wa.Entry.SetDigest, "builder should compute SetDigest") + require.NotEmpty(t, wa.Entry.Files, "builder should populate Files") + + // At least one layer (the bundled small files). + require.NotEmpty(t, wa.Layers, "expected at least one layer") + for _, l := range wa.Layers { + require.NotEmpty(t, l.Digest.Hex) + require.Greater(t, l.Size, int64(0)) + require.NotEmpty(t, l.Plan.Files, + "layer should retain its plan for streaming on push") + } - // Check descriptor + // Manifest descriptor should be populated without needing a registry. desc := wa.Descriptor() - require.Equal(t, expectedDigest, desc.Digest.String()) - require.Equal(t, int64(len(weightContent)), desc.Size) - - // Check weight-specific fields - require.Equal(t, weightFile, wa.FilePath) - require.Equal(t, "/srv/weights/model.safetensors", wa.Target) - - // Check WeightConfig - require.Equal(t, "1.0", wa.Config.SchemaVersion) - require.Equal(t, "0.15.0", wa.Config.CogVersion) - require.Equal(t, "my-model", wa.Config.Name) - require.Equal(t, "/srv/weights/model.safetensors", wa.Config.Target) - require.False(t, wa.Config.Created.IsZero(), "Created should be set") + require.NotEmpty(t, desc.Digest.Hex) + require.Greater(t, desc.Size, int64(0)) } -func TestWeightBuilder_WritesLockfile(t *testing.T) { - // After Build(), a weights.lock should be written/updated at lockPath. - tmpDir := t.TempDir() - weightContent := []byte("lockfile test weight") - err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), weightContent, 0o644) +func TestWeightBuilder_PopulatesStore(t *testing.T) { + // Core promise of cog-i12u: after Build, every file from the + // inventory exists in the local content store. cog predict can + // then hardlink-assemble without a separate `cog weights pull`. 
+ projectDir := t.TempDir() + makeWeightDir(t, projectDir, "w", map[string][]byte{ + "a.json": []byte(`{"x":1}`), + "b.json": []byte(`{"y":2}`), + }) + + wb, st := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, + }) + + spec := testWeightSpec(t, "w", "w", "/src/w") + art, err := wb.Build(context.Background(), spec) require.NoError(t, err) + wa := art.(*WeightArtifact) - hash := sha256.Sum256(weightContent) - expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) + for _, f := range wa.Entry.Files { + ok, err := st.Exists(context.Background(), f.Digest) + require.NoError(t, err) + assert.True(t, ok, "file %s (%s) should be in the store after Build", f.Path, f.Digest) + } +} +func TestWeightBuilder_FastPath_PopulatesEmptyStore(t *testing.T) { + // Scenario: the lockfile is present (e.g. checked into git) but + // the local store is cold (e.g. fresh clone, or a brand-new + // machine). Build must still ingress every file into the store + // — so `cog predict` works without a separate `cog weights + // pull` even on the fast path. + projectDir := t.TempDir() + makeWeightDir(t, projectDir, "w", map[string][]byte{ + "a.json": []byte(`{"x":1}`), + "b.json": []byte(`{"y":2}`), + }) + + // First, do a normal build to write the lockfile. + wb1, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, + }) + spec := testWeightSpec(t, "w", "w", "/src/w") + _, err := wb1.Build(context.Background(), spec) + require.NoError(t, err) + + // Now: same project, same lockfile on disk, but a brand-new + // (empty) store. This is the "fresh clone" scenario. 
src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ - {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, }, - }, tmpDir) + }, projectDir) + freshStore, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + wb2 := NewWeightBuilder(src, freshStore, filepath.Join(projectDir, "weights.lock")) + art, err := wb2.Build(context.Background(), spec) + require.NoError(t, err) + wa := art.(*WeightArtifact) + + // Every file in the lockfile must now be in the cold store too. + for _, f := range wa.Entry.Files { + ok, err := freshStore.Exists(context.Background(), f.Digest) + require.NoError(t, err) + assert.True(t, ok, + "fast-path build with cold store must populate file %s (%s)", + f.Path, f.Digest) + } +} - lockPath := filepath.Join(tmpDir, "weights.lock") - wb := NewWeightBuilder(src, "0.15.0", lockPath) +func TestWeightBuilder_StampsEnvelopeFormat(t *testing.T) { + // Every successful Build must stamp the current envelope format + // into the lockfile. This is the field that lets future imports + // detect cog-version drift in packer behavior. 
+ projectDir := t.TempDir() + makeWeightDir(t, projectDir, "w", map[string][]byte{"a.json": []byte(`{"x":1}`)}) - spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin") - _, err = wb.Build(context.Background(), spec) + wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, + }) + + spec := testWeightSpec(t, "w", "w", "/src/w") + _, err := wb.Build(context.Background(), spec) require.NoError(t, err) - // Lockfile should exist - _, err = os.Stat(lockPath) - require.NoError(t, err, "lockfile should be created") + lock, err := lockfile.LoadWeightsLock(filepath.Join(projectDir, "weights.lock")) + require.NoError(t, err) - // Load and verify lockfile contents - lock, err := LoadWeightsLock(lockPath) + want, err := computeEnvelopeFormat(envelopeFromOptions(packOptions{})) require.NoError(t, err) - require.Equal(t, "1.0", lock.Version) - require.Len(t, lock.Files, 1) - - wf := lock.Files[0] - require.Equal(t, "my-model", wf.Name) - require.Equal(t, "/weights/model.bin", wf.Dest) - require.Equal(t, expectedDigest, wf.Digest) - require.Equal(t, int64(len(weightContent)), wf.Size) + assert.Equal(t, want, lock.EnvelopeFormat, + "lockfile must stamp the current envelope format") } -func TestWeightBuilder_UpdatesExistingLockfile(t *testing.T) { - // If a lockfile already exists with entries, Build() should add/update the entry - // for the built weight without losing other entries. - tmpDir := t.TempDir() - - // Create two weight files - content1 := []byte("weight one data") - content2 := []byte("weight two data") - err := os.WriteFile(filepath.Join(tmpDir, "w1.bin"), content1, 0o644) +func TestWeightBuilder_EnvelopeFormatMismatch_TriggersRecompute(t *testing.T) { + // If the lockfile's recorded EnvelopeFormat doesn't match the + // current envelope (e.g. 
after a cog upgrade with a packer + // behavior change), Build must recompute layer digests rather + // than trust the lockfile's recorded values. Simulate the drift + // by writing a stale envelopeFormat into the lockfile and + // confirm Build rewrites it to the current value. + projectDir := t.TempDir() + makeWeightDir(t, projectDir, "w", map[string][]byte{"a.json": []byte(`{"x":1}`)}) + + wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, + }) + spec := testWeightSpec(t, "w", "w", "/src/w") + _, err := wb.Build(context.Background(), spec) require.NoError(t, err) - err = os.WriteFile(filepath.Join(tmpDir, "w2.bin"), content2, 0o644) + + lockPath := filepath.Join(projectDir, "weights.lock") + + // Corrupt the recorded EnvelopeFormat on disk. + lock, err := lockfile.LoadWeightsLock(lockPath) require.NoError(t, err) + lock.EnvelopeFormat = "sha256:0000000000000000000000000000000000000000000000000000000000000000" + require.NoError(t, lock.Save(lockPath)) - src := NewSourceFromConfig(&config.Config{ - Weights: []config.WeightSource{ - {Name: "weight-1", Source: "w1.bin", Target: "/weights/w1.bin"}, - {Name: "weight-2", Source: "w2.bin", Target: "/weights/w2.bin"}, - }, - }, tmpDir) + // Rebuild — recompute path should fire and stamp the correct + // envelope. 
+ _, err = wb.Build(context.Background(), spec) + require.NoError(t, err) - lockPath := filepath.Join(tmpDir, "weights.lock") - wb := NewWeightBuilder(src, "0.15.0", lockPath) + lock, err = lockfile.LoadWeightsLock(lockPath) + require.NoError(t, err) + want, err := computeEnvelopeFormat(envelopeFromOptions(packOptions{})) + require.NoError(t, err) + assert.Equal(t, want, lock.EnvelopeFormat, + "recompute path must stamp the current envelope format") +} - // Build first weight - spec1 := NewWeightSpec("weight-1", "w1.bin", "/weights/w1.bin") - _, err = wb.Build(context.Background(), spec1) +func TestWeightBuilder_FastPath_NoOpRebuild(t *testing.T) { + // Build the same source twice. Second build's source fingerprint + // matches the lockfile's recorded fingerprint, so canFastPath + // returns true and Build trusts the recorded layer digests + // without recomputing. The lockfile's mtime stays put (no write + // since EntriesEqual returns true), and the manifest digest is + // identical to the first build's. 
+ projectDir := t.TempDir() + makeWeightDir(t, projectDir, "w", map[string][]byte{"a.json": []byte(`{"x":1}`)}) + + wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "w"}}, + }) + spec := testWeightSpec(t, "w", "w", "/src/w") + first, err := wb.Build(context.Background(), spec) require.NoError(t, err) + fa := first.(*WeightArtifact) - // Build second weight - spec2 := NewWeightSpec("weight-2", "w2.bin", "/weights/w2.bin") - _, err = wb.Build(context.Background(), spec2) + lockPath := filepath.Join(projectDir, "weights.lock") + infoBefore, err := os.Stat(lockPath) require.NoError(t, err) - // Lockfile should contain both entries - lock, err := LoadWeightsLock(lockPath) + second, err := wb.Build(context.Background(), spec) require.NoError(t, err) - require.Len(t, lock.Files, 2) + sa := second.(*WeightArtifact) - names := map[string]bool{} - for _, f := range lock.Files { - names[f.Name] = true - } - require.True(t, names["weight-1"]) - require.True(t, names["weight-2"]) + assert.Equal(t, fa.Descriptor().Digest, sa.Descriptor().Digest, + "unchanged source must produce identical manifest digest") + + infoAfter, err := os.Stat(lockPath) + require.NoError(t, err) + assert.Equal(t, infoBefore.ModTime(), infoAfter.ModTime(), + "unchanged-source rebuild must not rewrite weights.lock") } -func TestWeightBuilder_CacheHit(t *testing.T) { - // When a lockfile entry exists with matching name and size, - // the builder should use the cached digest without re-hashing. 
- tmpDir := t.TempDir() - weightContent := []byte("cached weight data") - err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), weightContent, 0o644) +func TestWeightBuilder_WritesLockfile(t *testing.T) { + projectDir := t.TempDir() + makeWeightDir(t, projectDir, "weights/mw", map[string][]byte{ + "config.json": []byte(`{"x": 1}`), + "tokenizer.json": []byte(`{"y": 2}`), + }) + + wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{ + {Name: "mw", Target: "/src/weights/mw", Source: &config.WeightSourceConfig{URI: "weights/mw"}}, + }) + + spec := testWeightSpec(t, "mw", "weights/mw", "/src/weights/mw") + artifact, err := wb.Build(context.Background(), spec) require.NoError(t, err) - hash := sha256.Sum256(weightContent) - expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) + wa := artifact.(*WeightArtifact) - src := NewSourceFromConfig(&config.Config{ - Weights: []config.WeightSource{ - {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, - }, - }, tmpDir) + lockPath := filepath.Join(projectDir, "weights.lock") + lock, err := lockfile.LoadWeightsLock(lockPath) + require.NoError(t, err) + require.Equal(t, lockfile.Version, lock.Version) + require.Len(t, lock.Weights, 1) + + entry := lock.Weights[0] + require.Equal(t, "mw", entry.Name) + require.Equal(t, "/src/weights/mw", entry.Target) + require.Equal(t, wa.Descriptor().Digest.String(), entry.Digest) + require.Equal(t, wa.Entry.SetDigest, entry.SetDigest) + require.Len(t, entry.Layers, len(wa.Layers)) + + // Source block is populated with the normalized URI, a sha256 + // fingerprint, and empty include/exclude patterns. 
+	require.Equal(t, "file://./weights/mw", entry.Source.URI)
+	require.Equal(t, "sha256", entry.Source.Fingerprint.Scheme())
+	require.Equal(t, wa.Entry.SetDigest, entry.Source.Fingerprint.String(),
+		"file:// fingerprint is the set digest")
+	require.NotNil(t, entry.Source.Include)
+	require.NotNil(t, entry.Source.Exclude)
+	require.Empty(t, entry.Source.Include)
+	require.Empty(t, entry.Source.Exclude)
+	require.False(t, entry.Source.ImportedAt.IsZero())
+
+	// File index is populated and sorted by path.
+	require.Len(t, entry.Files, 2)
+	require.Equal(t, "config.json", entry.Files[0].Path)
+	require.Equal(t, "tokenizer.json", entry.Files[1].Path)
+	for _, f := range entry.Files {
+		require.NotEmpty(t, f.Digest)
+		require.NotEmpty(t, f.Layer)
+		require.Greater(t, f.Size, int64(0))
+	}
+
+	// Layer descriptors sorted by digest, carry compressed + uncompressed sizes.
+	for i := 1; i < len(entry.Layers); i++ {
+		require.Less(t, entry.Layers[i-1].Digest, entry.Layers[i].Digest,
+			"layers must be sorted by digest")
+	}
+	for _, l := range entry.Layers {
+		require.NotEmpty(t, l.Digest)
+		require.NotEmpty(t, l.MediaType)
+		require.Greater(t, l.Size, int64(0))
+		require.Greater(t, l.SizeUncompressed, int64(0))
+	}
+
+	// Totals match sums.
+	var wantSize, wantCompressed int64
+	for _, l := range entry.Layers {
+		wantSize += l.SizeUncompressed
+		wantCompressed += l.Size
+	}
+	require.Equal(t, wantSize, entry.Size)
+	require.Equal(t, wantCompressed, entry.SizeCompressed)
+}
+
+func TestWeightBuilder_UpdatesExistingLockfile(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "w1", map[string][]byte{"a.json": []byte(`{"w":1}`)})
+	makeWeightDir(t, projectDir, "w2", map[string][]byte{"b.json": []byte(`{"w":2}`)})
-	lockPath := filepath.Join(tmpDir, "weights.lock")
-	wb := NewWeightBuilder(src, "0.15.0", lockPath)
+	wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{
+		{Name: "w1", Target: "/src/w1", Source: &config.WeightSourceConfig{URI: "w1"}},
+		{Name: "w2", Target: "/src/w2", Source: &config.WeightSourceConfig{URI: "w2"}},
+	})
-	// First build — populates lockfile
-	spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin")
-	artifact1, err := wb.Build(context.Background(), spec)
+	_, err := wb.Build(context.Background(), testWeightSpec(t, "w1", "w1", "/src/w1"))
+	require.NoError(t, err)
+	_, err = wb.Build(context.Background(), testWeightSpec(t, "w2", "w2", "/src/w2"))
 	require.NoError(t, err)
-	// Second build — should hit cache
-	artifact2, err := wb.Build(context.Background(), spec)
+	lock, err := lockfile.LoadWeightsLock(filepath.Join(projectDir, "weights.lock"))
 	require.NoError(t, err)
+	require.Len(t, lock.Weights, 2)
+
+	names := map[string]bool{}
+	for _, w := range lock.Weights {
+		names[w.Name] = true
+	}
+	require.True(t, names["w1"])
+	require.True(t, names["w2"])
+}
+
+func TestWeightBuilder_FastPath_UpdatesConfigFields(t *testing.T) {
+	// Config-driven fields (target, source URI) can change in
+	// cog.yaml without the source content changing. The fast path
+	// must stamp the current values into the lockfile so weights
+	// status doesn't report the weight as stale.
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "w", map[string][]byte{"a.json": []byte(`{"x":1}`)})
+
+	oldTarget := "/src/w"
+	newTarget := "/src/w-moved"
-	// Both builds should produce the same digest
-	wa1 := artifact1.(*WeightArtifact)
-	wa2 := artifact2.(*WeightArtifact)
-	require.Equal(t, expectedDigest, wa1.Descriptor().Digest.String())
-	require.Equal(t, expectedDigest, wa2.Descriptor().Digest.String())
+	wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{
+		{Name: "w", Target: oldTarget, Source: &config.WeightSourceConfig{URI: "w"}},
+	})
-	// Lockfile should still have exactly one entry (not duplicated)
-	lock, err := LoadWeightsLock(lockPath)
+	// First build writes the lockfile with the old target.
+	spec := testWeightSpec(t, "w", "w", oldTarget)
+	first, err := wb.Build(context.Background(), spec)
 	require.NoError(t, err)
-	require.Len(t, lock.Files, 1)
-	require.Equal(t, "my-model", lock.Files[0].Name)
-}
+	fa := first.(*WeightArtifact)
 
-func TestWeightBuilder_CacheMiss_SizeChanged(t *testing.T) {
-	// When the file size changes, the builder should re-hash.
-	tmpDir := t.TempDir()
-	content1 := []byte("original content")
-	err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), content1, 0o644)
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	lock, err := lockfile.LoadWeightsLock(lockPath)
+	require.NoError(t, err)
+	require.Equal(t, oldTarget, lock.Weights[0].Target)
+	require.Equal(t, "file://./w", lock.Weights[0].Source.URI)
+
+	// Second build: same name, same source dir, different target.
+	// Layers should be reused (fast path) but the target must be
+	// stamped into the lockfile.
+	spec2 := testWeightSpec(t, "w", "./w", newTarget)
+	second, err := wb.Build(context.Background(), spec2)
 	require.NoError(t, err)
+	sa := second.(*WeightArtifact)
-	src := NewSourceFromConfig(&config.Config{
-		Weights: []config.WeightSource{
-			{Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"},
-		},
-	}, tmpDir)
+	// Layers reused via fast path.
+	require.Equal(t, fa.Layers[0].Digest, sa.Layers[0].Digest,
+		"fast path should reuse the same layers")
-	lockPath := filepath.Join(tmpDir, "weights.lock")
-	wb := NewWeightBuilder(src, "0.15.0", lockPath)
+	lock2, err := lockfile.LoadWeightsLock(lockPath)
+	require.NoError(t, err)
+	require.Len(t, lock2.Weights, 1)
+	require.Equal(t, newTarget, lock2.Weights[0].Target,
+		"fast-path rebuild must update the target in the lockfile")
-	spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin")
+	require.Equal(t, "file://./w", lock2.Weights[0].Source.URI,
+		"normalized source URI must be preserved")
+	require.Equal(t, newTarget, sa.Entry.Target)
+}
-	// First build
-	_, err = wb.Build(context.Background(), spec)
-	require.NoError(t, err)
+
+func TestWeightBuilder_CacheMiss_ContentsChanged(t *testing.T) {
+	projectDir := t.TempDir()
+	weightDir := "w"
+	makeWeightDir(t, projectDir, weightDir, map[string][]byte{"a.json": []byte(`{"x":1}`)})
-	// Change the file (different size)
-	content2 := []byte("modified content with different size!!")
-	err = os.WriteFile(filepath.Join(tmpDir, "model.bin"), content2, 0o644)
+	wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{
+		{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: weightDir}},
+	})
+
+	spec := testWeightSpec(t, "w", weightDir, "/src/w")
+	first, err := wb.Build(context.Background(), spec)
 	require.NoError(t, err)
+	fa := first.(*WeightArtifact)
+
+	// Change the file content (different bytes => different digest).
+	// canFastPath detects this through Source.Fingerprint mismatch
+	// (fingerprint is the dirhash of the file set for file://) and
+	// falls back to recompute.
+	require.NoError(t, os.WriteFile(
+		filepath.Join(projectDir, weightDir, "a.json"),
+		[]byte(`{"x":2,"y":3}`), 0o644))
-	// Second build — should detect size change and re-hash
-	artifact2, err := wb.Build(context.Background(), spec)
+	second, err := wb.Build(context.Background(), spec)
 	require.NoError(t, err)
+	sa := second.(*WeightArtifact)
-	wa2 := artifact2.(*WeightArtifact)
-	hash2 := sha256.Sum256(content2)
-	expectedDigest2 := "sha256:" + hex.EncodeToString(hash2[:])
-	require.Equal(t, expectedDigest2, wa2.Descriptor().Digest.String())
-	require.Equal(t, int64(len(content2)), wa2.Descriptor().Size)
+	require.NotEqual(t, fa.Descriptor().Digest, sa.Descriptor().Digest,
+		"changed content should yield a different manifest digest")
 }
 
 func TestWeightBuilder_ErrorWrongSpecType(t *testing.T) {
-	tmpDir := t.TempDir()
-	src := NewSourceFromConfig(&config.Config{}, tmpDir)
-	lockPath := filepath.Join(tmpDir, "weights.lock")
-	wb := NewWeightBuilder(src, "0.15.0", lockPath)
+	projectDir := t.TempDir()
+	wb, _ := newTestBuilder(t, projectDir, nil)
-	// Pass an ImageSpec instead of WeightSpec
 	imageSpec := NewImageSpec("model", "test-image")
 	_, err := wb.Build(context.Background(), imageSpec)
 	require.Error(t, err)
 	require.Contains(t, err.Error(), "expected *WeightSpec")
 }
 
-func TestWeightBuilder_ErrorFileNotFound(t *testing.T) {
-	tmpDir := t.TempDir()
-	src := NewSourceFromConfig(&config.Config{}, tmpDir)
-	lockPath := filepath.Join(tmpDir, "weights.lock")
-	wb := NewWeightBuilder(src, "0.15.0", lockPath)
+func TestWeightBuilder_ErrorSourceNotFound(t *testing.T) {
+	projectDir := t.TempDir()
+	wb, _ := newTestBuilder(t, projectDir, nil)
-	spec := NewWeightSpec("missing", "nonexistent.bin", "/weights/nonexistent.bin")
+	spec := testWeightSpec(t, "missing", "nonexistent-dir", "/src/missing")
 	_, err := wb.Build(context.Background(), spec)
 	require.Error(t, err)
 	require.Contains(t, err.Error(), "weight source not found")
 }
 
+func TestWeightBuilder_ErrorSourceIsFile(t *testing.T) {
+	projectDir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(projectDir, "oops.bin"), []byte("data"), 0o644))
+
+	wb, _ := newTestBuilder(t, projectDir, nil)
+
+	spec := testWeightSpec(t, "oops", "oops.bin", "/src/oops")
+	_, err := wb.Build(context.Background(), spec)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "is not a directory")
+}
+
 func TestWeightBuilder_ErrorContextCancelled(t *testing.T) {
-	tmpDir := t.TempDir()
-	err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), []byte("data"), 0o644)
-	require.NoError(t, err)
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "w", map[string][]byte{"a.json": []byte(`{"x":1}`)})
-	src := NewSourceFromConfig(&config.Config{}, tmpDir)
-	lockPath := filepath.Join(tmpDir, "weights.lock")
-	wb := NewWeightBuilder(src, "0.15.0", lockPath)
+	wb, _ := newTestBuilder(t, projectDir, nil)
 
 	ctx, cancel := context.WithCancel(context.Background())
-	cancel() // Cancel immediately
+	cancel()
 
-	spec := NewWeightSpec("model", "model.bin", "/weights/model.bin")
-	_, err = wb.Build(ctx, spec)
+	spec := testWeightSpec(t, "w", "w", "/src/w")
+	_, err := wb.Build(ctx, spec)
 	require.Error(t, err)
 	require.ErrorIs(t, err, context.Canceled)
 }
 
 func TestWeightBuilder_ImplementsBuilderInterface(t *testing.T) {
-	tmpDir := t.TempDir()
-	src := NewSourceFromConfig(&config.Config{}, tmpDir)
-	lockPath := filepath.Join(tmpDir, "weights.lock")
+	projectDir := t.TempDir()
+	wb, _ := newTestBuilder(t, projectDir, nil)
+	var _ Builder = wb
+}
-	// Compile-time check
-	var _ Builder = NewWeightBuilder(src, "0.1.0", lockPath)
+
+func TestWeightBuilder_NormalizesSourceURI(t *testing.T) {
+	// Different bare-path spellings of the same directory should
+	// produce the same normalized URI in the lockfile.
+	tests := []struct {
+		name    string
+		rawURI  string
+		wantURI string
+	}{
+		{"bare relative", "weights/mw", "file://./weights/mw"},
+		{"dot prefix", "./weights/mw", "file://./weights/mw"},
+		{"file scheme", "file://./weights/mw", "file://./weights/mw"},
+		{"redundant slashes", "weights//mw", "file://./weights/mw"},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			projectDir := t.TempDir()
+			makeWeightDir(t, projectDir, "weights/mw", map[string][]byte{"c.json": []byte(`{}`)})
+
+			wb, _ := newTestBuilder(t, projectDir, []config.WeightSource{
+				{Name: "mw", Target: "/src/weights/mw", Source: &config.WeightSourceConfig{URI: tc.rawURI}},
+			})
+			spec := testWeightSpec(t, "mw", tc.rawURI, "/src/weights/mw")
+			_, err := wb.Build(context.Background(), spec)
+			require.NoError(t, err)
+
+			lock, err := lockfile.LoadWeightsLock(filepath.Join(projectDir, "weights.lock"))
+			require.NoError(t, err)
+			require.Len(t, lock.Weights, 1)
+			require.Equal(t, tc.wantURI, lock.Weights[0].Source.URI)
+		})
+	}
 }
diff --git a/pkg/model/weight_import_plan.go b/pkg/model/weight_import_plan.go
new file mode 100644
index 0000000000..ab25cde2e5
--- /dev/null
+++ b/pkg/model/weight_import_plan.go
@@ -0,0 +1,138 @@
+package model
+
+import (
+	"context"
+	"fmt"
+	"slices"
+
+	"github.com/replicate/cog/pkg/model/weightsource"
+)
+
+// WeightImportPlanStatus describes what would happen to a weight on import.
+type WeightImportPlanStatus string
+
+const (
+	PlanStatusNew             WeightImportPlanStatus = "new"
+	PlanStatusUnchanged       WeightImportPlanStatus = "unchanged"
+	PlanStatusConfigChanged   WeightImportPlanStatus = "config-changed"
+	PlanStatusUpstreamChanged WeightImportPlanStatus = "upstream-changed"
+)
+
+// WeightImportPlan is the result of planning one weight's import without
+// executing it. It contains everything needed to show the user what
+// would happen and to pass pre-computed inventory into Build.
+type WeightImportPlan struct {
+	Spec *WeightSpec
+
+	Status  WeightImportPlanStatus
+	Changes []string // human-readable list of what changed
+
+	// Resolved is the pre-computed inventory from planning. Build can
+	// reuse this to avoid re-walking the source.
+	Resolved *resolvedInventory
+
+	// UnfilteredFiles is populated when include/exclude patterns are
+	// active, so the caller can show what was excluded.
+	UnfilteredFiles []weightsource.InventoryFile
+}
+
+// FilteredFiles returns the filtered inventory files.
+func (p *WeightImportPlan) FilteredFiles() []weightsource.InventoryFile {
+	return p.Resolved.filtered.Files
+}
+
+// TotalSize returns the sum of filtered file sizes.
+func (p *WeightImportPlan) TotalSize() int64 {
+	var total int64
+	for _, f := range p.Resolved.filtered.Files {
+		total += f.Size
+	}
+	return total
+}
+
+// ExcludedFiles returns files that were in the unfiltered inventory but
+// not in the filtered set.
+func (p *WeightImportPlan) ExcludedFiles() []weightsource.InventoryFile {
+	if len(p.UnfilteredFiles) == 0 {
+		return nil
+	}
+	included := make(map[string]bool, len(p.Resolved.filtered.Files))
+	for _, f := range p.Resolved.filtered.Files {
+		included[f.Path] = true
+	}
+	var excluded []weightsource.InventoryFile
+	for _, f := range p.UnfilteredFiles {
+		if !included[f.Path] {
+			excluded = append(excluded, f)
+		}
+	}
+	return excluded
+}
+
+// PlanImport runs the inventory + filter steps for one weight without
+// ingressing, packing, or pushing. It compares the result against the
+// existing lockfile to determine what would change on a real import.
+func (b *WeightBuilder) PlanImport(ctx context.Context, ws *WeightSpec) (*WeightImportPlan, error) {
+	resolved, err := b.resolveInventory(ctx, ws)
+	if err != nil {
+		return nil, err
+	}
+
+	plan := &WeightImportPlan{
+		Spec:     ws,
+		Resolved: resolved,
+	}
+
+	// Keep the unfiltered set if patterns are active.
+	if len(ws.Include) > 0 || len(ws.Exclude) > 0 {
+		plan.UnfilteredFiles = resolved.full.Files
+	}
+
+	// Compare against lockfile.
+	lock, err := loadLockfileOrEmpty(b.lockPath)
+	if err != nil {
+		return nil, err
+	}
+
+	existing := lock.FindWeight(ws.Name())
+	if existing == nil {
+		plan.Status = PlanStatusNew
+		return plan, nil
+	}
+
+	lockSpec := WeightSpecFromLock(*existing)
+	if !ws.Equal(lockSpec) {
+		plan.Status = PlanStatusConfigChanged
+		plan.Changes = describeSpecDrift(ws, lockSpec)
+		return plan, nil
+	}
+
+	if existing.Source.Fingerprint != resolved.full.Fingerprint {
+		plan.Status = PlanStatusUpstreamChanged
+		plan.Changes = []string{fmt.Sprintf("fingerprint: %s → %s",
+			existing.Source.Fingerprint, resolved.full.Fingerprint)}
+		return plan, nil
+	}
+
+	plan.Status = PlanStatusUnchanged
+	return plan, nil
+}
+
+// describeSpecDrift returns human-readable descriptions of what differs
+// between the config spec and the lockfile spec.
+func describeSpecDrift(config, lock *WeightSpec) []string {
+	var changes []string
+	if config.URI != lock.URI {
+		changes = append(changes, fmt.Sprintf("uri: %q → %q", lock.URI, config.URI))
+	}
+	if config.Target != lock.Target {
+		changes = append(changes, fmt.Sprintf("target: %q → %q", lock.Target, config.Target))
+	}
+	if !slices.Equal(config.Include, lock.Include) {
+		changes = append(changes, fmt.Sprintf("include: %v → %v", lock.Include, config.Include))
+	}
+	if !slices.Equal(config.Exclude, lock.Exclude) {
+		changes = append(changes, fmt.Sprintf("exclude: %v → %v", lock.Exclude, config.Exclude))
+	}
+	return changes
+}
diff --git a/pkg/model/weight_import_plan_test.go b/pkg/model/weight_import_plan_test.go
new file mode 100644
index 0000000000..b44306c428
--- /dev/null
+++ b/pkg/model/weight_import_plan_test.go
@@ -0,0 +1,281 @@
+package model
+
+import (
+	"context"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/replicate/cog/pkg/config"
+	"github.com/replicate/cog/pkg/weights/lockfile"
+)
+
+func TestPlanImport_NewWeight(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"model.safetensors": []byte("weights-data"),
+		"config.json":       []byte(`{"hidden_size": 768}`),
+	})
+
+	src := NewSourceFromConfig(&config.Config{
+		Weights: []config.WeightSource{
+			{Name: "my-model", Target: "/src/weights", Source: &config.WeightSourceConfig{URI: "weights"}},
+		},
+	}, projectDir)
+
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	builder := NewWeightBuilder(src, nil, lockPath)
+
+	spec := testWeightSpec(t, "my-model", "weights", "/src/weights")
+	plan, err := builder.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+
+	assert.Equal(t, "my-model", plan.Spec.Name())
+	assert.Equal(t, PlanStatusNew, plan.Status)
+	assert.Len(t, plan.FilteredFiles(), 2)
+	assert.Greater(t, plan.TotalSize(), int64(0))
+	assert.Empty(t, plan.Changes)
+}
+
+func TestPlanImport_Unchanged(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"config.json": []byte(`{"hidden_size": 768}`),
+	})
+
+	weights := []config.WeightSource{
+		{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "weights"}},
+	}
+	src := NewSourceFromConfig(&config.Config{Weights: weights}, projectDir)
+
+	wb, _ := newTestBuilder(t, projectDir, weights)
+	spec := testWeightSpec(t, "w", "weights", "/src/w")
+	_, err := wb.Build(context.Background(), spec)
+	require.NoError(t, err)
+
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	planner := NewWeightBuilder(src, nil, lockPath)
+	plan, err := planner.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+
+	assert.Equal(t, PlanStatusUnchanged, plan.Status)
+	assert.Empty(t, plan.Changes)
+}
+
+func TestPlanImport_ConfigChanged(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"model.safetensors": []byte("data"),
+		"config.json":       []byte("{}"),
+		"model.onnx":        []byte("onnx-data"),
+	})
+
+	weights := []config.WeightSource{
+		{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "weights"}},
+	}
+
+	wb, _ := newTestBuilder(t, projectDir, weights)
+	spec := testWeightSpec(t, "w", "weights", "/src/w")
+	_, err := wb.Build(context.Background(), spec)
+	require.NoError(t, err)
+
+	specWithExclude, err := WeightSpecFromConfig(config.WeightSource{
+		Name:   "w",
+		Target: "/src/w",
+		Source: &config.WeightSourceConfig{URI: "weights", Exclude: []string{"*.onnx"}},
+	})
+	require.NoError(t, err)
+
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	src := NewSourceFromConfig(&config.Config{Weights: weights}, projectDir)
+	planner := NewWeightBuilder(src, nil, lockPath)
+	plan, err := planner.PlanImport(context.Background(), specWithExclude)
+	require.NoError(t, err)
+
+	assert.Equal(t, PlanStatusConfigChanged, plan.Status)
+	require.NotEmpty(t, plan.Changes)
+	assert.Contains(t, plan.Changes[0], "exclude")
+
+	filtered := plan.FilteredFiles()
+	assert.Len(t, filtered, 2)
+	for _, f := range filtered {
+		assert.NotEqual(t, "model.onnx", f.Path)
+	}
+}
+
+func TestPlanImport_WithFilter_ShowsExcluded(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"model.safetensors": []byte("data"),
+		"config.json":       []byte("{}"),
+		"model.onnx":        []byte("onnx-data"),
+	})
+
+	spec, err := WeightSpecFromConfig(config.WeightSource{
+		Name:   "w",
+		Target: "/src/w",
+		Source: &config.WeightSourceConfig{URI: "weights", Exclude: []string{"*.onnx"}},
+	})
+	require.NoError(t, err)
+
+	src := NewSourceFromConfig(&config.Config{
+		Weights: []config.WeightSource{
+			{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "weights", Exclude: []string{"*.onnx"}}},
+		},
+	}, projectDir)
+
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	planner := NewWeightBuilder(src, nil, lockPath)
+	plan, err := planner.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+
+	require.NotEmpty(t, plan.UnfilteredFiles)
+	assert.Len(t, plan.UnfilteredFiles, 3)
+	assert.Len(t, plan.FilteredFiles(), 2)
+
+	excluded := plan.ExcludedFiles()
+	require.Len(t, excluded, 1)
+	assert.Equal(t, "model.onnx", excluded[0].Path)
+}
+
+func TestPlanImport_UpstreamChanged(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"config.json": []byte("v1"),
+	})
+
+	weights := []config.WeightSource{
+		{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "weights"}},
+	}
+
+	wb, _ := newTestBuilder(t, projectDir, weights)
+	spec := testWeightSpec(t, "w", "weights", "/src/w")
+	_, err := wb.Build(context.Background(), spec)
+	require.NoError(t, err)
+
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"config.json": []byte("v2-different-content"),
+	})
+
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	src := NewSourceFromConfig(&config.Config{Weights: weights}, projectDir)
+	planner := NewWeightBuilder(src, nil, lockPath)
+	plan, err := planner.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+
+	assert.Equal(t, PlanStatusUpstreamChanged, plan.Status)
+	require.Len(t, plan.Changes, 1)
+	assert.Contains(t, plan.Changes[0], "fingerprint")
+}
+
+func TestDescribeSpecDrift(t *testing.T) {
+	tests := []struct {
+		name    string
+		config  *WeightSpec
+		lock    *WeightSpec
+		wantLen int
+		wantSub string
+	}{
+		{
+			name:    "URI changed",
+			config:  &WeightSpec{URI: "hf://org/new", Target: "/src/w"},
+			lock:    &WeightSpec{URI: "hf://org/old", Target: "/src/w"},
+			wantLen: 1,
+			wantSub: "uri",
+		},
+		{
+			name:    "target changed",
+			config:  &WeightSpec{URI: "hf://org/m", Target: "/src/new"},
+			lock:    &WeightSpec{URI: "hf://org/m", Target: "/src/old"},
+			wantLen: 1,
+			wantSub: "target",
+		},
+		{
+			name:    "include changed",
+			config:  &WeightSpec{URI: "hf://org/m", Target: "/src/w", Include: []string{"*.json"}},
+			lock:    &WeightSpec{URI: "hf://org/m", Target: "/src/w"},
+			wantLen: 1,
+			wantSub: "include",
+		},
+		{
+			name:    "multiple changes",
+			config:  &WeightSpec{URI: "new-uri", Target: "/new", Exclude: []string{"*.bin"}},
+			lock:    &WeightSpec{URI: "old-uri", Target: "/old"},
+			wantLen: 3,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			changes := describeSpecDrift(tt.config, tt.lock)
+			assert.Len(t, changes, tt.wantLen)
+			if tt.wantSub != "" {
+				assert.Contains(t, changes[0], tt.wantSub)
+			}
+		})
+	}
+}
+
+func TestBuildFromPlan_MatchesBuild(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"model.safetensors": []byte("weights-data"),
+		"config.json":       []byte(`{"hidden_size": 768}`),
+		"model.onnx":        []byte("onnx-data"),
+	})
+
+	weights := []config.WeightSource{
+		{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{
+			URI: "weights", Exclude: []string{"*.onnx"},
+		}},
+	}
+
+	spec, err := WeightSpecFromConfig(weights[0])
+	require.NoError(t, err)
+
+	// Plan with nil-store builder.
+	src := NewSourceFromConfig(&config.Config{Weights: weights}, projectDir)
+	lockPath := filepath.Join(projectDir, "weights.lock")
+	planner := NewWeightBuilder(src, nil, lockPath)
+	plan, err := planner.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+	assert.Equal(t, PlanStatusNew, plan.Status)
+
+	// BuildFromPlan with a real-store builder.
+	builder, _ := newTestBuilder(t, projectDir, weights)
+	artifact, err := builder.BuildFromPlan(context.Background(), spec, plan)
+	require.NoError(t, err)
+
+	wa, ok := artifact.(*WeightArtifact)
+	require.True(t, ok)
+	assert.Equal(t, "w", wa.Name())
+	assert.Equal(t, "/src/w", wa.Entry.Target)
+
+	// Verify the onnx file was excluded.
+	for _, f := range wa.Entry.Files {
+		assert.NotEqual(t, "model.onnx", f.Path, "excluded file should not appear")
+	}
+	assert.Len(t, wa.Entry.Files, 2)
+}
+
+func TestPlanImport_NoLockfile(t *testing.T) {
+	projectDir := t.TempDir()
+	makeWeightDir(t, projectDir, "weights", map[string][]byte{
+		"data.bin": []byte("some data"),
+	})
+
+	src := NewSourceFromConfig(&config.Config{
+		Weights: []config.WeightSource{
+			{Name: "w", Target: "/src/w", Source: &config.WeightSourceConfig{URI: "weights"}},
+		},
+	}, projectDir)
+
+	lockPath := filepath.Join(projectDir, lockfile.WeightsLockFilename)
+	planner := NewWeightBuilder(src, nil, lockPath)
+	spec := testWeightSpec(t, "w", "weights", "/src/w")
+	plan, err := planner.PlanImport(context.Background(), spec)
+	require.NoError(t, err)
+	assert.Equal(t, PlanStatusNew, plan.Status)
+}
diff --git a/pkg/model/weight_manifest_v1.go b/pkg/model/weight_manifest_v1.go
new file mode 100644
index 0000000000..1dc4faffcb
--- /dev/null
+++ b/pkg/model/weight_manifest_v1.go
@@ -0,0 +1,373 @@
+package model
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"maps"
+	"slices"
+	"strconv"
+	"strings"
+	"sync"
+
+	v1 "github.com/google/go-containerregistry/pkg/v1"
+	"github.com/google/go-containerregistry/pkg/v1/empty"
+	"github.com/google/go-containerregistry/pkg/v1/mutate"
+	"github.com/google/go-containerregistry/pkg/v1/types"
+
+	"github.com/replicate/cog/pkg/weights/lockfile"
+	"github.com/replicate/cog/pkg/weights/store"
+)
+
+// Manifest-level annotation keys per spec §2.5 (v1 "run.cog.*" namespace).
+const (
+	AnnotationV1WeightName      = "run.cog.weight.name"
+	AnnotationV1WeightTarget    = "run.cog.weight.target"
+	AnnotationV1WeightSetDigest = "run.cog.weight.set-digest"
+
+	// AnnotationV1WeightSizeUncomp carries an uncompressed byte count.
+	// It appears in two places:
+	//   - On each layer descriptor inside the weight manifest (§2.5):
+	//     the uncompressed size of that single layer's contents. Set by
+	//     buildWeightManifestV1 from packedLayer.UncompressedSize.
+	//   - On the weight descriptor inside the outer OCI index (§2.6):
+	//     the sum across all layers — i.e. the total uncompressed size
+	//     of the weight. Set by IndexBuilder.AddWeightDescriptor from
+	//     the lockfile entry's Size field.
+	AnnotationV1WeightSizeUncomp = "run.cog.weight.size.uncompressed"
+)
+
+// MediaTypeWeightConfig is the config blob media type per spec §2.1.
+const MediaTypeWeightConfig = "application/vnd.cog.weight.config.v1+json"
+
+// buildWeightManifestV1 assembles a v1.Image representing a v1 weight
+// manifest from a lockfile entry and the corresponding packed layer
+// descriptors. The entry provides all metadata (name, target,
+// setDigest, file index for the config blob); the packed layers
+// carry their layer plans, which the wrapped fileLayer replays
+// against st to produce the on-wire tar bytes during push.
+//
+// ctx scopes any byte-streaming the manifest's layers do later
+// (Compressed/Uncompressed). Pass the push context here so a
+// canceled push tears down its layer-streaming goroutines
+// promptly. nil is fine for callers that only need manifest digest
+// + descriptor without ever reading layer bytes.
+//
+// st may be nil for callers that only need the manifest digest (no
+// blob upload). Push paths must supply a real store.
+//
+// Layers are canonicalized: the manifest emits them in digest-sorted order,
+// regardless of input order. This makes the manifest digest a pure function
+// of the layer *set* plus metadata, so cold-pack and warm-cache paths
+// (which can produce layers in different orders) produce identical
+// manifests. The lockfile is also digest-sorted when serialized, so the
+// two canonical forms agree.
+//
+// The returned image has:
+//   - artifactType: application/vnd.cog.weight.v1 (injected via RawManifest override)
+//   - config: real config blob (application/vnd.cog.weight.config.v1+json, §2.3)
+//   - layers: one descriptor per packedLayer, in digest-sorted order,
+//     preserving mediaType, digest, size
+//   - annotations: manifest-level weight annotations per spec §2.5
+func buildWeightManifestV1(ctx context.Context, entry lockfile.WeightLockEntry, layers []packedLayer, st store.Store) (v1.Image, error) {
+	if entry.Name == "" {
+		return nil, fmt.Errorf("weight name is required")
+	}
+	if entry.Target == "" {
+		return nil, fmt.Errorf("weight target is required")
+	}
+	if entry.SetDigest == "" {
+		return nil, fmt.Errorf("weight set digest is required")
+	}
+	if len(entry.Files) == 0 {
+		return nil, fmt.Errorf("weight files are required for config blob")
+	}
+	if len(layers) == 0 {
+		return nil, fmt.Errorf("at least one layer is required")
+	}
+
+	// Build config blob from the entry's file index. The lockfile is
+	// already a superset of the config blob shape (§2.3).
+	configBlob, err := buildWeightConfigBlob(entry.Name, entry.Target, entry.SetDigest, entry.Files)
+	if err != nil {
+		return nil, fmt.Errorf("build config blob: %w", err)
+	}
+
+	// Copy and sort by digest so callers' slices aren't reordered as a
+	// side effect and the manifest layer order is a pure function of
+	// input content.
+	sorted := slices.Clone(layers)
+	slices.SortFunc(sorted, func(a, b packedLayer) int {
+		return strings.Compare(a.Digest.String(), b.Digest.String())
+	})
+
+	// Build mutate.Addendum entries. Each addendum wraps our file-backed
+	// layer, supplies the media type (used by mutate.Append to build the
+	// manifest's layer descriptors), and carries the single layer-level
+	// annotation required by spec §2.5.
+	//
+	// Per spec §2.5 the layer descriptor carries one annotation —
+	// run.cog.weight.size.uncompressed — so consumers can make per-layer
+	// scheduling/disk decisions without fetching the config blob. All
+	// other file-level metadata (paths, per-file sizes, layer mappings)
+	// lives in the config blob.
+	adds := make([]mutate.Addendum, 0, len(sorted))
+	for i, lr := range sorted {
+		if lr.Digest.Algorithm == "" || lr.Digest.Hex == "" {
+			return nil, fmt.Errorf("layer %d: missing digest", i)
+		}
+		if lr.Size <= 0 {
+			return nil, fmt.Errorf("layer %d (%s): invalid size %d", i, lr.Digest, lr.Size)
+		}
+		if lr.UncompressedSize <= 0 {
+			return nil, fmt.Errorf("layer %d (%s): invalid uncompressed size %d", i, lr.Digest, lr.UncompressedSize)
+		}
+		if lr.MediaType == "" {
+			return nil, fmt.Errorf("layer %d (%s): missing media type", i, lr.Digest)
+		}
+
+		adds = append(adds, mutate.Addendum{
+			Layer:     newFileLayer(ctx, lr, st),
+			MediaType: lr.MediaType,
+			Annotations: map[string]string{
+				AnnotationV1WeightSizeUncomp: strconv.FormatInt(lr.UncompressedSize, 10),
+			},
+		})
+	}
+
+	// Base on empty.Image, switched to OCI manifest media type.
+	img := mutate.MediaType(empty.Image, types.OCIManifestSchema1)
+	img, err = mutate.Append(img, adds...)
+	if err != nil {
+		return nil, fmt.Errorf("append weight layers: %w", err)
+	}
+
+	// Compute config blob digest + size.
+	cfgSum := sha256.Sum256(configBlob)
+	cfgDigest := v1.Hash{
+		Algorithm: "sha256",
+		Hex:       hex.EncodeToString(cfgSum[:]),
+	}
+
+	annotations := map[string]string{
+		AnnotationV1WeightName:      entry.Name,
+		AnnotationV1WeightTarget:    entry.Target,
+		AnnotationV1WeightSetDigest: entry.SetDigest,
+	}
+
+	// Wrap to inject artifactType, override config to the real config
+	// blob descriptor, and attach manifest-level annotations.
+	return &weightManifestV1Image{
+		Image:       img,
+		annotations: annotations,
+		configBlob:  configBlob,
+		configDesc: v1.Descriptor{
+			MediaType: types.MediaType(MediaTypeWeightConfig),
+			Size:      int64(len(configBlob)),
+			Digest:    cfgDigest,
+		},
+	}, nil
+}
+
+// weightOCIManifest extends v1.Manifest with artifactType for OCI 1.1 support.
+// v1.Manifest in go-containerregistry does not include artifactType at the
+// manifest level (only on descriptors), so we serialize it ourselves.
+type weightOCIManifest struct {
+	SchemaVersion int64             `json:"schemaVersion"`
+	MediaType     types.MediaType   `json:"mediaType,omitempty"`
+	ArtifactType  string            `json:"artifactType,omitempty"`
+	Config        v1.Descriptor     `json:"config"`
+	Layers        []v1.Descriptor   `json:"layers"`
+	Annotations   map[string]string `json:"annotations,omitempty"`
+}
+
+// weightManifestV1Image wraps a v1.Image to produce a v1 weight manifest with:
+//   - artifactType set to application/vnd.cog.weight.v1
+//   - config pointing to the real config blob (§2.3)
+//   - manifest-level annotations per spec §2.5
+//
+// go-containerregistry's v1.Manifest struct has no ArtifactType field at the
+// top level (it lives only on Descriptor). This is a deliberate upstream design
+// choice rather than a version lag — upstream main (as of 2026-04) still omits
+// it. So we intercept RawManifest() and marshal our own struct that includes
+// artifactType. The result is cached via sync.Once so Digest() and
+// RawManifest() observe identical bytes, which the registry requires for the
+// manifest PUT to succeed.
+type weightManifestV1Image struct {
+	v1.Image
+	annotations map[string]string
+	configBlob  []byte
+	configDesc  v1.Descriptor
+
+	rawOnce        sync.Once
+	rawManifest    []byte
+	rawManifestErr error
+}
+
+// RawConfigFile returns the weight config blob (§2.3).
+func (w *weightManifestV1Image) RawConfigFile() ([]byte, error) { + return w.configBlob, nil +} + +// ArtifactType implements the withArtifactType interface used by partial.Descriptor. +func (w *weightManifestV1Image) ArtifactType() (string, error) { + return MediaTypeWeightArtifact, nil +} + +// Manifest returns the modified manifest with the real config descriptor and +// the v1 weight annotations merged in. +func (w *weightManifestV1Image) Manifest() (*v1.Manifest, error) { + m, err := w.Image.Manifest() + if err != nil { + return nil, err + } + mCopy := m.DeepCopy() + + // Override the config descriptor to point to the real config blob (§2.3). + mCopy.Config = w.configDesc + + // Merge in manifest-level annotations. Our annotations win over any upstream. + if len(w.annotations) > 0 { + if mCopy.Annotations == nil { + mCopy.Annotations = make(map[string]string, len(w.annotations)) + } + maps.Copy(mCopy.Annotations, w.annotations) + } + + return mCopy, nil +} + +// Digest returns the digest of the raw manifest bytes. It must match the +// serialized output of RawManifest so the registry accepts the push. +func (w *weightManifestV1Image) Digest() (v1.Hash, error) { + raw, err := w.RawManifest() + if err != nil { + return v1.Hash{}, err + } + sum := sha256.Sum256(raw) + return v1.Hash{ + Algorithm: "sha256", + Hex: hex.EncodeToString(sum[:]), + }, nil +} + +// RawManifest serializes the weight manifest, including the artifactType +// field that v1.Manifest does not carry. The result is cached so Digest() +// and RawManifest() always see identical bytes. 
+func (w *weightManifestV1Image) RawManifest() ([]byte, error) { + w.rawOnce.Do(func() { + m, err := w.Manifest() + if err != nil { + w.rawManifestErr = err + return + } + + ociManifest := weightOCIManifest{ + SchemaVersion: m.SchemaVersion, + MediaType: m.MediaType, + ArtifactType: MediaTypeWeightArtifact, + Config: m.Config, + Layers: m.Layers, + Annotations: m.Annotations, + } + w.rawManifest, w.rawManifestErr = json.Marshal(ociManifest) + }) + return w.rawManifest, w.rawManifestErr +} + +// fileLayer is a v1.Layer that streams its on-wire bytes by re-running +// the packer pipeline against a content-addressed store. There is no +// on-disk tar; the layer reproduces its bytes on demand. +// +// The bytes are deterministic for a given (layerPlan, store) pair, so +// repeated Compressed() calls — once during digest verification, again +// during upload, again during retry — observe identical content. +// +// Unlike tarball.LayerFromFile, fileLayer does not re-compress its +// input: the byte stream IS the on-wire blob. This matches OCI +// "artifact" layers where the blob is whatever the registry stores, +// regardless of the MIME type. +// +// The digest and size are supplied by the packer, not recomputed. +// Recomputing would require streaming the whole tar a third time per +// layer. +type fileLayer struct { + plan layerPlan + digest v1.Hash + size int64 + mediaType types.MediaType + store store.Store + // ctx scopes the streaming goroutine in Compressed(); see + // newFileLayer for why it lives on the struct. + ctx context.Context +} + +// newFileLayer constructs a fileLayer that streams its bytes from st +// when Compressed/Uncompressed is called. +// +// ctx scopes the streaming goroutine. The v1.Layer interface +// doesn't accept a context (Compressed and Uncompressed are +// zero-arg), so the canonical workaround — also used by +// go-containerregistry's own remote.Layer — is to stash one on the +// struct. When the caller cancels (e.g. 
user interrupt mid-push),
+// streamLayer observes ctx.Err at its next loop boundary and tears
+// down the pipe instead of grinding through more bytes for nobody.
+func newFileLayer(ctx context.Context, lr packedLayer, st store.Store) *fileLayer {
+	return &fileLayer{
+		plan:      lr.Plan,
+		digest:    lr.Digest,
+		size:      lr.Size,
+		mediaType: lr.MediaType,
+		store:     st,
+		ctx:       ctx,
+	}
+}
+
+// Digest returns the blob digest.
+func (l *fileLayer) Digest() (v1.Hash, error) { return l.digest, nil }
+
+// DiffID returns the diff ID for the layer.
+//
+// For weight artifacts the diff ID is not meaningful (there is no RootFS
+// overlay), but partial.Descriptor and mutate.Append both need a non-error
+// value. We return the blob digest, matching the pattern used by static.NewLayer.
+func (l *fileLayer) DiffID() (v1.Hash, error) { return l.digest, nil }
+
+// Compressed returns the on-wire layer bytes by re-streaming the
+// packer pipeline against the store. The store must contain every
+// file in l.plan; otherwise the read fails partway through with an
+// fs.ErrNotExist-wrapped error.
+func (l *fileLayer) Compressed() (io.ReadCloser, error) {
+	if l.store == nil {
+		return nil, fmt.Errorf("fileLayer: store is nil; cannot stream layer %s", l.digest)
+	}
+	ctx := l.ctx
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	pr, pw := io.Pipe()
+	go func() {
+		// Errors propagate through pw.CloseWithError so the consumer's
+		// next Read returns them. The consumer also controls lifetime
+		// from its end: closing the reader makes pw.Write return
+		// io.ErrClosedPipe, which streamLayer surfaces as an error.
+		_, err := newPacker(nil).streamLayer(ctx, l.store, l.plan, pw)
+		_ = pw.CloseWithError(err) //nolint:errcheck // CloseWithError always returns nil
+	}()
+	return pr, nil
+}
+
+// Uncompressed returns the on-wire layer bytes (same as Compressed for
+// weight layers — see fileLayer doc).
+func (l *fileLayer) Uncompressed() (io.ReadCloser, error) { + return l.Compressed() +} + +// Size returns the size of the layer blob in bytes. +func (l *fileLayer) Size() (int64, error) { return l.size, nil } + +// MediaType returns the layer's OCI media type. +func (l *fileLayer) MediaType() (types.MediaType, error) { return l.mediaType, nil } diff --git a/pkg/model/weight_manifest_v1_test.go b/pkg/model/weight_manifest_v1_test.go new file mode 100644 index 0000000000..e593925d53 --- /dev/null +++ b/pkg/model/weight_manifest_v1_test.go @@ -0,0 +1,482 @@ +package model + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "testing" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +// ============================================================================= +// Helpers +// ============================================================================= + +// packDir runs the full ingress + plan + computeLayerDigests +// pipeline against sourceDir and returns the layers plus the store +// they reference. Tests that need to read layer bytes pass the store +// to readLayerEntries / fileLayer. 
+func packDir(t *testing.T, sourceDir string, opts *packOptions) ([]packedLayer, store.Store) { + t.Helper() + st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + src, err := weightsource.NewFileSource("file://"+sourceDir, "") + require.NoError(t, err) + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + require.NoError(t, ingressFromInventory(t.Context(), src, st, inv)) + pkr := newPacker(opts) + pl := pkr.planLayers(inv) + require.NotEmpty(t, pl.Layers) + layers, err := pkr.computeLayerDigests(t.Context(), st, pl) + require.NoError(t, err) + return layers, st +} + +// writeSrcFile writes size bytes at relPath under dir. +func writeSrcFile(t *testing.T, dir, relPath string, size int64) { + t.Helper() + abs := filepath.Join(dir, relPath) + require.NoError(t, os.MkdirAll(filepath.Dir(abs), 0o755)) + f, err := os.Create(abs) //nolint:gosec // test file + require.NoError(t, err) + defer f.Close() //nolint:errcheck + if size > 0 { + require.NoError(t, f.Truncate(size)) + } +} + +// defaultEntry returns a minimal valid lockfile.WeightLockEntry for manifest tests. +func defaultEntry() lockfile.WeightLockEntry { + return lockfile.WeightLockEntry{ + Name: "z-image-turbo", + Target: "/src/weights", + SetDigest: "sha256:0000000000000000000000000000000000000000000000000000000000000000", + Files: []lockfile.WeightLockFile{ + {Path: "config.json", Size: 128, Digest: "sha256:aaa", Layer: "sha256:layer1"}, + }, + } +} + +// singleSmallFileLayers produces a valid single-layer result set for +// tests that only care about manifest shape, not layer contents. +// Returns the layers and the store the layers reference. 
+func singleSmallFileLayers(t *testing.T) ([]packedLayer, store.Store) { + t.Helper() + dir := t.TempDir() + writeSrcFile(t, dir, "config.json", 128) + return packDir(t, dir, nil) +} + +// ============================================================================= +// Entry validation via buildWeightManifestV1 +// ============================================================================= + +func TestBuildWeightManifestV1_RejectsInvalidEntry(t *testing.T) { + validSetDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000" + validFiles := []lockfile.WeightLockFile{{Path: "f.bin", Size: 1, Digest: "sha256:aaa", Layer: "sha256:l1"}} + layers, st := singleSmallFileLayers(t) + + tests := []struct { + name string + entry lockfile.WeightLockEntry + wantErr string + }{ + {"missing name", lockfile.WeightLockEntry{Target: "/x", SetDigest: validSetDigest, Files: validFiles}, "weight name is required"}, + {"missing target", lockfile.WeightLockEntry{Name: "n", SetDigest: validSetDigest, Files: validFiles}, "weight target is required"}, + {"missing set digest", lockfile.WeightLockEntry{Name: "n", Target: "/x", Files: validFiles}, "weight set digest is required"}, + {"missing files", lockfile.WeightLockEntry{Name: "n", Target: "/x", SetDigest: validSetDigest}, "weight files are required"}, + {"valid", lockfile.WeightLockEntry{Name: "n", Target: "/x", SetDigest: validSetDigest, Files: validFiles}, ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := buildWeightManifestV1(t.Context(), tc.entry, layers, st) + if tc.wantErr == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + } + }) + } +} + +func TestBuildWeightManifestV1_ManifestAnnotations(t *testing.T) { + layers, st := singleSmallFileLayers(t) + entry := defaultEntry() + img, err := buildWeightManifestV1(t.Context(), entry, layers, st) + require.NoError(t, err) + + m, err := img.Manifest() + 
require.NoError(t, err) + + assert.Equal(t, "z-image-turbo", m.Annotations[AnnotationV1WeightName]) + assert.Equal(t, "/src/weights", m.Annotations[AnnotationV1WeightTarget]) + assert.Equal(t, "sha256:0000000000000000000000000000000000000000000000000000000000000000", m.Annotations[AnnotationV1WeightSetDigest]) +} + +// ============================================================================= +// buildWeightManifestV1 — validation +// ============================================================================= + +func TestBuildWeightManifestV1_RejectsMissingName(t *testing.T) { + layers, st := singleSmallFileLayers(t) + + _, err := buildWeightManifestV1(t.Context(), lockfile.WeightLockEntry{ + Target: "/x", + SetDigest: "sha256:0000000000000000000000000000000000000000000000000000000000000000", + Files: []lockfile.WeightLockFile{{Path: "f", Size: 1, Digest: "sha256:a", Layer: "sha256:l"}}, + }, layers, st) + require.Error(t, err) + assert.Contains(t, err.Error(), "name") +} + +func TestBuildWeightManifestV1_RejectsEmptyLayers(t *testing.T) { + _, err := buildWeightManifestV1(t.Context(), defaultEntry(), nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one layer") +} + +func TestBuildWeightManifestV1_RejectsInvalidLayer(t *testing.T) { + base, st := singleSmallFileLayers(t) + + cases := []struct { + name string + mutate func(lr *packedLayer) + wantErr string + }{ + {"missing digest", func(lr *packedLayer) { lr.Digest = v1.Hash{} }, "missing digest"}, + {"zero size", func(lr *packedLayer) { lr.Size = 0 }, "invalid size"}, + {"missing media type", func(lr *packedLayer) { lr.MediaType = "" }, "missing media type"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + lr := base[0] + tc.mutate(&lr) + _, err := buildWeightManifestV1(t.Context(), defaultEntry(), []packedLayer{lr}, st) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + }) + } +} + +// 
============================================================================= +// buildWeightManifestV1 — manifest structure +// ============================================================================= + +func TestBuildWeightManifestV1_ManifestShape(t *testing.T) { + layers, st := singleSmallFileLayers(t) + entry := defaultEntry() + img, err := buildWeightManifestV1(t.Context(), entry, layers, st) + require.NoError(t, err) + + // Manifest schema and media type. + m, err := img.Manifest() + require.NoError(t, err) + assert.EqualValues(t, 2, m.SchemaVersion) + assert.Equal(t, types.OCIManifestSchema1, m.MediaType) + + // Config is the weight config descriptor. + assert.Equal(t, types.MediaType(MediaTypeWeightConfig), m.Config.MediaType) + assert.Equal(t, "sha256", m.Config.Digest.Algorithm) + assert.Greater(t, m.Config.Size, int64(0)) + + // Config blob is valid JSON containing the expected fields. + cfgBytes, err := img.RawConfigFile() + require.NoError(t, err) + + // Verify config digest matches the config blob bytes. + cfgSum := sha256.Sum256(cfgBytes) + assert.Equal(t, hex.EncodeToString(cfgSum[:]), m.Config.Digest.Hex) + assert.Equal(t, int64(len(cfgBytes)), m.Config.Size) + + // Layers preserve media type, size, and digest from the packer, and + // carry the uncompressed size annotation per spec §2.5. + require.Len(t, m.Layers, len(layers)) + for i, layer := range m.Layers { + assert.Equal(t, layers[i].MediaType, layer.MediaType) + assert.Equal(t, layers[i].Size, layer.Size) + assert.Equal(t, layers[i].Digest, layer.Digest) + assert.Equal(t, + strconv.FormatInt(layers[i].UncompressedSize, 10), + layer.Annotations[AnnotationV1WeightSizeUncomp], + "layer %d should carry uncompressed size annotation", i) + } + + // Manifest annotations carry the v1 spec keys. 
+ assert.Equal(t, "z-image-turbo", m.Annotations[AnnotationV1WeightName]) + assert.Equal(t, "/src/weights", m.Annotations[AnnotationV1WeightTarget]) + assert.Equal(t, "sha256:0000000000000000000000000000000000000000000000000000000000000000", m.Annotations[AnnotationV1WeightSetDigest]) +} + +func TestBuildWeightManifestV1_RawManifestContainsArtifactType(t *testing.T) { + layers, st := singleSmallFileLayers(t) + entry := defaultEntry() + img, err := buildWeightManifestV1(t.Context(), entry, layers, st) + require.NoError(t, err) + + raw, err := img.RawManifest() + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(raw, &parsed)) + + assert.Equal(t, MediaTypeWeightArtifact, parsed["artifactType"]) + assert.Equal(t, "application/vnd.oci.image.manifest.v1+json", parsed["mediaType"]) + assert.EqualValues(t, 2, parsed["schemaVersion"]) + + cfg, ok := parsed["config"].(map[string]any) + require.True(t, ok) + assert.Equal(t, MediaTypeWeightConfig, cfg["mediaType"]) + + // Verify config digest matches what buildWeightManifestV1 produced. 
+ cfgBytes, err := img.RawConfigFile() + require.NoError(t, err) + cfgSum := sha256.Sum256(cfgBytes) + assert.Equal(t, "sha256:"+hex.EncodeToString(cfgSum[:]), cfg["digest"]) + assert.EqualValues(t, len(cfgBytes), cfg["size"]) + + rawLayers, ok := parsed["layers"].([]any) + require.True(t, ok) + require.Len(t, rawLayers, len(layers)) +} + +func TestBuildWeightManifestV1_DigestMatchesRawManifest(t *testing.T) { + layers, st := singleSmallFileLayers(t) + img, err := buildWeightManifestV1(t.Context(), defaultEntry(), layers, st) + require.NoError(t, err) + + raw, err := img.RawManifest() + require.NoError(t, err) + + sum := sha256.Sum256(raw) + wantHex := hex.EncodeToString(sum[:]) + + got, err := img.Digest() + require.NoError(t, err) + assert.Equal(t, wantHex, got.Hex) + assert.Equal(t, "sha256", got.Algorithm) +} + +func TestBuildWeightManifestV1_LayersCanonicallySortedByDigest(t *testing.T) { + // The manifest emits layers in digest-sorted order regardless of input + // order, so that different paths producing the same layer set produce + // identical manifests. Mix a small bundle and two "large" files with + // different media types to make sure the sort is by digest, not by + // media type or size — and feed the builder the reverse of the + // already-sorted order so a no-op pass-through can't spuriously pass. + // + // BundleFileMax is set tiny so ~1 KB files qualify as "large" and get + // their own layer; avoids writing hundreds of MB per test. 
+	dir := t.TempDir()
+	writeSrcFile(t, dir, "config.json", 64)
+	writeSrcFile(t, dir, "tokenizer.json", 64)
+	writeSrcFile(t, dir, "model.safetensors", 1024) // extension marks it incompressible: plain .tar layer
+	writeSrcFile(t, dir, "aux.dat", 1024)           // compressible extension: gzipped .tar.gz layer
+
+	layers, st := packDir(t, dir, &packOptions{BundleFileMax: 512, BundleSizeMax: 1024})
+	require.GreaterOrEqual(t, len(layers), 3, "expected bundle + 2 large layers")
+
+	// Pre-sort then reverse so the input is guaranteed to be in the
+	// opposite of the expected output order. If the builder forgets to
+	// sort, the assertion below will fail.
+	input := slices.Clone(layers)
+	slices.SortFunc(input, func(a, b packedLayer) int {
+		return strings.Compare(a.Digest.String(), b.Digest.String())
+	})
+	slices.Reverse(input)
+
+	img, err := buildWeightManifestV1(t.Context(), defaultEntry(), input, st)
+	require.NoError(t, err)
+
+	m, err := img.Manifest()
+	require.NoError(t, err)
+	require.Len(t, m.Layers, len(layers))
+
+	// Layers are digest-sorted; assert strict ascending order on the
+	// serialized digest string.
+	for i := 1; i < len(m.Layers); i++ {
+		assert.Less(t, m.Layers[i-1].Digest.String(), m.Layers[i].Digest.String(),
+			"layer %d digest should sort before layer %d (manifest must be digest-sorted)", i-1, i)
+	}
+
+	// At least one .tar and one .tar+gzip layer should be present — a
+	// sanity check that the mixed media types didn't collapse.
+	var sawTar, sawGzip bool
+	for _, layer := range m.Layers {
+		switch layer.MediaType {
+		case types.MediaType(mediaTypeOCILayerTar):
+			sawTar = true
+		case types.MediaType(mediaTypeOCILayerTarGzip):
+			sawGzip = true
+		}
+	}
+	assert.True(t, sawTar, "expected at least one .tar layer")
+	assert.True(t, sawGzip, "expected at least one .tar+gzip layer")
+}
+
+func TestBuildWeightManifestV1_InputOrderDoesNotAffectDigest(t *testing.T) {
+	// Manifest digest must be a pure function of the layer set plus
+	// metadata — permuting the input slice must not change the digest.
+ dir := t.TempDir() + writeSrcFile(t, dir, "config.json", 64) + writeSrcFile(t, dir, "model.safetensors", 1024) + writeSrcFile(t, dir, "aux.dat", 1024) + + layers, st := packDir(t, dir, &packOptions{BundleFileMax: 512, BundleSizeMax: 1024}) + require.GreaterOrEqual(t, len(layers), 3, "expected bundle + 2 large layers for a meaningful permutation test") + + imgOriginal, err := buildWeightManifestV1(t.Context(), defaultEntry(), layers, st) + require.NoError(t, err) + originalDigest, err := imgOriginal.Digest() + require.NoError(t, err) + + // Reverse order. + reversed := make([]packedLayer, len(layers)) + for i, l := range layers { + reversed[len(layers)-1-i] = l + } + imgReversed, err := buildWeightManifestV1(t.Context(), defaultEntry(), reversed, st) + require.NoError(t, err) + reversedDigest, err := imgReversed.Digest() + require.NoError(t, err) + + assert.Equal(t, originalDigest, reversedDigest, "manifest digest must be order-invariant") + + // Swap two adjacent layers. + swapped := slices.Clone(layers) + swapped[0], swapped[1] = swapped[1], swapped[0] + imgSwapped, err := buildWeightManifestV1(t.Context(), defaultEntry(), swapped, st) + require.NoError(t, err) + swappedDigest, err := imgSwapped.Digest() + require.NoError(t, err) + assert.Equal(t, originalDigest, swappedDigest, "manifest digest must be invariant under adjacent swap") +} + +func TestBuildWeightManifestV1_DoesNotMutateInputSlice(t *testing.T) { + // Callers keep the packer's or lockfile's layer order; the manifest + // builder copies before sorting so that side effect is invisible. 
+ dir := t.TempDir() + writeSrcFile(t, dir, "config.json", 64) + writeSrcFile(t, dir, "model.safetensors", 1024) + writeSrcFile(t, dir, "aux.dat", 1024) + + layers, st := packDir(t, dir, &packOptions{BundleFileMax: 512, BundleSizeMax: 1024}) + require.GreaterOrEqual(t, len(layers), 2, "need at least two layers to detect mutation") + + before := slices.Clone(layers) + _, err := buildWeightManifestV1(t.Context(), defaultEntry(), layers, st) + require.NoError(t, err) + + assert.Equal(t, before, layers, "buildWeightManifestV1 must not reorder the caller's slice") +} + +func TestBuildWeightManifestV1_LayerDescriptorUncompressedSizeAnnotation(t *testing.T) { + // Spec §2.5: each layer descriptor carries + // run.cog.weight.size.uncompressed as the only layer-level + // annotation. All other file-level metadata lives in the config + // blob. + layers, st := singleSmallFileLayers(t) + img, err := buildWeightManifestV1(t.Context(), defaultEntry(), layers, st) + require.NoError(t, err) + + m, err := img.Manifest() + require.NoError(t, err) + require.Len(t, m.Layers, len(layers)) + for i, l := range m.Layers { + require.Len(t, l.Annotations, 1, + "layer %d should carry exactly one annotation (uncompressed size)", i) + assert.Equal(t, + strconv.FormatInt(layers[i].UncompressedSize, 10), + l.Annotations[AnnotationV1WeightSizeUncomp], + "layer %d uncompressed size annotation", i) + } +} + +// ============================================================================= +// fileLayer — interface contract +// ============================================================================= + +func TestFileLayer_ReturnsLayerBytes(t *testing.T) { + // Pack a directory and verify fileLayer's Compressed() and + // Uncompressed() both return the same byte stream (the on-wire + // tar/tar+gzip blob), and that the bytes hash to the layer + // digest the packer recorded. 
+ dir := t.TempDir() + writeSrcFile(t, dir, "data.txt", 100) + layers, st := packDir(t, dir, nil) + require.Len(t, layers, 1) + lr := layers[0] + + l := newFileLayer(t.Context(), lr, st) + + d, err := l.Digest() + require.NoError(t, err) + assert.Equal(t, lr.Digest, d) + + diffID, err := l.DiffID() + require.NoError(t, err) + assert.Equal(t, d, diffID) + + sz, err := l.Size() + require.NoError(t, err) + assert.Equal(t, lr.Size, sz) + + mt, err := l.MediaType() + require.NoError(t, err) + assert.Equal(t, lr.MediaType, mt) + + // Compressed and Uncompressed return identical streams + // (artifact layers never re-encode). The byte stream digests + // to the layer's recorded digest. + for _, name := range []string{"Compressed", "Uncompressed"} { + t.Run(name, func(t *testing.T) { + var rc io.ReadCloser + var err error + if name == "Compressed" { + rc, err = l.Compressed() + } else { + rc, err = l.Uncompressed() + } + require.NoError(t, err) + defer rc.Close() //nolint:errcheck + got, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, lr.Size, int64(len(got)), + "streamed bytes should match recorded layer size") + gotSum := sha256.Sum256(got) + assert.Equal(t, lr.Digest.Hex, hex.EncodeToString(gotSum[:]), + "streamed bytes should hash to the recorded layer digest") + }) + } +} + +func TestFileLayer_NilStoreFailsClosed(t *testing.T) { + // Constructing a fileLayer with a nil store must fail loudly on + // Compressed() rather than panicking. Tests that pass nil + // (because they only assert on metadata) must not accidentally + // reach byte-streaming paths. 
+ lr := packedLayer{ + Digest: v1.Hash{Algorithm: "sha256", Hex: "deadbeef"}, + Size: 1, + MediaType: mediaTypeOCILayerTar, + } + l := newFileLayer(t.Context(), lr, nil) + _, err := l.Compressed() + require.Error(t, err) + assert.Contains(t, err.Error(), "store is nil") +} diff --git a/pkg/model/weight_pipeline_e2e_test.go b/pkg/model/weight_pipeline_e2e_test.go new file mode 100644 index 0000000000..aa90356754 --- /dev/null +++ b/pkg/model/weight_pipeline_e2e_test.go @@ -0,0 +1,285 @@ +package model + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/registry_testhelpers" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +// TestWeightPipeline_EndToEnd exercises Pack → WeightPusher against a +// real test registry, then pulls each layer back and asserts the +// extracted contents match the source directory byte-for-byte. +// +// This covers the critical property "does the v1 artifact extract to +// the correct shape on disk" that would otherwise require a human +// running crane + tar locally. 
+// +// The source dir is sized so all three packer branches fire under a +// small bundle threshold (1 KiB): +// - a bundle layer (tar+gzip) for the small config/tokenizer files +// - an uncompressed tar layer for a single .safetensors (incompressible) +// - a gzipped tar layer for a single .nemo (compressible) +func TestWeightPipeline_EndToEnd(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + ctx := t.Context() + reg := registry_testhelpers.StartTestRegistry(t) + regHost := reg.RegistryHost() + + // Deterministic content so assertions compare the pushed bytes + // against what's on disk without depending on randomness. The + // safetensors and .nemo files exceed the 1 KiB BundleFileMax we + // pass to Pack, forcing them into single-file layers. + sources := map[string][]byte{ + "config.json": []byte(`{"hidden_size": 768}`), + "tokenizer.json": []byte(`{"vocab_size": 50257}`), + "generation_config.json": []byte(`{"max_length": 128}`), + "plots/asr.png": []byte("PNG\x89fake-png-bytes-for-deterministic-hash"), + "model.safetensors": bytes.Repeat([]byte{'S'}, 4096), + "model.nemo": bytes.Repeat([]byte{'N'}, 4096), + } + + sourceDir := t.TempDir() + writeSourceTree(t, sourceDir, sources) + + // Pack directly with a small bundle threshold so we don't need to + // write 64+ MiB of fixture content to cross the default cutoff. 
+ st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + src, err := weightsource.NewFileSource("file://"+sourceDir, "") + require.NoError(t, err, "source") + inv, err := src.Inventory(ctx) + require.NoError(t, err, "inventory") + require.NoError(t, ingressFromInventory(ctx, src, st, inv)) + + pkr := newPacker(&packOptions{BundleFileMax: 1024}) + pl := pkr.planLayers(inv) + layers, err := pkr.computeLayerDigests(ctx, st, pl) + require.NoError(t, err, "computeLayerDigests") + require.Len(t, layers, 3, "want 1 bundle + 2 single-file layers") + files := packedFilesFromPlan(layers) + + // Build a lock entry and artifact (manifest + descriptor + digest backfill). + entry := newWeightLockEntry("my-model", "/src/weights", lockfile.WeightLockSource{}, files, layers) + artifact, err := buildWeightArtifact(&entry, layers, st) + require.NoError(t, err) + setDigest := entry.SetDigest + + repo := regHost + "/test/my-model" + pusher := NewWeightPusher(registry.NewRegistryClient()) + result, err := pusher.Push(ctx, repo, artifact) + require.NoError(t, err, "push weights") + require.NotEmpty(t, result.Ref, "push result missing ref") + + // Pull the manifest back and assert spec-compliant shape. + manifestRef, err := name.ParseReference(result.Ref, name.Insecure) + require.NoError(t, err) + pulled, err := remote.Image(manifestRef) + require.NoError(t, err) + + mf, err := pulled.Manifest() + require.NoError(t, err) + + // go-containerregistry's v1.Manifest omits artifactType, so parse + // the raw manifest bytes to verify it. 
+ rawManifest, err := pulled.RawManifest() + require.NoError(t, err) + var rawMf struct { + ArtifactType string `json:"artifactType"` + } + require.NoError(t, json.Unmarshal(rawManifest, &rawMf)) + assert.Equal(t, "application/vnd.cog.weight.v1", rawMf.ArtifactType) + + assert.Equal(t, "my-model", mf.Annotations[AnnotationV1WeightName]) + assert.Equal(t, "/src/weights", mf.Annotations[AnnotationV1WeightTarget]) + assert.Equal(t, setDigest, mf.Annotations[AnnotationV1WeightSetDigest]) + + require.Len(t, mf.Layers, 3) + + // Per spec §2.5, each layer descriptor carries exactly the + // uncompressed-size annotation and nothing else. Partition layers by + // media type + extracted contents: the single uncompressed .tar + // layer is model.safetensors; the single .tar+gzip large-file layer + // is model.nemo; the remaining .tar+gzip layer is the bundle + // containing the JSON/PNG files. + var bundleCount int + var safetensorsLayer, nemoLayer *v1.Descriptor + for i := range mf.Layers { + d := &mf.Layers[i] + require.Len(t, d.Annotations, 1, + "layer %d should carry exactly one annotation (uncompressed size) per spec §2.5", i) + uncompStr := d.Annotations[AnnotationV1WeightSizeUncomp] + require.NotEmpty(t, uncompStr, "layer %d uncompressed size annotation missing", i) + uncomp, err := strconv.ParseInt(uncompStr, 10, 64) + require.NoError(t, err, "layer %d uncompressed size annotation not an integer", i) + assert.Positive(t, uncomp, "layer %d uncompressed size should be positive", i) + + paths := listFilesInPushedLayer(t, repo+"@"+d.Digest.String(), string(d.MediaType)) + switch { + case len(paths) > 1: + bundleCount++ + case len(paths) == 1 && paths[0] == "model.safetensors": + safetensorsLayer = d + case len(paths) == 1 && paths[0] == "model.nemo": + nemoLayer = d + default: + t.Fatalf("layer %s has unexpected file set %v", d.Digest, paths) + } + } + assert.Equal(t, 1, bundleCount, "expected exactly one bundle layer") + require.NotNil(t, safetensorsLayer) + 
require.NotNil(t, nemoLayer) + + // .safetensors stays uncompressed per spec §1.2; .nemo gets gzipped. + assert.Equal(t, mediaTypeOCILayerTar, string(safetensorsLayer.MediaType), + "model.safetensors should be uncompressed tar") + assert.Equal(t, mediaTypeOCILayerTarGzip, string(nemoLayer.MediaType), + "model.nemo should be gzipped") + + // Pull each layer, extract it, and assert the extracted tree + // matches the source byte-for-byte. + extractDir := t.TempDir() + for _, l := range mf.Layers { + blobRef := repo + "@" + l.Digest.String() + extractLayerToDir(t, blobRef, string(l.MediaType), extractDir) + } + + for relPath, want := range sources { + gotPath := filepath.Join(extractDir, relPath) + got, err := os.ReadFile(gotPath) //nolint:gosec // G304: relPath is a test constant + require.NoError(t, err, "extracted file %q missing under %s", relPath, extractDir) + assert.Equal(t, sha256Hex(want), sha256Hex(got), + "content mismatch for extracted file %q", relPath) + } +} + +// writeSourceTree materializes a map of relative-path → content into +// a directory, creating parent directories on demand. +func writeSourceTree(t *testing.T, dir string, files map[string][]byte) { + t.Helper() + for relPath, data := range files { + full := filepath.Join(dir, relPath) + require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) + require.NoError(t, os.WriteFile(full, data, 0o644)) + } +} + +// extractLayerToDir pulls a layer blob and extracts regular files into +// destDir, preserving relative paths. 
+func extractLayerToDir(t *testing.T, blobRef, mediaType, destDir string) { + t.Helper() + + rc := openLayerStream(t, blobRef, mediaType) + defer rc.Close() //nolint:errcheck + + tr := tar.NewReader(rc) + for { + hdr, err := tr.Next() + if err == io.EOF { + return + } + require.NoError(t, err, "read tar header") + + target := filepath.Join(destDir, hdr.Name) //nolint:gosec // G305: test-only, author-controlled tar input + switch hdr.Typeflag { + case tar.TypeDir: + require.NoError(t, os.MkdirAll(target, 0o755)) + case tar.TypeReg: + require.NoError(t, os.MkdirAll(filepath.Dir(target), 0o755)) + f, err := os.Create(target) //nolint:gosec // G304: test-only + require.NoError(t, err) + _, err = io.Copy(f, tr) //nolint:gosec // G110: test-only with small bounded inputs + require.NoError(t, err) + require.NoError(t, f.Close()) + } + } +} + +// listFilesInPushedLayer pulls a layer blob and returns the paths of +// the regular files it contains. Used to classify layers by content +// since layer descriptors carry only an uncompressed-size annotation +// (spec §2.5), not file membership. +func listFilesInPushedLayer(t *testing.T, blobRef, mediaType string) []string { + t.Helper() + + rc := openLayerStream(t, blobRef, mediaType) + defer rc.Close() //nolint:errcheck + + var paths []string + tr := tar.NewReader(rc) + for { + hdr, err := tr.Next() + if err == io.EOF { + return paths + } + require.NoError(t, err) + if hdr.Typeflag == tar.TypeReg { + paths = append(paths, hdr.Name) + } + } +} + +// openLayerStream pulls a layer blob by digest and returns a reader +// for its uncompressed tar bytes. 
+func openLayerStream(t *testing.T, blobRef, mediaType string) io.ReadCloser { + t.Helper() + + ref, err := name.ParseReference(blobRef, name.Insecure) + require.NoError(t, err) + + digest, ok := ref.(name.Digest) + require.True(t, ok, "expected digest reference, got %T", ref) + + layer, err := remote.Layer(digest) + require.NoError(t, err) + + raw, err := layer.Compressed() + require.NoError(t, err) + + if mediaType == mediaTypeOCILayerTarGzip { + gr, err := gzip.NewReader(raw) + require.NoError(t, err) + return &gzipReadCloser{Reader: gr, underlying: raw} + } + return raw +} + +// gzipReadCloser composes a *gzip.Reader with the underlying HTTP body +// so Close closes both. +type gzipReadCloser struct { + *gzip.Reader + underlying io.Closer +} + +func (g *gzipReadCloser) Close() error { + _ = g.Reader.Close() + return g.underlying.Close() +} + +// sha256Hex hashes b as a hex string, convenient for assertion output. +func sha256Hex(b []byte) string { + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} diff --git a/pkg/model/weight_pusher.go b/pkg/model/weight_pusher.go index 0ef74de7d5..63d4a3c9b7 100644 --- a/pkg/model/weight_pusher.go +++ b/pkg/model/weight_pusher.go @@ -2,72 +2,84 @@ package model import ( "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" "fmt" - "os" "strings" - "sync" "time" v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/google/go-containerregistry/pkg/v1/empty" - "github.com/google/go-containerregistry/pkg/v1/mutate" - "github.com/google/go-containerregistry/pkg/v1/tarball" - "github.com/google/go-containerregistry/pkg/v1/types" + "golang.org/x/sync/errgroup" "github.com/replicate/cog/pkg/registry" ) -// WeightPushOptions configures optional behavior for WeightPusher.Push. +// WeightPusher pushes a WeightArtifact as a v1 multi-layer OCI artifact. 
+// Each tar layer is uploaded via registry.WriteLayer (which supports +// multipart uploads, progress, and retry), followed by the manifest via +// registry.PushImage. Layers upload concurrently, bounded by +// GetPushConcurrency. +type WeightPusher struct { + registry registry.Client +} + +// NewWeightPusher creates a new WeightPusher. +func NewWeightPusher(reg registry.Client) *WeightPusher { + return &WeightPusher{registry: reg} +} + +// WeightPushOptions configures a weight push. type WeightPushOptions struct { - // ProgressFn is an optional callback for reporting upload progress. - ProgressFn func(PushProgress) - // RetryFn is an optional callback for reporting retry attempts. + // Concurrency is the maximum number of layers to upload in parallel. + // If <= 0, GetPushConcurrency() is used. + Concurrency int + // Tag overrides the manifest tag. Defaults to + // WeightTag(artifact.Name, tagSeed) where tagSeed is the set digest. + Tag string + // ProgressFn is an optional callback for per-layer upload progress. + ProgressFn func(WeightLayerProgress) + // RetryFn is an optional retry callback, invoked per-layer. // Return false to abort the retry. RetryFn func(WeightRetryEvent) bool } -// WeightRetryEvent reports a retry attempt for a weight file upload. +// WeightLayerProgress reports per-layer progress for a weight push. When +// dispatched by BundlePusher, WeightName identifies which artifact the +// layer belongs to; the per-weight WeightPusher.Push path leaves it empty. +type WeightLayerProgress struct { + WeightName string + LayerDigest string + Complete int64 + Total int64 +} + +// WeightRetryEvent reports a retry attempt for a weight layer upload. type WeightRetryEvent struct { - // Name identifies which file is being retried. - Name string - // Attempt is the current retry attempt number (1-indexed). - Attempt int - // MaxAttempts is the maximum number of retry attempts. + // Name identifies which layer is being retried. 
It combines the weight + // name and layer digest, e.g. "z-image-turbo layer sha256:abc…". + Name string + Attempt int MaxAttempts int - // Err is the error that caused the retry. - Err error - // NextRetryIn is the duration until the next retry attempt. + Err error NextRetryIn time.Duration } -// WeightPushResult contains the result of pushing a single weight artifact. +// WeightPushResult describes a successful weight push. type WeightPushResult struct { - // Ref is the full image reference for the pushed weight manifest (e.g., "registry/repo:weights-name-abc123"). + // Ref is the full image reference the manifest was pushed to + // (e.g. "registry/repo:weights-name-abc123"). Ref string - // Descriptor is the OCI descriptor for the pushed weight manifest. + // Descriptor is the OCI descriptor for the pushed manifest. Descriptor v1.Descriptor } -// WeightPusher pushes a WeightArtifact as a proper OCI artifact manifest -// with config blob and tarball layers. The layer blob is pushed via -// registry.WriteLayer (which supports multipart uploads, progress, and retry), -// followed by the manifest via PushImage. -type WeightPusher struct { - registry registry.Client -} - -// NewWeightPusher creates a new WeightPusher. -func NewWeightPusher(reg registry.Client) *WeightPusher { - return &WeightPusher{registry: reg} -} - -// Push pushes a WeightArtifact to the registry as an OCI artifact manifest. -// The layer blob is pushed first via WriteLayer (multipart uploads, progress, retry), -// then the manifest is pushed via PushImage. -// Returns the descriptor of the pushed manifest. +// Push pushes a WeightArtifact to the registry as a v1 OCI weight manifest. +// Layers upload concurrently; the manifest goes up last. +// +// The artifact owns the content-addressed store from which layer +// bytes are streamed; Push reads through that store on each upload +// (and on retry) without keeping any tar bytes in memory or on +// disk. 
On layer-upload failure the manifest is not attempted, but +// any already-uploaded layers remain in the registry +// (garbage-collectable). func (p *WeightPusher) Push(ctx context.Context, repo string, artifact *WeightArtifact, opts ...WeightPushOptions) (*WeightPushResult, error) { if artifact == nil { return nil, fmt.Errorf("artifact is nil") @@ -75,52 +87,103 @@ func (p *WeightPusher) Push(ctx context.Context, repo string, artifact *WeightAr if repo == "" { return nil, fmt.Errorf("repo is required") } + if len(artifact.Layers) == 0 { + return nil, fmt.Errorf("weight %q has no layers", artifact.Name()) + } - // Merge options (use first if provided) var opt WeightPushOptions if len(opts) > 0 { opt = opts[0] } - // Verify the weight file exists - if _, err := os.Stat(artifact.FilePath); err != nil { - return nil, fmt.Errorf("weight file %q: %w", artifact.FilePath, err) + // Build the manifest fresh inside Push so the embedded fileLayers + // inherit ctx, not the Background context from artifact build + // time. go-containerregistry's remote.Write may call + // layer.Compressed() during PushImage if the registry returns a + // HEAD miss for a layer we just uploaded (GC race, replication + // lag); without this the streaming goroutine would be deaf to + // user cancellation. 
+ img, err := buildWeightManifestV1(ctx, artifact.Entry, artifact.Layers, artifact.store) + if err != nil { + return nil, fmt.Errorf("build weight manifest: %w", err) + } + + if err := p.pushLayersConcurrently(ctx, repo, artifact, opt); err != nil { + return nil, fmt.Errorf("push weight layers: %w", err) } - // Build the OCI artifact image (config blob + tarball layer) - img, err := buildWeightImage(artifact) - if err != nil { - return nil, fmt.Errorf("build weight image: %w", err) + tag := opt.Tag + if tag == "" { + tag = WeightTag(artifact.Name(), artifact.Entry.SetDigest) + } + ref := repo + ":" + tag + if err := p.registry.PushImage(ctx, ref, img); err != nil { + return nil, fmt.Errorf("push weight manifest (%s): %w", tag, err) } - // Extract the layer to push via WriteLayer (gets multipart + progress + retry) - layers, err := img.Layers() + desc, err := descriptorFromImage(img) if err != nil { - return nil, fmt.Errorf("get image layers: %w", err) + return nil, fmt.Errorf("compute manifest descriptor: %w", err) } - if len(layers) != 1 { - return nil, fmt.Errorf("expected 1 layer, got %d", len(layers)) + return &WeightPushResult{Ref: ref, Descriptor: desc}, nil +} + +// pushLayersConcurrently pushes all layers using bounded concurrency, +// returning the first error (if any). +func (p *WeightPusher) pushLayersConcurrently( + ctx context.Context, + repo string, + artifact *WeightArtifact, + opt WeightPushOptions, +) error { + concurrency := opt.Concurrency + if concurrency <= 0 { + concurrency = GetPushConcurrency() + } + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(concurrency) + + for _, lr := range artifact.Layers { + g.Go(func() error { + return p.pushSingleLayer(ctx, repo, artifact, lr, opt) + }) } - layer := layers[0] - // Build progress callback + return g.Wait() +} + +// pushSingleLayer pushes a single tar layer via registry.WriteLayer, wiring +// up progress and retry callbacks if configured. 
+func (p *WeightPusher) pushSingleLayer( + ctx context.Context, + repo string, + artifact *WeightArtifact, + lr packedLayer, + opt WeightPushOptions, +) error { + layer := newFileLayer(ctx, lr, artifact.store) + weightName := artifact.Name() + digestStr := lr.Digest.String() + var onProgress func(v1.Update) if opt.ProgressFn != nil { onProgress = func(update v1.Update) { - opt.ProgressFn(PushProgress{ - Complete: update.Complete, - Total: update.Total, + opt.ProgressFn(WeightLayerProgress{ + WeightName: weightName, + LayerDigest: digestStr, + Complete: update.Complete, + Total: update.Total, }) } } - // Build retry configuration if callback is provided var retryConfig *registry.RetryConfig if opt.RetryFn != nil { retryConfig = ®istry.RetryConfig{ OnRetry: func(event registry.RetryEvent) bool { return opt.RetryFn(WeightRetryEvent{ - Name: artifact.Name(), + Name: fmt.Sprintf("%s layer %s", weightName, digestStr), Attempt: event.Attempt, MaxAttempts: event.MaxAttempts, Err: event.Err, @@ -130,74 +193,15 @@ func (p *WeightPusher) Push(ctx context.Context, repo string, artifact *WeightAr } } - // 1. Push layer blob via WriteLayer (multipart uploads, progress, retry) - writeErr := writeLayerWithProgress(ctx, p.registry, registry.WriteLayerOptions{ + err := writeLayerWithProgress(ctx, p.registry, registry.WriteLayerOptions{ Repo: repo, Layer: layer, Retry: retryConfig, }, onProgress) - - if writeErr != nil { - return nil, fmt.Errorf("push weight layer: %w", writeErr) - } - - // 2. Push manifest via PushImage with a single tag combining name and digest. - // The layer blob is already in the registry, so PushImage will skip re-uploading it. - // Tag format: :weights--<12chars> (e.g., :weights-model-v1-383d1f4afa43) - // - // We use the artifact's descriptor digest (original file hash from the lock file), - // NOT the tarball layer digest. 
This ensures that `weights inspect` can look up the tag - // using the same digest stored in weights.lock, independent of the transport format. - tag := WeightTag(artifact.Name(), artifact.Descriptor().Digest.String()) - ref := repo + ":" + tag - if err := p.registry.PushImage(ctx, ref, img); err != nil { - return nil, fmt.Errorf("push weight manifest (%s): %w", tag, err) - } - - // Build result descriptor from the pushed image - desc, err := descriptorFromImage(img) if err != nil { - return nil, fmt.Errorf("compute manifest descriptor: %w", err) + return fmt.Errorf("push layer %s: %w", digestStr, err) } - - return &WeightPushResult{Ref: ref, Descriptor: desc}, nil -} - -// buildWeightImage creates an OCI artifact image with a config blob (WeightConfig JSON) -// and a tarball layer for the weight file. -func buildWeightImage(artifact *WeightArtifact) (v1.Image, error) { - // 1. Create the base image with OCI manifest media type - img := mutate.MediaType(empty.Image, types.OCIManifestSchema1) - - // 2. Create tarball layer from the weight file. - // WithCompressedCaching memoizes the compressed output so that Digest() and - // Compressed() see identical bytes. Without this, gzip non-determinism between - // separate passes causes DIGEST_INVALID errors on large uploads. - layer, err := tarball.LayerFromFile(artifact.FilePath, - tarball.WithMediaType(types.MediaType(MediaTypeWeightLayer)), - tarball.WithCompressedCaching, - ) - if err != nil { - return nil, fmt.Errorf("create tarball layer: %w", err) - } - - // 3. Append the layer - img, err = mutate.AppendLayers(img, layer) - if err != nil { - return nil, fmt.Errorf("append weight layer: %w", err) - } - - // 4. Serialize the WeightConfig as the config blob - configJSON, err := json.Marshal(artifact.Config) - if err != nil { - return nil, fmt.Errorf("marshal weight config: %w", err) - } - - // 5. 
Wrap to set custom config blob, config media type, and artifactType - return &weightManifestImage{ - Image: img, - configBlob: configJSON, - }, nil + return nil } // descriptorFromImage computes the v1.Descriptor for a built image manifest. @@ -212,130 +216,41 @@ func descriptorFromImage(img v1.Image) (v1.Descriptor, error) { return v1.Descriptor{}, fmt.Errorf("get raw manifest: %w", err) } + mediaType, err := img.MediaType() + if err != nil { + return v1.Descriptor{}, fmt.Errorf("get media type: %w", err) + } + return v1.Descriptor{ - MediaType: types.OCIManifestSchema1, + MediaType: mediaType, Size: int64(len(rawManifest)), Digest: digest, }, nil } -// weightOCIManifest extends v1.Manifest with artifactType for OCI 1.1 support. -// go-containerregistry v0.20.5's v1.Manifest struct does not include artifactType, -// so we serialize it ourselves. -type weightOCIManifest struct { - SchemaVersion int64 `json:"schemaVersion"` - MediaType types.MediaType `json:"mediaType,omitempty"` - Config v1.Descriptor `json:"config"` - Layers []v1.Descriptor `json:"layers"` - Annotations map[string]string `json:"annotations,omitempty"` - ArtifactType string `json:"artifactType,omitempty"` -} - -// weightManifestImage wraps a v1.Image to set a custom config blob with -// the correct media type and artifactType. This produces a proper OCI 1.1 -// artifact manifest for weight data. -// -// The raw manifest is cached on first computation to ensure deterministic -// digests across multiple calls (e.g., during remote.Write which calls -// both RawManifest and Digest). -type weightManifestImage struct { - v1.Image - configBlob []byte - rawManifest []byte - rawManifestErr error - rawOnce sync.Once -} - -// RawConfigFile returns the WeightConfig JSON as the config blob. -func (w *weightManifestImage) RawConfigFile() ([]byte, error) { - return w.configBlob, nil -} - -// Digest computes the digest from the cached raw manifest. 
-func (w *weightManifestImage) Digest() (v1.Hash, error) { - raw, err := w.RawManifest() - if err != nil { - return v1.Hash{}, err - } - h := sha256.Sum256(raw) - return v1.Hash{ - Algorithm: "sha256", - Hex: hex.EncodeToString(h[:]), - }, nil -} - -// ArtifactType implements the withArtifactType interface used by partial.Descriptor. -func (w *weightManifestImage) ArtifactType() (string, error) { - return MediaTypeWeightArtifact, nil -} - -// Manifest returns the modified manifest with custom config descriptor. -func (w *weightManifestImage) Manifest() (*v1.Manifest, error) { - m, err := w.Image.Manifest() - if err != nil { - return nil, err - } - // Make a copy to avoid mutating the original - mCopy := m.DeepCopy() +const weightTagPrefix = "weights-" - // Set config to point to our custom config blob - configDigest := sha256.Sum256(w.configBlob) - mCopy.Config = v1.Descriptor{ - MediaType: types.MediaType(MediaTypeWeightConfig), - Size: int64(len(w.configBlob)), - Digest: v1.Hash{ - Algorithm: "sha256", - Hex: hex.EncodeToString(configDigest[:]), - }, +// WeightTag returns the tag for a weight manifest combining name and the +// short prefix of a digest. digest is "sha256:…"; the 12 hex chars after +// the algorithm prefix are used. Falls back to "weights-" if digest +// is empty or missing the algorithm prefix. +func WeightTag(name, digest string) string { + short := ShortDigest(digest) + if short == "" { + return weightTagPrefix + name } - - return mCopy, nil -} - -// RawManifest serializes our modified manifest with artifactType field. -// The result is cached to ensure deterministic digests across multiple calls. 
-func (w *weightManifestImage) RawManifest() ([]byte, error) { - w.rawOnce.Do(func() { - m, err := w.Manifest() - if err != nil { - w.rawManifestErr = err - return - } - - // Build the OCI manifest with artifactType (not in v1.Manifest struct) - ociManifest := weightOCIManifest{ - SchemaVersion: m.SchemaVersion, - MediaType: m.MediaType, - Config: m.Config, - Layers: m.Layers, - Annotations: m.Annotations, - ArtifactType: MediaTypeWeightArtifact, - } - - w.rawManifest, w.rawManifestErr = json.Marshal(ociManifest) - }) - - return w.rawManifest, w.rawManifestErr + return weightTagPrefix + name + "-" + short } -// ============================================================================= -// Weight tag helpers -// ============================================================================= - -const weightTagPrefix = "weights-" - -// WeightTag returns the tag for a weight manifest combining name and digest. -// The digest should be in "sha256:abc123..." format. -// Returns e.g., "weights-model-v1-abc123def456" (12-char hex suffix). -// Falls back to "weights-" if digest is empty or invalid. -func WeightTag(name, digest string) string { +// ShortDigest returns the 12-hex-char prefix of a "sha256:…" digest, or the +// empty string if the input is empty or has no algorithm prefix. 
+func ShortDigest(digest string) string { _, hex, ok := strings.Cut(digest, ":") if !ok || hex == "" { - return weightTagPrefix + name + return "" } - short := hex - if len(short) > 12 { - short = short[:12] + if len(hex) > 12 { + return hex[:12] } - return weightTagPrefix + name + "-" + short + return hex } diff --git a/pkg/model/weight_pusher_test.go b/pkg/model/weight_pusher_test.go index 162bcb9ed6..fde36d6c01 100644 --- a/pkg/model/weight_pusher_test.go +++ b/pkg/model/weight_pusher_test.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "sync" + "sync/atomic" "testing" "time" @@ -14,9 +15,57 @@ import ( "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/require" + "github.com/replicate/cog/pkg/model/weightsource" "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" ) +// packTestLayers packs a directory containing a single file into tar +// layers and returns the layer results. Used as a fixture builder so tests +// don't each reimplement packing. 
+func packTestLayers(t *testing.T, filename string, content []byte) (sourceDir string, layers []packedLayer, st store.Store) { + t.Helper() + sourceDir = t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(sourceDir, filename), content, 0o644)) + + st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + + src, err := weightsource.NewFileSource("file://"+sourceDir, "") + require.NoError(t, err) + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + require.NoError(t, ingressFromInventory(t.Context(), src, st, inv)) + + pkr := newPacker(nil) + pl := pkr.planLayers(inv) + require.NotEmpty(t, pl.Layers) + layers, err = pkr.computeLayerDigests(t.Context(), st, pl) + require.NoError(t, err) + require.NotEmpty(t, layers) + return sourceDir, layers, st +} + +// newTestWeightArtifact builds a WeightArtifact with packed layers and a +// fresh manifest descriptor, suitable for push tests. +func newTestWeightArtifact(t *testing.T, name, target string) *WeightArtifact { + t.Helper() + _, layers, st := packTestLayers(t, "config.json", []byte(`{"hidden_size": 768}`)) + + // Build a lock entry from the pack result. 
+ files := []packedFile{{ + Path: "config.json", + Size: int64(len(`{"hidden_size": 768}`)), + Digest: layers[0].Digest.String(), + LayerDigest: layers[0].Digest.String(), + }} + entry := newWeightLockEntry(name, target, lockfile.WeightLockSource{}, files, layers) + artifact, err := buildWeightArtifact(&entry, layers, st) + require.NoError(t, err) + return artifact +} + func TestWeightPusher_Push_ReturnsErrorForNilArtifact(t *testing.T) { reg := &mockRegistry{} pusher := NewWeightPusher(reg) @@ -27,51 +76,40 @@ func TestWeightPusher_Push_ReturnsErrorForNilArtifact(t *testing.T) { require.Contains(t, err.Error(), "artifact is nil") } -func TestWeightPusher_Push_ReturnsErrorForMissingFile(t *testing.T) { +func TestWeightPusher_Push_ReturnsErrorForEmptyRepo(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") + reg := &mockRegistry{} pusher := NewWeightPusher(reg) - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, "/nonexistent/path/weights.bin", "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) + _, err := pusher.Push(context.Background(), "", artifact) + require.Error(t, err) + require.Contains(t, err.Error(), "repo is required") +} - _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) +func TestWeightPusher_Push_ReturnsErrorForEmptyLayers(t *testing.T) { + // Empty layer set must be caught before we try to build a manifest. 
+ artifact := newWeightArtifact( + lockfile.WeightLockEntry{Name: "model-v1", Target: "/src/weights"}, + v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}}, + nil) + reg := &mockRegistry{} + pusher := NewWeightPusher(reg) + + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) require.Error(t, err) - require.Contains(t, err.Error(), "weight file") + require.Contains(t, err.Error(), "has no layers") } -func TestWeightPusher_Push_PushesCorrectOCIArtifact(t *testing.T) { - // Create a temp weight file - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.safetensors") - weightContent := []byte("fake weight data for testing tarball layer creation") - require.NoError(t, os.WriteFile(weightPath, weightContent, 0o644)) - - created := time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC) - cfg := WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.safetensors", - Created: created, - } - - desc := v1.Descriptor{ - Digest: v1.Hash{Algorithm: "sha256", Hex: "aabbccddee112233445566778899aabb00112233445566778899aabbccddeeff"}, - } - artifact := NewWeightArtifact("model-v1", desc, weightPath, "/weights/model.safetensors", cfg) +func TestWeightPusher_Push_PushesExpectedManifest(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") - // Capture what gets pushed - var pushedRefs []string + var pushedRef string var pushedImg v1.Image reg := &mockRegistry{ pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - pushedRefs = append(pushedRefs, ref) + pushedRef = ref pushedImg = img return nil }, @@ -79,224 +117,170 @@ func TestWeightPusher_Push_PushesCorrectOCIArtifact(t *testing.T) { pusher := NewWeightPusher(reg) result, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) - require.NoError(t, err) require.NotNil(t, result) - // Verify the image was pushed with a single combined tag - require.Len(t, pushedRefs, 1) - 
require.Equal(t, "r8.im/user/model:weights-model-v1-aabbccddee11", pushedRefs[0]) - require.NotNil(t, pushedImg) + // Tag derives from the set digest (12-char prefix after "sha256:"). + require.Contains(t, pushedRef, "weights-model-v1-") + require.Equal(t, pushedRef, result.Ref) - // Verify manifest structure + // Manifest shape matches spec §2.2: OCI manifest, config blob, layers + // with standard OCI media types, artifactType on the raw manifest. manifest, err := pushedImg.Manifest() require.NoError(t, err) require.Equal(t, types.OCIManifestSchema1, manifest.MediaType) - - // Verify config blob has correct media type require.Equal(t, types.MediaType(MediaTypeWeightConfig), manifest.Config.MediaType) + require.NotEmpty(t, manifest.Config.Digest.Hex) + require.NotEmpty(t, manifest.Layers) + for _, layer := range manifest.Layers { + require.Contains(t, []types.MediaType{ + types.MediaType(mediaTypeOCILayerTar), + types.MediaType(mediaTypeOCILayerTarGzip), + }, layer.MediaType) + } - // Verify config blob content is correct WeightConfig JSON - configBlob, err := pushedImg.RawConfigFile() + // Raw manifest carries artifactType; check it. 
+ rawManifest, err := pushedImg.RawManifest() require.NoError(t, err) - var parsedConfig WeightConfig - require.NoError(t, json.Unmarshal(configBlob, &parsedConfig)) - require.Equal(t, "1.0", parsedConfig.SchemaVersion) - require.Equal(t, "0.15.0", parsedConfig.CogVersion) - require.Equal(t, "model-v1", parsedConfig.Name) - require.Equal(t, "/weights/model.safetensors", parsedConfig.Target) - require.Equal(t, created, parsedConfig.Created) - - // Verify there's exactly one layer (single file = single layer) - require.Len(t, manifest.Layers, 1) - - // Verify layer media type - require.Equal(t, types.MediaType(MediaTypeWeightLayer), manifest.Layers[0].MediaType) - - // Verify layer size matches the tarball wrapping of the weight file - // (tarball will be larger than raw content due to tar headers) - require.Greater(t, manifest.Layers[0].Size, int64(0)) - - // Verify the result contains a valid descriptor - require.NotEmpty(t, result.Descriptor.Digest.String()) + var raw map[string]any + require.NoError(t, json.Unmarshal(rawManifest, &raw)) + require.Equal(t, MediaTypeWeightArtifact, raw["artifactType"]) + + // Manifest-level annotations per spec §2.5. + require.Equal(t, "model-v1", manifest.Annotations[AnnotationV1WeightName]) + require.Equal(t, "/src/weights", manifest.Annotations[AnnotationV1WeightTarget]) + require.Equal(t, artifact.Entry.SetDigest, manifest.Annotations[AnnotationV1WeightSetDigest]) + + // Result descriptor is populated. 
+ require.NotEmpty(t, result.Descriptor.Digest.Hex) require.Greater(t, result.Descriptor.Size, int64(0)) } -func TestWeightPusher_Push_PropagatesPushError(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) +func TestWeightPusher_Push_TagDerivesFromSetDigest(t *testing.T) { + // The tag derives from the set digest so content-identical builds land + // at the same tag. + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") + var pushedRef string reg := &mockRegistry{ pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - return fmt.Errorf("unauthorized: authentication required") + pushedRef = ref + return nil }, } pusher := NewWeightPusher(reg) _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) + require.NoError(t, err) - require.Error(t, err) - require.Contains(t, err.Error(), "push weight manifest") - require.Contains(t, err.Error(), "unauthorized") + require.Contains(t, pushedRef, "weights-model-v1-") + require.Contains(t, pushedRef, ShortDigest(artifact.Entry.SetDigest)) } -func TestWeightPusher_Push_RawManifestContainsArtifactType(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test weight data"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC), - }) +func TestWeightPusher_Push_CustomTagOverride(t *testing.T) { + artifact := newTestWeightArtifact(t, 
"model-v1", "/src/weights") - var pushedImg v1.Image + var pushedRef string reg := &mockRegistry{ pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - pushedImg = img + pushedRef = ref return nil }, } pusher := NewWeightPusher(reg) - _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) - require.NoError(t, err) - - // Parse raw manifest JSON to verify artifactType field - rawManifest, err := pushedImg.RawManifest() + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, + WeightPushOptions{Tag: "latest"}) require.NoError(t, err) + require.Equal(t, "r8.im/user/model:latest", pushedRef) +} - var manifestJSON map[string]any - require.NoError(t, json.Unmarshal(rawManifest, &manifestJSON)) - - // artifactType must be present at the manifest level (OCI 1.1) - require.Equal(t, MediaTypeWeightArtifact, manifestJSON["artifactType"]) +func TestWeightPusher_Push_PropagatesPushError(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") - // config.mediaType must be the weight config type - configMap, ok := manifestJSON["config"].(map[string]any) - require.True(t, ok, "config should be an object") - require.Equal(t, MediaTypeWeightConfig, configMap["mediaType"]) + reg := &mockRegistry{ + pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { + return fmt.Errorf("unauthorized: authentication required") + }, + } - // layers should have exactly one entry with the weight layer media type - layers, ok := manifestJSON["layers"].([]any) - require.True(t, ok, "layers should be an array") - require.Len(t, layers, 1) + pusher := NewWeightPusher(reg) + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) - layerMap, ok := layers[0].(map[string]any) - require.True(t, ok, "layer should be an object") - require.Equal(t, MediaTypeWeightLayer, layerMap["mediaType"]) + require.Error(t, err) + require.Contains(t, err.Error(), "push weight manifest") + 
require.Contains(t, err.Error(), "unauthorized") } -func TestWeightPusher_Push_ReturnsErrorForEmptyRepo(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) +func TestWeightPusher_Push_PropagatesLayerError(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") - reg := &mockRegistry{} - pusher := NewWeightPusher(reg) + reg := &mockRegistry{ + writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error { + return fmt.Errorf("upload failed: 503 Service Unavailable") + }, + } - _, err := pusher.Push(context.Background(), "", artifact) + pusher := NewWeightPusher(reg) + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) require.Error(t, err) - require.Contains(t, err.Error(), "repo is required") + require.Contains(t, err.Error(), "push weight layers") + require.Contains(t, err.Error(), "503 Service Unavailable") } -func TestWeightPusher_Push_ReportsProgressViaWriteLayer(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test weight data for progress tracking"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) +func TestWeightPusher_Push_ReportsProgressPerLayer(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") - // Track progress updates received via callback var ( - mu sync.Mutex - progress []PushProgress + mu sync.Mutex + events 
[]WeightLayerProgress ) - // Mock WriteLayer to simulate progress updates (caller owns closing the channel) reg := &mockRegistry{ writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error { - // Simulate progress updates like the real registry client if opts.ProgressCh != nil { opts.ProgressCh <- v1.Update{Complete: 500, Total: 1000} opts.ProgressCh <- v1.Update{Complete: 1000, Total: 1000} } return nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - return nil - }, + pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, } pusher := NewWeightPusher(reg) - result, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, WeightPushOptions{ - ProgressFn: func(p PushProgress) { - mu.Lock() - defer mu.Unlock() - progress = append(progress, p) - }, - }) - + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, + WeightPushOptions{ + ProgressFn: func(p WeightLayerProgress) { + mu.Lock() + defer mu.Unlock() + events = append(events, p) + }, + }) require.NoError(t, err) - require.NotNil(t, result) - // Verify we received progress updates mu.Lock() defer mu.Unlock() - require.GreaterOrEqual(t, len(progress), 2, "should receive at least 2 progress updates") - - // Verify progress updates contain expected values - require.Equal(t, int64(500), progress[0].Complete) - require.Equal(t, int64(1000), progress[0].Total) - require.Equal(t, int64(1000), progress[1].Complete) - require.Equal(t, int64(1000), progress[1].Total) + require.NotEmpty(t, events) + // Every event should carry a layer digest that matches one of the + // artifact's layers. 
+ digestsSeen := map[string]bool{} + for _, e := range events { + digestsSeen[e.LayerDigest] = true + } + for _, l := range artifact.Layers { + require.True(t, digestsSeen[l.Digest.String()], + "expected progress for layer %s", l.Digest) + } } func TestWeightPusher_Push_ForwardsRetryCallback(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test weight data"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") - // Mock WriteLayer to capture the retry config and invoke it var retryEvents []WeightRetryEvent + var mu sync.Mutex reg := &mockRegistry{ writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error { - // Simulate the registry invoking the retry callback if opts.Retry != nil && opts.Retry.OnRetry != nil { opts.Retry.OnRetry(registry.RetryEvent{ Attempt: 1, @@ -307,80 +291,111 @@ func TestWeightPusher_Push_ForwardsRetryCallback(t *testing.T) { } return nil }, - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - return nil - }, + pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, } pusher := NewWeightPusher(reg) - _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, WeightPushOptions{ - RetryFn: func(event WeightRetryEvent) bool { - retryEvents = append(retryEvents, event) - return true - }, - }) - + _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, + WeightPushOptions{ + RetryFn: func(event WeightRetryEvent) bool { + mu.Lock() + defer mu.Unlock() + retryEvents = append(retryEvents, event) + return true + }, + }) require.NoError(t, err) - require.Len(t, retryEvents, 1) - 
require.Equal(t, "model-v1", retryEvents[0].Name) - require.Equal(t, 1, retryEvents[0].Attempt) - require.Equal(t, 3, retryEvents[0].MaxAttempts) - require.Contains(t, retryEvents[0].Err.Error(), "connection reset") - require.Equal(t, 2*time.Second, retryEvents[0].NextRetryIn) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, retryEvents) + + ev := retryEvents[0] + require.Contains(t, ev.Name, "model-v1") + require.Contains(t, ev.Name, "layer sha256:") + require.Equal(t, 1, ev.Attempt) + require.Equal(t, 3, ev.MaxAttempts) + require.Contains(t, ev.Err.Error(), "connection reset") + require.Equal(t, 2*time.Second, ev.NextRetryIn) } -func TestWeightPusher_Push_WriteLayerErrorReported(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), - }) +func TestWeightPusher_Push_PropagatesContextCancellation(t *testing.T) { + artifact := newTestWeightArtifact(t, "model-v1", "/src/weights") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() reg := &mockRegistry{ writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error { - return fmt.Errorf("upload failed: 503 Service Unavailable") + return ctx.Err() + }, + pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { + return ctx.Err() }, } pusher := NewWeightPusher(reg) - _, err := pusher.Push(context.Background(), "r8.im/user/model", artifact) + _, err := pusher.Push(ctx, "r8.im/user/model", artifact) require.Error(t, err) - require.Contains(t, err.Error(), "push weight layer") - require.Contains(t, err.Error(), "503 Service Unavailable") + require.ErrorIs(t, err, context.Canceled) } -func 
TestWeightPusher_Push_PropagatesContextCancellation(t *testing.T) { - dir := t.TempDir() - weightPath := filepath.Join(dir, "model.bin") - require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) - - artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ - SchemaVersion: "1.0", - CogVersion: "0.15.0", - Name: "model-v1", - Target: "/weights/model.bin", - Created: time.Now().UTC(), +func TestWeightPusher_Push_HonoursConcurrencyLimit(t *testing.T) { + // Pack a source with enough large files that we end up with multiple + // layers. Since test data is small, we rely on tuning bundle_file_max + // so every file lands in its own layer. + sourceDir := t.TempDir() + const n = 4 + for i := range n { + require.NoError(t, os.WriteFile( + filepath.Join(sourceDir, fmt.Sprintf("w-%d.safetensors", i)), + fmt.Appendf(nil, "payload %d", i), + 0o644, + )) + } + + st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + src, err := weightsource.NewFileSource("file://"+sourceDir, "") + require.NoError(t, err) + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + require.NoError(t, ingressFromInventory(t.Context(), src, st, inv)) + pkr := newPacker(&packOptions{ + BundleFileMax: 1, // every file becomes its own layer }) + pl := pkr.planLayers(inv) + layers, err := pkr.computeLayerDigests(t.Context(), st, pl) + require.NoError(t, err) + files := packedFilesFromPlan(layers) + require.GreaterOrEqual(t, len(layers), n, "expected a layer per file") - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately + entry := newWeightLockEntry("model", "/src/weights", lockfile.WeightLockSource{}, files, layers) + artifact, err := buildWeightArtifact(&entry, layers, st) + require.NoError(t, err) + var inFlight, maxInFlight atomic.Int32 reg := &mockRegistry{ - pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { - return ctx.Err() + writeLayerFunc: 
func(ctx context.Context, opts registry.WriteLayerOptions) error { + cur := inFlight.Add(1) + for { + old := maxInFlight.Load() + if cur <= old || maxInFlight.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(10 * time.Millisecond) + inFlight.Add(-1) + return nil }, + pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, } pusher := NewWeightPusher(reg) - _, err := pusher.Push(ctx, "r8.im/user/model", artifact) - - require.Error(t, err) - require.Contains(t, err.Error(), "context canceled") + _, err = pusher.Push(context.Background(), "r8.im/user/model", artifact, + WeightPushOptions{Concurrency: 2}) + require.NoError(t, err) + require.LessOrEqual(t, int(maxInFlight.Load()), 2, + "concurrency limit not honored") } diff --git a/pkg/model/weights.go b/pkg/model/weights.go index e2e1283883..b8585e3f0a 100644 --- a/pkg/model/weights.go +++ b/pkg/model/weights.go @@ -1,23 +1,113 @@ package model -// WeightFile represents a single weight file entry in a weights lockfile or manifest. -// The Name field is an identifier/handle (like a Docker volume name), not a filename. -type WeightFile struct { - // Name is the identifier/handle for this weight (e.g., "personaplex-7b-v1", "model-v42.5"). - // This is a logical name that maps to deployment blob metadata, not a file path. - Name string `json:"name"` - // Dest is the mount path in the container (e.g., /cache/model.safetensors). - Dest string `json:"dest"` - // DigestOriginal is the SHA256 of the uncompressed file (canonical ID). - DigestOriginal string `json:"digestOriginal"` - // Digest is the SHA256 of the compressed blob (OCI layer ID). +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// newWeightLockEntry assembles a lockfile.WeightLockEntry from a source +// description, the packed file index, and the set of packed layers +// produced by pack. 
+// +// The set digest (spec §2.4) is computed from the canonical file index. +// The manifest digest is left empty — it is filled in by the caller after +// buildWeightManifestV1 assembles the manifest from this entry. +func newWeightLockEntry( + name, target string, + source lockfile.WeightLockSource, + files []packedFile, + layers []packedLayer, +) lockfile.WeightLockEntry { + lockFiles := make([]lockfile.WeightLockFile, len(files)) + for i, f := range files { + lockFiles[i] = lockfile.WeightLockFile{ + Path: f.Path, + Size: f.Size, + Digest: f.Digest, + Layer: f.LayerDigest, + } + } + + lockLayers := make([]lockfile.WeightLockLayer, len(layers)) + var totalSize, totalCompressed int64 + for i, l := range layers { + lockLayers[i] = lockfile.WeightLockLayer{ + Digest: l.Digest.String(), + MediaType: string(l.MediaType), + Size: l.Size, + SizeUncompressed: l.UncompressedSize, + } + totalSize += l.UncompressedSize + totalCompressed += l.Size + } + + entry := lockfile.WeightLockEntry{ + Name: name, + Target: target, + Source: source, + Size: totalSize, + SizeCompressed: totalCompressed, + Files: lockFiles, + Layers: lockLayers, + } + entry.SetDigest = entry.ComputeSetDigest() + return entry +} + +// WeightConfigBlob is the JSON structure for the config blob (§2.3). +type WeightConfigBlob struct { + Name string `json:"name"` + Target string `json:"target"` + SetDigest string `json:"setDigest"` + Files []WeightConfigFile `json:"files"` +} + +// WeightConfigFile is one entry in the config blob's files array. +type WeightConfigFile struct { + Path string `json:"path"` + Layer string `json:"layer"` + Size int64 `json:"size"` Digest string `json:"digest"` - // Size is the compressed size in bytes. - Size int64 `json:"size"` - // SizeUncompressed is the original size in bytes. - SizeUncompressed int64 `json:"sizeUncompressed"` - // MediaType is the OCI layer media type (e.g., application/vnd.cog.weight.layer.v1+gzip). 
- MediaType string `json:"mediaType"` - // ContentType is the file's MIME type (e.g., application/octet-stream). - ContentType string `json:"contentType,omitempty"` +} + +// buildWeightConfigBlob builds the serialized config blob JSON (§2.3). +// The setDigest and file index come from the lockfile entry — the +// lockfile is the single source of truth for these values. +func buildWeightConfigBlob(name, target, setDigest string, files []lockfile.WeightLockFile) ([]byte, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no files for config blob") + } + + // Sort by path for deterministic output (§2.3: "files array MUST be + // sorted by path lexicographically"). Clone to avoid mutating the + // caller's slice. + sorted := slices.Clone(files) + slices.SortFunc(sorted, func(a, b lockfile.WeightLockFile) int { + return strings.Compare(a.Path, b.Path) + }) + + cfg := WeightConfigBlob{ + Name: name, + Target: target, + SetDigest: setDigest, + Files: make([]WeightConfigFile, len(sorted)), + } + for i, f := range sorted { + cfg.Files[i] = WeightConfigFile{ + Path: f.Path, + Layer: f.Layer, + Size: f.Size, + Digest: f.Digest, + } + } + + configJSON, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal config blob: %w", err) + } + return configJSON, nil } diff --git a/pkg/model/weights_lock.go b/pkg/model/weights_lock.go deleted file mode 100644 index 57385ca74f..0000000000 --- a/pkg/model/weights_lock.go +++ /dev/null @@ -1,53 +0,0 @@ -// pkg/model/weights_lock.go -package model - -import ( - "encoding/json" - "fmt" - "os" - "time" -) - -// WeightsLockFilename is the default filename for the weights lock file. -const WeightsLockFilename = "weights.lock" - -// WeightsLock represents a weights.lock file that pins weight file metadata. -// This is a placeholder format that will be replaced by the declarative weights implementation. -type WeightsLock struct { - // Version is the lockfile format version. 
- Version string `json:"version"` - // Created is when the lockfile was generated. - Created time.Time `json:"created"` - // Files are the weight file entries. - Files []WeightFile `json:"files"` -} - -// ParseWeightsLock parses a weights.lock JSON document. -func ParseWeightsLock(data []byte) (*WeightsLock, error) { - var lock WeightsLock - if err := json.Unmarshal(data, &lock); err != nil { - return nil, fmt.Errorf("parse weights.lock: %w", err) - } - return &lock, nil -} - -// LoadWeightsLock loads a weights.lock file from disk. -func LoadWeightsLock(path string) (*WeightsLock, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("read weights.lock: %w", err) - } - return ParseWeightsLock(data) -} - -// Save writes the weights.lock to disk. -func (wl *WeightsLock) Save(path string) error { - data, err := json.MarshalIndent(wl, "", " ") - if err != nil { - return fmt.Errorf("marshal weights.lock: %w", err) - } - if err := os.WriteFile(path, data, 0o644); err != nil { - return fmt.Errorf("write weights.lock: %w", err) - } - return nil -} diff --git a/pkg/model/weights_lock_test.go b/pkg/model/weights_lock_test.go deleted file mode 100644 index 79a856d0e4..0000000000 --- a/pkg/model/weights_lock_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// pkg/model/weights_lock_test.go -package model - -import ( - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestWeightsLock(t *testing.T) { - t.Run("parse valid lockfile", func(t *testing.T) { - json := `{ - "version": "1", - "created": "2026-01-30T12:00:00Z", - "files": [ - { - "name": "model.safetensors", - "dest": "/cache/model.safetensors", - "digestOriginal": "sha256:abc123", - "digest": "sha256:def456", - "size": 1000, - "sizeUncompressed": 2000, - "mediaType": "application/vnd.cog.weights.layer.v1+gzip" - } - ] - }` - - lock, err := ParseWeightsLock([]byte(json)) - require.NoError(t, err) - require.Equal(t, "1", lock.Version) - require.Len(t, 
lock.Files, 1) - require.Equal(t, "model.safetensors", lock.Files[0].Name) - require.Equal(t, "/cache/model.safetensors", lock.Files[0].Dest) - require.Equal(t, "sha256:abc123", lock.Files[0].DigestOriginal) - require.Equal(t, "sha256:def456", lock.Files[0].Digest) - require.Equal(t, int64(1000), lock.Files[0].Size) - }) - - t.Run("load from file", func(t *testing.T) { - dir := t.TempDir() - lockPath := filepath.Join(dir, "weights.lock") - content := `{"version": "1", "created": "2026-01-30T12:00:00Z", "files": []}` - require.NoError(t, os.WriteFile(lockPath, []byte(content), 0o644)) - - lock, err := LoadWeightsLock(lockPath) - require.NoError(t, err) - require.Equal(t, "1", lock.Version) - }) - - t.Run("save to file", func(t *testing.T) { - dir := t.TempDir() - lockPath := filepath.Join(dir, "weights.lock") - - lock := &WeightsLock{ - Version: "1", - Created: time.Date(2026, 1, 30, 12, 0, 0, 0, time.UTC), - Files: []WeightFile{ - {Name: "test.bin", Dest: "/cache/test.bin"}, - }, - } - - require.NoError(t, lock.Save(lockPath)) - - loaded, err := LoadWeightsLock(lockPath) - require.NoError(t, err) - require.Equal(t, lock.Version, loaded.Version) - require.Len(t, loaded.Files, 1) - }) - -} diff --git a/pkg/model/weights_status.go b/pkg/model/weights_status.go new file mode 100644 index 0000000000..97ac0cda34 --- /dev/null +++ b/pkg/model/weights_status.go @@ -0,0 +1,250 @@ +package model + +import ( + "context" + + "golang.org/x/sync/errgroup" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// WeightStatus describes the resolved status of a weight. 
+type WeightStatus string + +const ( + WeightStatusReady WeightStatus = "ready" + WeightStatusIncomplete WeightStatus = "incomplete" + WeightStatusStale WeightStatus = "stale" + WeightStatusPending WeightStatus = "pending" + WeightStatusOrphaned WeightStatus = "orphaned" +) + +// LayerStatus describes the registry presence of a single layer. +type LayerStatus string + +const ( + LayerStatusReady LayerStatus = "ready" + LayerStatusMissing LayerStatus = "missing" +) + +// WeightStatusResult is one weight's resolved status. The LockEntry pointer +// is nil for pending weights; non-nil for all other statuses. Layers is +// populated only for weights that had registry checks performed. +type WeightStatusResult struct { + Name string + Target string + Status WeightStatus + LockEntry *lockfile.WeightLockEntry + Layers []LayerStatusResult +} + +// LayerStatusResult is one layer's status in the registry. +type LayerStatusResult struct { + Digest string + Size int64 + Status LayerStatus +} + +// WeightsStatus is the computed status of all weights for a model. +// It is the return value of ComputeWeightsStatus and provides methods +// to inspect the results. +type WeightsStatus struct { + results []WeightStatusResult +} + +// ComputeWeightsStatus determines the status of every weight by matching +// config declarations against the lockfile and checking the registry for +// per-layer blob presence. +// +// Registry checks run concurrently, bounded by GetPushConcurrency(). +// Per-weight registry errors are soft: the weight is marked "incomplete" +// and layers are marked "missing". +// Context cancellation is propagated via errgroup and returns an error. 
+func ComputeWeightsStatus(ctx context.Context, cfg *config.Config, lock *lockfile.WeightsLock, repo string, reg registry.Client) (*WeightsStatus, error) {
+	lockByName := make(map[string]*lockfile.WeightLockEntry)
+	if lock != nil {
+		for i := range lock.Weights {
+			lockByName[lock.Weights[i].Name] = &lock.Weights[i]
+		}
+	}
+
+	configNames := make(map[string]bool, len(cfg.Weights))
+
+	// First pass: config-declared weights. Determine local status
+	// (pending, stale, or needs-registry-check).
+	results := make([]WeightStatusResult, 0, len(cfg.Weights)+len(lockByName))
+	var needRegistryCheck []int // indices into results
+
+	for _, w := range cfg.Weights {
+		configNames[w.Name] = true
+		le := lockByName[w.Name]
+
+		r := WeightStatusResult{
+			Name:      w.Name,
+			Target:    w.Target,
+			LockEntry: le,
+		}
+
+		switch {
+		case le == nil:
+			r.Status = WeightStatusPending
+		case isStale(w, le):
+			r.Status = WeightStatusStale
+		case len(le.Layers) > 0:
+			// Config matches lockfile, has layers. Registry check needed.
+			needRegistryCheck = append(needRegistryCheck, len(results))
+		default:
+			// No layers to check (edge case: lockfile entry with no layers).
+			r.Status = WeightStatusReady
+		}
+
+		results = append(results, r)
+	}
+
+	// Orphaned: in lockfile but not in config.
+	for name := range lockByName {
+		if configNames[name] {
+			continue
+		}
+		le := lockByName[name]
+		results = append(results, WeightStatusResult{
+			Name:      le.Name,
+			Target:    le.Target,
+			Status:    WeightStatusOrphaned,
+			LockEntry: le,
+		})
+	}
+
+	// Second pass: concurrent per-layer registry checks.
+	if len(needRegistryCheck) > 0 {
+		if err := checkRegistryLayers(ctx, results, needRegistryCheck, repo, reg); err != nil {
+			return nil, err
+		}
+	}
+
+	return &WeightsStatus{results: results}, nil
+}
+
+// statusCheckConcurrency is the concurrency limit for registry HEAD
+// requests during status checks. These are lightweight operations,
+// not bandwidth-saturating uploads.
+const statusCheckConcurrency = 8 + +// checkRegistryLayers checks layer blob existence in the registry for each +// weight that needs verification. Each weight's layers are checked +// concurrently. The weight's status is set to "ready" if all layers exist, +// "incomplete" otherwise. +func checkRegistryLayers(ctx context.Context, results []WeightStatusResult, indices []int, repo string, reg registry.Client) error { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(statusCheckConcurrency) + + for _, idx := range indices { + r := &results[idx] + g.Go(func() error { + if err := ctx.Err(); err != nil { + return err + } + return checkWeightLayers(ctx, r, repo, reg) + }) + } + + return g.Wait() +} + +// checkWeightLayers checks each layer of a single weight against the registry +// and populates the result's Layers and Status fields. +func checkWeightLayers(ctx context.Context, r *WeightStatusResult, repo string, reg registry.Client) error { + le := r.LockEntry + r.Layers = make([]LayerStatusResult, len(le.Layers)) + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(statusCheckConcurrency) + + for i, layer := range le.Layers { + lr := &r.Layers[i] + lr.Digest = layer.Digest + lr.Size = layer.Size + + g.Go(func() error { + if err := ctx.Err(); err != nil { + return err + } + + exists, err := reg.BlobExists(ctx, repo, layer.Digest) + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + lr.Status = LayerStatusMissing + return nil + } + + if exists { + lr.Status = LayerStatusReady + } else { + lr.Status = LayerStatusMissing + } + return nil + }) + } + + if err := g.Wait(); err != nil { + return err + } + + // Derive weight status from layer results after all goroutines complete. + r.Status = WeightStatusReady + for _, lr := range r.Layers { + if lr.Status != LayerStatusReady { + r.Status = WeightStatusIncomplete + break + } + } + return nil +} + +// isStale reports whether a config declaration has drifted from its +// lockfile entry. 
An invalid config (URI that fails normalization) is +// treated as stale: the user asked for something we can't represent, so +// the safe answer is "out of sync". +// +// Source.Fingerprint and Source.ImportedAt are lockfile-side metadata, +// not user-declared inputs, and are excluded from the comparison. +func isStale(w config.WeightSource, le *lockfile.WeightLockEntry) bool { + configSpec, err := WeightSpecFromConfig(w) + if err != nil { + return true + } + return !configSpec.Equal(WeightSpecFromLock(*le)) +} + +// Results returns all weight status results in order: config-declared +// weights first (preserving cog.yaml order), then orphaned lockfile +// entries. +func (ws *WeightsStatus) Results() []WeightStatusResult { + return ws.results +} + +// AllReady reports whether every weight is in the "ready" state. +// Returns true for empty weight lists. +func (ws *WeightsStatus) AllReady() bool { + for _, r := range ws.results { + if r.Status != WeightStatusReady { + return false + } + } + return true +} + +// ByStatus returns all results with the given status. +func (ws *WeightsStatus) ByStatus(status WeightStatus) []WeightStatusResult { + var out []WeightStatusResult + for _, r := range ws.results { + if r.Status == status { + out = append(out, r) + } + } + return out +} diff --git a/pkg/model/weights_status_test.go b/pkg/model/weights_status_test.go new file mode 100644 index 0000000000..12d5a524eb --- /dev/null +++ b/pkg/model/weights_status_test.go @@ -0,0 +1,711 @@ +package model + +import ( + "context" + "fmt" + "testing" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// --- mock registry --- + +// statusMockRegistry implements registry.Client with controllable blob existence. 
+type statusMockRegistry struct { + // blobs maps "repo/digest" -> exists + blobs map[string]bool + // blobErr if set, BlobExists returns this error for all calls + blobErr error +} + +func newMockRegistry() *statusMockRegistry { + return &statusMockRegistry{blobs: make(map[string]bool)} +} + +func (m *statusMockRegistry) addBlob(repo, digest string) { + m.blobs[repo+"/"+digest] = true +} + +func (m *statusMockRegistry) BlobExists(_ context.Context, repo string, digest string) (bool, error) { + if m.blobErr != nil { + return false, m.blobErr + } + return m.blobs[repo+"/"+digest], nil +} + +// Unused interface methods — satisfy registry.Client. +func (m *statusMockRegistry) Inspect(_ context.Context, _ string, _ *registry.Platform) (*registry.ManifestResult, error) { + return nil, nil +} +func (m *statusMockRegistry) GetImage(_ context.Context, _ string, _ *registry.Platform) (v1.Image, error) { + return nil, nil +} +func (m *statusMockRegistry) Exists(_ context.Context, _ string) (bool, error) { return false, nil } +func (m *statusMockRegistry) GetDescriptor(_ context.Context, _ string) (v1.Descriptor, error) { + return v1.Descriptor{}, nil +} +func (m *statusMockRegistry) PushImage(_ context.Context, _ string, _ v1.Image) error { return nil } +func (m *statusMockRegistry) PushIndex(_ context.Context, _ string, _ v1.ImageIndex) error { + return nil +} +func (m *statusMockRegistry) WriteLayer(_ context.Context, _ registry.WriteLayerOptions) error { + return nil +} + +// --- helpers --- + +func lockEntry(name, target, uri, digest string, layers ...lockfile.WeightLockLayer) lockfile.WeightLockEntry { + return lockfile.WeightLockEntry{ + Name: name, + Target: target, + Source: lockfile.WeightLockSource{URI: uri, Include: []string{}, Exclude: []string{}}, + Digest: digest, + SetDigest: digest, + Layers: layers, + } +} + +func layer(digest string, size int64) lockfile.WeightLockLayer { + return lockfile.WeightLockLayer{Digest: digest, Size: size} +} + +func 
computeStatus(t *testing.T, cfg *config.Config, lock *lockfile.WeightsLock, repo string, reg registry.Client) *WeightsStatus { + t.Helper() + ws, err := ComputeWeightsStatus(context.Background(), cfg, lock, repo, reg) + require.NoError(t, err) + return ws +} + +// --- config vs lockfile (no registry needed for pending/stale/orphaned) --- + +func TestWeightStatuses_AllPending(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/weights/base", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + {Name: "lora", Target: "/weights/lora"}, + }, + } + + ws := computeStatus(t, cfg, nil, "repo", newMockRegistry()) + results := ws.Results() + + require.Len(t, results, 2) + assert.Equal(t, "base", results[0].Name) + assert.Equal(t, "/weights/base", results[0].Target) + assert.Equal(t, WeightStatusPending, results[0].Status) + assert.Nil(t, results[0].LockEntry) + + assert.Equal(t, "lora", results[1].Name) + assert.Equal(t, WeightStatusPending, results[1].Status) +} + +func TestWeightStatuses_BarePathNormalization(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "parakeet", Target: "/src/weights", Source: &config.WeightSourceConfig{URI: "weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "parakeet", + Target: "/src/weights", + Source: lockfile.WeightLockSource{URI: "file://./weights", Include: []string{}, Exclude: []string{}}, + Digest: "sha256:abc", SetDigest: "sha256:abc", + Layers: []lockfile.WeightLockLayer{layer("sha256:l1", 1024)}, + }, + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + require.Len(t, ws.Results(), 1) + assert.Equal(t, WeightStatusReady, ws.Results()[0].Status, "bare path 'weights' should match normalized 'file://./weights'") +} + +func TestWeightStatuses_DotSlashPathNormalization(t *testing.T) { + cfg := &config.Config{ 
+ Weights: []config.WeightSource{ + {Name: "w", Target: "/w", Source: &config.WeightSourceConfig{URI: "./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("w", "/w", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + assert.Equal(t, WeightStatusReady, ws.Results()[0].Status) +} + +func TestWeightStatuses_StaleTarget(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/weights/v2", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/weights/base", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_StaleSourceURI(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./new-weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_StaleIncludePatterns(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{ + URI: "file://./weights", + Include: []string{"*.bin"}, + }}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + 
}, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_StaleExcludePatterns(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{ + URI: "file://./weights", + Exclude: []string{"*.tmp"}, + }}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_NotStaleWithMatchingPatterns(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{ + URI: "file://./weights", + Include: []string{"*.bin", "*.safetensors"}, + Exclude: []string{"*.tmp"}, + }}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "base", + Target: "/w", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Include: []string{"*.bin", "*.safetensors"}, + Exclude: []string{"*.tmp"}, + }, + Digest: "sha256:abc", SetDigest: "sha256:abc", + Layers: []lockfile.WeightLockLayer{layer("sha256:l1", 100)}, + }, + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + assert.Equal(t, WeightStatusReady, ws.Results()[0].Status) +} + +func TestWeightStatuses_CogYAMLReorderingNotStale(t *testing.T) { + // Patterns in cog.yaml can appear in any order; what matters is the + // set. The lockfile always stores them in canonical (sorted) form, + // so a cog.yaml reorder against a canonical lockfile reports ready. 
+ cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{ + URI: "file://./weights", + Include: []string{"*.safetensors", "*.bin"}, + Exclude: []string{"*.onnx", "*.tmp"}, + }}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "base", + Target: "/w", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Include: []string{"*.bin", "*.safetensors"}, + Exclude: []string{"*.onnx", "*.tmp"}, + }, + Digest: "sha256:abc", SetDigest: "sha256:abc", + Layers: []lockfile.WeightLockLayer{layer("sha256:l1", 100)}, + }, + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + assert.Equal(t, WeightStatusReady, ws.Results()[0].Status) +} + +func TestWeightStatuses_UnsortedLockfileIsStale(t *testing.T) { + // A lockfile whose on-disk form does not match the canonical form + // we would write today must report stale so the next build rewrites + // it. Here the lockfile has unsorted include patterns — a fresh + // build from this config would produce sorted patterns, so the + // lockfile is out of date. 
+ cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{ + URI: "file://./weights", + Include: []string{"*.bin", "*.safetensors"}, + }}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "base", + Target: "/w", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Include: []string{"*.safetensors", "*.bin"}, + Exclude: []string{}, + }, + Digest: "sha256:abc", SetDigest: "sha256:abc", + Layers: []lockfile.WeightLockLayer{layer("sha256:l1", 100)}, + }, + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_Orphaned(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/weights/base", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/weights/base", "file://./weights", "sha256:abc", layer("sha256:l1", 100)), + { + Name: "old-weight", Target: "/weights/old", + Source: lockfile.WeightLockSource{URI: "file://./old", Include: []string{}, Exclude: []string{}}, + Digest: "sha256:def", Size: 2048, + }, + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + results := ws.Results() + + require.Len(t, results, 2) + assert.Equal(t, WeightStatusReady, results[0].Status) + assert.Equal(t, WeightStatusOrphaned, results[1].Status) + assert.Equal(t, int64(2048), results[1].LockEntry.Size) +} + +func TestWeightStatuses_EmptyConfig(t *testing.T) { + ws := computeStatus(t, &config.Config{}, nil, "repo", newMockRegistry()) + assert.Empty(t, ws.Results()) + assert.True(t, ws.AllReady()) +} + +func TestWeightStatuses_EmptyConfigWithOrphanedLockEntries(t *testing.T) { + lock := &lockfile.WeightsLock{ + 
Version: 1, + Weights: []lockfile.WeightLockEntry{ + {Name: "orphan", Target: "/w", Digest: "sha256:abc"}, + }, + } + + ws := computeStatus(t, &config.Config{}, lock, "repo", newMockRegistry()) + require.Len(t, ws.Results(), 1) + assert.Equal(t, WeightStatusOrphaned, ws.Results()[0].Status) + assert.False(t, ws.AllReady()) +} + +func TestWeightStatuses_NilSourceIsStale(t *testing.T) { + // A weight declared without a source URI is malformed and reports stale. + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w"}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "base", Target: "/w", + Source: lockfile.WeightLockSource{URI: "", Include: []string{}, Exclude: []string{}}, + Digest: "sha256:abc", + }, + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) +} + +func TestWeightStatuses_FingerprintNotCompared(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "base", Target: "/w", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", Include: []string{}, Exclude: []string{}, + Fingerprint: weightsource.Fingerprint("sha256:anything"), + }, + Digest: "sha256:abc", SetDigest: "sha256:abc", + Layers: []lockfile.WeightLockLayer{layer("sha256:l1", 100)}, + }, + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + + ws := computeStatus(t, cfg, lock, "repo", reg) + assert.Equal(t, WeightStatusReady, ws.Results()[0].Status) +} + +func TestWeightStatuses_PreservesConfigOrder(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "charlie", Target: "/c"}, + {Name: "alpha", Target: "/a"}, + {Name: "bravo", Target: "/b"}, + }, + } + + 
ws := computeStatus(t, cfg, nil, "repo", newMockRegistry()) + results := ws.Results() + + require.Len(t, results, 3) + assert.Equal(t, "charlie", results[0].Name) + assert.Equal(t, "alpha", results[1].Name) + assert.Equal(t, "bravo", results[2].Name) +} + +// --- per-layer registry checks --- + +func TestWeightStatuses_AllLayersPresent(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 1000), + layer("sha256:l2", 2000), + layer("sha256:l3", 3000), + ), + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + reg.addBlob("repo", "sha256:l2") + reg.addBlob("repo", "sha256:l3") + + ws := computeStatus(t, cfg, lock, "repo", reg) + r := ws.Results()[0] + + assert.Equal(t, WeightStatusReady, r.Status) + require.Len(t, r.Layers, 3) + for _, l := range r.Layers { + assert.Equal(t, LayerStatusReady, l.Status) + } +} + +func TestWeightStatuses_SomeLayersMissing(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 1000), + layer("sha256:l2", 2000), + layer("sha256:l3", 3000), + ), + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:l1") + // l2 missing + reg.addBlob("repo", "sha256:l3") + + ws := computeStatus(t, cfg, lock, "repo", reg) + r := ws.Results()[0] + + assert.Equal(t, WeightStatusIncomplete, r.Status) + require.Len(t, r.Layers, 3) + assert.Equal(t, LayerStatusReady, r.Layers[0].Status) + assert.Equal(t, LayerStatusMissing, r.Layers[1].Status) + 
assert.Equal(t, LayerStatusReady, r.Layers[2].Status) +} + +func TestWeightStatuses_AllLayersMissing(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 1000), + layer("sha256:l2", 2000), + ), + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + r := ws.Results()[0] + + assert.Equal(t, WeightStatusIncomplete, r.Status) + for _, l := range r.Layers { + assert.Equal(t, LayerStatusMissing, l.Status) + } +} + +func TestWeightStatuses_LayerSizesPreserved(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 4200000000), + layer("sha256:l2", 800000000), + ), + }, + } + + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + r := ws.Results()[0] + + require.Len(t, r.Layers, 2) + assert.Equal(t, int64(4200000000), r.Layers[0].Size) + assert.Equal(t, int64(800000000), r.Layers[1].Size) +} + +func TestWeightStatuses_RegistryErrorIsIncomplete(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 100), + ), + }, + } + + reg := newMockRegistry() + reg.blobErr = fmt.Errorf("network error") + + ws := computeStatus(t, cfg, lock, "repo", reg) + assert.Equal(t, WeightStatusIncomplete, 
ws.Results()[0].Status) + assert.Equal(t, LayerStatusMissing, ws.Results()[0].Layers[0].Status) +} + +func TestWeightStatuses_StaleSkipsRegistryCheck(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w/new", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w/old", "file://./weights", "sha256:abc", + layer("sha256:l1", 100), + ), + }, + } + + // Stale is decided from the config/lock comparison alone; the empty mock registry is never consulted. + ws := computeStatus(t, cfg, lock, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusStale, ws.Results()[0].Status) + assert.Nil(t, ws.Results()[0].Layers) +} + +func TestWeightStatuses_PendingSkipsRegistryCheck(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "new", Target: "/w"}, + }, + } + + ws := computeStatus(t, cfg, nil, "repo", newMockRegistry()) + assert.Equal(t, WeightStatusPending, ws.Results()[0].Status) + assert.Nil(t, ws.Results()[0].Layers) +} + +func TestWeightStatuses_ContextCancellation(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "base", Target: "/w", Source: &config.WeightSourceConfig{URI: "file://./weights"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("base", "/w", "file://./weights", "sha256:abc", + layer("sha256:l1", 100), + ), + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := ComputeWeightsStatus(ctx, cfg, lock, "repo", newMockRegistry()) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestWeightStatuses_MixedWithAllStatuses(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "ready", Target: "/a", Source: &config.WeightSourceConfig{URI: "file://./a"}}, + {Name: "incomplete", Target: "/b", Source: &config.WeightSourceConfig{URI: 
"file://./b"}}, + {Name: "stale", Target: "/c/new", Source: &config.WeightSourceConfig{URI: "file://./c"}}, + {Name: "pending", Target: "/d"}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("ready", "/a", "file://./a", "sha256:aaa", layer("sha256:la", 100)), + lockEntry("incomplete", "/b", "file://./b", "sha256:bbb", layer("sha256:lb", 100)), + lockEntry("stale", "/c/old", "file://./c", "sha256:ccc", layer("sha256:lc", 100)), + lockEntry("orphan", "/e", "file://./e", "sha256:eee", layer("sha256:le", 100)), + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:la") // ready has its layer + // incomplete missing sha256:lb + + ws := computeStatus(t, cfg, lock, "repo", reg) + results := ws.Results() + + require.Len(t, results, 5) + assert.Equal(t, WeightStatusReady, results[0].Status) + assert.Equal(t, WeightStatusIncomplete, results[1].Status) + assert.Equal(t, WeightStatusStale, results[2].Status) + assert.Equal(t, WeightStatusPending, results[3].Status) + assert.Equal(t, WeightStatusOrphaned, results[4].Status) +} + +// --- struct helpers --- + +func TestWeightsStatus_ByStatus(t *testing.T) { + cfg := &config.Config{ + Weights: []config.WeightSource{ + {Name: "a", Target: "/a", Source: &config.WeightSourceConfig{URI: "file://./a"}}, + {Name: "b", Target: "/b"}, + {Name: "c", Target: "/c", Source: &config.WeightSourceConfig{URI: "file://./c"}}, + }, + } + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + lockEntry("a", "/a", "file://./a", "sha256:aaa", layer("sha256:la", 100)), + lockEntry("c", "/c", "file://./c", "sha256:ccc", layer("sha256:lc", 100)), + }, + } + + reg := newMockRegistry() + reg.addBlob("repo", "sha256:la") + reg.addBlob("repo", "sha256:lc") + + ws := computeStatus(t, cfg, lock, "repo", reg) + + ready := ws.ByStatus(WeightStatusReady) + assert.Len(t, ready, 2) + + pending := ws.ByStatus(WeightStatusPending) + assert.Len(t, pending, 1) + 
assert.Equal(t, "b", pending[0].Name) +} diff --git a/pkg/model/weights_test.go b/pkg/model/weights_test.go index f71f9bfe96..60c6457301 100644 --- a/pkg/model/weights_test.go +++ b/pkg/model/weights_test.go @@ -1,15 +1,238 @@ package model import ( + "encoding/json" + "os" + "path/filepath" "testing" + "time" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" ) -func TestWeightFile(t *testing.T) { - t.Run("media type constants", func(t *testing.T) { - require.Equal(t, "application/vnd.cog.weight.layer.v1+gzip", MediaTypeWeightLayerGzip) - require.Equal(t, "application/vnd.cog.weight.v1", MediaTypeWeightArtifact) - require.Equal(t, "application/vnd.cog.weight.layer.v1", MediaTypeWeightLayer) - }) +func TestWeightLockEntry_JSONFieldNames(t *testing.T) { + entry := lockfile.WeightLockEntry{ + Name: "z-image-turbo", + Target: "/src/weights", + Digest: "sha256:abc", + SetDigest: "sha256:def", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Fingerprint: weightsource.Fingerprint("sha256:def"), + Include: []string{}, + Exclude: []string{}, + }, + Files: []lockfile.WeightLockFile{ + {Path: "a.json", Size: 100, Digest: "sha256:f01", Layer: "sha256:aaa"}, + }, + Layers: []lockfile.WeightLockLayer{ + {Digest: "sha256:aaa", MediaType: mediaTypeOCILayerTarGzip, Size: 110, SizeUncompressed: 100}, + }, + } + + data, err := json.Marshal(entry) + require.NoError(t, err) + s := string(data) + + // Sanity-check that every documented field name is present. 
+ for _, key := range []string{ + `"name":"z-image-turbo"`, + `"target":"/src/weights"`, + `"digest":"sha256:abc"`, + `"setDigest":"sha256:def"`, + `"source":`, + `"uri":"file://./weights"`, + `"fingerprint":"sha256:def"`, + `"files":`, + `"layers":`, + `"sizeUncompressed":100`, + } { + assert.Contains(t, s, key, "expected field %q in JSON", key) + } +} + +func TestMediaTypeArtifactConstant(t *testing.T) { + require.Equal(t, "application/vnd.cog.weight.v1", MediaTypeWeightArtifact) +} + +func TestMediaTypeWeightConfigConstant(t *testing.T) { + require.Equal(t, "application/vnd.cog.weight.config.v1+json", MediaTypeWeightConfig) +} + +// setDigestOf wraps files in a throwaway entry to compute the set digest. +// Unit coverage for ComputeSetDigest itself lives in pkg/weights/lockfile; +// here we just need a digest value to plug into buildWeightConfigBlob. +func setDigestOf(files []lockfile.WeightLockFile) string { + e := lockfile.WeightLockEntry{Files: files} + return e.ComputeSetDigest() +} + +func TestBuildWeightConfigBlob_Deterministic(t *testing.T) { + files := []lockfile.WeightLockFile{ + {Path: "config.json", Size: 100, Digest: "sha256:aaa", Layer: "sha256:l1"}, + {Path: "model.bin", Size: 9999, Digest: "sha256:bbb", Layer: "sha256:l2"}, + } + sd := setDigestOf(files) + cfg1, err := buildWeightConfigBlob("test-weight", "/src/weights", sd, files) + require.NoError(t, err) + cfg2, err := buildWeightConfigBlob("test-weight", "/src/weights", sd, files) + require.NoError(t, err) + assert.Equal(t, cfg1, cfg2, "config blob must be deterministic") +} + +func TestBuildWeightConfigBlob_Structure(t *testing.T) { + files := []lockfile.WeightLockFile{ + {Path: "config.json", Size: 100, Digest: "sha256:aaa", Layer: "sha256:l1"}, + {Path: "model.bin", Size: 9999, Digest: "sha256:bbb", Layer: "sha256:l2"}, + } + setDigest := setDigestOf(files) + configJSON, err := buildWeightConfigBlob("z-image-turbo", "/src/weights", setDigest, files) + require.NoError(t, err) + + var cfg 
WeightConfigBlob + require.NoError(t, json.Unmarshal(configJSON, &cfg)) + + assert.Equal(t, "z-image-turbo", cfg.Name) + assert.Equal(t, "/src/weights", cfg.Target) + assert.Equal(t, setDigest, cfg.SetDigest) + require.Len(t, cfg.Files, 2) + + // Files should be sorted by path. + assert.Equal(t, "config.json", cfg.Files[0].Path) + assert.Equal(t, "model.bin", cfg.Files[1].Path) + assert.Equal(t, int64(100), cfg.Files[0].Size) + assert.Equal(t, "sha256:aaa", cfg.Files[0].Digest) + assert.Equal(t, "sha256:l1", cfg.Files[0].Layer) +} + +func TestBuildWeightConfigBlob_RejectsEmptyFiles(t *testing.T) { + _, err := buildWeightConfigBlob("name", "/target", "sha256:000", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no files") +} + +func TestSetDigest_StableAcrossRepacks(t *testing.T) { + // Pack the same directory twice with different thresholds (producing + // different layers) and verify the set digest is identical. + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644)) + + pr1, _, err := packTestDir(t, dir, &packOptions{BundleFileMax: 1024, BundleSizeMax: 1024}) + require.NoError(t, err) + + pr2, _, err := packTestDir(t, dir, &packOptions{BundleFileMax: 1, BundleSizeMax: 1}) + require.NoError(t, err) + + entry1 := newWeightLockEntry("w", "/w", lockfile.WeightLockSource{}, pr1.Files, pr1.Layers) + entry2 := newWeightLockEntry("w", "/w", lockfile.WeightLockSource{}, pr2.Files, pr2.Layers) + + assert.Equal(t, entry1.SetDigest, entry2.SetDigest, + "set digest must be stable across different packing strategies") +} + +func TestConfigBlob_DiffersAcrossRepacks(t *testing.T) { + // Different packing parameters → different config blobs (different + // layer digests), but same set digest. 
+ dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644)) + + pr1, _, err := packTestDir(t, dir, &packOptions{BundleFileMax: 1024, BundleSizeMax: 1024}) + require.NoError(t, err) + + // With BundleFileMax=1, all files are "large" (standalone layers). + pr2, _, err := packTestDir(t, dir, &packOptions{BundleFileMax: 1, BundleSizeMax: 1}) + require.NoError(t, err) + + entry1 := newWeightLockEntry("w", "/w", lockfile.WeightLockSource{}, pr1.Files, pr1.Layers) + entry2 := newWeightLockEntry("w", "/w", lockfile.WeightLockSource{}, pr2.Files, pr2.Layers) + + cfg1, err := buildWeightConfigBlob("w", "/w", entry1.SetDigest, entry1.Files) + require.NoError(t, err) + cfg2, err := buildWeightConfigBlob("w", "/w", entry2.SetDigest, entry2.Files) + require.NoError(t, err) + + // Layer digests differ → config blobs differ. + assert.NotEqual(t, cfg1, cfg2, "config blobs should differ when packing strategy differs") +} + +func TestNewWeightLockEntry_PopulatesFromPackResult(t *testing.T) { + layers := []packedLayer{ + { + Digest: v1.Hash{Algorithm: "sha256", Hex: "aaa"}, + Size: 110, + UncompressedSize: 100, + MediaType: mediaTypeOCILayerTarGzip, + }, + { + Digest: v1.Hash{Algorithm: "sha256", Hex: "bbb"}, + Size: 2000, + UncompressedSize: 2000, + MediaType: mediaTypeOCILayerTar, + }, + } + files := []packedFile{ + {Path: "a.json", Size: 100, Digest: "sha256:f01", LayerDigest: "sha256:aaa"}, + {Path: "b.bin", Size: 2000, Digest: "sha256:f02", LayerDigest: "sha256:bbb"}, + } + src := lockfile.WeightLockSource{ + URI: "file://./weights", + Fingerprint: weightsource.Fingerprint("sha256:setdigest"), + Include: []string{}, + Exclude: []string{}, + ImportedAt: time.Date(2026, 4, 16, 17, 27, 7, 0, time.UTC), + } + + entry := newWeightLockEntry("w", "/src/w", src, files, layers) + + assert.Equal(t, "w", entry.Name) + assert.Equal(t, "/src/w", 
entry.Target) + assert.Empty(t, entry.Digest, "Digest should be empty (filled by caller after manifest build)") + assert.NotEmpty(t, entry.SetDigest, "SetDigest should be computed internally") + + // Size = sum of uncompressed; SizeCompressed = sum of layer sizes. + assert.Equal(t, int64(100+2000), entry.Size) + assert.Equal(t, int64(110+2000), entry.SizeCompressed) + + require.Len(t, entry.Files, 2) + require.Len(t, entry.Layers, 2) + assert.Equal(t, src, entry.Source) + + // Files sorted by path, layers sorted by digest. + assert.Equal(t, "a.json", entry.Files[0].Path) + assert.Equal(t, "sha256:aaa", entry.Layers[0].Digest) +} + +// TestSetDigest_CrossPath verifies that the packer-based set digest +// (WeightLockEntry.ComputeSetDigest) and the weightsource-based +// fingerprint (computeInventory via FileSource.Inventory) produce the +// same value for the same directory. Both paths feed weightsource.DirHash, +// so this test catches regressions in the Dirhashable adapters on +// WeightLockFile and InventoryFile — e.g. wiring the wrong field into +// DirhashParts. +func TestSetDigest_CrossPath(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644)) + + // Path 1: pack, convert to lock entry, compute from the entry. + pr, _, err := packTestDir(t, dir, nil) + require.NoError(t, err) + entry := newWeightLockEntry("w", "/w", lockfile.WeightLockSource{}, pr.Files, pr.Layers) + packerSetDigest := entry.ComputeSetDigest() + + // Path 2: inventory fingerprint from directory walk (the weightsource path). 
+ src, err := weightsource.NewFileSource("file://"+dir, "") + require.NoError(t, err) + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + assert.Equal(t, packerSetDigest, inv.Fingerprint.String(), + "packer and weightsource must produce the same set digest for the same directory") } diff --git a/pkg/model/weightsource/dirhash.go b/pkg/model/weightsource/dirhash.go new file mode 100644 index 0000000000..020b88f072 --- /dev/null +++ b/pkg/model/weightsource/dirhash.go @@ -0,0 +1,50 @@ +package weightsource + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strings" +) + +// DirhashPart is the atomic input to DirHash: the pair of fields that +// uniquely identify a file's contribution to the dirhash. Path is the +// relative path (forward slashes) and Digest is the file's sha256 content +// digest in "sha256:" form. +type DirhashPart struct { + Path string + Digest string +} + +// Dirhashable is implemented by types that can participate in DirHash. +// Both weightsource.InventoryFile and lockfile.WeightLockFile implement +// it, letting the two call sites share one digest implementation. +type Dirhashable interface { + DirhashParts() DirhashPart +} + +// DirHash computes a content-addressable digest of a file set per spec §2.4: +// +// sha256(join(sort(lines), "\n")) +// +// where each line is the file's sha256 hex digest and relative path joined +// by two spaces (matching sha256sum output). DirHash sorts the lines +// itself, so the caller's input order does not affect the result. +// +// The result is returned in "sha256:<hex>" form. This formula computes the weight +// set digest stored in weights.lock (WeightLockEntry.SetDigest), and is +// also used by file:// sources specifically as their Fingerprint — +// content-addressable stores happen to match their fingerprint to their +// dirhash. Other schemes (hf://, s3://, http://) use scheme-native +// identifiers (commit SHA, ETag, etc.) for their Fingerprint instead. 
+func DirHash[F Dirhashable](files []F) string { + lines := make([]string, len(files)) + for i, f := range files { + p := f.DirhashParts() + _, hexStr, _ := strings.Cut(p.Digest, ":") + lines[i] = hexStr + " " + p.Path + } + sort.Strings(lines) + sum := sha256.Sum256([]byte(strings.Join(lines, "\n"))) + return "sha256:" + hex.EncodeToString(sum[:]) +} diff --git a/pkg/model/weightsource/dirhash_test.go b/pkg/model/weightsource/dirhash_test.go new file mode 100644 index 0000000000..2aa84cd72e --- /dev/null +++ b/pkg/model/weightsource/dirhash_test.go @@ -0,0 +1,77 @@ +package weightsource + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDirHash_Deterministic(t *testing.T) { + files := []InventoryFile{ + {Path: "a.txt", Digest: "sha256:aaa"}, + {Path: "b.txt", Digest: "sha256:bbb"}, + } + d1 := DirHash(files) + d2 := DirHash(files) + require.Equal(t, d1, d2) + assert.True(t, len(d1) > len("sha256:"), "digest must be non-trivial") +} + +func TestDirHash_InputOrderIndependent(t *testing.T) { + ordered := []InventoryFile{ + {Path: "a.txt", Digest: "sha256:aaa"}, + {Path: "b.txt", Digest: "sha256:bbb"}, + } + reversed := []InventoryFile{ + {Path: "b.txt", Digest: "sha256:bbb"}, + {Path: "a.txt", Digest: "sha256:aaa"}, + } + assert.Equal(t, DirHash(ordered), DirHash(reversed), + "DirHash must sort internally — caller order must not matter") +} + +func TestDirHash_DistinguishesContent(t *testing.T) { + a := []InventoryFile{{Path: "f", Digest: "sha256:aaa"}} + b := []InventoryFile{{Path: "f", Digest: "sha256:bbb"}} + assert.NotEqual(t, DirHash(a), DirHash(b)) +} + +func TestDirHash_DistinguishesPath(t *testing.T) { + a := []InventoryFile{{Path: "foo.txt", Digest: "sha256:aaa"}} + b := []InventoryFile{{Path: "bar.txt", Digest: "sha256:aaa"}} + assert.NotEqual(t, DirHash(a), DirHash(b)) +} + +func TestDirHash_EmptyInput(t *testing.T) { + // Empty input hashes to sha256 of the empty string. 
Not a failure — + // just documents the behavior. + got := DirHash([]InventoryFile{}) + assert.Equal(t, "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", got) +} + +// fakeFile lets us exercise DirHash with a hand-made Dirhashable type to +// confirm the generic constraint works for types outside this package. +type fakeFile struct { + path string + digest string +} + +func (f fakeFile) DirhashParts() DirhashPart { + return DirhashPart{Path: f.path, Digest: f.digest} +} + +func TestDirHash_ArbitraryType(t *testing.T) { + // Confirm that DirHash is truly generic: any type implementing + // Dirhashable should produce the same digest as an InventoryFile + // carrying the same Path/Digest data. + want := DirHash([]InventoryFile{ + {Path: "a.txt", Digest: "sha256:aaa"}, + {Path: "b.txt", Digest: "sha256:bbb"}, + }) + got := DirHash([]fakeFile{ + {path: "a.txt", digest: "sha256:aaa"}, + {path: "b.txt", digest: "sha256:bbb"}, + }) + assert.Equal(t, want, got) +} diff --git a/pkg/model/weightsource/file.go b/pkg/model/weightsource/file.go new file mode 100644 index 0000000000..0f2500fe7e --- /dev/null +++ b/pkg/model/weightsource/file.go @@ -0,0 +1,149 @@ +package weightsource + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// FileScheme is the URI scheme for local filesystem sources. +const FileScheme = "file" + +// FileSource is the Source implementation for file:// URIs and bare paths. 
+// +// URIs take one of these forms: +// +// file:///abs/path — absolute path +// file://./rel/path — canonical relative path (explicit ./) +// /abs/path — bare absolute path (normalized to file://) +// ./rel/path — bare relative path (normalized to file://) +// rel/path — bare relative path, no ./ prefix (normalized) +// +// The lockfile stores only the normalized form (see NormalizeURI); the +// absolute on-disk path is resolved once at construction time so the +// Source methods do not re-resolve on every call. +type FileSource struct { + // dir is the resolved absolute path to the source directory. + dir string + // fsys is an fs.FS rooted at dir. Open uses this instead of raw + // os.Open to guarantee path-escape prevention at the FS boundary. + fsys fs.FS +} + +// NewFileSource constructs a FileSource bound to uri, resolving relative +// URIs against projectDir. It validates that the resolved path exists +// and is a directory. +func NewFileSource(uri, projectDir string) (*FileSource, error) { + path, err := resolvePath(uri, projectDir) + if err != nil { + return nil, err + } + fi, err := os.Stat(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("weight source not found: %s", uri) + } + return nil, fmt.Errorf("stat weight source %s: %w", uri, err) + } + if !fi.IsDir() { + return nil, fmt.Errorf("weight source %s is not a directory (file:// sources must be directories)", uri) + } + return &FileSource{dir: path, fsys: os.DirFS(path)}, nil +} + +// sourceDir returns the resolved absolute path of the source directory. +// Exposed primarily for tests and diagnostics; the import pipeline should +// use Inventory + Open rather than reaching for the on-disk path. +func (s *FileSource) sourceDir() string { return s.dir } + +// Inventory walks the source directory and returns per-file path / size / +// content digest plus the source fingerprint (sha256 of the sorted file +// set, spec §2.4). 
+// +// The .cog state directory is skipped. Non-regular entries (symlinks, +// devices, etc.) are skipped — the spec defines packing over concrete +// files only. +func (s *FileSource) Inventory(ctx context.Context) (Inventory, error) { + if err := ctx.Err(); err != nil { + return Inventory{}, err + } + return computeInventory(ctx, s.dir) +} + +// Open returns a reader for a single file in the source, identified by +// its inventory path (relative to the source root, using forward +// slashes). The caller closes the returned reader. +func (s *FileSource) Open(ctx context.Context, path string) (io.ReadCloser, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + f, err := s.fsys.Open(path) + if err != nil { + return nil, fmt.Errorf("open %s: %w", path, err) + } + rc, ok := f.(io.ReadCloser) + if !ok { + _ = f.Close() + return nil, fmt.Errorf("open %s: filesystem does not support ReadCloser", path) + } + return rc, nil +} + +// normalizeFileURI produces the canonical file:// URI for a path value +// that already has the file:// prefix stripped (or was never present). +func normalizeFileURI(path string) (string, error) { + if path == "" { + return "", fmt.Errorf("empty weight source path") + } + + // On some forms (file:///abs/path) the caller has already stripped + // "file://", leaving "/abs/path". Bare "/abs/path" is the same + // string, so we treat them uniformly. + cleaned := filepath.Clean(path) + if filepath.IsAbs(cleaned) { + return "file://" + cleaned, nil + } + + // filepath.Clean drops a leading "./"; re-add it so the relative + // form is visually unambiguous. Note that "." collapses to itself: + // it names the project directory, which is not a valid weight + // source, so reject it here. + if cleaned == "." 
{ + return "", fmt.Errorf("weight source cannot be the project directory itself") + } + if strings.HasPrefix(cleaned, "..") { + return "", fmt.Errorf("weight source %q escapes the project directory", path) + } + return "file://./" + cleaned, nil +} + +// resolvePath turns a file:// URI or bare path into an absolute on-disk +// path, resolving relative paths against projectDir. +func resolvePath(uri, projectDir string) (string, error) { + // Strip the file:// scheme if present; bare paths pass through. + path := uri + if rest, ok := strings.CutPrefix(uri, "file://"); ok { + path = rest + } + normalized, err := normalizeFileURI(path) + if err != nil { + return "", err + } + // normalized is always "file://" at this point. + path = strings.TrimPrefix(normalized, "file://") + if filepath.IsAbs(path) { + return path, nil + } + // Relative: resolve against the project directory. The canonical + // form has a "./" prefix — trim it so filepath.Join doesn't double up. + path = strings.TrimPrefix(path, "./") + if projectDir == "" { + return "", fmt.Errorf("relative weight source %q requires a project directory", uri) + } + return filepath.Join(projectDir, path), nil +} diff --git a/pkg/model/weightsource/file_test.go b/pkg/model/weightsource/file_test.go new file mode 100644 index 0000000000..5ced30c0ba --- /dev/null +++ b/pkg/model/weightsource/file_test.go @@ -0,0 +1,286 @@ +package weightsource + +import ( + "context" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNormalizeURI(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErrSubs string + }{ + {"absolute bare", "/data/weights", "file:///data/weights", ""}, + {"absolute file scheme", "file:///data/weights", "file:///data/weights", ""}, + {"relative bare no dot", "weights", "file://./weights", ""}, + {"relative bare dot prefix", "./weights", "file://./weights", ""}, + {"relative file 
scheme", "file://./weights", "file://./weights", ""}, + {"relative with slash", "./weights/models", "file://./weights/models", ""}, + {"clean double slash", "./weights//models", "file://./weights/models", ""}, + {"clean dot segment", "./weights/./models", "file://./weights/models", ""}, + {"absolute clean", "/data//weights", "file:///data/weights", ""}, + + {"empty", "", "", "empty weight source uri"}, + {"empty file scheme", "file://", "", "empty weight source path"}, + {"project dir itself rejected", ".", "", "project directory itself"}, + {"parent escape rejected", "../sibling", "", "escapes the project directory"}, + {"unknown scheme rejected", "s3://bucket/key", "", "unsupported weight source scheme"}, + + {"hf basic", "hf://org/repo", "hf://org/repo", ""}, + {"hf with ref", "hf://org/repo@v1.0", "hf://org/repo@v1.0", ""}, + {"huggingface canonicalized", "huggingface://org/repo", "hf://org/repo", ""}, + {"huggingface with ref canonicalized", "huggingface://org/repo@main", "hf://org/repo", ""}, + {"hf explicit main ref stripped", "hf://org/repo@main", "hf://org/repo", ""}, + {"hf sha ref preserved", "hf://org/repo@abc123", "hf://org/repo@abc123", ""}, + {"hf invalid repo", "hf://justrepo", "", "expected org/repo"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := NormalizeURI(tc.in) + if tc.wantErrSubs != "" { + assert.ErrorContains(t, err, tc.wantErrSubs) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestNewFileSource_Absolute(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("x"), 0o644)) + + // Absolute URI; projectDir is ignored. 
+ uri := "file://" + dir + src, err := NewFileSource(uri, "/unused") + require.NoError(t, err) + assert.Equal(t, dir, src.sourceDir()) +} + +func TestNewFileSource_BareAbsolutePath(t *testing.T) { + dir := t.TempDir() + src, err := NewFileSource(dir, "") + require.NoError(t, err) + assert.Equal(t, dir, src.sourceDir()) +} + +func TestNewFileSource_Relative(t *testing.T) { + projectDir := t.TempDir() + weightsDir := filepath.Join(projectDir, "weights") + require.NoError(t, os.MkdirAll(weightsDir, 0o755)) + + src, err := NewFileSource("file://./weights", projectDir) + require.NoError(t, err) + assert.Equal(t, weightsDir, src.sourceDir()) +} + +func TestNewFileSource_BareRelative(t *testing.T) { + projectDir := t.TempDir() + weightsDir := filepath.Join(projectDir, "weights") + require.NoError(t, os.MkdirAll(weightsDir, 0o755)) + + src, err := NewFileSource("weights", projectDir) + require.NoError(t, err) + assert.Equal(t, weightsDir, src.sourceDir()) +} + +func TestNewFileSource_ErrorCases(t *testing.T) { + projectDir := t.TempDir() + + t.Run("missing", func(t *testing.T) { + _, err := NewFileSource("file://./missing", projectDir) + assert.ErrorContains(t, err, "not found") + }) + + t.Run("is a file not a dir", func(t *testing.T) { + filePath := filepath.Join(projectDir, "oops.bin") + require.NoError(t, os.WriteFile(filePath, []byte("x"), 0o644)) + _, err := NewFileSource("file://./oops.bin", projectDir) + assert.ErrorContains(t, err, "is not a directory") + }) + + t.Run("relative uri without project dir", func(t *testing.T) { + _, err := NewFileSource("file://./weights", "") + assert.ErrorContains(t, err, "project directory") + }) +} + +func TestFileSource_Open(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.MkdirAll(filepath.Join(dir, "sub"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "sub", "b.txt"), []byte("world"), 0o644)) + + src, err := 
NewFileSource("file://"+dir, "") + require.NoError(t, err) + + t.Run("top level", func(t *testing.T) { + rc, err := src.Open(t.Context(), "a.txt") + require.NoError(t, err) + defer rc.Close() + b, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, "hello", string(b)) + }) + + t.Run("nested", func(t *testing.T) { + rc, err := src.Open(t.Context(), "sub/b.txt") + require.NoError(t, err) + defer rc.Close() + b, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, "world", string(b)) + }) + + t.Run("missing file", func(t *testing.T) { + _, err := src.Open(t.Context(), "missing.txt") + require.Error(t, err) + }) + + t.Run("path traversal rejected", func(t *testing.T) { + _, err := src.Open(t.Context(), "../"+filepath.Base(dir)+"/../etc/passwd") + require.Error(t, err) + }) + + t.Run("absolute path rejected", func(t *testing.T) { + _, err := src.Open(t.Context(), "/etc/passwd") + require.Error(t, err) + }) + + t.Run("canceled context", func(t *testing.T) { + // Cancellation is tested with an independent context because + // t.Context() is tied to the test lifetime; we need a context + // we can cancel explicitly before the call. 
+ ctx, cancel := context.WithCancel(t.Context()) + cancel() + _, err := src.Open(ctx, "a.txt") + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +func TestFileSource_Inventory(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644)) + + src, err := NewFileSource("file://"+dir, "") + require.NoError(t, err) + + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + require.Len(t, inv.Files, 2) + assert.Equal(t, "a.txt", inv.Files[0].Path) + assert.Equal(t, int64(5), inv.Files[0].Size) + assert.True(t, strings.HasPrefix(inv.Files[0].Digest, "sha256:")) + assert.Equal(t, "b.txt", inv.Files[1].Path) + assert.Equal(t, int64(5), inv.Files[1].Size) + + assert.Equal(t, "sha256", inv.Fingerprint.Scheme()) + assert.Len(t, inv.Fingerprint.value(), 64, "sha256 hex is 64 chars") +} + +func TestFileSource_Inventory_Stable(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644)) + + src, err := NewFileSource("file://"+dir, "") + require.NoError(t, err) + + inv1, err := src.Inventory(t.Context()) + require.NoError(t, err) + inv2, err := src.Inventory(t.Context()) + require.NoError(t, err) + assert.Equal(t, inv1.Fingerprint, inv2.Fingerprint, + "fingerprint must be stable across calls") + assert.Equal(t, inv1.Files, inv2.Files, + "file list must be stable across calls") +} + +func TestFileSource_Inventory_DiffersOnChange(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644)) + + src, err := NewFileSource("file://"+dir, "") + require.NoError(t, err) + + inv1, err := src.Inventory(t.Context()) + require.NoError(t, err) + + require.NoError(t, 
os.WriteFile(filepath.Join(dir, "a.txt"), []byte("changed"), 0o644)) + + inv2, err := src.Inventory(t.Context()) + require.NoError(t, err) + assert.NotEqual(t, inv1.Fingerprint, inv2.Fingerprint, + "fingerprint must change when content changes") + assert.NotEqual(t, inv1.Files[0].Digest, inv2.Files[0].Digest, + "per-file digest must change when content changes") +} + +func TestFileSource_Inventory_SkipsDotCog(t *testing.T) { + withoutCog := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(withoutCog, "a.txt"), []byte("hello"), 0o644)) + + withCog := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(withCog, "a.txt"), []byte("hello"), 0o644)) + require.NoError(t, os.MkdirAll(filepath.Join(withCog, ".cog"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(withCog, ".cog", "state"), []byte("stuff"), 0o644)) + + src1, err := NewFileSource("file://"+withoutCog, "") + require.NoError(t, err) + src2, err := NewFileSource("file://"+withCog, "") + require.NoError(t, err) + + inv1, err := src1.Inventory(t.Context()) + require.NoError(t, err) + inv2, err := src2.Inventory(t.Context()) + require.NoError(t, err) + assert.Equal(t, inv1.Fingerprint, inv2.Fingerprint, + ".cog directory must be excluded from inventory") + assert.Len(t, inv2.Files, 1) +} + +func TestFileSource_Inventory_ContextCanceled(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("x"), 0o644)) + + src, err := NewFileSource("file://"+dir, "") + require.NoError(t, err) + + // Cancellation is tested with an independent context because + // t.Context() is tied to the test lifetime; we need a context we can + // cancel explicitly before the call. + ctx, cancel := context.WithCancel(t.Context()) + cancel() + _, err = src.Inventory(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +// Cross-check: the inventory fingerprint is the published set-digest +// formula (sha256 of sorted "hex path" lines). 
Guards against the
+// set membership dropping out of the digest: adding a file must change
+// the fingerprint even though every existing per-file digest stays the
+// same.
+func TestFileSource_Inventory_FingerprintCoversFileSet(t *testing.T) {
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644))
+
+	src, err := NewFileSource("file://"+dir, "")
+	require.NoError(t, err)
+
+	inv1, err := src.Inventory(t.Context())
+	require.NoError(t, err)
+	assert.True(t, strings.HasPrefix(inv1.Fingerprint.String(), "sha256:"))
+
+	// b.txt is new; a.txt (and its per-file digest) is untouched.
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0o644))
+
+	inv2, err := src.Inventory(t.Context())
+	require.NoError(t, err)
+	assert.NotEqual(t, inv1.Fingerprint, inv2.Fingerprint,
+		"fingerprint must change when a file is added")
+}
diff --git a/pkg/model/weightsource/filter.go b/pkg/model/weightsource/filter.go
new file mode 100644
index 0000000000..d1f99e371c
--- /dev/null
+++ b/pkg/model/weightsource/filter.go
@@ -0,0 +1,103 @@
+package weightsource
+
+import (
+	"fmt"
+	"strings"
+
+	ignore "github.com/sabhiram/go-gitignore"
+)
+
+// FilterInventory applies include/exclude glob patterns to an inventory's
+// file list and returns a new inventory with only the matching files.
+// The returned inventory shares the original's Fingerprint (which is the
+// upstream version identity, not affected by filtering).
+//
+// Semantics:
+// - If include is non-empty, a file must match at least one include pattern.
+// - If a file matches any exclude pattern, it is excluded (even if it also
+// matches an include pattern — exclude wins).
+// - If both lists are empty/nil, all files pass through unchanged.
+//
+// Pattern matching uses gitignore-style globs via go-gitignore: bare patterns
+// float across directories ("*.bin" matches any depth), path-shaped patterns
+// anchor ("onnx/*.bin" matches direct children of onnx/), and "**" matches
+// any number of path segments.
+//
+// Returns an error if the filter yields zero files — an empty weight set is
+// almost always a mistake and should surface immediately.
+func FilterInventory(inv Inventory, include, exclude []string) (Inventory, error) { + if len(include) == 0 && len(exclude) == 0 { + return inv, nil + } + + var includeMatcher *ignore.GitIgnore + if len(include) > 0 { + includeMatcher = ignore.CompileIgnoreLines(include...) + } + + var excludeMatcher *ignore.GitIgnore + if len(exclude) > 0 { + excludeMatcher = ignore.CompileIgnoreLines(exclude...) + } + + filtered := make([]InventoryFile, 0, len(inv.Files)) + for _, f := range inv.Files { + if !fileIncluded(f.Path, includeMatcher, excludeMatcher) { + continue + } + filtered = append(filtered, f) + } + + if len(filtered) == 0 { + return Inventory{}, &ZeroSurvivorsError{ + InventorySize: len(inv.Files), + Include: include, + Exclude: exclude, + } + } + + return Inventory{ + Files: filtered, + Fingerprint: inv.Fingerprint, + }, nil +} + +// fileIncluded reports whether a file path passes the include/exclude filter. +func fileIncluded(path string, includeMatcher, excludeMatcher *ignore.GitIgnore) bool { + if excludeMatcher != nil && excludeMatcher.MatchesPath(path) { + return false + } + if includeMatcher != nil { + return includeMatcher.MatchesPath(path) + } + return true +} + +// ZeroSurvivorsError is returned when include/exclude filtering removes all +// files from an inventory. 
+type ZeroSurvivorsError struct { + InventorySize int + Include []string + Exclude []string +} + +func (e *ZeroSurvivorsError) Error() string { + var b strings.Builder + fmt.Fprintf(&b, "include/exclude patterns matched zero files out of %d in the source", e.InventorySize) + if len(e.Include) > 0 { + fmt.Fprintf(&b, "\n include: %s", formatPatterns(e.Include)) + } + if len(e.Exclude) > 0 { + fmt.Fprintf(&b, "\n exclude: %s", formatPatterns(e.Exclude)) + } + b.WriteString("\n check your patterns — did you mean a different glob?") + return b.String() +} + +func formatPatterns(patterns []string) string { + quoted := make([]string, len(patterns)) + for i, p := range patterns { + quoted[i] = fmt.Sprintf("%q", p) + } + return "[" + strings.Join(quoted, ", ") + "]" +} diff --git a/pkg/model/weightsource/filter_test.go b/pkg/model/weightsource/filter_test.go new file mode 100644 index 0000000000..df901c1bc7 --- /dev/null +++ b/pkg/model/weightsource/filter_test.go @@ -0,0 +1,242 @@ +package weightsource + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testInventory builds an Inventory with the given paths (all other fields +// are placeholders — FilterInventory only inspects Path). 
+func testInventory(paths ...string) Inventory { + files := make([]InventoryFile, len(paths)) + for i, p := range paths { + files[i] = InventoryFile{Path: p, Size: 100, Digest: "sha256:deadbeef"} + } + return Inventory{ + Files: files, + Fingerprint: Fingerprint("test:abc123"), + } +} + +func filePaths(inv Inventory) []string { + paths := make([]string, len(inv.Files)) + for i, f := range inv.Files { + paths[i] = f.Path + } + return paths +} + +func TestFilterInventory(t *testing.T) { + allFiles := testInventory( + ".gitattributes", + "README.md", + "config.json", + "model.safetensors", + "pytorch_model.bin", + "onnx/model.onnx", + "onnx/model_O1.onnx", + "openvino/openvino_model.bin", + "openvino/openvino_model.xml", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + ) + + tests := []struct { + name string + inv Inventory + include []string + exclude []string + wantPaths []string + wantErr string + }{ + { + name: "no patterns passes all files", + inv: allFiles, + wantPaths: filePaths(allFiles), + }, + { + name: "nil patterns passes all files", + inv: allFiles, + include: nil, + exclude: nil, + wantPaths: filePaths(allFiles), + }, + { + name: "empty slices passes all files", + inv: allFiles, + include: []string{}, + exclude: []string{}, + wantPaths: filePaths(allFiles), + }, + { + name: "include safetensors and json", + inv: allFiles, + include: []string{"*.safetensors", "*.json"}, + wantPaths: []string{ + "config.json", + "model.safetensors", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "exclude onnx and bin", + inv: allFiles, + exclude: []string{"*.onnx", "*.bin"}, + wantPaths: []string{ + ".gitattributes", + "README.md", + "config.json", + "model.safetensors", + "openvino/openvino_model.xml", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "include and exclude together", + inv: allFiles, + include: 
[]string{"*.safetensors", "*.json", "*.bin"}, + exclude: []string{"*.bin"}, + wantPaths: []string{ + "config.json", + "model.safetensors", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "exclude takes precedence over include", + inv: allFiles, + include: []string{"*.bin"}, + exclude: []string{"*.bin"}, + wantErr: "matched zero files", + }, + { + name: "anchored path pattern", + inv: allFiles, + exclude: []string{"onnx/*"}, + wantPaths: []string{ + ".gitattributes", + "README.md", + "config.json", + "model.safetensors", + "pytorch_model.bin", + "openvino/openvino_model.bin", + "openvino/openvino_model.xml", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "double-star recursion in include", + inv: allFiles, + include: []string{"**/*.safetensors"}, + wantPaths: []string{ + "model.safetensors", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "directory pattern excludes subtree", + inv: allFiles, + exclude: []string{"openvino/"}, + wantPaths: []string{ + ".gitattributes", + "README.md", + "config.json", + "model.safetensors", + "pytorch_model.bin", + "onnx/model.onnx", + "onnx/model_O1.onnx", + "tokenizer.json", + "tokenizer_config.json", + "deep/nested/dir/weights.safetensors", + }, + }, + { + name: "case sensitive matching", + inv: testInventory("Model.Safetensors", "model.safetensors"), + include: []string{"*.safetensors"}, + wantPaths: []string{ + "model.safetensors", + }, + }, + { + name: "zero survivors from include mismatch", + inv: testInventory("model.safetensors", "config.json"), + include: []string{"*.gguf"}, + wantErr: "matched zero files out of 2", + }, + { + name: "zero survivors from total exclude", + inv: testInventory("a.bin"), + exclude: []string{"*.bin"}, + wantErr: "matched zero files out of 1", + }, + { + name: "empty inventory with no patterns is ok", + inv: testInventory(), + include: nil, + exclude: nil, + 
wantPaths: []string{}, + }, + { + name: "single include pattern", + inv: allFiles, + include: []string{"config.json"}, + wantPaths: []string{ + "config.json", + }, + }, + { + name: "forward slashes only", + inv: testInventory("a/b/c.bin", "a/b/d.safetensors"), + include: []string{"a/b/*.safetensors"}, + wantPaths: []string{ + "a/b/d.safetensors", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := FilterInventory(tt.inv, tt.include, tt.exclude) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + + // Verify it's a ZeroSurvivorsError + var zse *ZeroSurvivorsError + assert.ErrorAs(t, err, &zse) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantPaths, filePaths(result)) + + // Fingerprint must be preserved from original. + assert.Equal(t, tt.inv.Fingerprint, result.Fingerprint) + }) + } +} + +func TestZeroSurvivorsError_Message(t *testing.T) { + err := &ZeroSurvivorsError{ + InventorySize: 42, + Include: []string{"*.safetensors"}, + Exclude: []string{"*.bin", "*.onnx"}, + } + msg := err.Error() + assert.Contains(t, msg, "42") + assert.Contains(t, msg, "*.safetensors") + assert.Contains(t, msg, "*.bin") + assert.Contains(t, msg, "*.onnx") + assert.Contains(t, msg, "check your patterns") +} diff --git a/pkg/model/weightsource/fingerprint.go b/pkg/model/weightsource/fingerprint.go new file mode 100644 index 0000000000..5dd7059d02 --- /dev/null +++ b/pkg/model/weightsource/fingerprint.go @@ -0,0 +1,63 @@ +// Package weightsource is the pluggable source layer for weight imports. +// +// A Source is a stateful provider bound at construction time to a specific +// URI. It exposes two capabilities: Inventory lists the files the source +// offers (with sizes, per-file digests, and a source-identity Fingerprint), +// and Open streams one file's bytes. 
The packer drives the import one file
+// at a time so that sources larger than local disk can be imported without
+// full materialization.
+//
+// Implementations exist for file:// (local directory) and hf://
+// (HuggingFace Hub).
+package weightsource

+import (
+	"strings"
+)
+
+// Fingerprint is a source's version identity, carrying its algorithm (or
+// source-native identifier type) as a scheme prefix.
+//
+// Examples:
+//
+//	sha256:<hex>      — content hash (file:// sources)
+//	commit:<sha>      — git commit (hf:// repos pinned to a commit)
+//	etag:<value>      — HTTP ETag (http:// sources)
+//	md5:<hex>         — MD5 hash (s3:// objects)
+//	timestamp:<stamp> — last-modified timestamp (fallback for systems
+//	that expose nothing stronger)
+//
+// The prefix makes two fingerprints from different sources unambiguously
+// unequal even when the opaque values happen to collide. The empty string
+// is not a valid Fingerprint; the zero value is reserved to mean
+// "no fingerprint known" (see isZero).
+type Fingerprint string
+
+// Scheme returns the fingerprint's algorithm or identifier prefix (the
+// part before the first colon). Returns "" if the fingerprint is malformed
+// (no colon).
+func (f Fingerprint) Scheme() string {
+	scheme, _, ok := strings.Cut(string(f), ":")
+	if !ok {
+		return ""
+	}
+	return scheme
+}
+
+// value returns the fingerprint's opaque value (the part after the first
+// colon). Returns "" if the fingerprint is malformed.
+func (f Fingerprint) value() string {
+	_, value, ok := strings.Cut(string(f), ":")
+	if !ok {
+		return ""
+	}
+	return value
+}
+
+// String returns the fingerprint in its canonical "<scheme>:<value>" form.
+func (f Fingerprint) String() string { return string(f) }
+
+// isZero reports whether f is the zero value (empty string). Use this to
+// distinguish "no fingerprint" from a real fingerprint whose scheme or
+// value happens to be empty.
+func (f Fingerprint) isZero() bool { return f == "" } diff --git a/pkg/model/weightsource/fingerprint_test.go b/pkg/model/weightsource/fingerprint_test.go new file mode 100644 index 0000000000..0f500a7177 --- /dev/null +++ b/pkg/model/weightsource/fingerprint_test.go @@ -0,0 +1,54 @@ +package weightsource + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFingerprint_Scheme(t *testing.T) { + tests := []struct { + name string + fp Fingerprint + want string + }{ + {"sha256", Fingerprint("sha256:abc123"), "sha256"}, + {"commit", Fingerprint("commit:deadbeef"), "commit"}, + {"etag", Fingerprint("etag:W/\"abc\""), "etag"}, + {"timestamp with colons", Fingerprint("timestamp:2026-04-17T12:00:00Z"), "timestamp"}, + {"empty", Fingerprint(""), ""}, + {"no separator", Fingerprint("bare"), ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tc.fp.Scheme()) + }) + } +} + +func TestFingerprint_Value(t *testing.T) { + tests := []struct { + name string + fp Fingerprint + want string + }{ + {"sha256", Fingerprint("sha256:abc123"), "abc123"}, + {"timestamp preserves inner colons", Fingerprint("timestamp:2026-04-17T12:00:00Z"), "2026-04-17T12:00:00Z"}, + {"empty", Fingerprint(""), ""}, + {"no separator", Fingerprint("bare"), ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, tc.fp.value()) + }) + } +} + +func TestFingerprint_IsZero(t *testing.T) { + assert.True(t, Fingerprint("").isZero()) + assert.False(t, Fingerprint("sha256:abc").isZero()) +} + +func TestFingerprint_String(t *testing.T) { + assert.Equal(t, "sha256:abc", Fingerprint("sha256:abc").String()) +} diff --git a/pkg/model/weightsource/huggingface.go b/pkg/model/weightsource/huggingface.go new file mode 100644 index 0000000000..e42e0ce832 --- /dev/null +++ b/pkg/model/weightsource/huggingface.go @@ -0,0 +1,470 @@ +package weightsource + +import ( + "context" + "crypto/sha256" + "encoding/hex" 
+ "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "strings" + + "github.com/hashicorp/go-retryablehttp" + "golang.org/x/sync/errgroup" + + "github.com/replicate/cog/pkg/util/console" +) + +// HFScheme is the short URI scheme for HuggingFace Hub sources. +const HFScheme = "hf" + +// HFSchemeLong is the long-form URI scheme alias for HuggingFace Hub. +const HFSchemeLong = "huggingface" + +// hfDefaultBaseURL is the base URL for the HuggingFace Hub API. It can +// be overridden via the HF_ENDPOINT env var (useful for testing and +// mirrors). +const hfDefaultBaseURL = "https://huggingface.co" + +// hfInlineDigestConcurrency controls how many inline (non-LFS) files +// are fetched concurrently during Inventory to compute their sha256. +const hfInlineDigestConcurrency = 4 + +// HFSource is the Source implementation for hf:// URIs. +// +// URI forms: +// +// hf://org/repo — follows main branch +// hf://org/repo@ref — ref is a branch, tag, or 40-char commit sha +// +// The source resolves the ref to a full commit sha at Inventory time and +// uses that pinned sha for all subsequent Open calls. Callers must call +// Inventory before Open to ensure content is pinned to a specific commit. +type HFSource struct { + repo string // "org/repo" + ref string // user-provided ref (branch, tag, or sha); defaults to "main" + resolvedRef string // full commit sha, set by Inventory; Open uses this when non-empty + baseURL *url.URL // parsed once at construction; cloned in buildURL to avoid mutation + token string + client *http.Client +} + +// NewHFSource constructs an HFSource bound to the given hf:// URI. +// It parses the URI and looks up auth from env vars but does not make +// any network calls — validation happens at Inventory time. 
+func NewHFSource(uri string) (*HFSource, error) { + repo, ref, err := parseHFURI(uri) + if err != nil { + return nil, err + } + + rawURL := os.Getenv("HF_ENDPOINT") + if rawURL == "" { + rawURL = hfDefaultBaseURL + } + baseURL, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("invalid HF base URL %q: %w", rawURL, err) + } + + token := os.Getenv("HF_TOKEN") + if token == "" { + token = os.Getenv("HUGGING_FACE_HUB_TOKEN") + } + + return &HFSource{ + repo: repo, + ref: ref, + baseURL: baseURL, + token: token, + client: newHFHTTPClient(), + }, nil +} + +// newHFHTTPClient returns a standard *http.Client whose transport +// retries on 5xx, 429, and network errors with exponential backoff. +// The retryable behavior is provided by go-retryablehttp configured as +// a transport — callers use the standard http.Client API. +func newHFHTTPClient() *http.Client { + rc := retryablehttp.NewClient() + rc.RetryMax = 3 + rc.Logger = nil // Silence default logger; errors surface via return values. + rc.CheckRetry = hfCheckRetry + return rc.StandardClient() +} + +// hfCheckRetry is a retryablehttp.CheckRetry that retries on 5xx and +// 429 but treats other 4xx status codes as permanent failures. +func hfCheckRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { + // Network errors: let the default policy decide (retries them). + if err != nil { + return retryablehttp.DefaultRetryPolicy(ctx, resp, err) + } + // 429 Too Many Requests: retry. + if resp.StatusCode == http.StatusTooManyRequests { + return true, nil + } + // 5xx: retry. + if resp.StatusCode >= 500 { + return true, nil + } + // Everything else (2xx, 3xx, 4xx other than 429): do not retry. + return false, nil +} + +// normalizeHFURI returns the canonical hf:// form of an HF URI. It +// validates the URI, canonicalizes huggingface:// to hf://, and +// preserves the @ref suffix if present. The default ref ("main") is +// not appended — it is implied. 
+func normalizeHFURI(uri string) (string, error) {
+	repo, ref, err := parseHFURI(uri)
+	if err != nil {
+		return "", err
+	}
+	if ref == "main" {
+		return "hf://" + repo, nil
+	}
+	return "hf://" + repo + "@" + ref, nil
+}
+
+// parseHFURI parses "hf://org/repo" or "huggingface://org/repo" (with
+// optional @ref suffix) and returns the repo and ref components.
+func parseHFURI(uri string) (repo, ref string, err error) {
+	var rest string
+	switch {
+	case strings.HasPrefix(uri, "hf://"):
+		rest = strings.TrimPrefix(uri, "hf://")
+	case strings.HasPrefix(uri, "huggingface://"):
+		rest = strings.TrimPrefix(uri, "huggingface://")
+	default:
+		return "", "", fmt.Errorf("not an hf:// URI: %q", uri)
+	}
+	if rest == "" {
+		return "", "", fmt.Errorf("empty hf:// URI")
+	}
+
+	// Split off @ref suffix if present.
+	repo = rest
+	if idx := strings.LastIndex(rest, "@"); idx > 0 {
+		repo = rest[:idx]
+		ref = rest[idx+1:]
+		if ref == "" {
+			return "", "", fmt.Errorf("empty ref in hf:// URI: %q", uri)
+		}
+	}
+
+	// Validate repo has exactly one slash (org/name).
+	parts := strings.Split(repo, "/")
+	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+		return "", "", fmt.Errorf("invalid hf:// repo %q: expected org/repo", repo)
+	}
+
+	if ref == "" {
+		ref = "main"
+	}
+	return repo, ref, nil
+}
+
+// Inventory calls the HuggingFace Hub API to list files and resolve the
+// ref to a pinned commit sha. For LFS/xet-tracked files the sha256
+// digest comes from the API response (free, no download). Inline files
+// (small, git-tracked) are fetched and hashed.
+//
+// The fingerprint is "commit:<full 40-char sha>".
+func (s *HFSource) Inventory(ctx context.Context) (Inventory, error) {
+	if err := ctx.Err(); err != nil {
+		return Inventory{}, err
+	}
+
+	console.Debugf("hf: resolving %s@%s", s.repo, s.ref)
+
+	// 1. Resolve ref → commit sha and pin for subsequent Open calls.
+ commitSHA, err := s.resolveRef(ctx) + if err != nil { + return Inventory{}, fmt.Errorf("resolve ref %q for %s: %w", s.ref, s.repo, err) + } + s.resolvedRef = commitSHA + console.Debugf("hf: resolved to commit %s", commitSHA[:12]) + + // 2. Fetch the recursive tree listing at the resolved commit. + entries, err := s.listTree(ctx, commitSHA) + if err != nil { + return Inventory{}, fmt.Errorf("list tree for %s@%s: %w", s.repo, commitSHA, err) + } + console.Debugf("hf: tree listing returned %d entries", len(entries)) + + // 3. Build inventory files. LFS entries have digests already; + // inline entries need to be fetched and hashed. + files, err := s.buildInventoryFiles(ctx, commitSHA, entries) + if err != nil { + return Inventory{}, err + } + + sortInventoryFiles(files) + + console.Debugf("hf: inventory complete — %d files, fingerprint commit:%s", len(files), commitSHA[:12]) + return Inventory{ + Files: files, + Fingerprint: Fingerprint("commit:" + commitSHA), + }, nil +} + +// Open returns a reader that streams the file from the HuggingFace CDN. +// It follows the redirect from the resolve endpoint to the appropriate +// backend (LFS CDN, xet cas-bridge, or inline git blob). +// +// Open uses the commit sha resolved during Inventory, so file content +// is pinned to the same revision that was inventoried. If Inventory has +// not been called, Open falls back to the original ref. +func (s *HFSource) Open(ctx context.Context, path string) (io.ReadCloser, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + ref := s.resolvedRef + if ref == "" { + ref = s.ref + } + console.Debugf("hf: open %s (ref %s)", path, ref[:min(12, len(ref))]) + return s.fetchFile(ctx, ref, path) +} + +// hfRevisionResponse is the subset of the /api/models/{repo}/revision/{ref} +// response we need. +type hfRevisionResponse struct { + SHA string `json:"sha"` +} + +// resolveRef calls the Hub API to resolve a ref (branch/tag/sha) to the +// full 40-char commit sha. 
+func (s *HFSource) resolveRef(ctx context.Context) (string, error) { + u := s.buildURL("api", "models", s.repo, "revision", s.ref) + + body, err := s.doGet(ctx, u) + if err != nil { + return "", err + } + defer body.Close() + + var resp hfRevisionResponse + if err := json.NewDecoder(body).Decode(&resp); err != nil { + return "", fmt.Errorf("decode revision response: %w", err) + } + if resp.SHA == "" { + return "", fmt.Errorf("empty sha in revision response for %s@%s", s.repo, s.ref) + } + return resp.SHA, nil +} + +// hfTreeEntry represents one file in the recursive tree listing. +type hfTreeEntry struct { + Type string `json:"type"` + Path string `json:"path"` + Size int64 `json:"size"` + LFS *hfLFSInfo `json:"lfs,omitempty"` +} + +// hfLFSInfo carries the LFS pointer metadata. Present for both LFS and +// xet-tracked files (the Hub API exposes sha256 in the lfs field for +// both). +type hfLFSInfo struct { + OID string `json:"oid"` // sha256 of the full file + Size int64 `json:"size"` +} + +// listTree fetches the recursive tree listing at the given commit sha. +// NOTE: the HF Hub API paginates large repos. This implementation does +// not follow pagination yet — repos with very many files may return an +// incomplete listing. A follow-up should add cursor-based pagination. +func (s *HFSource) listTree(ctx context.Context, commitSHA string) ([]hfTreeEntry, error) { + u := s.buildURLWithQuery("recursive=true", "api", "models", s.repo, "tree", commitSHA) + + body, err := s.doGet(ctx, u) + if err != nil { + return nil, err + } + defer body.Close() + + var entries []hfTreeEntry + if err := json.NewDecoder(body).Decode(&entries); err != nil { + return nil, fmt.Errorf("decode tree response: %w", err) + } + return entries, nil +} + +// buildInventoryFiles converts tree entries into InventoryFiles. LFS +// entries use the lfs.oid as the digest directly. Inline entries are +// fetched and hashed with bounded concurrency. 
+func (s *HFSource) buildInventoryFiles(ctx context.Context, commitSHA string, entries []hfTreeEntry) ([]InventoryFile, error) { + var lfsFiles []InventoryFile + var inlineEntries []hfTreeEntry + + for _, e := range entries { + if e.Type != "file" { + continue + } + if e.LFS != nil && e.LFS.OID != "" { + lfsFiles = append(lfsFiles, InventoryFile{ + Path: e.Path, + Size: e.LFS.Size, + Digest: "sha256:" + e.LFS.OID, + }) + } else { + inlineEntries = append(inlineEntries, e) + } + } + + console.Debugf("hf: %d LFS files (digest from API), %d inline files (need fetch+hash)", len(lfsFiles), len(inlineEntries)) + + // Hash inline files with bounded concurrency. + inlineFiles, err := s.hashInlineFiles(ctx, commitSHA, inlineEntries) + if err != nil { + return nil, err + } + + return append(lfsFiles, inlineFiles...), nil +} + +// hashInlineFiles fetches and hashes inline (non-LFS) files with +// bounded concurrency via errgroup. +func (s *HFSource) hashInlineFiles(ctx context.Context, commitSHA string, entries []hfTreeEntry) ([]InventoryFile, error) { + if len(entries) == 0 { + return nil, nil + } + + files := make([]InventoryFile, len(entries)) + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(hfInlineDigestConcurrency) + + for i, e := range entries { + g.Go(func() error { + f, err := s.hashOneInlineFile(ctx, commitSHA, e) + if err != nil { + return err + } + files[i] = f + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + return files, nil +} + +// hashOneInlineFile fetches one inline file from the resolve endpoint +// and computes its sha256 while reading. 
+func (s *HFSource) hashOneInlineFile(ctx context.Context, commitSHA string, entry hfTreeEntry) (InventoryFile, error) { + rc, err := s.fetchFile(ctx, commitSHA, entry.Path) + if err != nil { + return InventoryFile{}, fmt.Errorf("fetch inline file %s: %w", entry.Path, err) + } + defer rc.Close() + + h := sha256.New() + n, err := io.Copy(h, rc) + if err != nil { + return InventoryFile{}, fmt.Errorf("hash inline file %s: %w", entry.Path, err) + } + + return InventoryFile{ + Path: entry.Path, + Size: n, + Digest: "sha256:" + hex.EncodeToString(h.Sum(nil)), + }, nil +} + +// fetchFile streams one file from the resolve endpoint. The endpoint +// issues a 302 to the appropriate backend (LFS CDN, xet, or inline). +func (s *HFSource) fetchFile(ctx context.Context, ref, filePath string) (io.ReadCloser, error) { + u := s.buildURL(s.repo, "resolve", ref, filePath) + return s.doGet(ctx, u) +} + +// escapeURLPath percent-encodes each component of a forward-slash-separated +// path, preserving the directory structure. This ensures filenames containing +// special URL characters (#, %, spaces, etc.) produce valid URLs. +func escapeURLPath(p string) string { + parts := strings.Split(p, "/") + for i, part := range parts { + parts[i] = url.PathEscape(part) + } + return strings.Join(parts, "/") +} + +// buildURL joins path segments onto s.baseURL. path.Join cleans ".." +// and double-slash components; url.URL.String handles encoding. +func (s *HFSource) buildURL(segments ...string) string { + return s.buildURLWithQuery("", segments...) +} + +// buildURLWithQuery is like buildURL but appends a raw query string. +// It sets both Path (decoded) and RawPath (percent-encoded) so that +// url.URL.String() emits correctly escaped URLs even when path segments +// contain special characters (#, %, spaces, etc.). +func (s *HFSource) buildURLWithQuery(query string, segments ...string) string { + u := *s.baseURL // shallow copy + + joined := path.Join(segments...) 
+ u.Path = path.Join(u.Path, joined) + + escaped := make([]string, len(segments)) + for i, seg := range segments { + escaped[i] = escapeURLPath(seg) + } + joinedEscaped := path.Join(escaped...) + if joined != joinedEscaped { + basePath := u.Path[:len(u.Path)-len(joined)] + u.RawPath = basePath + joinedEscaped + } + + u.RawQuery = query + return u.String() +} + +// doGet performs an HTTP GET with retries (via the retrying transport) +// and returns the response body. The caller must close the body. +// +// Non-retryable 4xx responses are translated into specific errors: +// - 401 → auth hint +// - 403 → permissions hint +// - 404 → not-found +// - others → raw status + snippet of body +func (s *HFSource) doGet(ctx context.Context, url string) (io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if s.token != "" { + req.Header.Set("Authorization", "Bearer "+s.token) + } + + resp, err := s.client.Do(req) //nolint:gosec // G704: URL is constructed from parsed hf:// URI components, not arbitrary user input + if err != nil { + return nil, fmt.Errorf("request %s: %w", url, err) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp.Body, nil + } + + // Non-2xx: read a snippet of the body for diagnostics, then close. 
+ errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + _ = resp.Body.Close() + + switch resp.StatusCode { + case http.StatusUnauthorized: + return nil, fmt.Errorf("authentication failed for %s (HTTP 401): set HF_TOKEN or HUGGING_FACE_HUB_TOKEN", url) + case http.StatusForbidden: + return nil, fmt.Errorf("access denied for %s (HTTP 403): check repo visibility and token permissions", url) + case http.StatusNotFound: + return nil, fmt.Errorf("not found: %s (HTTP 404)", url) + default: + return nil, fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, url, string(errBody)) + } +} diff --git a/pkg/model/weightsource/huggingface_test.go b/pkg/model/weightsource/huggingface_test.go new file mode 100644 index 0000000000..53f0966155 --- /dev/null +++ b/pkg/model/weightsource/huggingface_test.go @@ -0,0 +1,560 @@ +package weightsource + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseHFURI(t *testing.T) { + tests := []struct { + name string + uri string + wantRepo string + wantRef string + wantErrSubs string + }{ + {"basic", "hf://myorg/mymodel", "myorg/mymodel", "main", ""}, + {"with tag ref", "hf://myorg/mymodel@v1.0", "myorg/mymodel", "v1.0", ""}, + {"with sha ref", "hf://myorg/mymodel@abc123def456", "myorg/mymodel", "abc123def456", ""}, + {"with branch ref", "hf://myorg/mymodel@feature/branch", "myorg/mymodel", "feature/branch", ""}, + {"long scheme", "huggingface://myorg/mymodel", "myorg/mymodel", "main", ""}, + {"long scheme with ref", "huggingface://myorg/mymodel@v2", "myorg/mymodel", "v2", ""}, + + {"not hf scheme", "file:///data", "", "", "not an hf:// URI"}, + {"empty after prefix", "hf://", "", "", "empty hf:// URI"}, + {"no slash", "hf://justarepo", "", "", "expected org/repo"}, + {"too many slashes", "hf://a/b/c", "", "", "expected org/repo"}, + {"empty org", "hf:///repo", 
"", "", "expected org/repo"}, + {"empty repo name", "hf://org/", "", "", "expected org/repo"}, + {"empty ref", "hf://org/repo@", "", "", "empty ref"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + repo, ref, err := parseHFURI(tc.uri) + if tc.wantErrSubs != "" { + assert.ErrorContains(t, err, tc.wantErrSubs) + return + } + require.NoError(t, err) + assert.Equal(t, tc.wantRepo, repo) + assert.Equal(t, tc.wantRef, ref) + }) + } +} + +// hfMock is a minimal mock of the HuggingFace Hub API. It serves: +// - GET /api/models/{repo}/revision/{ref} → revision response +// - GET /api/models/{repo}/tree/{ref}?recursive=true → tree listing +// - GET /{repo}/resolve/{ref}/{path} → file content +type hfMock struct { + commitSHA string + tree []hfTreeEntry + files map[string]string // path → content +} + +func (m *hfMock) handler() http.Handler { + mux := http.NewServeMux() + + // Revision endpoint. + mux.HandleFunc("/api/models/", func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + if strings.Contains(path, "/revision/") { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(hfRevisionResponse{SHA: m.commitSHA}) + return + } + if strings.Contains(path, "/tree/") { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(m.tree) + return + } + http.NotFound(w, r) + }) + + // Resolve/download endpoint: /{repo}/resolve/{ref}/{path...} + // Pattern: strip the leading slash, expect "org/repo/resolve/ref/path..." + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/api/") { + http.NotFound(w, r) + return + } + // Find "resolve/" in the path to extract the file path. + parts := strings.SplitN(r.URL.Path, "/resolve/", 2) + if len(parts) != 2 { + http.NotFound(w, r) + return + } + // parts[1] is "ref/file/path" — strip the ref prefix. 
+ _, filePath, ok := strings.Cut(parts[1], "/") + if !ok { + http.NotFound(w, r) + return + } + content, ok := m.files[filePath] + if !ok { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + _, _ = w.Write([]byte(content)) + }) + + return mux +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return u +} + +func newTestHFSource(t *testing.T, serverURL, repo, ref string) *HFSource { + t.Helper() + if ref == "" { + ref = "main" + } + return &HFSource{ + repo: repo, + ref: ref, + baseURL: mustParseURL(t, serverURL), + token: "", + client: newHFHTTPClient(), + } +} + +func TestHFSource_Inventory_LFSFiles(t *testing.T) { + mock := &hfMock{ + commitSHA: "abc123def456abc123def456abc123def456abc123", + tree: []hfTreeEntry{ + { + Type: "file", + Path: "model.safetensors", + Size: 1000, + LFS: &hfLFSInfo{OID: "aabbccdd" + strings.Repeat("00", 28), Size: 1000}, + }, + { + Type: "file", + Path: "weights/shard-00.bin", + Size: 2000, + LFS: &hfLFSInfo{OID: "11223344" + strings.Repeat("00", 28), Size: 2000}, + }, + }, + files: map[string]string{}, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + assert.Equal(t, Fingerprint("commit:abc123def456abc123def456abc123def456abc123"), inv.Fingerprint) + assert.Equal(t, "commit", inv.Fingerprint.Scheme()) + + require.Len(t, inv.Files, 2) + // Files are sorted by path. 
+ assert.Equal(t, "model.safetensors", inv.Files[0].Path) + assert.Equal(t, int64(1000), inv.Files[0].Size) + assert.Equal(t, "sha256:aabbccdd"+strings.Repeat("00", 28), inv.Files[0].Digest) + + assert.Equal(t, "weights/shard-00.bin", inv.Files[1].Path) + assert.Equal(t, int64(2000), inv.Files[1].Size) + assert.Equal(t, "sha256:11223344"+strings.Repeat("00", 28), inv.Files[1].Digest) +} + +func TestHFSource_Inventory_InlineFiles(t *testing.T) { + mock := &hfMock{ + commitSHA: "ffff" + strings.Repeat("00", 18), + tree: []hfTreeEntry{ + {Type: "file", Path: "config.json", Size: 13}, + {Type: "file", Path: "tokenizer.json", Size: 5}, + }, + files: map[string]string{ + "config.json": `{"key":"val"}`, + "tokenizer.json": "hello", + }, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + require.Len(t, inv.Files, 2) + // Both should have sha256 digests computed from content. + for _, f := range inv.Files { + assert.True(t, strings.HasPrefix(f.Digest, "sha256:"), "digest should start with sha256: for %s", f.Path) + assert.Len(t, strings.TrimPrefix(f.Digest, "sha256:"), 64, "digest hex should be 64 chars for %s", f.Path) + } + + // Verify sizes match actual content. 
+ assert.Equal(t, "config.json", inv.Files[0].Path) + assert.Equal(t, int64(13), inv.Files[0].Size) + assert.Equal(t, "tokenizer.json", inv.Files[1].Path) + assert.Equal(t, int64(5), inv.Files[1].Size) +} + +func TestHFSource_Inventory_MixedLFSAndInline(t *testing.T) { + mock := &hfMock{ + commitSHA: "aaaa" + strings.Repeat("00", 18), + tree: []hfTreeEntry{ + { + Type: "file", + Path: "model.bin", + Size: 5000, + LFS: &hfLFSInfo{OID: "deadbeef" + strings.Repeat("00", 28), Size: 5000}, + }, + {Type: "file", Path: "config.json", Size: 4}, + {Type: "directory", Path: "subdir"}, // should be skipped + }, + files: map[string]string{ + "config.json": "test", + }, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + require.Len(t, inv.Files, 2, "directory entry should be excluded") + assert.Equal(t, "config.json", inv.Files[0].Path) + assert.Equal(t, "model.bin", inv.Files[1].Path) + assert.Equal(t, "sha256:deadbeef"+strings.Repeat("00", 28), inv.Files[1].Digest) +} + +func TestHFSource_Inventory_FilesAreSorted(t *testing.T) { + mock := &hfMock{ + commitSHA: "bbbb" + strings.Repeat("00", 18), + tree: []hfTreeEntry{ + {Type: "file", Path: "z.txt", Size: 1, LFS: &hfLFSInfo{OID: strings.Repeat("aa", 32), Size: 1}}, + {Type: "file", Path: "a.txt", Size: 1, LFS: &hfLFSInfo{OID: strings.Repeat("bb", 32), Size: 1}}, + {Type: "file", Path: "m.txt", Size: 1, LFS: &hfLFSInfo{OID: strings.Repeat("cc", 32), Size: 1}}, + }, + files: map[string]string{}, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + + require.Len(t, inv.Files, 3) + assert.Equal(t, "a.txt", inv.Files[0].Path) + assert.Equal(t, "m.txt", inv.Files[1].Path) + assert.Equal(t, "z.txt", inv.Files[2].Path) +} + +func 
TestHFSource_Inventory_Stable(t *testing.T) { + mock := &hfMock{ + commitSHA: "cccc" + strings.Repeat("00", 18), + tree: []hfTreeEntry{ + {Type: "file", Path: "a.txt", Size: 5}, + }, + files: map[string]string{"a.txt": "hello"}, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + + inv1, err := src.Inventory(t.Context()) + require.NoError(t, err) + inv2, err := src.Inventory(t.Context()) + require.NoError(t, err) + + assert.Equal(t, inv1.Fingerprint, inv2.Fingerprint, "fingerprint must be stable across calls") + assert.Equal(t, inv1.Files, inv2.Files, "file list must be stable across calls") +} + +func TestHFSource_Inventory_ContextCanceled(t *testing.T) { + mock := &hfMock{ + commitSHA: "dddd" + strings.Repeat("00", 18), + tree: []hfTreeEntry{}, + files: map[string]string{}, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + _, err := src.Inventory(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestHFSource_Open(t *testing.T) { + mock := &hfMock{ + commitSHA: "eeee" + strings.Repeat("00", 18), + tree: []hfTreeEntry{}, + files: map[string]string{ + "model.bin": "model-bytes", + "sub/data.bin": "nested-data", + }, + } + ts := httptest.NewServer(mock.handler()) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + + t.Run("top level", func(t *testing.T) { + rc, err := src.Open(t.Context(), "model.bin") + require.NoError(t, err) + defer rc.Close() + b, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, "model-bytes", string(b)) + }) + + t.Run("nested path", func(t *testing.T) { + rc, err := src.Open(t.Context(), "sub/data.bin") + require.NoError(t, err) + defer rc.Close() + b, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, "nested-data", string(b)) + }) 
+ + t.Run("missing file", func(t *testing.T) { + _, err := src.Open(t.Context(), "nope.bin") + assert.ErrorContains(t, err, "404") + }) + + t.Run("canceled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + cancel() + _, err := src.Open(ctx, "model.bin") + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +// TestHFSource_Open_UsesResolvedRef verifies that Open uses the commit +// sha resolved during Inventory, not the original mutable ref. This +// prevents content drift between Inventory and Open. +func TestHFSource_Open_UsesResolvedRef(t *testing.T) { + const resolvedSHA = "aaaa" + "bbbb" + "cccc" + "dddd" + "eeee" + "ffff" + "0000" + "1111" + "2222" + "3333" + var resolveRef string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + // Revision endpoint. + if strings.Contains(path, "/revision/") { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(hfRevisionResponse{SHA: resolvedSHA}) + return + } + // Tree endpoint. + if strings.Contains(path, "/tree/") { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]hfTreeEntry{ + {Type: "file", Path: "data.bin", Size: 3, LFS: &hfLFSInfo{OID: strings.Repeat("ab", 32), Size: 3}}, + }) + return + } + // Resolve/download endpoint — capture the ref used. + if parts := strings.SplitN(path, "/resolve/", 2); len(parts) == 2 { + ref, _, _ := strings.Cut(parts[1], "/") + resolveRef = ref + _, _ = w.Write([]byte("abc")) + return + } + http.NotFound(w, r) + })) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "myorg/mymodel", "main") + + // Before Inventory, Open falls back to the original ref. + rc, err := src.Open(t.Context(), "data.bin") + require.NoError(t, err) + _ = rc.Close() + assert.Equal(t, "main", resolveRef, "before Inventory, Open should use original ref") + + // Run Inventory to resolve the ref. 
+ _, err = src.Inventory(t.Context()) + require.NoError(t, err) + + // After Inventory, Open should use the resolved commit sha. + rc, err = src.Open(t.Context(), "data.bin") + require.NoError(t, err) + _ = rc.Close() + assert.Equal(t, resolvedSHA, resolveRef, "after Inventory, Open should use resolved commit sha") +} + +func TestHFSource_AuthHeader(t *testing.T) { + var gotAuth string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(hfRevisionResponse{SHA: "abcd" + strings.Repeat("00", 18)}) + })) + defer ts.Close() + + src := &HFSource{ + repo: "org/repo", + ref: "main", + baseURL: mustParseURL(t, ts.URL), + token: "hf_test_token_123", + client: newHFHTTPClient(), + } + + _, err := src.resolveRef(t.Context()) + require.NoError(t, err) + assert.Equal(t, "Bearer hf_test_token_123", gotAuth) +} + +func TestHFSource_NoAuthHeader_WhenNoToken(t *testing.T) { + var gotAuth string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(hfRevisionResponse{SHA: "abcd" + strings.Repeat("00", 18)}) + })) + defer ts.Close() + + src := &HFSource{ + repo: "org/repo", + ref: "main", + baseURL: mustParseURL(t, ts.URL), + token: "", + client: newHFHTTPClient(), + } + + _, err := src.resolveRef(t.Context()) + require.NoError(t, err) + assert.Empty(t, gotAuth) +} + +func TestHFSource_HTTPErrors(t *testing.T) { + tests := []struct { + name string + statusCode int + wantSub string + }{ + {"401 auth", http.StatusUnauthorized, "HF_TOKEN"}, + {"403 forbidden", http.StatusForbidden, "permissions"}, + {"404 not found", http.StatusNotFound, "HTTP 404"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + })) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "org/repo", "main") + _, err := src.Inventory(t.Context()) + assert.ErrorContains(t, err, tc.wantSub) + }) + } +} + +func TestHFSource_HTTP500_Retries(t *testing.T) { + var attempts atomic.Int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := attempts.Add(1) + if n <= 1 { + w.WriteHeader(http.StatusInternalServerError) + return + } + // Succeed on second attempt. + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(hfRevisionResponse{SHA: "abcd" + strings.Repeat("00", 18)}) + })) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "org/repo", "main") + sha, err := src.resolveRef(t.Context()) + require.NoError(t, err) + assert.Equal(t, "abcd"+strings.Repeat("00", 18), sha) + assert.Equal(t, int32(2), attempts.Load(), "should have retried once") +} + +func TestHFSource_Open_EscapesPathComponents(t *testing.T) { + // Verify that file paths with special characters are properly + // URL-escaped when sent to the server. Go's net/http server + // decodes percent-encoding in r.URL.Path, so we capture the raw + // request line from the underlying connection via RequestURI. + var gotRequestURI string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRequestURI = r.RequestURI + _, _ = w.Write([]byte("data")) + })) + defer ts.Close() + + src := newTestHFSource(t, ts.URL, "org/repo", "main") + + // File path with spaces and special chars. 
+ rc, err := src.Open(t.Context(), "sub dir/model file.bin") + require.NoError(t, err) + _ = rc.Close() + + assert.Contains(t, gotRequestURI, "sub%20dir/model%20file.bin", + "path components should be individually escaped") +} + +func TestHFSource_BuildURL(t *testing.T) { + src := &HFSource{baseURL: mustParseURL(t, "https://huggingface.co")} + tests := []struct { + name string + segments []string + want string + }{ + {"simple", []string{"api", "models", "org/repo", "revision", "main"}, "https://huggingface.co/api/models/org/repo/revision/main"}, + {"cleans dots", []string{"api", "models", "org/repo", "revision", ".."}, "https://huggingface.co/api/models/org/repo"}, + {"cleans double slash", []string{"api", "models", "org/repo", "", "tree", "main"}, "https://huggingface.co/api/models/org/repo/tree/main"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := src.buildURL(tc.segments...) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestHFSource_BuildURLWithQuery(t *testing.T) { + src := &HFSource{baseURL: mustParseURL(t, "https://huggingface.co")} + got := src.buildURLWithQuery("recursive=true", "api", "models", "org/repo", "tree", "abc123") + assert.Equal(t, "https://huggingface.co/api/models/org/repo/tree/abc123?recursive=true", got) +} + +func TestFor_HFSchemes(t *testing.T) { + tests := []struct { + name string + uri string + wantRepo string + wantRef string + }{ + {"hf short", "hf://org/repo", "org/repo", "main"}, + {"hf short with ref", "hf://org/repo@v1.0", "org/repo", "v1.0"}, + {"huggingface long", "huggingface://org/repo", "org/repo", "main"}, + {"huggingface long with ref", "huggingface://org/repo@v2.0", "org/repo", "v2.0"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + src, err := For(tc.uri, "") + require.NoError(t, err) + hf, ok := src.(*HFSource) + require.True(t, ok, "expected *HFSource, got %T", src) + assert.Equal(t, tc.wantRepo, hf.repo) + assert.Equal(t, tc.wantRef, hf.ref) + }) + } +} 
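The Path/RawPath interplay that `buildURLWithQuery` relies on (and that `TestHFSource_Open_EscapesPathComponents` exercises) can be seen in isolation. The sketch below is a minimal standalone version under stated assumptions: `buildResolveURL` is a hypothetical helper, not part of the package, and it reproduces the same per-component escaping idea rather than the exact production code.

```go
package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

// buildResolveURL sketches the Path/RawPath technique: Path holds the
// decoded form, RawPath the percent-encoded form, and url.URL.String()
// emits RawPath only when it is a valid encoding of Path.
func buildResolveURL(base, repo, ref, filePath string) (string, error) {
	u, err := url.Parse(base)
	if err != nil {
		return "", err
	}
	segments := []string{repo, "resolve", ref, filePath}

	joined := path.Join(segments...)
	u.Path = path.Join(u.Path, joined)

	// Percent-encode each slash-separated component individually so the
	// directory structure survives while spaces, '#', '%', etc. do not.
	escaped := make([]string, len(segments))
	for i, seg := range segments {
		parts := strings.Split(seg, "/")
		for j, p := range parts {
			parts[j] = url.PathEscape(p)
		}
		escaped[i] = strings.Join(parts, "/")
	}
	if joinedEscaped := path.Join(escaped...); joined != joinedEscaped {
		// path.Join(u.Path, joined) always ends in joined, so slicing
		// off its length leaves the (already clean) base path prefix.
		u.RawPath = u.Path[:len(u.Path)-len(joined)] + joinedEscaped
	}
	return u.String(), nil
}

func main() {
	got, _ := buildResolveURL("https://huggingface.co", "org/repo", "main", "sub dir/model file.bin")
	fmt.Println(got)
	// → https://huggingface.co/org/repo/resolve/main/sub%20dir/model%20file.bin
}
```

Keeping `Path` decoded matters: later `path.Join` calls and comparisons operate on real path segments, while `String()` validates that `RawPath` round-trips back to `Path` before trusting it for output.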
diff --git a/pkg/model/weightsource/setdigest.go b/pkg/model/weightsource/setdigest.go new file mode 100644 index 0000000000..0640c821b3 --- /dev/null +++ b/pkg/model/weightsource/setdigest.go @@ -0,0 +1,102 @@ +package weightsource + +import ( + "context" + "fmt" + "io/fs" + "path/filepath" + "runtime" + + "golang.org/x/sync/errgroup" + + "github.com/replicate/cog/pkg/util" +) + +// fileEntry holds the metadata collected during the walk phase, before +// hashing. Separating walk from hash lets us parallelize the expensive +// SHA-256 computation. +type fileEntry struct { + absPath string + rel string + size int64 +} + +// computeInventory walks dir and produces an Inventory: per-file +// path/size/digest plus the source fingerprint. For file:// sources the +// fingerprint is the dirhash of the file set (spec §2.4) — the same +// formula used for the weight set digest. +// +// The walk phase collects file metadata (fast, sequential). The hash +// phase computes SHA-256 digests concurrently, bounded by GOMAXPROCS. +// +// The .cog state directory is skipped to match the packer's behavior. +// Symlinks and non-regular files are skipped — same reason. +func computeInventory(ctx context.Context, dir string) (Inventory, error) { + // Phase 1: walk to collect paths and sizes (metadata only, fast). 
+ var entries []fileEntry + err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() && d.Name() == ".cog" { + return filepath.SkipDir + } + if !d.Type().IsRegular() { + return nil + } + if err := ctx.Err(); err != nil { + return err + } + + rel, err := filepath.Rel(dir, path) + if err != nil { + return fmt.Errorf("rel path for %s: %w", path, err) + } + rel = filepath.ToSlash(rel) + + info, err := d.Info() + if err != nil { + return fmt.Errorf("stat %s: %w", rel, err) + } + + entries = append(entries, fileEntry{absPath: path, rel: rel, size: info.Size()}) + return nil + }) + if err != nil { + return Inventory{}, err + } + + // Phase 2: hash files concurrently. + files := make([]InventoryFile, len(entries)) + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.GOMAXPROCS(0)) + + for i, e := range entries { + g.Go(func() error { + if err := ctx.Err(); err != nil { + return err + } + h, err := util.SHA256HashFile(e.absPath) + if err != nil { + return fmt.Errorf("hash %s: %w", e.rel, err) + } + files[i] = InventoryFile{ + Path: e.rel, + Size: e.size, + Digest: "sha256:" + h, + } + return nil + }) + } + + if err := g.Wait(); err != nil { + return Inventory{}, err + } + + sortInventoryFiles(files) + + return Inventory{ + Files: files, + Fingerprint: Fingerprint(DirHash(files)), + }, nil +} diff --git a/pkg/model/weightsource/source.go b/pkg/model/weightsource/source.go new file mode 100644 index 0000000000..8bb3cb22e3 --- /dev/null +++ b/pkg/model/weightsource/source.go @@ -0,0 +1,144 @@ +package weightsource + +import ( + "context" + "fmt" + "io" + "sort" + "strings" +) + +// Source is the provider for a weight-source scheme, bound at +// construction time to a specific URI. +// +// Implementations translate a scheme-specific URI (file://, hf://, s3://, +// http://, ...) into (a) an inventory of what the source contains, and +// (b) an on-demand byte stream for any one file in that inventory. 
The +// weights subsystem drives the import pipeline off these two capabilities +// — there is deliberately no "materialize the whole source to disk" step, +// so sources whose contents do not fit on local disk can still flow +// through the packer one file at a time. +// +// A Source instance is bound to one URI for its entire lifetime. Callers +// construct a Source via For(uri, projectDir). Methods are expected to be +// context-cancellable and safe to call concurrently for different paths. +type Source interface { + // Inventory returns the file list and version identity for the + // bound source. For file:// this walks and hashes (unavoidable for + // a local directory). For future remote sources it is expected to + // be cheap — HuggingFace Hub exposes per-file sha256 via its API, + // OCI sources read them from the source manifest's config blob. + Inventory(ctx context.Context) (Inventory, error) + + // Open returns a reader for a single file in the source, identified + // by its inventory path (relative to the source root). Called on + // demand during packing. The caller closes the returned reader. + Open(ctx context.Context, path string) (io.ReadCloser, error) +} + +// Inventory is the result of Source.Inventory: everything needed to plan +// an import without transferring payload bytes. +// +// Fingerprint is the source's version identity for the currently bound +// URI. Files is the list of content-addressed entries that make up the +// source; the packer consumes this list to produce tar layers. +type Inventory struct { + Files []InventoryFile + Fingerprint Fingerprint +} + +// InventoryFile is one entry in an Inventory: a file's relative path, +// size, and content digest. For file:// the digest is computed by +// walking and hashing; for remote sources it is read from a source-side +// index. +type InventoryFile struct { + // Path is the file path relative to the source root, using forward + // slashes regardless of the host OS. 
+ Path string + // Size is the uncompressed file size in bytes. + Size int64 + // Digest is the SHA-256 content digest with the "sha256:" prefix. + Digest string +} + +// DirhashParts implements Dirhashable so InventoryFile slices can be +// passed directly to DirHash. +func (f InventoryFile) DirhashParts() DirhashPart { + return DirhashPart{Path: f.Path, Digest: f.Digest} +} + +// sortInventoryFiles sorts files by path. Every Source implementation +// must return a sorted inventory; this helper enforces the convention. +func sortInventoryFiles(files []InventoryFile) { + sort.Slice(files, func(i, j int) bool { return files[i].Path < files[j].Path }) +} + +// NormalizeURI returns the canonical form of a weight source URI. +// +// Each scheme has its own normalization rules: +// - file:// and bare paths → canonical file:// form (see normalizeFileURI) +// - hf:// and huggingface:// → canonical hf:// form (see normalizeHFURI) +// +// Empty strings and unsupported schemes return an error. +func NormalizeURI(uri string) (string, error) { + if uri == "" { + return "", fmt.Errorf("empty weight source uri") + } + + scheme := schemeOf(uri) + switch scheme { + case "file", "": + // Bare paths and file:// URIs. For bare paths the full URI is + // the path; for file:// we strip the scheme prefix before + // normalizing. + path := uri + if scheme == "file" { + path = strings.TrimPrefix(uri, "file://") + } + return normalizeFileURI(path) + case HFScheme, HFSchemeLong: + return normalizeHFURI(uri) + default: + return "", fmt.Errorf("unsupported weight source scheme %q in URI %q", scheme, uri) + } +} + +// For returns the Source implementation for the given URI's scheme, +// bound to uri and projectDir. +// +// The scheme is the substring before the first "://". Bare paths (no +// scheme) are treated as file:// — this accepts both absolute ("/data") +// and relative ("./weights") forms as a convenience at the interface +// boundary. 
+// +// Unknown schemes return a clear error listing the currently supported +// schemes. This is the only place where scheme → implementation dispatch +// happens; adding s3:// or http:// is a single case here plus the +// matching Source implementation. +// +// For validates that the source exists and is usable. A file:// URI that +// points at a missing path or at a non-directory returns an error here, +// not at Inventory time. +func For(uri, projectDir string) (Source, error) { + scheme := schemeOf(uri) + switch scheme { + case "file", "": + return NewFileSource(uri, projectDir) + case HFScheme, HFSchemeLong: + return NewHFSource(uri) + default: + return nil, fmt.Errorf("unsupported weight source scheme %q (supported: file, hf, huggingface)", scheme) + } +} + +// schemeOf returns the scheme component of a URI, or "" for bare paths. +// It intentionally does not try to parse with net/url — hf://org/repo, +// s3://bucket/key, etc. are not RFC 3986-conformant URLs and net/url +// behaves inconsistently for them across schemes. +func schemeOf(uri string) string { + scheme, _, ok := strings.Cut(uri, "://") + if !ok { + return "" + } + return scheme +} diff --git a/pkg/model/weightsource/source_test.go b/pkg/model/weightsource/source_test.go new file mode 100644 index 0000000000..a3cc64fb01 --- /dev/null +++ b/pkg/model/weightsource/source_test.go @@ -0,0 +1,76 @@ +package weightsource + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFor(t *testing.T) { + // Prepare a real directory so file:// constructors succeed. 
+ projectDir := t.TempDir() + weightsDir := filepath.Join(projectDir, "weights") + require.NoError(t, os.MkdirAll(weightsDir, 0o755)) + + absDir := t.TempDir() + + tests := []struct { + name string + uri string + projectDir string + wantType string // "file", "hf", or "" for expected error + wantErrSubs string + }{ + {"file scheme", "file://" + absDir, "", "file", ""}, + {"file scheme relative", "file://./weights", projectDir, "file", ""}, + {"bare absolute path", absDir, "", "file", ""}, + {"bare relative path", "./weights", projectDir, "file", ""}, + {"bare no prefix", "weights", projectDir, "file", ""}, + {"hf scheme", "hf://org/repo", "", "hf", ""}, + {"huggingface scheme", "huggingface://org/repo", "", "hf", ""}, + {"s3 scheme rejected", "s3://bucket/key", "", "", "unsupported"}, + {"http scheme rejected", "http://example.com/x", "", "", "unsupported"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, err := For(tc.uri, tc.projectDir) + if tc.wantType == "" { + assert.ErrorContains(t, err, tc.wantErrSubs) + return + } + require.NoError(t, err) + switch tc.wantType { + case "file": + _, ok := s.(*FileSource) + assert.True(t, ok, "expected *FileSource, got %T", s) + case "hf": + _, ok := s.(*HFSource) + assert.True(t, ok, "expected *HFSource, got %T", s) + } + }) + } +} + +func TestSchemeOf(t *testing.T) { + tests := []struct { + uri string + want string + }{ + {"file:///abs", "file"}, + {"hf://org/repo", "hf"}, + {"huggingface://org/repo", "huggingface"}, + {"s3://bucket/key", "s3"}, + {"/abs", ""}, + {"./rel", ""}, + {"bare", ""}, + {"", ""}, + } + for _, tc := range tests { + t.Run(tc.uri, func(t *testing.T) { + assert.Equal(t, tc.want, schemeOf(tc.uri)) + }) + } +} diff --git a/pkg/paths/paths.go b/pkg/paths/paths.go new file mode 100644 index 0000000000..c918beb6d9 --- /dev/null +++ b/pkg/paths/paths.go @@ -0,0 +1,56 @@ +// Package paths resolves on-disk locations for Cog's per-user caches. 
+//
+// Callers SHOULD route every cache lookup through this package so a single
+// environment variable (COG_CACHE_DIR) can relocate all of Cog's caches in
+// one step — useful when the default cache directory lives on a different
+// filesystem than the user's project tree and hardlinking would fail with
+// EXDEV.
+package paths
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// envCacheDir lets users override the default cache root. When set, it
+// replaces the default entirely.
+const envCacheDir = "COG_CACHE_DIR"
+
+// envXDGCache is the XDG Base Directory Specification cache-home
+// variable. Respected on every platform (not just Linux) so users who
+// have set it get what they expect.
+const envXDGCache = "XDG_CACHE_HOME"
+
+// WeightsStoreDir returns the directory that backs the local
+// content-addressed weight file store.
+//
+// Resolution order:
+//
+// 1. $COG_CACHE_DIR/weights, if COG_CACHE_DIR is set.
+// 2. $XDG_CACHE_HOME/cog/weights, if set.
+// 3. $HOME/.cache/cog/weights otherwise — on every platform.
+//
+// Note the deliberate deviation from os.UserCacheDir, which returns
+// $HOME/Library/Caches on macOS. Dev tooling conventionally lives under
+// ~/.cache or other dot-directories in $HOME, not ~/Library, so we
+// follow suit.
+//
+// WeightsStoreDir does not create the directory; callers that need it
+// to exist should MkdirAll it themselves (FileStore does).
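+//
+// As a sketch of the resolution, with illustrative (not prescriptive)
+// host paths:
+//
+//	COG_CACHE_DIR=/mnt/cache          → /mnt/cache/weights
+//	XDG_CACHE_HOME=/home/u/xdg-cache  → /home/u/xdg-cache/cog/weights
+//	neither set                       → $HOME/.cache/cog/weights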
+func WeightsStoreDir() (string, error) {
+	if dir := os.Getenv(envCacheDir); dir != "" {
+		return filepath.Join(dir, "weights"), nil
+	}
+	if dir := os.Getenv(envXDGCache); dir != "" {
+		return filepath.Join(dir, "cog", "weights"), nil
+	}
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", fmt.Errorf("resolve user home dir: %w", err)
+	}
+	if home == "" {
+		return "", errors.New("user home dir is empty")
+	}
+	return filepath.Join(home, ".cache", "cog", "weights"), nil
+}
diff --git a/pkg/paths/paths_test.go b/pkg/paths/paths_test.go
new file mode 100644
index 0000000000..706ee1a731
--- /dev/null
+++ b/pkg/paths/paths_test.go
@@ -0,0 +1,43 @@
+package paths
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestWeightsStoreDir_COGCacheDir_Wins(t *testing.T) {
+	t.Setenv("COG_CACHE_DIR", "/tmp/custom-cache")
+	t.Setenv("XDG_CACHE_HOME", "/tmp/xdg-ignored")
+
+	dir, err := WeightsStoreDir()
+	require.NoError(t, err)
+	require.Equal(t, filepath.Join("/tmp/custom-cache", "weights"), dir)
+}
+
+func TestWeightsStoreDir_XDGCacheHome_RespectedOnAllPlatforms(t *testing.T) {
+	t.Setenv("COG_CACHE_DIR", "")
+	t.Setenv("XDG_CACHE_HOME", "/tmp/xdg")
+
+	dir, err := WeightsStoreDir()
+	require.NoError(t, err)
+	require.Equal(t, filepath.Join("/tmp/xdg", "cog", "weights"), dir)
+}
+
+// TestWeightsStoreDir_Defaults verifies the no-env default is
+// $HOME/.cache/cog/weights on every platform — explicitly NOT
+// $HOME/Library/Caches on macOS. Dev tools conventionally live under
+// ~/.cache or other dot-directories in $HOME.
+func TestWeightsStoreDir_Defaults(t *testing.T) { + t.Setenv("COG_CACHE_DIR", "") + t.Setenv("XDG_CACHE_HOME", "") + + home, err := os.UserHomeDir() + require.NoError(t, err) + + dir, err := WeightsStoreDir() + require.NoError(t, err) + require.Equal(t, filepath.Join(home, ".cache", "cog", "weights"), dir) +} diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index 2de3046603..82f51b5433 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -16,6 +16,7 @@ import ( "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) type status string @@ -49,40 +50,87 @@ type ValidationErrorResponse struct { } `json:"detail"` } +// PredictorOptions configures a Predictor. +// +// RunOptions carries everything the user supplied (image, volumes, +// env, GPUs, ports). If WeightManager is non-nil, Predictor.Start +// will call Prepare and merge the resulting read-only mounts into +// RunOptions.Volumes before launching the container; Stop will +// Release them afterwards. A nil WeightManager preserves the +// historical behavior for callers that don't deal with managed +// weights. +type PredictorOptions struct { + RunOptions command.RunOptions + IsTrain bool + Docker command.Command + WeightManager *weights.Manager +} + type Predictor struct { runOptions command.RunOptions isTrain bool dockerClient command.Command + weightManager *weights.Manager + mounts *weights.Mounts // populated by Start when weightManager != nil + // Running state containerID string port int } -func NewPredictor(ctx context.Context, runOptions command.RunOptions, isTrain bool, dockerCommand command.Command) (*Predictor, error) { +// NewPredictor constructs a Predictor. See PredictorOptions for the +// meaning of each field. 
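+//
+// A minimal call site might look like this (sketch; the image name and
+// docker client are illustrative):
+//
+//	p, err := NewPredictor(ctx, PredictorOptions{
+//		RunOptions: command.RunOptions{Image: "my-model"},
+//		Docker:     dockerClient,
+//	})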
+func NewPredictor(_ context.Context, opts PredictorOptions) (*Predictor, error) {
 	if global.Debug {
-		runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=debug")
+		opts.RunOptions.Env = append(opts.RunOptions.Env, "COG_LOG_LEVEL=debug")
 	} else {
-		runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=warning")
+		opts.RunOptions.Env = append(opts.RunOptions.Env, "COG_LOG_LEVEL=warning")
 	}
 	return &Predictor{
-		runOptions:   runOptions,
-		isTrain:      isTrain,
-		dockerClient: dockerCommand,
+		runOptions:    opts.RunOptions,
+		isTrain:       opts.IsTrain,
+		dockerClient:  opts.Docker,
+		weightManager: opts.WeightManager,
 	}, nil
 }
 
-func (p *Predictor) Start(ctx context.Context, logsWriter io.Writer, timeout time.Duration) error {
-	var err error
+func (p *Predictor) Start(ctx context.Context, logsWriter io.Writer, timeout time.Duration) (retErr error) {
 	containerPort := 5000
 
+	if p.weightManager != nil {
+		mounts, err := p.weightManager.Prepare(ctx)
+		if err != nil {
+			return fmt.Errorf("prepare weights: %w", err)
+		}
+		p.mounts = mounts
+		// Mount dirs are hardlinks from the store; on any Start
+		// failure we release them immediately so the caller (whose
+		// defer Stop is only registered on successful Start) doesn't
+		// orphan mount dirs under the project's .cog/mounts/ directory.
+		defer func() {
+			if retErr != nil {
+				_ = p.mounts.Release()
+				p.mounts = nil
+			}
+		}()
+		for _, spec := range mounts.Specs {
+			p.runOptions.Volumes = append(p.runOptions.Volumes, command.Volume{
+				Source:      spec.Source,
+				Destination: spec.Target,
+				ReadOnly:    true,
+			})
+		}
+	}
+
 	p.runOptions.Ports = append(p.runOptions.Ports, command.Port{HostPort: 0, ContainerPort: containerPort})
 
-	p.containerID, err = docker.RunDaemon(ctx, p.dockerClient, p.runOptions, logsWriter)
+	containerID, err := docker.RunDaemon(ctx, p.dockerClient, p.runOptions, logsWriter)
 	if err != nil {
 		return fmt.Errorf("Failed to start container: %w", err)
 	}
+	p.containerID = containerID
 
 	p.port, err = docker.GetHostPortForContainer(ctx, p.dockerClient, p.containerID, containerPort)
 	if err != nil {
@@ -164,7 +212,20 @@
 }
 
 func (p *Predictor) Stop(ctx context.Context) error {
-	return p.dockerClient.ContainerStop(ctx, p.containerID)
+	stopErr := p.dockerClient.ContainerStop(ctx, p.containerID)
+
+	// Always attempt mount cleanup, even if ContainerStop failed — a
+	// leftover bind source is better surfaced as a warning than silently
+	// orphaned. Mount removal after container stop is safe on Linux:
+	// bind mounts don't prevent source-side removal.
+	if p.mounts != nil {
+		if err := p.mounts.Release(); err != nil {
+			console.Warnf("Failed to clean up weight mounts: %s", err)
+		}
+		p.mounts = nil
+	}
+
+	return stopErr
 }
 
 func (p *Predictor) Predict(inputs Inputs, context RequestContext) (*Response, error) {
diff --git a/pkg/predict/predictor_test.go b/pkg/predict/predictor_test.go
new file mode 100644
index 0000000000..de9f5d4081
--- /dev/null
+++ b/pkg/predict/predictor_test.go
@@ -0,0 +1,91 @@
+package predict
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+
+	"github.com/replicate/cog/pkg/docker/command"
+	"github.com/replicate/cog/pkg/docker/dockertest"
+	"github.com/replicate/cog/pkg/registry"
+	"github.com/replicate/cog/pkg/registry/registrytest"
+	"github.com/replicate/cog/pkg/weights"
+	"github.com/replicate/cog/pkg/weights/lockfile"
+	"github.com/replicate/cog/pkg/weights/store"
+)
+
+// TestPredictor_Start_CleansUpMountsOnContainerStartFailure is a
+// regression test: if Prepare succeeds but the docker container fails
+// to start, the per-invocation mount dir under the project's
+// .cog/mounts/ directory must be cleaned up. Callers only register
+// defer Stop() on successful Start, so Start is responsible for
+// cleanup on its own failure paths.
+func TestPredictor_Start_CleansUpMountsOnContainerStartFailure(t *testing.T) { + t.Parallel() + + data := []byte("x") + sum := sha256.Sum256(data) + digest := "sha256:" + hex.EncodeToString(sum[:]) + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{{ + Name: "w1", + Target: "/src/weights/w1", + Files: []lockfile.WeightLockFile{{Path: "f", Size: int64(len(data)), Digest: digest, Layer: "sha256:x"}}, + }}, + } + + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + require.NoError(t, fs.PutFile(context.Background(), digest, int64(len(data)), bytes.NewReader(data))) + + projectDir := t.TempDir() + mgr, err := weights.NewManager(weights.ManagerOptions{ + Store: fs, + Registry: stubRegistry{Client: registrytest.NewMockRegistryClient()}, + Repo: "example.com/test", + Lock: lock, + ProjectDir: projectDir, + }) + require.NoError(t, err) + + mockDocker := dockertest.NewMockCommand2(t) + mockDocker.EXPECT(). + ContainerStart(mock.Anything, mock.Anything). + Return("", errors.New("simulated container start failure")) + + p, err := NewPredictor(context.Background(), PredictorOptions{ + RunOptions: command.RunOptions{Image: "fake"}, + Docker: mockDocker, + WeightManager: mgr, + }) + require.NoError(t, err) + + err = p.Start(context.Background(), os.Stderr, time.Second) + require.Error(t, err) + + // Invocation dir must be cleaned up by Start's on-error defer. + mountRoot := filepath.Join(projectDir, ".cog", "mounts") + entries, readErr := os.ReadDir(mountRoot) + if readErr == nil { + require.Empty(t, entries, "failed Start must not leave invocation dirs under %s", mountRoot) + } else { + require.ErrorIs(t, readErr, os.ErrNotExist) + } +} + +// stubRegistry is a minimal registry.Client the predict test can +// construct — it's never actually called because Prepare doesn't need +// the registry. 
+type stubRegistry struct { + registry.Client +} diff --git a/pkg/registry/client.go b/pkg/registry/client.go index d26d6ea303..9058f1bad8 100644 --- a/pkg/registry/client.go +++ b/pkg/registry/client.go @@ -82,4 +82,8 @@ type Client interface { // This method handles transient failures automatically with exponential backoff. // Use WriteLayerOptions to configure progress reporting and retry callbacks. WriteLayer(ctx context.Context, opts WriteLayerOptions) error + + // BlobExists checks whether a blob with the given digest exists in the + // repository. This is a lightweight HEAD request. + BlobExists(ctx context.Context, repo string, digest string) (bool, error) } diff --git a/pkg/registry/registry_client.go b/pkg/registry/registry_client.go index 8e80f8ddad..4524fad6a4 100644 --- a/pkg/registry/registry_client.go +++ b/pkg/registry/registry_client.go @@ -579,6 +579,34 @@ func (c *RegistryClient) writeLayerMultipart(ctx context.Context, repo name.Repo return nil } +// BlobExists checks whether a blob with the given digest exists in the +// repository. This is a lightweight HEAD request. 
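+//
+// A call site might look like this (sketch; the repository and digest
+// values are illustrative placeholders):
+//
+//	exists, err := client.BlobExists(ctx, "example.com/user/model", "sha256:<digest>")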
+func (c *RegistryClient) BlobExists(ctx context.Context, repoStr string, digestStr string) (bool, error) { + repo, err := name.NewRepository(repoStr, name.Insecure) + if err != nil { + return false, fmt.Errorf("parse repository %q: %w", repoStr, err) + } + + digest, err := v1.NewHash(digestStr) + if err != nil { + return false, fmt.Errorf("parse digest %q: %w", digestStr, err) + } + + auth, err := authn.Resolve(ctx, authn.DefaultKeychain, repo) + if err != nil { + return false, fmt.Errorf("resolving auth: %w", err) + } + + scopes := []string{repo.Scope(transport.PullScope)} + tr, err := transport.NewWithContext(ctx, repo.Registry, auth, c.transport, scopes) + if err != nil { + return false, fmt.Errorf("creating transport: %w", err) + } + + client := &http.Client{Transport: tr} + return c.checkBlobExists(ctx, client, repo, digest) +} + // checkBlobExists checks if a blob already exists in the repository. func (c *RegistryClient) checkBlobExists(ctx context.Context, client *http.Client, repo name.Repository, digest v1.Hash) (bool, error) { u := url.URL{ diff --git a/pkg/registry/registrytest/mock_client.go b/pkg/registry/registrytest/mock_client.go index c9c895435d..726c212e39 100644 --- a/pkg/registry/registrytest/mock_client.go +++ b/pkg/registry/registrytest/mock_client.go @@ -52,3 +52,7 @@ func (c *MockRegistryClient) GetDescriptor(ctx context.Context, imageRef string) func (c *MockRegistryClient) WriteLayer(ctx context.Context, opts registry.WriteLayerOptions) error { return nil } + +func (c *MockRegistryClient) BlobExists(ctx context.Context, repo string, digest string) (bool, error) { + return false, nil +} diff --git a/pkg/weights/check_drift.go b/pkg/weights/check_drift.go new file mode 100644 index 0000000000..613fe79756 --- /dev/null +++ b/pkg/weights/check_drift.go @@ -0,0 +1,100 @@ +package weights + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/replicate/cog/pkg/config" + 
"github.com/replicate/cog/pkg/model/weightsource" + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// CheckDrift loads the lockfile from projectDir and compares it against +// the config's weight declarations. It returns a user-facing error if +// any drift is detected, telling the user to run "cog weights import". +// +// Returns nil when weights is empty, when config and lockfile agree, +// or when the lockfile is missing and there are no config weights. +func CheckDrift(projectDir string, weights []config.WeightSource) error { + if len(weights) == 0 { + return nil + } + + lockPath := filepath.Join(projectDir, lockfile.WeightsLockFilename) + lock, err := lockfile.LoadWeightsLock(lockPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("load %s: %w", lockfile.WeightsLockFilename, err) + } + lock = nil // missing lockfile → all weights are pending + } + + configWeights, err := toConfigWeights(weights) + if err != nil { + return err + } + + results := lockfile.CheckDrift(lock, configWeights) + if len(results) == 0 { + return nil + } + + return formatDriftError(results) +} + +// toConfigWeights converts config weight declarations to +// lockfile.ConfigWeight values, normalizing URIs and sorting +// include/exclude patterns so the comparison matches the lockfile's +// canonical form. +func toConfigWeights(ws []config.WeightSource) ([]lockfile.ConfigWeight, error) { + cws := make([]lockfile.ConfigWeight, 0, len(ws)) + for _, w := range ws { + uri, err := weightsource.NormalizeURI(w.SourceURI()) + if err != nil { + return nil, fmt.Errorf("weight %q: %w", w.Name, err) + } + cw := lockfile.ConfigWeight{ + Name: w.Name, + Target: w.Target, + URI: uri, + } + if w.Source != nil { + cw.Include = sortedCopy(w.Source.Include) + cw.Exclude = sortedCopy(w.Source.Exclude) + } + cws = append(cws, cw) + } + return cws, nil +} + +// sortedCopy returns a sorted copy of s, or nil if s is nil. 
+func sortedCopy(s []string) []string { + if s == nil { + return nil + } + out := slices.Clone(s) + slices.Sort(out) + return out +} + +// formatDriftError builds a user-facing error from drift results. +func formatDriftError(results []lockfile.DriftResult) error { + var b strings.Builder + b.WriteString("weights.lock is out of sync with cog.yaml:\n") + for _, r := range results { + switch r.Kind { + case lockfile.DriftPending: + fmt.Fprintf(&b, " - %q: not imported yet\n", r.Name) + case lockfile.DriftOrphaned: + fmt.Fprintf(&b, " - %q: removed from cog.yaml but still in lockfile\n", r.Name) + case lockfile.DriftConfigChanged: + fmt.Fprintf(&b, " - %q: config changed (%s)\n", r.Name, r.Details) + } + } + b.WriteString("Run 'cog weights import' to update.") + return errors.New(b.String()) +} diff --git a/pkg/weights/check_drift_test.go b/pkg/weights/check_drift_test.go new file mode 100644 index 0000000000..71039afda8 --- /dev/null +++ b/pkg/weights/check_drift_test.go @@ -0,0 +1,180 @@ +package weights + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/weights/lockfile" +) + +func TestCheckDrift(t *testing.T) { + t.Run("no weights in config: always passes", func(t *testing.T) { + require.NoError(t, CheckDrift(t.TempDir(), nil)) + }) + + t.Run("config and lockfile in sync: passes", func(t *testing.T) { + dir := t.TempDir() + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "my-model", + Target: "/src/weights", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Include: []string{}, + Exclude: []string{}, + }, + }, + }, + } + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) + + ws := []config.WeightSource{ + { + Name: "my-model", + Target: "/src/weights", + Source: &config.WeightSourceConfig{ + URI: "./weights", + Include: 
[]string{}, + Exclude: []string{}, + }, + }, + } + require.NoError(t, CheckDrift(dir, ws)) + }) + + t.Run("pending weight: errors", func(t *testing.T) { + dir := t.TempDir() + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{}} + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) + + ws := []config.WeightSource{ + { + Name: "new-model", + Target: "/src/weights", + Source: &config.WeightSourceConfig{URI: "./weights"}, + }, + } + + err := CheckDrift(dir, ws) + require.Error(t, err) + assert.Contains(t, err.Error(), "weights.lock is out of sync") + assert.Contains(t, err.Error(), "new-model") + assert.Contains(t, err.Error(), "not imported yet") + assert.Contains(t, err.Error(), "cog weights import") + }) + + t.Run("orphaned weight: errors", func(t *testing.T) { + dir := t.TempDir() + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "kept", + Target: "/src/kept", + Source: lockfile.WeightLockSource{ + URI: "file://./kept", + Include: []string{}, + Exclude: []string{}, + }, + }, + { + Name: "removed", + Target: "/src/removed", + Source: lockfile.WeightLockSource{ + URI: "file://./removed", + Include: []string{}, + Exclude: []string{}, + }, + }, + }, + } + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) + + ws := []config.WeightSource{ + { + Name: "kept", + Target: "/src/kept", + Source: &config.WeightSourceConfig{URI: "./kept"}, + }, + } + + err := CheckDrift(dir, ws) + require.Error(t, err) + assert.Contains(t, err.Error(), "removed") + assert.Contains(t, err.Error(), "removed from cog.yaml") + }) + + t.Run("config changed: errors", func(t *testing.T) { + dir := t.TempDir() + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "my-model", + Target: "/src/old-path", + Source: lockfile.WeightLockSource{ + URI: "file://./weights", + Include: []string{}, + Exclude: []string{}, + }, + }, + }, + 
} + require.NoError(t, lock.Save(filepath.Join(dir, lockfile.WeightsLockFilename))) + + ws := []config.WeightSource{ + { + Name: "my-model", + Target: "/src/new-path", + Source: &config.WeightSourceConfig{URI: "./weights"}, + }, + } + + err := CheckDrift(dir, ws) + require.Error(t, err) + assert.Contains(t, err.Error(), "my-model") + assert.Contains(t, err.Error(), "config changed") + }) + + t.Run("missing lockfile with config weights: errors", func(t *testing.T) { + ws := []config.WeightSource{ + { + Name: "my-model", + Target: "/src/weights", + Source: &config.WeightSourceConfig{URI: "./weights"}, + }, + } + + err := CheckDrift(t.TempDir(), ws) + require.Error(t, err) + assert.Contains(t, err.Error(), "not imported yet") + }) + + t.Run("corrupt lockfile: errors", func(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, lockfile.WeightsLockFilename), + []byte("not json"), + 0o644, + )) + + ws := []config.WeightSource{ + { + Name: "my-model", + Target: "/src/weights", + Source: &config.WeightSourceConfig{URI: "./weights"}, + }, + } + + err := CheckDrift(dir, ws) + require.Error(t, err) + assert.Contains(t, err.Error(), "load weights.lock") + }) +} diff --git a/pkg/weights/lockfile/drift.go b/pkg/weights/lockfile/drift.go new file mode 100644 index 0000000000..792add2d59 --- /dev/null +++ b/pkg/weights/lockfile/drift.go @@ -0,0 +1,112 @@ +package lockfile + +import ( + "fmt" + "slices" + "strings" +) + +// DriftKind classifies how a weight declaration has drifted from the +// lockfile. +type DriftKind string + +const ( + // DriftOrphaned means the lockfile has an entry with no matching + // config declaration — the weight was removed from cog.yaml. + DriftOrphaned DriftKind = "orphaned" + // DriftPending means config declares a weight that has no lockfile + // entry — the weight has never been imported. 
+ DriftPending DriftKind = "pending" + // DriftConfigChanged means both config and lockfile have the weight + // but a user-intent field (URI, target, include, exclude) differs. + DriftConfigChanged DriftKind = "config-changed" +) + +// DriftResult describes a single config-vs-lockfile mismatch. +type DriftResult struct { + Name string + Kind DriftKind + Details string // human-readable detail, e.g. "target: /old → /new" +} + +// ConfigWeight is the lockfile package's view of a weight declaration +// from cog.yaml. It carries only the user-intent fields that affect +// whether a lockfile entry is stale. Callers must normalize URI and +// sort Include/Exclude before constructing a ConfigWeight — CheckDrift +// does byte-exact comparison. +type ConfigWeight struct { + Name string + URI string + Target string + Include []string + Exclude []string +} + +// CheckDrift compares config declarations against lockfile entries and +// returns every mismatch. The result is empty when config and lockfile +// agree. A nil lock is treated as an empty lockfile (every config +// weight is "pending"). The function is pure: no I/O, no network. +func CheckDrift(lock *WeightsLock, configWeights []ConfigWeight) []DriftResult { + lockByName := make(map[string]*WeightLockEntry) + if lock != nil { + for i := range lock.Weights { + lockByName[lock.Weights[i].Name] = &lock.Weights[i] + } + } + + configNames := make(map[string]bool, len(configWeights)) + var results []DriftResult + + // Check each config weight against the lockfile. + for _, cw := range configWeights { + configNames[cw.Name] = true + le := lockByName[cw.Name] + + if le == nil { + results = append(results, DriftResult{ + Name: cw.Name, + Kind: DriftPending, + }) + } else if details := configChanged(cw, le); details != "" { + results = append(results, DriftResult{ + Name: cw.Name, + Kind: DriftConfigChanged, + Details: details, + }) + } + } + + // Check for orphaned lockfile entries. 
+ if lock != nil { + for _, le := range lock.Weights { + if !configNames[le.Name] { + results = append(results, DriftResult{ + Name: le.Name, + Kind: DriftOrphaned, + }) + } + } + } + + return results +} + +// configChanged returns a human-readable diff string listing every +// user-intent field that differs between config and lockfile. Returns +// "" when they match. +func configChanged(cw ConfigWeight, le *WeightLockEntry) string { + var diffs []string + if cw.URI != le.Source.URI { + diffs = append(diffs, fmt.Sprintf("uri: %s → %s", le.Source.URI, cw.URI)) + } + if cw.Target != le.Target { + diffs = append(diffs, fmt.Sprintf("target: %s → %s", le.Target, cw.Target)) + } + if !slices.Equal(cw.Include, le.Source.Include) { + diffs = append(diffs, fmt.Sprintf("include: %v → %v", le.Source.Include, cw.Include)) + } + if !slices.Equal(cw.Exclude, le.Source.Exclude) { + diffs = append(diffs, fmt.Sprintf("exclude: %v → %v", le.Source.Exclude, cw.Exclude)) + } + return strings.Join(diffs, "; ") +} diff --git a/pkg/weights/lockfile/drift_test.go b/pkg/weights/lockfile/drift_test.go new file mode 100644 index 0000000000..8e53275ce0 --- /dev/null +++ b/pkg/weights/lockfile/drift_test.go @@ -0,0 +1,202 @@ +package lockfile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCheckDrift(t *testing.T) { + tests := []struct { + name string + lock *WeightsLock + config []ConfigWeight + want []DriftResult + }{ + { + name: "no drift when config matches lockfile", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "model-a", Target: "/weights/a", Source: WeightLockSource{URI: "file://./a", Include: []string{}, Exclude: []string{}}}, + {Name: "model-b", Target: "/weights/b", Source: WeightLockSource{URI: "file://./b", Include: []string{"*.bin"}, Exclude: []string{"README*"}}}, + }, + }, + config: []ConfigWeight{ + {Name: "model-a", URI: "file://./a", Target: "/weights/a", Include: 
[]string{}, Exclude: []string{}}, + {Name: "model-b", URI: "file://./b", Target: "/weights/b", Include: []string{"*.bin"}, Exclude: []string{"README*"}}, + }, + want: []DriftResult{}, + }, + { + name: "orphaned: lockfile entry not in config", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "kept", Target: "/kept", Source: WeightLockSource{URI: "file://./kept", Include: []string{}, Exclude: []string{}}}, + {Name: "removed", Target: "/removed", Source: WeightLockSource{URI: "file://./removed", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "kept", URI: "file://./kept", Target: "/kept", Include: []string{}, Exclude: []string{}}, + }, + want: []DriftResult{ + {Name: "removed", Kind: DriftOrphaned}, + }, + }, + { + name: "pending: config weight not in lockfile", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "existing", Target: "/existing", Source: WeightLockSource{URI: "file://./existing", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "existing", URI: "file://./existing", Target: "/existing", Include: []string{}, Exclude: []string{}}, + {Name: "new-weight", URI: "file://./new", Target: "/new"}, + }, + want: []DriftResult{ + {Name: "new-weight", Kind: DriftPending}, + }, + }, + { + name: "config-changed: URI differs", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Target: "/w", Source: WeightLockSource{URI: "file://./old", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "w", URI: "file://./new", Target: "/w", Include: []string{}, Exclude: []string{}}, + }, + want: []DriftResult{ + {Name: "w", Kind: DriftConfigChanged, Details: "uri: file://./old → file://./new"}, + }, + }, + { + name: "config-changed: target differs", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Target: "/old-path", Source: 
WeightLockSource{URI: "file://./w", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "w", URI: "file://./w", Target: "/new-path", Include: []string{}, Exclude: []string{}}, + }, + want: []DriftResult{ + {Name: "w", Kind: DriftConfigChanged, Details: "target: /old-path → /new-path"}, + }, + }, + { + name: "config-changed: include differs", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Target: "/w", Source: WeightLockSource{URI: "file://./w", Include: []string{"*.bin"}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "w", URI: "file://./w", Target: "/w", Include: []string{"*.safetensors"}, Exclude: []string{}}, + }, + want: []DriftResult{ + {Name: "w", Kind: DriftConfigChanged, Details: "include: [*.bin] → [*.safetensors]"}, + }, + }, + { + name: "config-changed: exclude differs", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Target: "/w", Source: WeightLockSource{URI: "file://./w", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "w", URI: "file://./w", Target: "/w", Include: []string{}, Exclude: []string{"README*"}}, + }, + want: []DriftResult{ + {Name: "w", Kind: DriftConfigChanged, Details: "exclude: [] → [README*]"}, + }, + }, + { + name: "config-changed: multiple fields differ", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Target: "/old-path", Source: WeightLockSource{URI: "file://./old", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "w", URI: "file://./new", Target: "/new-path", Include: []string{}, Exclude: []string{}}, + }, + want: []DriftResult{ + {Name: "w", Kind: DriftConfigChanged, Details: "uri: file://./old → file://./new; target: /old-path → /new-path"}, + }, + }, + { + name: "nil lockfile with no config weights", + lock: nil, + config: nil, + want: []DriftResult{}, + }, + { + name: "nil 
lockfile with config weights: all pending", + lock: nil, + config: []ConfigWeight{ + {Name: "a", URI: "file://./a", Target: "/a"}, + {Name: "b", URI: "file://./b", Target: "/b"}, + }, + want: []DriftResult{ + {Name: "a", Kind: DriftPending}, + {Name: "b", Kind: DriftPending}, + }, + }, + { + name: "empty lockfile with config weights: all pending", + lock: &WeightsLock{Version: Version, Weights: []WeightLockEntry{}}, + config: []ConfigWeight{ + {Name: "a", URI: "file://./a", Target: "/a"}, + }, + want: []DriftResult{ + {Name: "a", Kind: DriftPending}, + }, + }, + { + name: "multiple drift types in one check", + lock: &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "ok", Target: "/ok", Source: WeightLockSource{URI: "file://./ok", Include: []string{}, Exclude: []string{}}}, + {Name: "stale", Target: "/stale", Source: WeightLockSource{URI: "file://./old-uri", Include: []string{}, Exclude: []string{}}}, + {Name: "orphan", Target: "/orphan", Source: WeightLockSource{URI: "file://./orphan", Include: []string{}, Exclude: []string{}}}, + }, + }, + config: []ConfigWeight{ + {Name: "ok", URI: "file://./ok", Target: "/ok", Include: []string{}, Exclude: []string{}}, + {Name: "stale", URI: "file://./new-uri", Target: "/stale", Include: []string{}, Exclude: []string{}}, + {Name: "brand-new", URI: "file://./brand-new", Target: "/brand-new"}, + }, + want: []DriftResult{ + {Name: "stale", Kind: DriftConfigChanged, Details: "uri: file://./old-uri → file://./new-uri"}, + {Name: "brand-new", Kind: DriftPending}, + {Name: "orphan", Kind: DriftOrphaned}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CheckDrift(tt.lock, tt.config) + if got == nil { + got = []DriftResult{} + } + require.Len(t, got, len(tt.want)) + for i, want := range tt.want { + assert.Equal(t, want, got[i], "result[%d]", i) + } + }) + } +} diff --git a/pkg/weights/lockfile/lockfile.go b/pkg/weights/lockfile/lockfile.go new file mode 100644 index 
0000000000..f05a2cd714 --- /dev/null +++ b/pkg/weights/lockfile/lockfile.go @@ -0,0 +1,408 @@ +// Package lockfile defines the on-disk weights.lock format and operations +// on it: parsing, loading, canonical serialization, and entry-level +// equality checks. +// +// The lockfile is Cog's source-of-truth for imported weights. It captures +// the source (URI + fingerprint + include/exclude), the resulting content +// (setDigest, files, layers), and the assembled OCI manifest digest. +// Everything downstream — OCI manifests, the runtime /.cog/weights.json, +// registry state validation — is a projection of these fields. +package lockfile + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "slices" + "sort" + "time" + + "github.com/replicate/cog/pkg/model/weightsource" +) + +// WeightsLockFilename is the default filename for the weights lock file. +const WeightsLockFilename = "weights.lock" + +// Version is the current lockfile format version. +// +// It is an integer; monotonic bumps (1 → 2) signal schema changes. +// Pre-release "v1" string versions have no migration path. +const Version = 1 + +// WeightsLock is the parsed representation of a weights.lock file. +// +// The serialized form is stable and deterministic: Weights is kept in +// insertion order (matching cog.yaml), every entry's Files slice is +// sorted by path, and every entry's Layers slice is sorted by digest. +// Regenerating the lockfile from the same source produces byte-identical +// output, which is what makes weights.lock safe to check into git. +type WeightsLock struct { + Version int `json:"version"` + // EnvelopeFormat is the sha256 digest (with "sha256:" prefix) + // identifying the packer configuration that produced — or, on + // the next import, will produce — the recorded layer digests. + // + // Cog stamps the current envelope digest into the lockfile on + // every rewrite. 
On a subsequent import a mismatch (including a + // missing/empty value, treated as "no match") forces the builder + // to recompute layer digests from the local content store + // instead of trusting the cached entry. See + // pkg/model/envelope.go for what feeds into the digest. + // + // Empty when the lockfile has never been written by a + // version of cog that knows about this field — that empty + // value compares unequal to any current envelope digest, which + // is exactly the "force a recompute" behavior we want. + EnvelopeFormat string `json:"envelopeFormat"` + Weights []WeightLockEntry `json:"weights"` +} + +// WeightLockEntry is one declared weight in the lockfile. +// +// The entry carries everything needed to reproduce the OCI artifacts: +// - identity of the source (Source block) +// - content-addressable identity of the file set (SetDigest) +// - per-file index mapping each file to its layer (Files) +// - intrinsic layer properties for the manifest (Layers) +// - the assembled manifest digest (Digest) +// +// No annotations are stored here; OCI presentation annotations are derived +// at manifest-build time from the typed fields (name, target, setDigest, +// etc.). +type WeightLockEntry struct { + // Name is the weight's logical name (e.g. "z-image-turbo"). + Name string `json:"name"` + // Target is the container mount path for this weight. + Target string `json:"target"` + // Source records where the weight came from and how it was filtered. + Source WeightLockSource `json:"source"` + // Digest is the sha256 digest of the assembled OCI manifest. + Digest string `json:"digest"` + // SetDigest is the weight set digest (spec §2.4): a content-addressable + // identifier for the file set, independent of packing strategy. + SetDigest string `json:"setDigest"` + // Size is the total uncompressed size of all files in bytes (sum of + // layer sizeUncompressed). 
+ Size int64 `json:"size"` + // SizeCompressed is the total compressed layer size in bytes (sum of + // layer size) — the bytes the registry stores. + SizeCompressed int64 `json:"sizeCompressed"` + // Files is the per-file index, sorted by path. Each entry records the + // file's size, content digest, and which layer contains it. + Files []WeightLockFile `json:"files"` + // Layers is the set of packed tar layers, sorted by digest. Layer + // emission order from the packer is not guaranteed stable (future + // concurrency) — sorting produces deterministic output. + Layers []WeightLockLayer `json:"layers"` +} + +// WeightLockSource records provenance for a WeightLockEntry. +// +// An import is a pure function of (source URI, source fingerprint, +// include/exclude). Recording all four inputs plus the import timestamp +// makes the lockfile self-contained: given these fields and the source at +// Fingerprint, you can deterministically reproduce the Files/Layers that +// the entry describes. +type WeightLockSource struct { + // URI is the normalized source URI (e.g. file://./weights, + // hf://org/model, s3://bucket/prefix/). + URI string `json:"uri"` + // Fingerprint is the source's version identity at import time. + // Scheme-prefixed (sha256:, commit:, etag:, …). + Fingerprint weightsource.Fingerprint `json:"fingerprint"` + // Include is the sorted list of glob-style include patterns applied + // to the source. Sorted because order is not semantically meaningful + // (the patterns are a set, not a sequence) and canonicalizing here + // keeps the lockfile stable across reorderings in cog.yaml. Empty + // patterns are serialized as [] so the shape is stable. + Include []string `json:"include"` + // Exclude is the sorted list of exclude patterns, same shape as Include. + Exclude []string `json:"exclude"` + // ImportedAt is the wall-clock time of the import that produced this + // entry. 
It is informational only — it never participates in + // equality checks (see EntriesEqual). + ImportedAt time.Time `json:"importedAt"` +} + +// WeightLockFile is a single file in a WeightLockEntry's Files index. +// +// This mirrors the config blob entry shape (spec §2.3) so the config blob +// can be projected directly from Files without a second walk of the +// source directory. +type WeightLockFile struct { + // Path is the file path relative to the weight source directory, + // with forward slashes regardless of host OS. + Path string `json:"path"` + // Size is the file's uncompressed size in bytes. + Size int64 `json:"size"` + // Digest is the sha256 content digest of the file (hex-encoded with + // the "sha256:" prefix). + Digest string `json:"digest"` + // Layer is the digest of the layer containing this file. + Layer string `json:"layer"` +} + +// DirhashParts implements weightsource.Dirhashable so WeightLockFile +// slices can be passed directly to weightsource.DirHash. +func (f WeightLockFile) DirhashParts() weightsource.DirhashPart { + return weightsource.DirhashPart{Path: f.Path, Digest: f.Digest} +} + +// WeightLockLayer is an intrinsic description of a single packed tar layer. +// +// Only intrinsic properties live here — digest, mediaType, compressed size +// (Size), uncompressed size (SizeUncompressed). Layer content type +// ("bundle" vs "file") is not stored; it is derivable from Files (one +// file referencing the layer = single-file layer, many = bundle). +// Annotations are an OCI presentation detail and never stored in the +// lockfile. +type WeightLockLayer struct { + // Digest is the sha256 digest of the layer blob. + Digest string `json:"digest"` + // MediaType is the OCI layer media type + // (application/vnd.oci.image.layer.v1.tar or .tar+gzip). + MediaType string `json:"mediaType"` + // Size is the size of the layer blob in bytes (the bytes the + // registry stores, post-compression for gzip layers). 
+ Size int64 `json:"size"` + // SizeUncompressed is the sum of regular-file bytes in the layer, + // matching the definition used for run.cog.weight.size.uncompressed + // on index descriptors. + SizeUncompressed int64 `json:"sizeUncompressed"` +} + +// ParseWeightsLock parses a weights.lock JSON document and rejects +// anything that is not a supported lockfile version. +func ParseWeightsLock(data []byte) (*WeightsLock, error) { + var lock WeightsLock + if err := json.Unmarshal(data, &lock); err != nil { + return nil, fmt.Errorf("parse weights.lock: %w", err) + } + if lock.Version != Version { + return nil, fmt.Errorf("unsupported weights.lock version %d (want %d)", + lock.Version, Version) + } + return &lock, nil +} + +// LoadWeightsLock loads a weights.lock file from disk. +func LoadWeightsLock(path string) (*WeightsLock, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read weights.lock: %w", err) + } + return ParseWeightsLock(data) +} + +// Save writes the weights.lock to disk in canonical JSON form. +// +// Save is deterministic: for any given WeightsLock value, repeated calls +// produce byte-identical output. It sorts each entry's Files by path and +// Layers by digest before serializing, normalizes empty Include/Exclude +// slices to [] (never omitted), and emits standard two-space indent. +func (wl *WeightsLock) Save(path string) error { + data, err := wl.Marshal() + if err != nil { + return err + } + if err := os.WriteFile(path, data, 0o644); err != nil { //nolint:gosec // lockfile is checked into the repo + return fmt.Errorf("write weights.lock: %w", err) + } + return nil +} + +// Marshal serializes the lockfile to canonical JSON bytes. It applies the +// sort + normalization rules described on Save. Marshal mutates the +// receiver's entries in place (sorting their Files and Layers); this is +// safe because the sort order is the canonical order. 
+func (wl *WeightsLock) Marshal() ([]byte, error) { + if wl.Version == 0 { + wl.Version = Version + } + for i := range wl.Weights { + canonicalize(&wl.Weights[i]) + } + data, err := json.MarshalIndent(wl, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal weights.lock: %w", err) + } + return data, nil +} + +// ComputeSetDigest returns the weight set digest (spec §2.4): the dirhash +// of the entry's file set. ComputeSetDigest canonicalizes the entry +// in place before hashing, so Files order at call time does not affect +// the result. +func (e *WeightLockEntry) ComputeSetDigest() string { + canonicalize(e) + return weightsource.DirHash(e.Files) +} + +// RuntimeWeightsManifest is the in-image /.cog/weights.json file that +// signals managed weights to coglet. It is a minimal projection of the +// lockfile: only the fields coglet needs to know which weights to expect +// and where (spec §3.3). +type RuntimeWeightsManifest struct { + Weights []RuntimeWeightEntry `json:"weights"` +} + +// RuntimeWeightEntry is one weight in the runtime manifest. Three fields +// per entry: name, target, and the content-addressable set digest. +type RuntimeWeightEntry struct { + Name string `json:"name"` + Target string `json:"target"` + SetDigest string `json:"setDigest"` +} + +// RuntimeManifest projects the lockfile into the minimal runtime manifest +// written to /.cog/weights.json (spec §3.3). The result contains only the +// fields coglet needs: name, target, and setDigest per weight. 
+func (wl *WeightsLock) RuntimeManifest() *RuntimeWeightsManifest { + entries := make([]RuntimeWeightEntry, len(wl.Weights)) + for i, w := range wl.Weights { + entries[i] = RuntimeWeightEntry{ + Name: w.Name, + Target: w.Target, + SetDigest: w.SetDigest, + } + } + return &RuntimeWeightsManifest{Weights: entries} +} + +// canonicalize applies the serialization rules to a single entry: +// Files sorted by path, Layers sorted by digest, nil Include/Exclude +// normalized to [] so the shape is stable. Include/Exclude ordering is +// already canonical — WeightSpec sorts at construction, and all writes +// to WeightLockSource flow through a WeightSpec. +func canonicalize(e *WeightLockEntry) { + sort.Slice(e.Files, func(i, j int) bool { return e.Files[i].Path < e.Files[j].Path }) + sort.Slice(e.Layers, func(i, j int) bool { return e.Layers[i].Digest < e.Layers[j].Digest }) + if e.Source.Include == nil { + e.Source.Include = []string{} + } + if e.Source.Exclude == nil { + e.Source.Exclude = []string{} + } +} + +// FindWeight returns the lockfile entry with the given name, or nil if no +// such entry exists. +func (wl *WeightsLock) FindWeight(name string) *WeightLockEntry { + for i := range wl.Weights { + if wl.Weights[i].Name == name { + return &wl.Weights[i] + } + } + return nil +} + +// Retain removes any entries whose Name is not in keep. The order of +// surviving entries is preserved. Retain is used after a full import +// pass to prune weights that were removed from cog.yaml. +func (wl *WeightsLock) Retain(keep []string) { + set := make(map[string]bool, len(keep)) + for _, n := range keep { + set[n] = true + } + kept := make([]WeightLockEntry, 0, len(keep)) + for _, e := range wl.Weights { + if set[e.Name] { + kept = append(kept, e) + } + } + wl.Weights = kept +} + +// PruneLockfile removes lockfile entries whose names are not in keep. 
+// It is a no-op when the lockfile does not exist or when nothing would +// change, avoiding unnecessary file rewrites (which churn git diffs). +func PruneLockfile(lockPath string, keep []string) error { + lock, err := LoadWeightsLock(lockPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + before := len(lock.Weights) + lock.Retain(keep) + if len(lock.Weights) == before { + return nil // nothing pruned + } + + return lock.Save(lockPath) +} + +// Upsert inserts or replaces the entry with the matching Name. It leaves +// all other entries in place and untouched. +func (wl *WeightsLock) Upsert(entry WeightLockEntry) { + for i := range wl.Weights { + if wl.Weights[i].Name == entry.Name { + wl.Weights[i] = entry + return + } + } + wl.Weights = append(wl.Weights, entry) +} + +// EntriesEqual reports whether two entries are identical in both content +// and source. ImportedAt is intentionally excluded — it is a consequence +// of an import being written, not an input to the equality check. +// +// A lockfile entry is safe to leave unchanged only when both a and b are +// non-nil and every field (besides ImportedAt) agrees. +func EntriesEqual(a, b *WeightLockEntry) bool { + return entriesContentEqual(a, b) && entriesSourceEqual(a, b) +} + +// entriesContentEqual reports whether two entries describe identical +// on-registry content: same manifest digest, same set digest, same total +// sizes, same file index, same layer descriptors. 
+func entriesContentEqual(a, b *WeightLockEntry) bool { + if a == nil || b == nil { + return false + } + if a.Name != b.Name || a.Target != b.Target || + a.Digest != b.Digest || a.SetDigest != b.SetDigest || + a.Size != b.Size || a.SizeCompressed != b.SizeCompressed { + return false + } + if len(a.Files) != len(b.Files) { + return false + } + for i := range a.Files { + if a.Files[i] != b.Files[i] { + return false + } + } + if len(a.Layers) != len(b.Layers) { + return false + } + for i := range a.Layers { + if a.Layers[i] != b.Layers[i] { + return false + } + } + return true +} + +// entriesSourceEqual reports whether two entries have identical source +// metadata: same URI, same fingerprint, same include/exclude patterns. +// ImportedAt is intentionally excluded. +func entriesSourceEqual(a, b *WeightLockEntry) bool { + if a == nil || b == nil { + return false + } + if a.Source.URI != b.Source.URI || a.Source.Fingerprint != b.Source.Fingerprint { + return false + } + if !slices.Equal(a.Source.Include, b.Source.Include) { + return false + } + if !slices.Equal(a.Source.Exclude, b.Source.Exclude) { + return false + } + return true +} diff --git a/pkg/weights/lockfile/lockfile_test.go b/pkg/weights/lockfile/lockfile_test.go new file mode 100644 index 0000000000..0bd8a117e6 --- /dev/null +++ b/pkg/weights/lockfile/lockfile_test.go @@ -0,0 +1,667 @@ +package lockfile + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/model/weightsource" +) + +// OCI layer media types used in fixtures. +var ( + mediaTypeOCILayerTar = string(types.OCIUncompressedLayer) + mediaTypeOCILayerTarGzip = string(types.OCILayer) +) + +// sampleEntry returns a fully-populated WeightLockEntry for tests. 
+func sampleEntry() WeightLockEntry { + return WeightLockEntry{ + Name: "z-image-turbo", + Target: "/src/weights", + Source: WeightLockSource{ + URI: "file://./weights", + Fingerprint: weightsource.Fingerprint("sha256:def456"), + Include: []string{}, + Exclude: []string{}, + ImportedAt: time.Date(2026, 4, 16, 17, 27, 7, 0, time.UTC), + }, + Digest: "sha256:abc123", + SetDigest: "sha256:def456", + Size: 1500, + SizeCompressed: 1200, + Files: []WeightLockFile{ + {Path: "a.json", Size: 100, Digest: "sha256:f01", Layer: "sha256:aaa"}, + {Path: "b.bin", Size: 1400, Digest: "sha256:f02", Layer: "sha256:bbb"}, + }, + Layers: []WeightLockLayer{ + {Digest: "sha256:aaa", MediaType: mediaTypeOCILayerTarGzip, Size: 110, SizeUncompressed: 100}, + {Digest: "sha256:bbb", MediaType: mediaTypeOCILayerTar, Size: 1400, SizeUncompressed: 1400}, + }, + } +} + +func TestWeightsLock_ParseValid(t *testing.T) { + data := `{ + "version": 1, + "weights": [ + { + "name": "z-image-turbo", + "target": "/src/weights", + "source": { + "uri": "file://./weights", + "fingerprint": "sha256:def456", + "include": [], + "exclude": [], + "importedAt": "2026-04-16T17:27:07Z" + }, + "digest": "sha256:abc123", + "setDigest": "sha256:def456", + "size": 1500, + "sizeCompressed": 1200, + "files": [ + {"path": "a.json", "size": 100, "digest": "sha256:f01", "layer": "sha256:aaa"} + ], + "layers": [ + {"digest": "sha256:aaa", "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", "size": 110, "sizeUncompressed": 100} + ] + } + ] + }` + + lock, err := ParseWeightsLock([]byte(data)) + require.NoError(t, err) + assert.Equal(t, Version, lock.Version) + require.Len(t, lock.Weights, 1) + + w := lock.Weights[0] + assert.Equal(t, "z-image-turbo", w.Name) + assert.Equal(t, "/src/weights", w.Target) + assert.Equal(t, "sha256:abc123", w.Digest) + assert.Equal(t, "sha256:def456", w.SetDigest) + assert.Equal(t, int64(1500), w.Size) + assert.Equal(t, int64(1200), w.SizeCompressed) + + assert.Equal(t, "file://./weights", 
w.Source.URI) + assert.Equal(t, weightsource.Fingerprint("sha256:def456"), w.Source.Fingerprint) + + require.Len(t, w.Files, 1) + assert.Equal(t, "a.json", w.Files[0].Path) + assert.Equal(t, int64(100), w.Files[0].Size) + assert.Equal(t, "sha256:aaa", w.Files[0].Layer) + + require.Len(t, w.Layers, 1) + assert.Equal(t, "sha256:aaa", w.Layers[0].Digest) + assert.Equal(t, mediaTypeOCILayerTarGzip, w.Layers[0].MediaType) + assert.Equal(t, int64(110), w.Layers[0].Size) + assert.Equal(t, int64(100), w.Layers[0].SizeUncompressed) +} + +func TestWeightsLock_RejectsUnknownVersion(t *testing.T) { + // The pre-release lockfile used version "v1" (string). The v1 + // schema uses integer 1; anything else is rejected. + data := `{"version": "v1", "weights": []}` + _, err := ParseWeightsLock([]byte(data)) + require.Error(t, err) + + data = `{"version": 2, "weights": []}` + _, err = ParseWeightsLock([]byte(data)) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported weights.lock version") +} + +func TestWeightsLock_LoadFromFile(t *testing.T) { + dir := t.TempDir() + lockPath := filepath.Join(dir, "weights.lock") + content := `{"version": 1, "weights": []}` + require.NoError(t, os.WriteFile(lockPath, []byte(content), 0o644)) + + lock, err := LoadWeightsLock(lockPath) + require.NoError(t, err) + assert.Equal(t, Version, lock.Version) + assert.Empty(t, lock.Weights) +} + +func TestWeightsLock_Save_SetsMissingVersion(t *testing.T) { + dir := t.TempDir() + lockPath := filepath.Join(dir, "weights.lock") + + lock := &WeightsLock{ + Weights: []WeightLockEntry{sampleEntry()}, + } + require.NoError(t, lock.Save(lockPath)) + assert.Equal(t, Version, lock.Version, "Save fills in the missing version") + + loaded, err := LoadWeightsLock(lockPath) + require.NoError(t, err) + assert.Equal(t, Version, loaded.Version) + require.Len(t, loaded.Weights, 1) + assert.Equal(t, "z-image-turbo", loaded.Weights[0].Name) +} + +func TestWeightsLock_Save_Deterministic(t *testing.T) { + dir := 
t.TempDir() + path1 := filepath.Join(dir, "a.lock") + path2 := filepath.Join(dir, "b.lock") + + lock1 := &WeightsLock{Version: Version, Weights: []WeightLockEntry{sampleEntry()}} + lock2 := &WeightsLock{Version: Version, Weights: []WeightLockEntry{sampleEntry()}} + + require.NoError(t, lock1.Save(path1)) + require.NoError(t, lock2.Save(path2)) + + d1, err := os.ReadFile(path1) + require.NoError(t, err) + d2, err := os.ReadFile(path2) + require.NoError(t, err) + assert.Equal(t, d1, d2, "saving the same lockfile twice must be byte-identical") +} + +func TestWeightsLock_Marshal_SortsFilesByPath(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + { + Name: "w", + Files: []WeightLockFile{ + {Path: "z.txt", Size: 1, Digest: "sha256:z", Layer: "sha256:a"}, + {Path: "a.txt", Size: 1, Digest: "sha256:a", Layer: "sha256:a"}, + {Path: "m.txt", Size: 1, Digest: "sha256:m", Layer: "sha256:a"}, + }, + Layers: []WeightLockLayer{ + {Digest: "sha256:a", MediaType: mediaTypeOCILayerTar, Size: 1, SizeUncompressed: 1}, + }, + }, + }, + } + _, err := lock.Marshal() + require.NoError(t, err) + + got := []string{lock.Weights[0].Files[0].Path, lock.Weights[0].Files[1].Path, lock.Weights[0].Files[2].Path} + assert.Equal(t, []string{"a.txt", "m.txt", "z.txt"}, got, "Marshal sorts files by path") +} + +func TestWeightsLock_Marshal_SortsLayersByDigest(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + { + Name: "w", + Layers: []WeightLockLayer{ + {Digest: "sha256:zzz", MediaType: mediaTypeOCILayerTar, Size: 1, SizeUncompressed: 1}, + {Digest: "sha256:aaa", MediaType: mediaTypeOCILayerTar, Size: 1, SizeUncompressed: 1}, + {Digest: "sha256:mmm", MediaType: mediaTypeOCILayerTar, Size: 1, SizeUncompressed: 1}, + }, + }, + }, + } + _, err := lock.Marshal() + require.NoError(t, err) + + got := []string{lock.Weights[0].Layers[0].Digest, lock.Weights[0].Layers[1].Digest, lock.Weights[0].Layers[2].Digest} + 
assert.Equal(t, []string{"sha256:aaa", "sha256:mmm", "sha256:zzz"}, got, "Marshal sorts layers by digest") +} + +func TestWeightsLock_Marshal_NormalizesEmptyPatterns(t *testing.T) { + // Source.Include and Source.Exclude should serialize as [] (never + // omitted) when empty or nil, so the schema shape is stable. + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w", Source: WeightLockSource{URI: "file://./x"}}, + }, + } + data, err := lock.Marshal() + require.NoError(t, err) + assert.Contains(t, string(data), `"include": []`) + assert.Contains(t, string(data), `"exclude": []`) +} + +func TestWeightsLock_Upsert(t *testing.T) { + t.Run("replaces existing entry", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "a", Target: "/a", Digest: "sha256:aaa"}, + {Name: "b", Target: "/b", Digest: "sha256:bbb"}, + }, + } + + lock.Upsert(WeightLockEntry{Name: "a", Target: "/a-new", Digest: "sha256:aaa2"}) + require.Len(t, lock.Weights, 2, "upsert replaces in place, does not append") + + got := lock.FindWeight("a") + require.NotNil(t, got) + assert.Equal(t, "/a-new", got.Target) + assert.Equal(t, "sha256:aaa2", got.Digest) + + b := lock.FindWeight("b") + require.NotNil(t, b) + assert.Equal(t, "sha256:bbb", b.Digest) + }) + + t.Run("appends new entry", func(t *testing.T) { + lock := &WeightsLock{Version: Version} + lock.Upsert(WeightLockEntry{Name: "a", Target: "/a", Digest: "sha256:aaa"}) + lock.Upsert(WeightLockEntry{Name: "b", Target: "/b", Digest: "sha256:bbb"}) + + require.Len(t, lock.Weights, 2) + assert.Equal(t, "a", lock.Weights[0].Name) + assert.Equal(t, "b", lock.Weights[1].Name) + }) +} + +func TestWeightsLock_EnvelopeFormat_RoundTrip(t *testing.T) { + // EnvelopeFormat is a top-level lockfile field, persisted across + // Save/Load so a subsequent import can detect packer-config + // drift across cog versions. 
+ dir := t.TempDir() + path := filepath.Join(dir, "weights.lock") + + want := "sha256:abcd1234" + lock := &WeightsLock{ + Version: Version, + EnvelopeFormat: want, + Weights: []WeightLockEntry{sampleEntry()}, + } + require.NoError(t, lock.Save(path)) + + loaded, err := LoadWeightsLock(path) + require.NoError(t, err) + assert.Equal(t, want, loaded.EnvelopeFormat, + "EnvelopeFormat must survive round-trip through disk") +} + +func TestWeightsLock_EnvelopeFormat_OmittedReadsAsEmpty(t *testing.T) { + // Older lockfiles written before the EnvelopeFormat field + // existed simply lack the JSON key. This must parse and produce + // an empty string — the "force a recompute on next import" + // signal — not error out. + data := `{"version": 1, "weights": []}` + lock, err := ParseWeightsLock([]byte(data)) + require.NoError(t, err) + assert.Empty(t, lock.EnvelopeFormat, + "missing envelopeFormat must parse as empty string") +} + +func TestWeightsLock_RoundTrip(t *testing.T) { + original := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{sampleEntry()}, + } + data, err := original.Marshal() + require.NoError(t, err) + + var decoded WeightsLock + require.NoError(t, json.Unmarshal(data, &decoded)) + assert.Equal(t, original.Version, decoded.Version) + require.Len(t, decoded.Weights, 1) + assert.Equal(t, original.Weights[0].Source.Fingerprint, decoded.Weights[0].Source.Fingerprint) + assert.Equal(t, original.Weights[0].Files, decoded.Weights[0].Files) + assert.Equal(t, original.Weights[0].Layers, decoded.Weights[0].Layers) +} + +func TestEntriesContentEqual(t *testing.T) { + a := sampleEntry() + b := sampleEntry() + assert.True(t, entriesContentEqual(&a, &b), "identical entries are content-equal") + + c := sampleEntry() + c.Digest = "sha256:different" + assert.False(t, entriesContentEqual(&a, &c), "differing manifest digest breaks equality") + + d := sampleEntry() + d.SetDigest = "sha256:different" + assert.False(t, entriesContentEqual(&a, &d), "differing set 
digest breaks equality") + + e := sampleEntry() + e.Files[0].Digest = "sha256:tampered" + assert.False(t, entriesContentEqual(&a, &e), "differing file digest breaks equality") + + f := sampleEntry() + f.Size = 99999 + assert.False(t, entriesContentEqual(&a, &f), "differing size breaks equality") +} + +func TestEntriesSourceEqual(t *testing.T) { + a := sampleEntry() + b := sampleEntry() + // Different importedAt must still be source-equal. + b.Source.ImportedAt = a.Source.ImportedAt.Add(1 * time.Hour) + assert.True(t, entriesSourceEqual(&a, &b), "importedAt must not affect source equality") + + c := sampleEntry() + c.Source.URI = "file://./different" + assert.False(t, entriesSourceEqual(&a, &c), "differing URI breaks source equality") + + d := sampleEntry() + d.Source.Fingerprint = "sha256:different" + assert.False(t, entriesSourceEqual(&a, &d), "differing fingerprint breaks source equality") + + e := sampleEntry() + e.Source.Include = []string{"*.safetensors"} + assert.False(t, entriesSourceEqual(&a, &e), "differing include patterns break source equality") + + f := sampleEntry() + f.Source.Exclude = []string{"README*"} + assert.False(t, entriesSourceEqual(&a, &f), "differing exclude patterns break source equality") +} + +func TestEntriesEqual_RequiresBothContentAndSource(t *testing.T) { + a := sampleEntry() + b := sampleEntry() + assert.True(t, EntriesEqual(&a, &b)) + + // Same content, different source — not equal. + c := sampleEntry() + c.Source.URI = "file://./other" + assert.False(t, EntriesEqual(&a, &c)) + + // Same source, different content — not equal. + d := sampleEntry() + d.Digest = "sha256:different" + assert.False(t, EntriesEqual(&a, &d)) +} + +// setDigestOf returns the set digest for a file set by wrapping it in a +// throwaway entry. Used by the ComputeSetDigest tests below where the +// caller only cares about the files, not the rest of the entry fields. 
+func setDigestOf(files []WeightLockFile) string { + e := WeightLockEntry{Files: files} + return e.ComputeSetDigest() +} + +func TestComputeSetDigest_Deterministic(t *testing.T) { + files := []WeightLockFile{ + {Path: "config.json", Size: 100, Digest: "sha256:aaa111", Layer: "sha256:layer1"}, + {Path: "model.safetensors", Size: 9999, Digest: "sha256:bbb222", Layer: "sha256:layer2"}, + } + d1 := setDigestOf(files) + d2 := setDigestOf(files) + require.Equal(t, d1, d2, "same inputs must produce same digest") + assert.Greater(t, len(d1), len("sha256:"), "digest must be non-trivial") +} + +func TestComputeSetDigest_PackingIndependent(t *testing.T) { + // Same files, different layer assignments → same set digest. + files1 := []WeightLockFile{ + {Path: "a.txt", Size: 10, Digest: "sha256:aaa", Layer: "sha256:layer1"}, + {Path: "b.txt", Size: 20, Digest: "sha256:bbb", Layer: "sha256:layer1"}, + } + files2 := []WeightLockFile{ + {Path: "a.txt", Size: 10, Digest: "sha256:aaa", Layer: "sha256:layerX"}, + {Path: "b.txt", Size: 20, Digest: "sha256:bbb", Layer: "sha256:layerY"}, + } + assert.Equal(t, setDigestOf(files1), setDigestOf(files2), + "set digest must be independent of layer assignment") +} + +func TestComputeSetDigest_DiffersForDifferentContent(t *testing.T) { + files1 := []WeightLockFile{ + {Path: "a.txt", Size: 10, Digest: "sha256:aaa"}, + } + files2 := []WeightLockFile{ + {Path: "a.txt", Size: 10, Digest: "sha256:bbb"}, + } + assert.NotEqual(t, setDigestOf(files1), setDigestOf(files2), + "different content must produce different set digest") +} + +func TestComputeSetDigest_FileOrderIndependent(t *testing.T) { + // ComputeSetDigest canonicalizes in place, so the caller's input + // order doesn't affect the result. 
+ ordered := []WeightLockFile{ + {Path: "a.txt", Size: 10, Digest: "sha256:aaa"}, + {Path: "b.txt", Size: 20, Digest: "sha256:bbb"}, + } + reversed := []WeightLockFile{ + {Path: "b.txt", Size: 20, Digest: "sha256:bbb"}, + {Path: "a.txt", Size: 10, Digest: "sha256:aaa"}, + } + assert.Equal(t, setDigestOf(ordered), setDigestOf(reversed), + "set digest must be independent of file input order") +} + +func TestRuntimeManifest_ProjectsSpecFields(t *testing.T) { + // Verify the projection matches spec §3.3: only name, target, setDigest. + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + { + Name: "z-image-turbo", + Target: "/src/weights", + Digest: "sha256:abc123", + SetDigest: "sha256:def456", + Size: 32600000000, + SizeCompressed: 32457803776, + Source: WeightLockSource{ + URI: "file://./weights", + Fingerprint: "sha256:def456", + Include: []string{}, + Exclude: []string{}, + }, + Files: []WeightLockFile{ + {Path: "config.json", Size: 1234, Digest: "sha256:f01", Layer: "sha256:aaa"}, + }, + Layers: []WeightLockLayer{ + {Digest: "sha256:aaa", MediaType: mediaTypeOCILayerTarGzip, Size: 15000000, SizeUncompressed: 18500000}, + }, + }, + }, + } + + rm := lock.RuntimeManifest() + require.Len(t, rm.Weights, 1) + + w := rm.Weights[0] + assert.Equal(t, "z-image-turbo", w.Name) + assert.Equal(t, "/src/weights", w.Target) + assert.Equal(t, "sha256:def456", w.SetDigest) +} + +func TestRuntimeManifest_RoundTrip(t *testing.T) { + // Verify that serializing and deserializing the runtime manifest + // produces the exact spec §3.3 shape. 
+ lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + { + Name: "z-image-turbo", + Target: "/src/weights", + SetDigest: "sha256:def456", + }, + }, + } + + rm := lock.RuntimeManifest() + data, err := json.MarshalIndent(rm, "", " ") + require.NoError(t, err) + + var decoded RuntimeWeightsManifest + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Len(t, decoded.Weights, 1) + assert.Equal(t, "z-image-turbo", decoded.Weights[0].Name) + assert.Equal(t, "/src/weights", decoded.Weights[0].Target) + assert.Equal(t, "sha256:def456", decoded.Weights[0].SetDigest) + + // Verify the JSON contains only the expected keys (no extras from lockfile). + var raw map[string]json.RawMessage + require.NoError(t, json.Unmarshal(data, &raw)) + + var entries []map[string]json.RawMessage + require.NoError(t, json.Unmarshal(raw["weights"], &entries)) + require.Len(t, entries, 1) + + keys := make([]string, 0, len(entries[0])) + for k := range entries[0] { + keys = append(keys, k) + } + assert.ElementsMatch(t, []string{"name", "target", "setDigest"}, keys, + "runtime manifest entries must contain exactly the spec §3.3 fields") +} + +func TestRuntimeManifest_MultipleWeights(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "model-a", Target: "/src/weights/a", SetDigest: "sha256:aaa"}, + {Name: "model-b", Target: "/src/weights/b", SetDigest: "sha256:bbb"}, + }, + } + + rm := lock.RuntimeManifest() + require.Len(t, rm.Weights, 2) + assert.Equal(t, "model-a", rm.Weights[0].Name) + assert.Equal(t, "model-b", rm.Weights[1].Name) +} + +func TestRuntimeManifest_Empty(t *testing.T) { + lock := &WeightsLock{Version: Version, Weights: []WeightLockEntry{}} + rm := lock.RuntimeManifest() + assert.Empty(t, rm.Weights) +} + +func TestWeightsLock_Retain(t *testing.T) { + t.Run("removes entries not in keep set", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "a", Target: 
"/a", Digest: "sha256:aaa"}, + {Name: "b", Target: "/b", Digest: "sha256:bbb"}, + {Name: "c", Target: "/c", Digest: "sha256:ccc"}, + }, + } + + lock.Retain([]string{"a", "c"}) + require.Len(t, lock.Weights, 2) + assert.Equal(t, "a", lock.Weights[0].Name) + assert.Equal(t, "c", lock.Weights[1].Name) + }) + + t.Run("preserves order", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "z", Target: "/z"}, + {Name: "m", Target: "/m"}, + {Name: "a", Target: "/a"}, + }, + } + + lock.Retain([]string{"a", "z"}) + require.Len(t, lock.Weights, 2) + assert.Equal(t, "z", lock.Weights[0].Name, "original insertion order must be preserved") + assert.Equal(t, "a", lock.Weights[1].Name) + }) + + t.Run("empty keep set removes all", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "a", Target: "/a"}, + }, + } + + lock.Retain(nil) + assert.Empty(t, lock.Weights) + }) + + t.Run("noop when all kept", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "a", Target: "/a"}, + {Name: "b", Target: "/b"}, + }, + } + + lock.Retain([]string{"a", "b"}) + require.Len(t, lock.Weights, 2) + assert.Equal(t, "a", lock.Weights[0].Name) + assert.Equal(t, "b", lock.Weights[1].Name) + }) + + t.Run("keep set with unknown names is safe", func(t *testing.T) { + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "a", Target: "/a"}, + }, + } + + lock.Retain([]string{"a", "nonexistent"}) + require.Len(t, lock.Weights, 1) + assert.Equal(t, "a", lock.Weights[0].Name) + }) +} + +func TestPruneLockfile_RemovesOrphanedEntries(t *testing.T) { + // Regression test: removing a weight from cog.yaml must remove its + // entry from weights.lock. Before the fix, orphaned entries + // persisted and were projected into /.cog/weights.json, causing + // coglet to expect weights that no longer exist. 
+ dir := t.TempDir() + lockPath := filepath.Join(dir, "weights.lock") + + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w1", Target: "/src/w1", Digest: "sha256:aaa"}, + {Name: "w2", Target: "/src/w2", Digest: "sha256:bbb"}, + }, + } + require.NoError(t, lock.Save(lockPath)) + + err := PruneLockfile(lockPath, []string{"w1"}) + require.NoError(t, err) + + loaded, err := LoadWeightsLock(lockPath) + require.NoError(t, err) + require.Len(t, loaded.Weights, 1, "orphaned entry w2 should be removed") + assert.Equal(t, "w1", loaded.Weights[0].Name) +} + +func TestPruneLockfile_NoopWhenNothingRemoved(t *testing.T) { + // PruneLockfile should not rewrite the lockfile if there are no + // orphaned entries, keeping the mtime stable for git. + dir := t.TempDir() + lockPath := filepath.Join(dir, "weights.lock") + + lock := &WeightsLock{ + Version: Version, + Weights: []WeightLockEntry{ + {Name: "w1", Target: "/src/w1", Digest: "sha256:aaa"}, + }, + } + require.NoError(t, lock.Save(lockPath)) + + infoBefore, err := os.Stat(lockPath) + require.NoError(t, err) + + err = PruneLockfile(lockPath, []string{"w1"}) + require.NoError(t, err) + + infoAfter, err := os.Stat(lockPath) + require.NoError(t, err) + assert.Equal(t, infoBefore.ModTime(), infoAfter.ModTime(), + "PruneLockfile must not rewrite lockfile when nothing changed") +} + +func TestPruneLockfile_MissingLockfileIsNoop(t *testing.T) { + // If no lockfile exists yet, PruneLockfile is a safe no-op. + dir := t.TempDir() + lockPath := filepath.Join(dir, "weights.lock") + + err := PruneLockfile(lockPath, []string{"w1"}) + require.NoError(t, err, "pruning a missing lockfile should not error") + + // Lockfile should not have been created. 
+ _, err = os.Stat(lockPath) + require.True(t, os.IsNotExist(err), "PruneLockfile must not create a lockfile") +} diff --git a/pkg/weights/manager.go b/pkg/weights/manager.go new file mode 100644 index 0000000000..3a579a8f6b --- /dev/null +++ b/pkg/weights/manager.go @@ -0,0 +1,113 @@ +// Package weights orchestrates managed-weight operations: populating +// the local content-addressed store from a registry (Pull) and +// assembling per-invocation mount dirs (Prepare). +// +// The Manager is the single entry point. CLI commands construct one +// and call it; no CLI surface constructs stores, fetches layers, or +// walks tars directly. +package weights + +import ( + "context" + "errors" + "fmt" + + v1 "github.com/google/go-containerregistry/pkg/v1" + + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +// imageFetcher is the subset of registry.Client that Pull uses. The +// full registry.Client satisfies it; narrowing here makes the +// dependency explicit and lets tests mock a single method instead of +// the whole fat interface. +type imageFetcher interface { + GetImage(ctx context.Context, ref string, platform *registry.Platform) (v1.Image, error) +} + +// Manager orchestrates managed-weight operations against a local +// content-addressed store and a remote OCI registry. +type Manager struct { + store store.Store + registry imageFetcher + repo string + lock *lockfile.WeightsLock + projectDir string +} + +// ManagerOptions is the argument struct for NewManager. +// +// Store and Registry are always required. Lock, Repo, and ProjectDir +// are required only if the model has weights — a Manager constructed +// with a nil or empty Lock is a valid no-op, so callers don't need to +// branch on "does cog.yaml declare weights?" before constructing one. 
+type ManagerOptions struct { + Store store.Store + Registry registry.Client + Repo string + Lock *lockfile.WeightsLock + ProjectDir string +} + +func NewManager(opts ManagerOptions) (*Manager, error) { + if opts.Store == nil { + return nil, errors.New("weights manager: store is required") + } + if opts.Registry == nil { + return nil, errors.New("weights manager: registry is required") + } + // Repo/Lock/ProjectDir are only required when the model actually + // has weights. Validated lazily in Pull and Prepare. + if opts.Lock != nil && len(opts.Lock.Weights) > 0 && opts.Repo == "" { + return nil, errors.New("weights manager: repo is required when lock has weights") + } + return &Manager{ + store: opts.Store, + registry: opts.Registry, + repo: opts.Repo, + lock: opts.Lock, + projectDir: opts.ProjectDir, + }, nil +} + +// ProjectDir returns the project directory configured on the Manager. +// Primarily useful for tests and for CLI code that wants to log where +// mounts will live. +func (m *Manager) ProjectDir() string { return m.projectDir } + +// selectEntries returns the lockfile entries matching names, in name +// order. Empty names means every entry in lockfile order. Unknown +// names are reported in a single error so the user sees all typos in +// one shot. 
+func (m *Manager) selectEntries(names []string) ([]*lockfile.WeightLockEntry, error) { + if m.lock == nil { + if len(names) > 0 { + return nil, fmt.Errorf("unknown weight(s): %v (model has no weights)", names) + } + return nil, nil + } + if len(names) == 0 { + out := make([]*lockfile.WeightLockEntry, len(m.lock.Weights)) + for i := range m.lock.Weights { + out[i] = &m.lock.Weights[i] + } + return out, nil + } + + out := make([]*lockfile.WeightLockEntry, 0, len(names)) + var missing []string + for _, n := range names { + entry := m.lock.FindWeight(n) + if entry == nil { + missing = append(missing, n) + continue + } + out = append(out, entry) + } + if len(missing) > 0 { + return nil, fmt.Errorf("unknown weight(s): %v", missing) + } + return out, nil +} diff --git a/pkg/weights/mount.go b/pkg/weights/mount.go new file mode 100644 index 0000000000..6bf197cfbc --- /dev/null +++ b/pkg/weights/mount.go @@ -0,0 +1,184 @@ +package weights + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "syscall" + + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// MountSpec describes one bind mount from host to container. Managed- +// weight mounts are always read-only; callers set ReadOnly on their +// container runtime's volume type unconditionally. +type MountSpec struct { + Source string + Target string +} + +// Mounts is the handle returned from Prepare. It owns a per-invocation +// scratch directory under /.cog/mounts and MUST be +// released — either via Release or by the caller noticing the +// context was canceled. +type Mounts struct { + Specs []MountSpec + + root string +} + +// Release removes the per-invocation mount directory and every +// hardlink beneath it. The store's blobs are untouched. +// Release is idempotent and nil-safe. 
+func (m *Mounts) Release() error { + if m == nil || m.root == "" { + return nil + } + root := m.root + m.root = "" + if err := os.RemoveAll(root); err != nil { + return fmt.Errorf("remove mount dir %s: %w", root, err) + } + return nil +} + +// Prepare assembles per-invocation mount directories for every weight +// in the lockfile. Each weight gets its own directory populated by +// hardlinking blobs from the local store. +// +// If any file is missing from the store, Prepare returns an error +// directing the user at `cog weights pull`. v1 does NOT auto-pull. +// +// Hardlinks require the store and project dir to share a filesystem. +// On EXDEV, Prepare returns a clear error pointing at COG_CACHE_DIR; +// silent copy/symlink fallbacks would defeat the zero-duplication +// property. +// +// On any failure the partially-assembled invocation dir is removed +// before returning. +func (m *Manager) Prepare(ctx context.Context) (_ *Mounts, retErr error) { + // A Manager configured for a weights-less model is a valid no-op: + // Predictor can always call Prepare without checking whether + // weights exist in cog.yaml. 
+ if m.lock == nil || len(m.lock.Weights) == 0 { + return &Mounts{}, nil + } + + if m.projectDir == "" { + return nil, errors.New("prepare: Manager has no project dir") + } + + invocationID, err := newInvocationID() + if err != nil { + return nil, fmt.Errorf("generate invocation id: %w", err) + } + root := filepath.Join(m.projectDir, global.CogBuildArtifactsFolder, "mounts", invocationID) + if err := os.MkdirAll(root, 0o755); err != nil { + return nil, fmt.Errorf("create mount root %s: %w", root, err) + } + + mounts := &Mounts{root: root} + + defer func() { + if retErr != nil { + _ = mounts.Release() + } + }() + + for i := range m.lock.Weights { + entry := &m.lock.Weights[i] + weightDir, err := safeJoin(root, entry.Name) + if err != nil { + return nil, fmt.Errorf("weight %q: %w", entry.Name, err) + } + if err := m.assembleWeightDir(ctx, entry, weightDir); err != nil { + return nil, err + } + mounts.Specs = append(mounts.Specs, MountSpec{ + Source: weightDir, + Target: entry.Target, + }) + } + + return mounts, nil +} + +// safeJoin joins rel onto base and rejects the result if it escapes +// base. Lockfile entries are normally authored by `cog weights import`, +// but they're checked-in files a malicious or corrupt source could +// poison — filepath.Join cleans `..` components but doesn't prevent +// escape (`filepath.Join("/root", "../../etc")` returns `/etc`). 
+func safeJoin(base, rel string) (string, error) { + if rel == "" { + return "", errors.New("empty path component") + } + cleanBase := filepath.Clean(base) + joined := filepath.Clean(filepath.Join(cleanBase, rel)) + if !strings.HasPrefix(joined+string(filepath.Separator), cleanBase+string(filepath.Separator)) && joined != cleanBase { + return "", fmt.Errorf("path %q escapes parent directory", rel) + } + return joined, nil +} + +func (m *Manager) assembleWeightDir(ctx context.Context, entry *lockfile.WeightLockEntry, weightDir string) error { + if err := os.MkdirAll(weightDir, 0o755); err != nil { + return fmt.Errorf("create %s: %w", weightDir, err) + } + for _, f := range entry.Files { + if err := ctx.Err(); err != nil { + return err + } + src, err := m.store.Path(ctx, f.Digest) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("weight %q is not fully cached locally (missing %s); run 'cog weights pull' first", entry.Name, f.Path) + } + return fmt.Errorf("locate %s (%s): %w", f.Path, f.Digest, err) + } + dst, err := safeJoin(weightDir, filepath.FromSlash(f.Path)) + if err != nil { + return fmt.Errorf("weight %q file %q: %w", entry.Name, f.Path, err) + } + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return fmt.Errorf("create parent of %s: %w", dst, err) + } + if err := os.Link(src, dst); err != nil { + return wrapLinkError(err, src, dst) + } + } + return nil +} + +// wrapLinkError decorates os.Link errors, explicitly diagnosing +// EXDEV — different filesystems for cache and project — because +// silent fallback (copy or symlink) would defeat the zero-duplication +// property or be unreliable inside bind-mounted containers. +func wrapLinkError(err error, src, dst string) error { + if errors.Is(err, syscall.EXDEV) { + return fmt.Errorf( + "hardlink %s -> %s failed: cache directory and project directory are on different filesystems. 
"+ + "Set COG_CACHE_DIR to a path on the same filesystem as your project, then re-run 'cog weights pull'. "+ + "underlying error: %w", + src, dst, err, + ) + } + return fmt.Errorf("hardlink %s -> %s: %w", src, dst, err) +} + +// newInvocationID returns 8 hex chars (2^32 distinct). Short enough +// for pleasant paths, wide enough that concurrent Predictors don't +// collide in practice. +func newInvocationID() (string, error) { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return hex.EncodeToString(b[:]), nil +} diff --git a/pkg/weights/mount_test.go b/pkg/weights/mount_test.go new file mode 100644 index 0000000000..14ca5e3e08 --- /dev/null +++ b/pkg/weights/mount_test.go @@ -0,0 +1,312 @@ +package weights + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +func sha256Of(data []byte) string { + sum := sha256.Sum256(data) + return "sha256:" + hex.EncodeToString(sum[:]) +} + +// primedManager returns a Manager whose store is pre-populated with +// every digest in bytesByDigest. The store is returned so tests can +// inspect it without reaching into unexported Manager fields. 
+func primedManager(t *testing.T, lock *lockfile.WeightsLock, bytesByDigest map[string][]byte) (*Manager, *store.FileStore) { + t.Helper() + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + for digest, data := range bytesByDigest { + require.NoError(t, fs.PutFile(context.Background(), digest, int64(len(data)), bytes.NewReader(data))) + } + mgr, err := NewManager(ManagerOptions{ + Store: fs, + Registry: newStubRegistry(), + Repo: "example.com/me/model", + Lock: lock, + ProjectDir: t.TempDir(), + }) + require.NoError(t, err) + return mgr, fs +} + +func buildSimpleLock() (*lockfile.WeightsLock, map[string][]byte) { + fileA := []byte("alpha content") + fileB := []byte("bravo content") + dA := sha256Of(fileA) + dB := sha256Of(fileB) + + entry := lockfile.WeightLockEntry{ + Name: "parakeet", + Target: "/src/weights/parakeet", + Files: []lockfile.WeightLockFile{ + {Path: "config.json", Size: int64(len(fileA)), Digest: dA, Layer: "sha256:deadbeef"}, + {Path: "model/weights.bin", Size: int64(len(fileB)), Digest: dB, Layer: "sha256:deadbeef"}, + }, + } + return &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{entry}}, + map[string][]byte{dA: fileA, dB: fileB} +} + +func TestPrepare_HappyPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + lock, bytesByDigest := buildSimpleLock() + mgr, _ := primedManager(t, lock, bytesByDigest) + + mounts, err := mgr.Prepare(ctx) + require.NoError(t, err) + t.Cleanup(func() { _ = mounts.Release() }) + + require.Len(t, mounts.Specs, 1) + spec := mounts.Specs[0] + assert.Equal(t, "/src/weights/parakeet", spec.Target) + assert.Contains(t, spec.Source, filepath.Join(".cog", "mounts")) + assert.True(t, filepath.IsAbs(spec.Source)) + + for _, f := range lock.Weights[0].Files { + onDisk := filepath.Join(spec.Source, filepath.FromSlash(f.Path)) + got, err := os.ReadFile(onDisk) //nolint:gosec // test-owned path + require.NoError(t, err) + require.Equal(t, bytesByDigest[f.Digest], got) + } +} 
+ +func TestPrepare_HardlinksShareInodes(t *testing.T) { + t.Parallel() + // Avoiding byte duplication is the whole point of Prepare — + // verify the mount and store point at the same inode. + ctx := context.Background() + lock, bytesByDigest := buildSimpleLock() + mgr, fs := primedManager(t, lock, bytesByDigest) + + mounts, err := mgr.Prepare(ctx) + require.NoError(t, err) + t.Cleanup(func() { _ = mounts.Release() }) + + f := lock.Weights[0].Files[0] + storePath, err := fs.Path(ctx, f.Digest) + require.NoError(t, err) + mountPath := filepath.Join(mounts.Specs[0].Source, filepath.FromSlash(f.Path)) + + storeStat, err := os.Stat(storePath) + require.NoError(t, err) + mountStat, err := os.Stat(mountPath) + require.NoError(t, err) + require.True(t, os.SameFile(storeStat, mountStat), + "hard-linked mount file must share inode with store file") +} + +func TestPrepare_MissingFile_ErrorMentionsPull(t *testing.T) { + t.Parallel() + ctx := context.Background() + lock, bytesByDigest := buildSimpleLock() + first := lock.Weights[0].Files[0].Digest + mgr, _ := primedManager(t, lock, map[string][]byte{first: bytesByDigest[first]}) + + _, err := mgr.Prepare(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "cog weights pull") + assert.Contains(t, err.Error(), "parakeet") +} + +func TestPrepare_CleansUpOnFailure(t *testing.T) { + t.Parallel() + // A later weight being absent must not leak partial dirs for + // earlier weights. 
+ ctx := context.Background() + dataA := []byte("a") + dataB := []byte("b") + dA := sha256Of(dataA) + dB := sha256Of(dataB) + + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{ + { + Name: "present", Target: "/w1", + Files: []lockfile.WeightLockFile{{Path: "f", Size: 1, Digest: dA, Layer: "sha256:x"}}, + }, + { + Name: "absent", Target: "/w2", + Files: []lockfile.WeightLockFile{{Path: "f", Size: 1, Digest: dB, Layer: "sha256:y"}}, + }, + }, + } + mgr, _ := primedManager(t, lock, map[string][]byte{dA: dataA}) // dB missing + + _, err := mgr.Prepare(ctx) + require.Error(t, err) + + entries, err := os.ReadDir(filepath.Join(mgr.ProjectDir(), ".cog", "mounts")) + if err == nil { + assert.Empty(t, entries, "failed Prepare must not leave invocation dirs behind") + } else { + require.ErrorIs(t, err, os.ErrNotExist) + } +} + +func TestMounts_Release_Idempotent(t *testing.T) { + t.Parallel() + ctx := context.Background() + lock, bytesByDigest := buildSimpleLock() + mgr, _ := primedManager(t, lock, bytesByDigest) + + mounts, err := mgr.Prepare(ctx) + require.NoError(t, err) + + require.NoError(t, mounts.Release()) + _, statErr := os.Stat(mounts.Specs[0].Source) + require.ErrorIs(t, statErr, os.ErrNotExist) + + require.NoError(t, mounts.Release()) +} + +func TestMounts_Release_NilSafe(t *testing.T) { + t.Parallel() + var m *Mounts + require.NoError(t, m.Release()) +} + +func TestPrepare_RejectsPathTraversalInWeightName(t *testing.T) { + t.Parallel() + // A lockfile entry whose Name tries to escape the mount root must + // be refused — the lockfile is normally authored by import, but + // it's a checked-in file that could come from a hand-edit or an + // untrusted fork. 
+ ctx := context.Background() + data := []byte("x") + d := sha256Of(data) + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{{ + Name: "../escape", + Target: "/w", + Files: []lockfile.WeightLockFile{{Path: "f", Size: 1, Digest: d, Layer: "sha256:x"}}, + }}, + } + mgr, _ := primedManager(t, lock, map[string][]byte{d: data}) + + _, err := mgr.Prepare(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "escape") +} + +func TestPrepare_RejectsPathTraversalInFilePath(t *testing.T) { + t.Parallel() + ctx := context.Background() + data := []byte("x") + d := sha256Of(data) + lock := &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{{ + Name: "m1", + Target: "/w", + Files: []lockfile.WeightLockFile{{ + Path: "../../etc/passwd", Size: 1, Digest: d, Layer: "sha256:x", + }}, + }}, + } + mgr, _ := primedManager(t, lock, map[string][]byte{d: data}) + + _, err := mgr.Prepare(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "escape") +} + +func TestSafeJoin(t *testing.T) { + t.Parallel() + tests := []struct { + name string + base string + rel string + wantErr bool + }{ + {name: "simple", base: "/root", rel: "child", wantErr: false}, + {name: "nested", base: "/root", rel: "a/b/c", wantErr: false}, + {name: "parent escape", base: "/root", rel: "../outside", wantErr: true}, + {name: "double parent escape", base: "/root", rel: "../../etc", wantErr: true}, + // Absolute-looking paths are re-rooted under base by filepath.Join, + // so they stay inside and are allowed. 
+ {name: "absolute path in rel gets re-rooted", base: "/root", rel: "/etc/passwd", wantErr: false}, + {name: "empty rel", base: "/root", rel: "", wantErr: true}, + {name: "dot stays in", base: "/root", rel: "./a", wantErr: false}, + {name: "parent then back in", base: "/root", rel: "a/../b", wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := safeJoin(tt.base, tt.rel) + if tt.wantErr { + require.Error(t, err, "rel %q should be rejected", tt.rel) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestWrapLinkError_EXDEV(t *testing.T) { + t.Parallel() + // Bare EXDEV. + err := wrapLinkError(syscall.EXDEV, "/cache/blob", "/project/mount/file") + require.Error(t, err) + assert.Contains(t, err.Error(), "COG_CACHE_DIR", + "EXDEV error must point users at the COG_CACHE_DIR escape hatch") + assert.Contains(t, err.Error(), "different filesystems") + assert.Contains(t, err.Error(), "/cache/blob") + assert.Contains(t, err.Error(), "/project/mount/file") + assert.ErrorIs(t, err, syscall.EXDEV, "wrap must preserve EXDEV for errors.Is") +} + +func TestWrapLinkError_EXDEV_ThroughLinkError(t *testing.T) { + t.Parallel() + // os.Link returns *os.LinkError wrapping syscall.EXDEV; errors.Is + // unwraps through that chain, so the EXDEV branch must still fire. + linkErr := &os.LinkError{Op: "link", Old: "/cache/blob", New: "/project/mount/file", Err: syscall.EXDEV} + err := wrapLinkError(linkErr, "/cache/blob", "/project/mount/file") + assert.Contains(t, err.Error(), "COG_CACHE_DIR") +} + +func TestWrapLinkError_NonEXDEV(t *testing.T) { + t.Parallel() + // Other errors get a plain wrap — no COG_CACHE_DIR hint, since + // that's EXDEV-specific advice. 
+ err := wrapLinkError(errors.New("disk full"), "/a", "/b") + assert.Contains(t, err.Error(), "hardlink /a -> /b") + assert.NotContains(t, err.Error(), "COG_CACHE_DIR") +} + +func TestPrepare_NoProjectDirWithWeights(t *testing.T) { + t.Parallel() + // Weights-less Manager is a no-op regardless of projectDir; only + // a lock with actual weights triggers the projectDir requirement. + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + mgr, err := NewManager(ManagerOptions{ + Store: fs, + Registry: newStubRegistry(), + Repo: "r", + Lock: &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{{Name: "w", Target: "/t"}}, + }, + }) + require.NoError(t, err) + _, err = mgr.Prepare(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "project dir") +} diff --git a/pkg/weights/pull.go b/pkg/weights/pull.go new file mode 100644 index 0000000000..579f0556b1 --- /dev/null +++ b/pkg/weights/pull.go @@ -0,0 +1,264 @@ +package weights + +import ( + "archive/tar" + "context" + "errors" + "fmt" + "io" + + v1 "github.com/google/go-containerregistry/pkg/v1" + + "github.com/replicate/cog/pkg/weights/lockfile" +) + +// PullResult summarizes what happened for a single weight during Pull. +// Returned in the same order as the input names (or lockfile order +// when names is empty). +type PullResult struct { + Name string + FullyCached bool + FilesFetched int + BytesFetched int64 + LayersFetched int +} + +// PullEvent is emitted during Pull to drive progress output. Delivered +// on the calling goroutine in order; handlers MUST NOT block. +// +// Fields are populated per Kind — see each kind's comment. +type PullEvent struct { + // Kind identifies which fields below are populated. + Kind PullEventKind + // Weight is set on every event. + Weight string + + // WeightStart: summary of what's about to happen for the weight. 
+ // ManifestRef is set only when MissingFiles > 0 (fully-cached + // weights need no registry round trip). + Target string + TotalFiles int + MissingFiles int + ManifestRef string + + // LayerStart / LayerDone / FileStored: layer context. + // LayerSize is 0 when the backing layer does not expose a size + // (in-memory test layers). + LayerDigest string + LayerSize int64 + + // FileStored: per-file detail for a file just written to the store. + FilePath string + FileDigest string + FileSize int64 + + // WeightDone: cumulative totals for the weight. FullyCached is + // true when no registry I/O happened. + BytesFetched int64 + FilesFetched int + LayersFetched int + FullyCached bool +} + +type PullEventKind int + +const ( + // PullEventUnknown is the zero value so that a freshly-constructed + // PullEvent{} is distinguishable from a legitimate event. + PullEventUnknown PullEventKind = iota + PullEventWeightStart + PullEventLayerStart + PullEventFileStored + PullEventLayerDone + PullEventWeightDone +) + +// Pull populates the local store with every file referenced by the +// lockfile for the named weights. Empty names means "all weights". +// +// Behavior: +// - Files already present locally are skipped (no registry I/O). +// - A layer is fetched only if at least one of its files is missing. +// The whole layer must be streamed to reach any one file, so we +// store every expected file the layer contains — PutFile is +// idempotent so pre-cached files drain through without rewrites. +// - Registry is authoritative. v1 does not fall back to the source +// URI; re-run `cog weights import` if the registry is missing a +// layer. +// - Every file path in the tar must be in the lockfile. Unexpected +// paths error out. +// +// onEvent, if non-nil, is called synchronously with each PullEvent. 
+func (m *Manager) Pull(ctx context.Context, names []string, onEvent func(PullEvent)) ([]PullResult, error) { + entries, err := m.selectEntries(names) + if err != nil { + return nil, err + } + + emit := onEvent + if emit == nil { + emit = func(PullEvent) {} + } + + results := make([]PullResult, 0, len(entries)) + for _, entry := range entries { + if err := ctx.Err(); err != nil { + return results, err + } + r, err := m.pullEntry(ctx, entry, emit) + results = append(results, r) + if err != nil { + return results, fmt.Errorf("pull %s: %w", entry.Name, err) + } + } + return results, nil +} + +func (m *Manager) pullEntry(ctx context.Context, entry *lockfile.WeightLockEntry, emit func(PullEvent)) (PullResult, error) { + result := PullResult{Name: entry.Name} + + missingByLayer := map[string][]lockfile.WeightLockFile{} + var missingCount int + for _, f := range entry.Files { + ok, err := m.store.Exists(ctx, f.Digest) + if err != nil { + return result, fmt.Errorf("check %s: %w", f.Digest, err) + } + if ok { + continue + } + missingByLayer[f.Layer] = append(missingByLayer[f.Layer], f) + missingCount++ + } + + manifestRef := "" + if missingCount > 0 { + manifestRef = m.repo + "@" + entry.Digest + } + emit(PullEvent{ + Kind: PullEventWeightStart, + Weight: entry.Name, + Target: entry.Target, + TotalFiles: len(entry.Files), + MissingFiles: missingCount, + ManifestRef: manifestRef, + }) + + if missingCount == 0 { + result.FullyCached = true + emit(PullEvent{Kind: PullEventWeightDone, Weight: entry.Name, FullyCached: true}) + return result, nil + } + + img, err := m.registry.GetImage(ctx, manifestRef, nil) + if err != nil { + return result, fmt.Errorf("fetch weight manifest %s: %w", manifestRef, err) + } + + fileByPath := make(map[string]lockfile.WeightLockFile, len(entry.Files)) + for _, f := range entry.Files { + fileByPath[f.Path] = f + } + + for layerDigest, needed := range missingByLayer { + if err := ctx.Err(); err != nil { + return result, err + } + if err := 
m.pullLayer(ctx, entry.Name, img, layerDigest, needed, fileByPath, emit); err != nil { + return result, err + } + result.LayersFetched++ + for _, f := range needed { + result.FilesFetched++ + result.BytesFetched += f.Size + } + } + + emit(PullEvent{ + Kind: PullEventWeightDone, + Weight: entry.Name, + BytesFetched: result.BytesFetched, + FilesFetched: result.FilesFetched, + LayersFetched: result.LayersFetched, + }) + return result, nil +} + +// pullLayer streams a layer's tar blob and stores every regular file +// it contains, verifying the expected files for this layer appeared. +func (m *Manager) pullLayer( + ctx context.Context, + weightName string, + img v1.Image, + layerDigest string, + needed []lockfile.WeightLockFile, + fileByPath map[string]lockfile.WeightLockFile, + emit func(PullEvent), +) error { + hash, err := v1.NewHash(layerDigest) + if err != nil { + return fmt.Errorf("parse layer digest %q: %w", layerDigest, err) + } + layer, err := img.LayerByDigest(hash) + if err != nil { + return fmt.Errorf("find layer %s: %w", layerDigest, err) + } + layerSize, _ := layer.Size() + emit(PullEvent{ + Kind: PullEventLayerStart, + Weight: weightName, + LayerDigest: layerDigest, + LayerSize: layerSize, + }) + + rc, err := layer.Uncompressed() + if err != nil { + return fmt.Errorf("open layer %s: %w", layerDigest, err) + } + defer rc.Close() //nolint:errcheck // best-effort close on read path + + tr := tar.NewReader(rc) + written := map[string]bool{} + for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("read layer %s: %w", layerDigest, err) + } + if hdr.Typeflag != tar.TypeReg { + continue + } + + file, ok := fileByPath[hdr.Name] + if !ok { + return fmt.Errorf("layer %s: unexpected file %q not in lockfile", layerDigest, hdr.Name) + } + + if err := m.store.PutFile(ctx, file.Digest, file.Size, tr); err != nil { + return fmt.Errorf("store %s (%s): %w", 
file.Path, file.Digest, err) + } + written[file.Path] = true + emit(PullEvent{ + Kind: PullEventFileStored, + Weight: weightName, + LayerDigest: layerDigest, + FilePath: file.Path, + FileDigest: file.Digest, + FileSize: file.Size, + }) + } + + for _, f := range needed { + if !written[f.Path] { + return fmt.Errorf("layer %s: missing expected file %q", layerDigest, f.Path) + } + } + + emit(PullEvent{Kind: PullEventLayerDone, Weight: weightName, LayerDigest: layerDigest}) + return nil +} diff --git a/pkg/weights/pull_test.go b/pkg/weights/pull_test.go new file mode 100644 index 0000000000..ee8ee2ac77 --- /dev/null +++ b/pkg/weights/pull_test.go @@ -0,0 +1,697 @@ +package weights + +import ( + "archive/tar" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "testing" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/empty" + "github.com/google/go-containerregistry/pkg/v1/mutate" + "github.com/google/go-containerregistry/pkg/v1/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/registry/registrytest" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +// --------------------------------------------------------------------------- +// Test fixtures: in-memory layers + registry stub. +// --------------------------------------------------------------------------- + +// rawTarLayer implements v1.Layer over a fixed byte slice of uncompressed +// tar data. Digest/DiffID are computed from the bytes so LayerByDigest +// lookups resolve correctly. 
+type rawTarLayer struct {
+	bytes []byte
+	hash  v1.Hash
+}
+
+func newRawTarLayer(data []byte) *rawTarLayer {
+	sum := sha256.Sum256(data)
+	return &rawTarLayer{
+		bytes: data,
+		hash:  v1.Hash{Algorithm: "sha256", Hex: hex.EncodeToString(sum[:])},
+	}
+}
+
+func (l *rawTarLayer) Digest() (v1.Hash, error) { return l.hash, nil }
+func (l *rawTarLayer) DiffID() (v1.Hash, error) { return l.hash, nil }
+func (l *rawTarLayer) Size() (int64, error) { return int64(len(l.bytes)), nil }
+func (l *rawTarLayer) MediaType() (types.MediaType, error) {
+	return types.OCILayer, nil // uncompressed tar
+}
+
+func (l *rawTarLayer) Compressed() (io.ReadCloser, error) {
+	return io.NopCloser(bytes.NewReader(l.bytes)), nil
+}
+
+func (l *rawTarLayer) Uncompressed() (io.ReadCloser, error) {
+	return io.NopCloser(bytes.NewReader(l.bytes)), nil
+}
+
+// truncatedLayer wraps a rawTarLayer but its Uncompressed reader
+// returns the first `cutoff` bytes of the tar and then surfaces a
+// read error. Simulates a flaky network / truncated blob mid-stream.
+type truncatedLayer struct {
+	*rawTarLayer
+	cutoff int
+}
+
+func (l *truncatedLayer) Uncompressed() (io.ReadCloser, error) {
+	head := l.bytes
+	if l.cutoff < len(head) {
+		head = head[:l.cutoff]
+	}
+	return io.NopCloser(io.MultiReader(
+		bytes.NewReader(head),
+		errReader{err: errors.New("simulated blob truncation")},
+	)), nil
+}
+
+// errReader always returns the configured error on Read.
+type errReader struct{ err error }
+
+func (r errReader) Read(_ []byte) (int, error) { return 0, r.err }
+
+// buildLayer returns (tarBytes, []WeightLockFile describing its content).
+// Each file's Layer field is filled with layerDigest once known.
+func buildLayer(t *testing.T, files map[string][]byte) ([]byte, []lockfile.WeightLockFile) {
+	t.Helper()
+	var buf bytes.Buffer
+	tw := tar.NewWriter(&buf)
+	lockFiles := make([]lockfile.WeightLockFile, 0, len(files))
+
+	// Go randomizes map iteration, so the tar's entry order (and with
+	// it the layer digest) can vary between runs. That's fine here:
+	// every expected digest is derived from the image a test just
+	// built, so nothing relies on cross-run stability.
+	paths := make([]string, 0, len(files))
+	for p := range files {
+		paths = append(paths, p)
+	}
+	// Emit directory headers first (mirrors the real packer's behavior
+	// so we exercise the "skip non-regular entries" branch in pullLayer).
+	require.NoError(t, tw.WriteHeader(&tar.Header{
+		Typeflag: tar.TypeDir,
+		Name:     "./",
+	}))
+	for _, p := range paths {
+		data := files[p]
+		require.NoError(t, tw.WriteHeader(&tar.Header{
+			Typeflag: tar.TypeReg,
+			Name:     p,
+			Size:     int64(len(data)),
+		}))
+		_, err := tw.Write(data)
+		require.NoError(t, err)
+
+		sum := sha256.Sum256(data)
+		lockFiles = append(lockFiles, lockfile.WeightLockFile{
+			Path:   p,
+			Size:   int64(len(data)),
+			Digest: "sha256:" + hex.EncodeToString(sum[:]),
+		})
+	}
+	require.NoError(t, tw.Close())
+	return buf.Bytes(), lockFiles
+}
+
+// buildWeightImage returns a v1.Image and the lockfile entry for a
+// weight whose layers contain the given per-layer file maps; the
+// entry's Digest field carries the image's manifest digest.
+func buildWeightImage(t *testing.T, name, target string, layerFiles []map[string][]byte) (v1.Image, *lockfile.WeightLockEntry) {
+	t.Helper()
+
+	img := empty.Image
+	img = mutate.MediaType(img, types.OCIManifestSchema1)
+
+	var allFiles []lockfile.WeightLockFile
+	var lockLayers []lockfile.WeightLockLayer
+
+	for _, files := range layerFiles {
+		tarBytes, fs := buildLayer(t, files)
+		layer := newRawTarLayer(tarBytes)
+		digest, err := layer.Digest()
+		require.NoError(t, err)
+		size, _ := layer.Size()
+
+		// Attach layer digest to each lock file.
+		for i := range fs {
+			fs[i].Layer = digest.String()
+		}
+		allFiles = append(allFiles, fs...)
+ lockLayers = append(lockLayers, lockfile.WeightLockLayer{ + Digest: digest.String(), + MediaType: string(types.OCILayer), + Size: size, + SizeUncompressed: size, + }) + + img, err = mutate.Append(img, mutate.Addendum{Layer: layer}) + require.NoError(t, err) + } + + manifestDigest, err := img.Digest() + require.NoError(t, err) + + return img, &lockfile.WeightLockEntry{ + Name: name, + Target: target, + Digest: manifestDigest.String(), + Files: allFiles, + Layers: lockLayers, + } +} + +// stubRegistry composes MockRegistryClient and overrides GetImage to +// return real in-memory v1.Image values (which the mock returns nil +// for). Tests that want to assert Pull doesn't touch the registry set +// getImageErr. +type stubRegistry struct { + *registrytest.MockRegistryClient + images map[string]v1.Image + getImageErr error +} + +func newStubRegistry() *stubRegistry { + return &stubRegistry{ + MockRegistryClient: registrytest.NewMockRegistryClient(), + images: map[string]v1.Image{}, + } +} + +func (s *stubRegistry) put(ref string, img v1.Image) { s.images[ref] = img } + +func (s *stubRegistry) GetImage(_ context.Context, ref string, _ *registry.Platform) (v1.Image, error) { + if s.getImageErr != nil { + return nil, s.getImageErr + } + img, ok := s.images[ref] + if !ok { + return nil, fmt.Errorf("stub registry: no image at %s", ref) + } + return img, nil +} + +const testRepo = "example.com/me/model" + +func newTestManager(t *testing.T, reg registry.Client, lock *lockfile.WeightsLock) (*Manager, store.Store) { + t.Helper() + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + mgr, err := NewManager(ManagerOptions{ + Store: fs, + Registry: reg, + Repo: testRepo, + Lock: lock, + ProjectDir: t.TempDir(), + }) + require.NoError(t, err) + return mgr, fs +} + +func TestManager_Pull_HappyPath(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/src/weights", []map[string][]byte{ 
+ {"a.txt": []byte("alpha bytes"), "b.txt": []byte("bravo bytes")}, + {"c.bin": []byte("charlie bytes")}, + }) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, fs := newTestManager(t, reg, lock) + + results, err := mgr.Pull(ctx, nil, nil) + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, "m1", results[0].Name) + assert.False(t, results[0].FullyCached) + assert.Equal(t, 3, results[0].FilesFetched) + assert.Equal(t, 2, results[0].LayersFetched) + + // Every file is now in the store. + for _, f := range entry.Files { + ok, err := fs.Exists(ctx, f.Digest) + require.NoError(t, err) + require.True(t, ok, "file %s should be cached", f.Path) + } +} + +func TestManager_Pull_AllCached(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/w", []map[string][]byte{ + {"x": []byte("data")}, + }) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, fs := newTestManager(t, reg, lock) + + // Pre-populate the store. + for _, f := range entry.Files { + require.NoError(t, fs.PutFile(ctx, f.Digest, f.Size, bytes.NewReader([]byte("data")))) + } + + // Make the registry explode if touched — a fully-cached pull must + // not call it. 
+ reg.getImageErr = errors.New("registry should not be touched for cached pull") + + results, err := mgr.Pull(ctx, nil, nil) + require.NoError(t, err) + require.Len(t, results, 1) + assert.True(t, results[0].FullyCached) + assert.Equal(t, 0, results[0].FilesFetched) +} + +func TestManager_Pull_Idempotent(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/w", []map[string][]byte{ + {"x": []byte("one"), "y": []byte("two")}, + }) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, _ := newTestManager(t, reg, lock) + + // First pull populates. + results, err := mgr.Pull(ctx, nil, nil) + require.NoError(t, err) + assert.False(t, results[0].FullyCached) + + // Second pull is a no-op. + results, err = mgr.Pull(ctx, nil, nil) + require.NoError(t, err) + assert.True(t, results[0].FullyCached) +} + +func TestManager_Pull_DigestMismatchInTar(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/w", []map[string][]byte{ + {"a": []byte("legitimate content")}, + }) + // Corrupt the lockfile's expected digest so the bytes from the + // registry fail verification. The manifest digest stays valid + // (that addresses the manifest, not the file content). + entry.Files[0].Digest = "sha256:" + hex.EncodeToString(make([]byte, 32)) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, fs := newTestManager(t, reg, lock) + + _, err := mgr.Pull(ctx, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "digest mismatch") + + // Store is unchanged. 
+ ok, err := fs.Exists(ctx, entry.Files[0].Digest) + require.NoError(t, err) + require.False(t, ok) +} + +func TestManager_Pull_UnexpectedFileInTar(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + // Build an image whose layer tar contains a file the lockfile + // doesn't describe. Constructing this by hand: write "a" + "b" + // into the tar, but only list "a" in the lockfile. + tarBytes, lockFiles := buildLayer(t, map[string][]byte{ + "a": []byte("known"), + "b": []byte("secret extra"), + }) + layer := newRawTarLayer(tarBytes) + layerDigest, _ := layer.Digest() + layerSize, _ := layer.Size() + + img := empty.Image + img = mutate.MediaType(img, types.OCIManifestSchema1) + img, err := mutate.Append(img, mutate.Addendum{Layer: layer}) + require.NoError(t, err) + manifestDigest, err := img.Digest() + require.NoError(t, err) + + // Keep only the "a" entry in the lockfile. + var keep lockfile.WeightLockFile + for _, f := range lockFiles { + if f.Path == "a" { + keep = f + keep.Layer = layerDigest.String() + } + } + entry := lockfile.WeightLockEntry{ + Name: "m1", + Target: "/w", + Digest: manifestDigest.String(), + Files: []lockfile.WeightLockFile{keep}, + Layers: []lockfile.WeightLockLayer{{ + Digest: layerDigest.String(), + MediaType: string(types.OCILayer), + Size: layerSize, + SizeUncompressed: layerSize, + }}, + } + + reg.put(testRepo+"@"+entry.Digest, img) + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{entry}} + mgr, _ := newTestManager(t, reg, lock) + + _, err = mgr.Pull(ctx, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not in lockfile") +} + +func TestManager_Pull_LayerReadError(t *testing.T) { + t.Parallel() + // Simulate a flaky network: the layer blob read errors partway + // through the tar stream. The pull must fail with the underlying + // read error, and the partially-written file must not appear in + // the store (PutFile's temp+rename atomicity). 
+ ctx := context.Background() + reg := newStubRegistry() + + // Build a layer with one small file and one large enough that + // cutting the tar part-way through the file body triggers a mid- + // stream read error. + small := []byte("small") + big := bytes.Repeat([]byte("X"), 8192) + tarBytes, lockFiles := buildLayer(t, map[string][]byte{ + "small": small, + "big": big, + }) + + // Wrap the layer so Uncompressed returns a truncated reader. The + // cutoff sits inside the payload of the second file. + rawLayer := newRawTarLayer(tarBytes) + layer := &truncatedLayer{rawTarLayer: rawLayer, cutoff: 1024} + layerDigest, _ := layer.Digest() + layerSize, _ := layer.Size() + + for i := range lockFiles { + lockFiles[i].Layer = layerDigest.String() + } + + img := empty.Image + img = mutate.MediaType(img, types.OCIManifestSchema1) + img, err := mutate.Append(img, mutate.Addendum{Layer: layer}) + require.NoError(t, err) + manifestDigest, err := img.Digest() + require.NoError(t, err) + + entry := lockfile.WeightLockEntry{ + Name: "m1", + Target: "/w", + Digest: manifestDigest.String(), + Files: lockFiles, + Layers: []lockfile.WeightLockLayer{{ + Digest: layerDigest.String(), + MediaType: string(types.OCILayer), + Size: layerSize, + SizeUncompressed: layerSize, + }}, + } + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{entry}} + mgr, fs := newTestManager(t, reg, lock) + + _, err = mgr.Pull(ctx, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "truncation") + + // The partially-received file must not be exposed in the store. + // At most the fully-streamed file (whichever the tar emitted + // first) may be present — the torn one never is. + for _, f := range entry.Files { + ok, err := fs.Exists(ctx, f.Digest) + require.NoError(t, err) + if ok { + // If something landed, it must be the small file whose + // payload fit entirely before cutoff. The digest test + // below proves integrity. 
+ require.Equal(t, sha256Of(small), f.Digest, + "only the fully-streamed small file may be present; torn files must not appear") + } + } +} + +func TestManager_Pull_LayerMissingExpectedFile(t *testing.T) { + t.Parallel() + // The lockfile claims layer L contains files A and B. The layer + // tar in the registry only has A. Pull must fail for the weight + // with a "missing expected file" error — guarding against + // registry/lockfile drift. + ctx := context.Background() + reg := newStubRegistry() + + // Construct a layer tar that only contains "a". + tarBytes, lockFiles := buildLayer(t, map[string][]byte{ + "a": []byte("alpha content"), + }) + layer := newRawTarLayer(tarBytes) + layerDigest, _ := layer.Digest() + layerSize, _ := layer.Size() + + img := empty.Image + img = mutate.MediaType(img, types.OCIManifestSchema1) + img, err := mutate.Append(img, mutate.Addendum{Layer: layer}) + require.NoError(t, err) + manifestDigest, err := img.Digest() + require.NoError(t, err) + + // Lockfile claims both "a" and "b" live in this layer — "b" is + // fabricated to trigger the post-walk missing-file check. 
+ aFile := lockFiles[0] + aFile.Layer = layerDigest.String() + fakeB := lockfile.WeightLockFile{ + Path: "b", + Size: 5, + Digest: sha256Of([]byte("bravo")), + Layer: layerDigest.String(), + } + + entry := lockfile.WeightLockEntry{ + Name: "m1", + Target: "/w", + Digest: manifestDigest.String(), + Files: []lockfile.WeightLockFile{aFile, fakeB}, + Layers: []lockfile.WeightLockLayer{{ + Digest: layerDigest.String(), + MediaType: string(types.OCILayer), + Size: layerSize, + SizeUncompressed: layerSize, + }}, + } + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{entry}} + mgr, _ := newTestManager(t, reg, lock) + + _, err = mgr.Pull(ctx, nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing expected file") + assert.Contains(t, err.Error(), "b", "error should name the missing path") +} + +func TestManager_Pull_NameFilter(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img1, e1 := buildWeightImage(t, "keep", "/k", []map[string][]byte{{"x": []byte("1")}}) + img2, e2 := buildWeightImage(t, "skip", "/s", []map[string][]byte{{"y": []byte("2")}}) + reg.put(testRepo+"@"+e1.Digest, img1) + reg.put(testRepo+"@"+e2.Digest, img2) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*e1, *e2}} + mgr, fs := newTestManager(t, reg, lock) + + results, err := mgr.Pull(ctx, []string{"keep"}, nil) + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, "keep", results[0].Name) + + // keep is cached; skip is not. 
+ ok, err := fs.Exists(ctx, e1.Files[0].Digest) + require.NoError(t, err) + assert.True(t, ok) + + ok, err = fs.Exists(ctx, e2.Files[0].Digest) + require.NoError(t, err) + assert.False(t, ok) +} + +func TestManager_Pull_UnknownName(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + _, e1 := buildWeightImage(t, "known", "/k", []map[string][]byte{{"x": []byte("1")}}) + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*e1}} + mgr, _ := newTestManager(t, reg, lock) + + _, err := mgr.Pull(ctx, []string{"known", "nope"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "nope") +} + +func TestManager_Pull_EmitsEvents(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/src/weights", []map[string][]byte{ + {"a.txt": []byte("alpha"), "b.txt": []byte("bravo")}, + }) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, _ := newTestManager(t, reg, lock) + + var events []PullEvent + _, err := mgr.Pull(ctx, nil, func(e PullEvent) { events = append(events, e) }) + require.NoError(t, err) + + // Expected sequence for a single weight with one layer of two + // files: WeightStart, LayerStart, FileStored x2, LayerDone, + // WeightDone. + kinds := make([]PullEventKind, len(events)) + for i, e := range events { + kinds[i] = e.Kind + } + require.Equal(t, []PullEventKind{ + PullEventWeightStart, + PullEventLayerStart, + PullEventFileStored, + PullEventFileStored, + PullEventLayerDone, + PullEventWeightDone, + }, kinds) + + // WeightStart carries the manifest reference and file counts. 
+ start := events[0] + assert.Equal(t, "m1", start.Weight) + assert.Equal(t, "/src/weights", start.Target) + assert.Equal(t, 2, start.TotalFiles) + assert.Equal(t, 2, start.MissingFiles) + assert.Equal(t, testRepo+"@"+entry.Digest, start.ManifestRef) + + // FileStored events carry path + digest. + for _, e := range events[2:4] { + assert.NotEmpty(t, e.FilePath) + assert.NotEmpty(t, e.FileDigest) + } +} + +func TestManager_Pull_EmitsFullyCachedEvent(t *testing.T) { + t.Parallel() + ctx := context.Background() + reg := newStubRegistry() + + img, entry := buildWeightImage(t, "m1", "/w", []map[string][]byte{ + {"x": []byte("data")}, + }) + reg.put(testRepo+"@"+entry.Digest, img) + + lock := &lockfile.WeightsLock{Version: 1, Weights: []lockfile.WeightLockEntry{*entry}} + mgr, fs := newTestManager(t, reg, lock) + + // Pre-populate. + for _, f := range entry.Files { + require.NoError(t, fs.PutFile(ctx, f.Digest, f.Size, bytes.NewReader([]byte("data")))) + } + + var events []PullEvent + _, err := mgr.Pull(ctx, nil, func(e PullEvent) { events = append(events, e) }) + require.NoError(t, err) + + // Fully-cached weights emit exactly WeightStart + WeightDone. 
+ require.Len(t, events, 2) + assert.Equal(t, PullEventWeightStart, events[0].Kind) + assert.Equal(t, 0, events[0].MissingFiles) + assert.Empty(t, events[0].ManifestRef, "fully-cached weight should not set manifest ref") + assert.Equal(t, PullEventWeightDone, events[1].Kind) + assert.True(t, events[1].FullyCached) +} + +func TestNewManager_RequiresStore(t *testing.T) { + t.Parallel() + _, err := NewManager(ManagerOptions{ + Registry: newStubRegistry(), + Repo: "r", + Lock: &lockfile.WeightsLock{}, + }) + require.Error(t, err) +} + +func TestNewManager_RequiresRegistry(t *testing.T) { + t.Parallel() + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + _, err = NewManager(ManagerOptions{ + Store: fs, + Repo: "r", + Lock: &lockfile.WeightsLock{}, + }) + require.Error(t, err) +} + +func TestNewManager_RequiresRepoWhenLockHasWeights(t *testing.T) { + t.Parallel() + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + // Lock with at least one entry → Repo is now required. + _, err = NewManager(ManagerOptions{ + Store: fs, + Registry: newStubRegistry(), + Lock: &lockfile.WeightsLock{ + Version: 1, + Weights: []lockfile.WeightLockEntry{{Name: "w", Target: "/t"}}, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "repo") +} + +func TestNewManager_NilLockIsNoop(t *testing.T) { + t.Parallel() + fs, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + // No Lock, no Repo: weights-less model. Manager still constructs. + mgr, err := NewManager(ManagerOptions{ + Store: fs, + Registry: newStubRegistry(), + }) + require.NoError(t, err) + + // Pull is a no-op. + results, err := mgr.Pull(context.Background(), nil, nil) + require.NoError(t, err) + assert.Empty(t, results) + + // Prepare is a no-op. 
+ mounts, err := mgr.Prepare(context.Background()) + require.NoError(t, err) + assert.Empty(t, mounts.Specs) + require.NoError(t, mounts.Release()) +} diff --git a/pkg/weights/setup.go b/pkg/weights/setup.go new file mode 100644 index 0000000000..98c8774fdb --- /dev/null +++ b/pkg/weights/setup.go @@ -0,0 +1,56 @@ +package weights + +import ( + "errors" + "fmt" + "io/fs" + "path/filepath" + + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/weights/lockfile" + "github.com/replicate/cog/pkg/weights/store" +) + +// NewFromSource constructs a Manager from a model.Source and an +// already-parsed repository string. Callers (typically CLI commands) +// are responsible for parsing their `--image` flag / `cog.yaml image:` +// value into a bare repo before calling. +// +// repo may be empty for models that declare no weights — the returned +// Manager is a valid no-op in that case, so CLI callers can construct +// one unconditionally and let Pull/Prepare decide if there's anything +// to do. +// +// Missing lockfiles error out with an actionable message pointing at +// `cog weights import` when the model actually has weights. 
+func NewFromSource(src *model.Source, repo string) (*Manager, error) {
+	fileStore, err := store.OpenDefault()
+	if err != nil {
+		return nil, fmt.Errorf("open weights cache: %w", err)
+	}
+
+	var lock *lockfile.WeightsLock
+	if len(src.Config.Weights) > 0 {
+		if repo == "" {
+			return nil, errors.New("cog.yaml declares weights but no repository was resolved; set 'image:' in cog.yaml or pass --image")
+		}
+		lockPath := filepath.Join(src.ProjectDir, lockfile.WeightsLockFilename)
+		loaded, err := lockfile.LoadWeightsLock(lockPath)
+		if err != nil {
+			if errors.Is(err, fs.ErrNotExist) {
+				return nil, fmt.Errorf("%s not found (run 'cog weights import' first)", lockfile.WeightsLockFilename)
+			}
+			return nil, fmt.Errorf("load %s: %w", lockfile.WeightsLockFilename, err)
+		}
+		lock = loaded
+	}
+
+	return NewManager(ManagerOptions{
+		Store:      fileStore,
+		Registry:   registry.NewRegistryClient(),
+		Repo:       repo,
+		Lock:       lock,
+		ProjectDir: src.ProjectDir,
+	})
+}
diff --git a/pkg/weights/store/file.go b/pkg/weights/store/file.go
new file mode 100644
index 0000000000..1c18faa921
--- /dev/null
+++ b/pkg/weights/store/file.go
@@ -0,0 +1,259 @@
+package store
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"iter"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/replicate/cog/pkg/paths"
+)
+
+const (
+	digestAlgorithm = "sha256"
+	sha256HexLen    = 64
+	filesDir        = "files"
+)
+
+// FileStore is a Store backed by a directory on the local filesystem.
+// Files are stored content-addressed under <root>/files/sha256/<hh>/<hex>,
+// where <hh> is the first two characters of the digest's hex.
+//
+// FileStore is safe for concurrent use by multiple goroutines and
+// processes: PutFile writes to a temporary file and atomically renames;
+// reads are stateless.
+type FileStore struct {
+	root string
+}
+
+// OpenDefault opens the per-user weights FileStore at the path
+// returned by paths.WeightsStoreDir.
Use this whenever you'd +// otherwise pair WeightsStoreDir + NewFileStore — same intent, less +// boilerplate, one error message style across call sites. +func OpenDefault() (*FileStore, error) { + dir, err := paths.WeightsStoreDir() + if err != nil { + return nil, fmt.Errorf("resolve weights store dir: %w", err) + } + return NewFileStore(dir) +} + +// NewFileStore returns a FileStore rooted at dir. The root and the +// files/sha256/ subtree are created if they don't exist. +func NewFileStore(dir string) (*FileStore, error) { + if dir == "" { + return nil, errors.New("file store: root directory must not be empty") + } + if err := os.MkdirAll(filepath.Join(dir, filesDir, digestAlgorithm), 0o755); err != nil { + return nil, fmt.Errorf("file store: create root: %w", err) + } + return &FileStore{root: dir}, nil +} + +// Root returns the on-disk root of the store. +func (s *FileStore) Root() string { return s.root } + +// parseDigest splits and validates "sha256:<64-lowercase-hex>". +func parseDigest(digest string) (string, error) { + algo, hexStr, ok := strings.Cut(digest, ":") + if !ok { + return "", fmt.Errorf("invalid digest %q: missing algorithm prefix", digest) + } + if algo != digestAlgorithm { + return "", fmt.Errorf("invalid digest %q: only %s is supported", digest, digestAlgorithm) + } + if len(hexStr) != sha256HexLen { + return "", fmt.Errorf("invalid digest %q: expected %d hex chars, got %d", digest, sha256HexLen, len(hexStr)) + } + // hex.DecodeString tolerates uppercase; we require lowercase so + // paths stay canonical. 
+ if strings.ToLower(hexStr) != hexStr { + return "", fmt.Errorf("invalid digest %q: non-lowercase hex", digest) + } + if _, err := hex.DecodeString(hexStr); err != nil { + return "", fmt.Errorf("invalid digest %q: %w", digest, err) + } + return hexStr, nil +} + +func (s *FileStore) pathFor(hexStr string) string { + return filepath.Join(s.root, filesDir, digestAlgorithm, hexStr[:2], hexStr) +} + +func (s *FileStore) prefixDir(hexStr string) string { + return filepath.Join(s.root, filesDir, digestAlgorithm, hexStr[:2]) +} + +// Exists reports whether a file with the given digest is in the store. +func (s *FileStore) Exists(_ context.Context, digest string) (bool, error) { + hexStr, err := parseDigest(digest) + if err != nil { + return false, err + } + switch _, statErr := os.Stat(s.pathFor(hexStr)); { + case statErr == nil: + return true, nil + case errors.Is(statErr, fs.ErrNotExist): + return false, nil + default: + return false, fmt.Errorf("stat %s: %w", digest, statErr) + } +} + +// PutFile writes r to the store under expectedDigest, verifying the +// computed digest as it streams. +// +// Idempotency: if the digest is already present, r is drained to +// io.Discard and nil is returned. This matters because Pull streams a +// whole layer tar and may encounter files already stored from a +// previous pull — we need those to succeed without desyncing the tar. 
+func (s *FileStore) PutFile(ctx context.Context, expectedDigest string, expectedSize int64, r io.Reader) error { + hexStr, err := parseDigest(expectedDigest) + if err != nil { + return err + } + + if ok, err := s.Exists(ctx, expectedDigest); err != nil { + return err + } else if ok { + _, _ = io.Copy(io.Discard, r) + return nil + } + + prefix := s.prefixDir(hexStr) + if err := os.MkdirAll(prefix, 0o755); err != nil { + return fmt.Errorf("create prefix dir: %w", err) + } + + tmp, err := os.CreateTemp(prefix, "put-*") + if err != nil { + return fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmp.Name() + defer func() { + if tmpPath != "" { + _ = os.Remove(tmpPath) + } + }() + + hasher := sha256.New() + reader := &ctxReader{ctx: ctx, r: io.TeeReader(r, hasher)} + + n, err := io.Copy(tmp, reader) + if err != nil { + _ = tmp.Close() + return fmt.Errorf("write %s: %w", expectedDigest, err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("close temp file: %w", err) + } + + if expectedSize >= 0 && n != expectedSize { + return fmt.Errorf("size mismatch for %s: expected %d bytes, got %d", expectedDigest, expectedSize, n) + } + + gotHex := hex.EncodeToString(hasher.Sum(nil)) + if gotHex != hexStr { + return fmt.Errorf("digest mismatch: expected sha256:%s, got sha256:%s", hexStr, gotHex) + } + + final := s.pathFor(hexStr) + // gosec G304/G703: tmpPath comes from os.CreateTemp inside prefix, + // final is composed from the validated sha256 hex — both paths are + // constrained to the store root by construction. + if err := os.Rename(tmpPath, final); err != nil { //nolint:gosec // see comment above + return fmt.Errorf("rename %s: %w", final, err) + } + tmpPath = "" + return nil +} + +// Path returns the on-disk path for the file at digest, or an error +// wrapping fs.ErrNotExist if the digest is not in the store. 
+func (s *FileStore) Path(ctx context.Context, digest string) (string, error) { + hexStr, err := parseDigest(digest) + if err != nil { + return "", err + } + ok, err := s.Exists(ctx, digest) + if err != nil { + return "", err + } + if !ok { + return "", fmt.Errorf("path %s: %w", digest, fs.ErrNotExist) + } + return s.pathFor(hexStr), nil +} + +// List walks files/sha256/ and yields one FileInfo per entry. Stray +// files (stale temp files from interrupted writes, anything whose name +// isn't a 64-char hex digest) are skipped. +func (s *FileStore) List(ctx context.Context) iter.Seq2[FileInfo, error] { + return func(yield func(FileInfo, error) bool) { + root := filepath.Join(s.root, filesDir, digestAlgorithm) + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + if errors.Is(err, fs.ErrNotExist) && path == root { + return nil + } + return err + } + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if d.IsDir() { + return nil + } + name := d.Name() + if len(name) != sha256HexLen { + return nil + } + info, err := d.Info() + if err != nil { + return err + } + if !yield(FileInfo{Digest: digestAlgorithm + ":" + name, Size: info.Size()}, nil) { + return filepath.SkipAll + } + return nil + }) + if err != nil && !errors.Is(err, filepath.SkipAll) { + yield(FileInfo{}, err) + } + } +} + +// Delete removes the file at digest. Missing digests are not an error. +func (s *FileStore) Delete(_ context.Context, digest string) error { + hexStr, err := parseDigest(digest) + if err != nil { + return err + } + switch err := os.Remove(s.pathFor(hexStr)); { + case err == nil, errors.Is(err, fs.ErrNotExist): + return nil + default: + return fmt.Errorf("delete %s: %w", digest, err) + } +} + +// ctxReader makes an io.Reader cancelable at Read-boundary granularity. 
+type ctxReader struct {
+	ctx context.Context
+	r   io.Reader
+}
+
+func (c *ctxReader) Read(p []byte) (int, error) {
+	if err := c.ctx.Err(); err != nil {
+		return 0, err
+	}
+	return c.r.Read(p)
+}
+
+var _ Store = (*FileStore)(nil)
diff --git a/pkg/weights/store/file_test.go b/pkg/weights/store/file_test.go
new file mode 100644
index 0000000000..517afa9b6e
--- /dev/null
+++ b/pkg/weights/store/file_test.go
@@ -0,0 +1,345 @@
+package store
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func digestOf(data []byte) string {
+	sum := sha256.Sum256(data)
+	return "sha256:" + hex.EncodeToString(sum[:])
+}
+
+func newStore(t *testing.T) *FileStore {
+	t.Helper()
+	s, err := NewFileStore(t.TempDir())
+	require.NoError(t, err)
+	return s
+}
+
+func TestNewFileStore_EmptyRootRejected(t *testing.T) {
+	t.Parallel()
+	_, err := NewFileStore("")
+	require.Error(t, err)
+}
+
+func TestFile_RoundTrip(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	s := newStore(t)
+	data := []byte("hello weights")
+	d := digestOf(data)
+
+	ok, err := s.Exists(ctx, d)
+	require.NoError(t, err)
+	require.False(t, ok)
+
+	require.NoError(t, s.PutFile(ctx, d, int64(len(data)), bytes.NewReader(data)))
+
+	ok, err = s.Exists(ctx, d)
+	require.NoError(t, err)
+	require.True(t, ok)
+
+	p, err := s.Path(ctx, d)
+	require.NoError(t, err)
+	onDisk, err := os.ReadFile(p) //nolint:gosec // test-owned path
+	require.NoError(t, err)
+	require.Equal(t, data, onDisk)
+
+	// Path is under the expected layout:
+	// <root>/files/sha256/<hh>/<hex>, with <hh> the first two hex chars.
+ hexStr := strings.TrimPrefix(d, "sha256:") + require.Equal(t, filepath.Join(s.Root(), "files", "sha256", hexStr[:2], hexStr), p) +} + +func TestFile_PutFile_SizeMismatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("hello weights") + d := digestOf(data) + + // Claim the file is larger than it really is. + err := s.PutFile(ctx, d, int64(len(data))+100, bytes.NewReader(data)) + require.Error(t, err) + assert.Contains(t, err.Error(), "size mismatch") + + // File must not be stored. + ok, err := s.Exists(ctx, d) + require.NoError(t, err) + require.False(t, ok) +} + +func TestFile_PutFile_DigestMismatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("real content") + wrong := digestOf([]byte("something else")) + + err := s.PutFile(ctx, wrong, int64(len(data)), bytes.NewReader(data)) + require.Error(t, err) + require.Contains(t, err.Error(), "digest mismatch") + + ok, err := s.Exists(ctx, wrong) + require.NoError(t, err) + require.False(t, ok) + + // No stray temp files in the prefix dir. + hexStr := strings.TrimPrefix(wrong, "sha256:") + prefix := filepath.Join(s.Root(), "files", "sha256", hexStr[:2]) + if entries, err := os.ReadDir(prefix); err == nil { + for _, e := range entries { + assert.NotContains(t, e.Name(), "put-", "stray temp file left behind: %s", e.Name()) + } + } +} + +func TestFile_PutFile_Idempotent(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("idempotent bytes") + d := digestOf(data) + + require.NoError(t, s.PutFile(ctx, d, int64(len(data)), bytes.NewReader(data))) + + // Second Put must succeed AND drain the reader — Pull relies on + // the tar stream staying in sync. 
+ r := &countingReader{Reader: bytes.NewReader(data)} + require.NoError(t, s.PutFile(ctx, d, int64(len(data)), r)) + require.Equal(t, len(data), r.n) +} + +func TestFile_Path_NotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + _, err := s.Path(ctx, digestOf([]byte("absent"))) + require.Error(t, err) + require.ErrorIs(t, err, fs.ErrNotExist) +} + +func TestFile_Delete_Idempotent(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("to delete") + d := digestOf(data) + + // Delete on empty is fine. + require.NoError(t, s.Delete(ctx, d)) + + require.NoError(t, s.PutFile(ctx, d, int64(len(data)), bytes.NewReader(data))) + require.NoError(t, s.Delete(ctx, d)) + + ok, err := s.Exists(ctx, d) + require.NoError(t, err) + require.False(t, ok) + + require.NoError(t, s.Delete(ctx, d)) +} + +func TestFile_List(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + + want := map[string]int64{} + for _, b := range [][]byte{[]byte("alpha"), []byte("beta"), []byte("gamma")} { + d := digestOf(b) + want[d] = int64(len(b)) + require.NoError(t, s.PutFile(ctx, d, int64(len(b)), bytes.NewReader(b))) + } + + got := map[string]int64{} + for fi, err := range s.List(ctx) { + require.NoError(t, err) + got[fi.Digest] = fi.Size + } + require.Equal(t, want, got) +} + +func TestFile_List_EmptyStore(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + count := 0 + for _, err := range s.List(ctx) { + require.NoError(t, err) + count++ + } + require.Equal(t, 0, count) +} + +func TestFile_List_SkipsStrayTempFiles(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + + data := []byte("real file") + d := digestOf(data) + require.NoError(t, s.PutFile(ctx, d, int64(len(data)), bytes.NewReader(data))) + + // Drop a stray temp-file-like entry alongside the real blob. 
+ hexStr := strings.TrimPrefix(d, "sha256:") + prefix := filepath.Join(s.Root(), "files", "sha256", hexStr[:2]) + require.NoError(t, os.WriteFile(filepath.Join(prefix, "put-stray"), []byte("trash"), 0o644)) //nolint:gosec // test + + count := 0 + for fi, err := range s.List(ctx) { + require.NoError(t, err) + require.Equal(t, d, fi.Digest) + count++ + } + require.Equal(t, 1, count) +} + +func TestFile_ConcurrentPutSameDigest(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("racy bytes") + d := digestOf(data) + + const goroutines = 8 + var wg sync.WaitGroup + errs := make([]error, goroutines) + wg.Add(goroutines) + for i := range goroutines { + go func(i int) { + defer wg.Done() + errs[i] = s.PutFile(ctx, d, int64(len(data)), bytes.NewReader(data)) + }(i) + } + wg.Wait() + for i, err := range errs { + assert.NoError(t, err, "goroutine %d", i) + } + + p, err := s.Path(ctx, d) + require.NoError(t, err) + got, err := os.ReadFile(p) //nolint:gosec // test-owned path + require.NoError(t, err) + require.Equal(t, data, got) +} + +func TestFile_InterruptedWriteLeavesNoFinalFile(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + data := []byte("interrupted") + d := digestOf(data) + + err := s.PutFile(ctx, d, int64(len(data)), &failingReader{after: 4, data: data}) + require.Error(t, err) + + ok, err := s.Exists(ctx, d) + require.NoError(t, err) + require.False(t, ok) +} + +func TestFile_PutFile_ContextCanceled(t *testing.T) { + t.Parallel() + s := newStore(t) + data := []byte("bytes that will never finish writing") + d := digestOf(data) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel the context via the reader — the reader blocks its + // second Read until the context is observed as canceled, so the + // next ctxReader.Read guarantees ctx.Err() is set. 
+ reader := &gatedReader{data: data, cancel: cancel} + + err := s.PutFile(ctx, d, int64(len(data)), reader) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + // The store must not expose a partial file. + ok, err := s.Exists(context.Background(), d) + require.NoError(t, err) + require.False(t, ok) +} + +func TestFile_InvalidDigestRejected(t *testing.T) { + t.Parallel() + ctx := context.Background() + s := newStore(t) + + for _, bad := range []string{ + "", + "no-colon", + "md5:" + strings.Repeat("0", 32), + "sha256:", + "sha256:tooShort", + "sha256:" + strings.Repeat("Z", 64), // uppercase + } { + _, err := s.Exists(ctx, bad) + require.Error(t, err, "digest %q must be rejected", bad) + } +} + +// countingReader counts bytes read. +type countingReader struct { + io.Reader + n int +} + +func (c *countingReader) Read(p []byte) (int, error) { + n, err := c.Reader.Read(p) + c.n += n + return n, err +} + +// gatedReader emits one byte per Read and cancels its context after +// the first byte. The next Read is preceded by ctxReader's ctx.Err() +// check, which deterministically observes the canceled context. +type gatedReader struct { + data []byte + off int + cancel context.CancelFunc +} + +func (g *gatedReader) Read(p []byte) (int, error) { + if g.off >= len(g.data) { + return 0, io.EOF + } + p[0] = g.data[g.off] + g.off++ + if g.off == 1 { + g.cancel() + } + return 1, nil +} + +// failingReader returns data[:after] then an error. 
+type failingReader struct {
+	after int
+	off   int
+	data  []byte
+}
+
+func (f *failingReader) Read(p []byte) (int, error) {
+	remaining := f.after - f.off
+	if remaining <= 0 {
+		return 0, errors.New("simulated transport failure")
+	}
+	if remaining > len(p) {
+		remaining = len(p)
+	}
+	n := copy(p, f.data[f.off:f.off+remaining])
+	f.off += n
+	return n, nil
+}
diff --git a/pkg/weights/store/store.go b/pkg/weights/store/store.go
new file mode 100644
index 0000000000..95a689829e
--- /dev/null
+++ b/pkg/weights/store/store.go
@@ -0,0 +1,56 @@
+// Package store defines a narrow, content-addressed interface for
+// storing individual weight files on the local machine.
+//
+// The store knows only digests. Filenames, layer membership, and
+// registry URIs are Manager-level concerns. Keeping the surface small
+// is what makes the store swappable — a future containerd-backed store
+// can drop in behind the same interface.
+//
+// Digests are "sha256:<hex>"; v1 implementations may reject other
+// algorithms. Missing digests surface as errors wrapping fs.ErrNotExist.
+package store
+
+import (
+	"context"
+	"io"
+	"iter"
+)
+
+// Store is a content-addressed store of individual weight files.
+//
+// Every method takes a context; implementations SHOULD honor cancellation
+// where it makes sense. Missing digests on Path surface as errors
+// wrapping fs.ErrNotExist.
+type Store interface {
+	// Exists reports whether a file with the given digest is in the store.
+	// A nil error with false is the ordinary "not present" result.
+	Exists(ctx context.Context, digest string) (bool, error)
+
+	// PutFile stores r under expectedDigest, hash-verifying as it streams.
+	// A digest mismatch leaves the store unchanged.
+	//
+	// PutFile is idempotent: if the digest is already present, the
+	// reader is drained to io.Discard and nil is returned. This lets
+	// callers loop over tar entries without branching on Exists first.
+	//
+	// size is the caller's expected byte count; implementations MAY
+	// verify it and reject mismatches.
+ PutFile(ctx context.Context, expectedDigest string, size int64, r io.Reader) error + + // Path returns an on-disk path for the file, suitable for hardlinking. + // The file MUST be treated as read-only. Not every backend can + // satisfy this; such backends return an error. + Path(ctx context.Context, digest string) (string, error) + + // List iterates every file in the store. Walk errors surface as a + // final (zero, err) pair before the iterator terminates. + List(ctx context.Context) iter.Seq2[FileInfo, error] + + // Delete removes the file at digest. Missing digests are not an error. + Delete(ctx context.Context, digest string) error +} + +// FileInfo describes one stored file. +type FileInfo struct { + Digest string + Size int64 +} diff --git a/specs/draft-weights.md b/specs/draft-weights.md new file mode 100644 index 0000000000..03e87a1e24 --- /dev/null +++ b/specs/draft-weights.md @@ -0,0 +1,393 @@ +# Managed Weights: OCI Format and Runtime State Specification + +- Version: 1.0-draft +- Status: Draft + +## Overview + +A managed weight is a named set of files (model weights, configs, tokenizers, etc.) stored as an OCI artifact in a container registry. Each weight maps to a target directory in the running container and is delivered independently from the model image. + +This spec covers three things: how weight data is packed into OCI layers, how those layers are described in an OCI manifest, and how weight readiness is communicated between a provider and a consumer. + +The spec is driven by the needs of [Cog](https://github.com/replicate/cog) but is a general-purpose format for storing and delivering AI model weights via OCI registries. + +## 1. Layer Format + +All layers are tar archives. Tar provides file metadata (path, size, permissions) at negligible overhead (512 bytes per entry) and supports streaming extraction without buffering. 
+ +**Layers are immutable.** Once a layer is produced and pushed, its content never changes -- the digest is its identity. A weight set is a fixed collection of immutable layers. If all layers are present, the complete file set is present. This property is fundamental to the caching and delivery model: layers can be cached indefinitely, shared across weights and models, and assembled into weight sets without re-verification. + +### 1.1 Layer independence (order-invariant extraction) + +**Layers MUST be extractable in any order and produce identical results.** Unlike Docker image layers which use overlay semantics (later layers shadow earlier ones), weight layers are independent units. Each layer contains a disjoint set of files. No file path appears in more than one layer within a manifest. + +Specifically: + +- No file path appears in more than one layer (disjoint file sets). +- Layers MUST NOT contain overlay/union filesystem artifacts: no whiteout files (`.wh.*`), no opaque whiteout markers (`.wh..wh..opq`), no delete markers of any kind. +- Extracting all layers to the same target directory in any order MUST produce a byte-identical result. + +This constraint exists because weight layers are large (multi-GB) and must be downloaded and extracted in parallel without coordination or sequencing. Requiring ordered extraction would force either buffering of out-of-order layers or serialized download, both unacceptable at the scale of model weights. + +The packing algorithm (§1.2) enforces this: each source file is assigned to exactly one layer. The manifest records the full set of layers; consumers extract all layers to the same target directory in whatever order they arrive. + +### 1.2 Packing strategy + +Producers assign each source file to exactly one layer, maintaining the disjoint file set invariant (§1.1). Two categories of layers exist: + +- **Bundle layers** contain multiple small files packed into a single tar archive. 
Files within a bundle MUST be stable-sorted by relative path so that identical source files produce byte-identical tar archives (and therefore identical layer digests) across reimports.
+- **Standalone layers** contain a single file as a single-entry tar.
+
+Whether to bundle or not, whether to compress or not, and all other packing parameters are producer implementation choices. Consumers MUST process each layer according to its media type (§2.1) regardless of the producer's choices. For example, a producer might bundle all files under 64 MB into compressed tar layers (up to 256 MB each) and give every file at or above 64 MB its own uncompressed layer. A different producer could skip bundling entirely and emit one layer per file.
+
+As a general principle, producers SHOULD compress bundle layers (dominated by compressible text formats like JSON and YAML) and SHOULD NOT compress standalone layers (often high-entropy binary data where compression yields negligible savings at substantial CPU cost). If a producer does compress large standalone layers, it SHOULD first probe the content to verify compression yields meaningful savings -- many weight formats are high-entropy and compress poorly. It SHOULD also use a format that supports parallel decompression (e.g., seekable zstd) to avoid serializing extraction of multi-GB layers.
+
+### 1.3 Allowed content
+
+Layers MUST contain only regular files and directories. The following are not permitted:
+
+- **Links** (symbolic and hard) -- symlinks introduce ambiguity (relative vs absolute targets, dangling references, circular chains) and path traversal risk during extraction. Source directories containing links MUST be resolved to regular files before import.
+- **Device nodes, FIFOs, sockets** -- not meaningful for weight data.
+- **Whiteout files** (`.wh.*`, `.wh..wh..opq`) -- overlay filesystem artifacts that imply ordered layer semantics, which this format forbids (§1.1).
+- **Extended attributes, ACLs, security labels** -- platform-specific metadata that breaks deterministic packing. + +Producers MUST reject (not silently skip) excluded content with a descriptive error. + +Producers MUST also reject source directories containing a `.cog/` (or equivalent state directory, see §3) directory, which is reserved for the runtime state protocol. + +### 1.4 Tar properties (deterministic packing) + +All tar archives MUST be produced with these properties to ensure byte-identical digests across re-imports from the same source: + +- Format: PAX (for paths exceeding 100 characters) +- `mtime`, `atime`, `ctime`: 0 (Unix epoch) +- UID/GID: 0/0 +- Permissions: 0644 (files), 0755 (directories) +- No extended attributes, no system-specific metadata +- Paths relative to the weight's target directory (no leading `/` or `./`) +- Paths are case-sensitive with no case folding. +- Paths MUST be valid UTF-8. +- Path components MUST NOT contain: NUL (`\0`), forward slash (`/` -- used only as the path separator), backslash (`\`), or control characters (bytes `0x01`-`0x1F` and `0x7F`). + +### 1.5 No file splitting + +Each file is packed into exactly one layer, whole. Files are never split across multiple layers. This keeps the format simple -- no reassembly metadata, no ordering dependencies between layers, and each extracted file is immediately usable. + +This works because training frameworks already shard large models into multiple files (e.g., 64x 9.8 GB safetensors for kimi-k2.5). The sharding provides natural parallelism at the layer level. If a use case arises where individual files are too large for practical single-layer transport, file splitting would require a reassembly protocol and is deferred to a future spec version. + +## 2. OCI Manifest + +Each named weight is an OCI manifest with `artifactType` identifying it as a cog weight artifact. 
+ +### 2.1 Media types + +| Media type | Usage | +| --------------------------------------------- | ----------------------------- | +| `application/vnd.cog.weight.v1` | Manifest `artifactType` field | +| `application/vnd.cog.weight.config.v1+json` | Config blob media type | +| `application/vnd.oci.image.layer.v1.tar` | Uncompressed tar layer | +| `application/vnd.oci.image.layer.v1.tar+gzip` | Gzip-compressed tar layer | +| `application/vnd.oci.image.layer.v1.tar+zstd` | Zstd-compressed tar layer | + +The layer media types (`tar`, `tar+gzip`, `tar+zstd`) are standard OCI types defined in the [OCI image spec](https://github.com/opencontainers/image-spec/blob/main/layer.md), reused here for ecosystem compatibility. Consumers MUST accept all three. Producers choose which to use for each layer (§1.2); the media type communicates that choice. Because the manifest uses standard OCI media types throughout, existing tools (crane, skopeo, containerd, `docker pull`) work with weight artifacts without modification. + +The `artifactType` distinguishes weight manifests from runnable image manifests. The config blob carries a file-level index for the weight (§2.3). + +### 2.2 Manifest structure + +```json +{ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "artifactType": "application/vnd.cog.weight.v1", + "config": { + "mediaType": "application/vnd.cog.weight.config.v1+json", + "digest": "sha256:config123...", + "size": 512 + }, + "layers": [ + { + "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", + "digest": "sha256:aaa...", + "size": 15000000 + }, + { + "mediaType": "application/vnd.oci.image.layer.v1.tar", + "digest": "sha256:bbb...", + "size": 3957900840 + } + ], + "annotations": { + "run.cog.weight.name": "z-image-turbo", + "run.cog.weight.target": "/src/weights", + "run.cog.weight.set-digest": "sha256:def456..." + } +} +``` + +The manifest contains no timestamps, source URIs, or producer version metadata. 
This makes the manifest a pure function of the weight content (files), the packing strategy (layers), and the cog.yaml config (name, target). Identical inputs always produce an identical manifest digest, and the registry handles dedup at the storage level (§2.7). + +### 2.3 Config blob (file index) + +The config blob is a JSON document with media type `application/vnd.cog.weight.config.v1+json`. It describes the weight artifact and provides a file-level index: every file, which layer it belongs to, its size, and its content digest. + +```json +{ + "name": "z-image-turbo", + "target": "/src/weights", + "setDigest": "sha256:def456...", + "files": [ + { + "path": "config.json", + "layer": "sha256:aaa...", + "size": 1234, + "digest": "sha256:f01..." + }, + { + "path": "tokenizer.json", + "layer": "sha256:aaa...", + "size": 5678, + "digest": "sha256:f02..." + }, + { + "path": "text_encoder/model-00001-of-00003.safetensors", + "layer": "sha256:bbb...", + "size": 3957900840, + "digest": "sha256:f03..." + } + ] +} +``` + +**Top-level fields:** + +| Field | Type | Description | +| ----------- | ------ | --------------------------------------------------------------------------------------------- | +| `name` | string | Weight name (e.g., `z-image-turbo`). Same as the manifest annotation. | +| `target` | string | Absolute mount path in the container (e.g., `/src/weights`). Same as the manifest annotation. | +| `setDigest` | string | Weight set digest (§2.4). Same as the manifest annotation. | + +**File entry fields:** + +| Field | Type | Description | +| -------- | ------- | ------------------------------------------------------------------------------ | +| `path` | string | File path relative to the weight target directory. Same as the tar entry path. | +| `layer` | string | Digest of the layer containing this file. | +| `size` | integer | File size in bytes (uncompressed). | +| `digest` | string | SHA-256 content digest of the individual file. 
|
+
+The `files` array MUST be sorted by `path` lexicographically. This ensures the config blob is deterministic for a given packing: the same source files packed with the same parameters always produce an identical config blob. Note that the config blob may differ across packing changes (different `layer` values), but the weight set digest (§2.4) remains stable because it is computed from file content only.
+
+The config blob provides a complete file-level index of the weight. Infra uses it to assemble the final weight directory from extracted layers without walking the filesystem -- for each file, it knows exactly which layer to source from. The per-file `digest` additionally enables infra to identify identical files across different layers or weights. The `name`, `target`, and `setDigest` fields duplicate the manifest annotations so the config blob is self-describing -- a consumer with only the config blob has enough context to understand what it is and where it goes.
+
+### 2.4 Weight set digest
+
+The **weight set digest** is the content identity of a weight's file set, independent of how those files are packed into layers, and independent of manifest metadata (annotations, timestamps, producer version). It is the canonical content-addressable identifier for a set of weight files: two weight manifests with identical weight set digests produce byte-identical extracted results.
+
+Producers MUST compute the weight set digest as:
+
+```
+sha256(join(sort(entries), "\n"))
+```
+
+Where each entry is `<hex>  <path>` (hex-encoded SHA-256 hash of the file content, two spaces, file path) from the config blob's `files` array, sorted lexicographically by `path`. The result is encoded as a standard OCI digest string (e.g., `sha256:def456...`).
+ +This entry format matches the output of `sha256sum`, so producers and operators can verify a weight set digest from a shell: + +```bash +sha256sum $(find -type f | sort) | sha256sum +``` + +Because the weight set digest is computed from file content (not layer structure), it is **packing-independent**: changing bundle thresholds, compression settings, or any other packing parameter does not change the weight set digest as long as the source files are identical. Different producer versions producing different layer layouts from the same source files will produce the same weight set digest. + +Producers MUST include this value as the `run.cog.weight.set-digest` manifest annotation. The computation is specified so that any party (producers, infra, operators) can independently verify or recompute it. + +The weight set digest enables several behaviors: + +- **Caching**: Infra uses it as the key for assembled weights. If the assembled result already exists for this digest, skip extraction entirely. +- **Cross-model reuse**: Two models using identical weight files produce identical file digests and therefore identical weight set digests, enabling shared caching even when the models have separate weight repositories and different layer layouts. + +### 2.5 Annotations + +Annotations use the `run.cog.*` namespace (reverse-domain of cog.run). + +**Manifest-level annotations:** + +| Key | Value | Description | +| --------------------------- | ------------- | ---------------------------------------------------------------------- | +| `run.cog.weight.name` | string | Weight name (e.g., `z-image-turbo`). REQUIRED. | +| `run.cog.weight.target` | string | Absolute mount path in the container (e.g., `/src/weights`). REQUIRED. | +| `run.cog.weight.set-digest` | digest string | Weight set digest (§2.4). REQUIRED. | + +All manifest-level annotations are deterministic from the weight content and cog.yaml config. 
No timestamps, source URIs, or producer metadata are included -- identical inputs always produce an identical manifest digest. + +**Layer descriptor annotations:** + +| Key | Value | Description | +| ---------------------------------- | -------------- | ---------------------------------------------------------------------------------------------------------------- | +| `run.cog.weight.size.uncompressed` | integer string | Uncompressed size of the layer's contents in bytes (sum of regular-file bytes, excluding tar headers). REQUIRED. | + +This is the only annotation layer descriptors carry. All file-level metadata (paths, per-file sizes, layer mappings) lives in the config blob (§2.3); consumers that need it MUST read the config blob. + +The uncompressed size is present at the descriptor level so consumers can make per-layer decisions (disk allocation, parallel extraction progress, partial pulls) without fetching the config blob. For compressed layers (`tar+gzip`) the descriptor's `size` is the compressed byte count; this annotation carries the uncompressed count. For uncompressed layers (`tar`) the two are approximately equal (modulo tar headers). 
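Concretely, the first layer descriptor from the §2.2 example would carry the annotation like this (the uncompressed byte count is illustrative):

```json
{
  "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip",
  "digest": "sha256:aaa...",
  "size": 15000000,
  "annotations": {
    "run.cog.weight.size.uncompressed": "52428800"
  }
}
```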
+ +### 2.6 OCI index (bundle) + +When a model uses managed weights, the push operation produces an OCI image index containing the model image manifest and all weight manifests: + +```json +{ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.index.v1+json", + "manifests": [ + { + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "digest": "sha256:image...", + "size": 1234, + "platform": { "os": "linux", "architecture": "amd64" } + }, + { + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "digest": "sha256:weight...", + "size": 5678, + "artifactType": "application/vnd.cog.weight.v1", + "platform": { "os": "unknown", "architecture": "unknown" }, + "annotations": { + "run.cog.weight.name": "z-image-turbo", + "run.cog.weight.set-digest": "sha256:def456...", + "run.cog.weight.size.uncompressed": "32457803776" + } + } + ] +} +``` + +The model image gets a real platform descriptor. Weight descriptors carry both `artifactType` and `platform`: + +- **`artifactType`**: Set to `application/vnd.cog.weight.v1`. This is the OCI-standard mechanism ([image-spec descriptor](https://github.com/opencontainers/image-spec/blob/main/descriptor.md)) for identifying non-image content in an index. It enables tooling to distinguish weight manifests from runnable images without inspecting annotations. +- **`platform`**: Set to `{"os": "unknown", "architecture": "unknown"}`. Weight data is not platform-specific, but the field is included for compatibility. This follows the precedent set by [Docker BuildKit attestations](https://docs.docker.com/build/metadata/attestations/attestation-storage/), which use the same convention to prevent container runtimes from accidentally pulling non-image entries. Omitting `platform` entirely is spec-valid (the field is OPTIONAL per the OCI image-spec) but risks being filtered out by containerd's platform matcher and other tools that assume its presence. 
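Because weight entries are flagged with `artifactType`, tooling can scan an index for weights without fetching any child manifests. A stdlib-only sketch (the `descriptor`/`index` types are minimal local mirrors, not types from an OCI library):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Minimal local mirror of the index fields this sketch needs.
type descriptor struct {
	MediaType    string            `json:"mediaType"`
	Digest       string            `json:"digest"`
	ArtifactType string            `json:"artifactType,omitempty"`
	Annotations  map[string]string `json:"annotations,omitempty"`
}

type index struct {
	Manifests []descriptor `json:"manifests"`
}

// weightEntries returns the index entries that are cog weight artifacts.
func weightEntries(raw []byte) ([]descriptor, error) {
	var idx index
	if err := json.Unmarshal(raw, &idx); err != nil {
		return nil, err
	}
	var out []descriptor
	for _, m := range idx.Manifests {
		if m.ArtifactType == "application/vnd.cog.weight.v1" {
			out = append(out, m)
		}
	}
	return out, nil
}

func main() {
	raw := []byte(`{"schemaVersion":2,"manifests":[
	  {"mediaType":"application/vnd.oci.image.manifest.v1+json","digest":"sha256:image..."},
	  {"mediaType":"application/vnd.oci.image.manifest.v1+json","digest":"sha256:weight...",
	   "artifactType":"application/vnd.cog.weight.v1",
	   "annotations":{"run.cog.weight.name":"z-image-turbo","run.cog.weight.size.uncompressed":"32457803776"}}]}`)
	ws, err := weightEntries(raw)
	if err != nil {
		panic(err)
	}
	for _, w := range ws {
		fmt.Println(w.Annotations["run.cog.weight.name"], w.Annotations["run.cog.weight.size.uncompressed"])
	}
	// prints: z-image-turbo 32457803776
}
```

The same filter is how a scheduler would sum `run.cog.weight.size.uncompressed` across entries to check disk headroom before pulling anything.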
+ +**Index descriptor annotations:** + +| Key | Value | Description | +| ---------------------------------- | -------------- | --------------------------------------------------------- | +| `run.cog.weight.name` | string | Weight name. REQUIRED. | +| `run.cog.weight.set-digest` | digest string | Weight set digest (§2.4). REQUIRED. | +| `run.cog.weight.size.uncompressed` | integer string | Total uncompressed size of all layers in bytes. REQUIRED. | + +These annotations exist so the index is scannable without fetching child manifests. `name` and `set-digest` identify what the weight is and enable cache lookups. `size.uncompressed` enables scheduling decisions (e.g., whether a node has enough disk space) without downloading any weight data. + +Note that `run.cog.weight.target` is intentionally omitted from the index descriptor. The target path is operational detail available in the weight manifest annotations and config blob (§2.3); it is not needed for scanning or scheduling at the index level. + +The `size` field on the descriptor itself is the size of the weight manifest JSON (per the OCI spec, descriptor `size` is always the byte count of the referenced blob). It is NOT the size of the weight data. Use `run.cog.weight.size.uncompressed` for the actual weight size. + +The binding between a model image and its weights is structural: both appear as siblings in the same OCI index. No back-reference annotation from weight to model is needed. + +### 2.7 Import behavior + +`cog weights import` hashes source files, packs layers, builds the config blob, and pushes the manifest. The lockfile is updated when the weight content changes; if nothing changed, the lockfile is unchanged. + +Because the manifest contains no volatile metadata (§2.2), identical inputs always produce an identical manifest digest. The registry handles blob and manifest dedup at the storage level -- pushing an already-existing blob or manifest is a no-op. No special client-side dedup logic is required. 
+
+Note that the manifest digest and weight set digest (§2.4) operate at different levels. The manifest digest changes when packing changes (different layers, config, and annotations); the weight set digest does not (same files). Infra uses the weight set digest to reuse assembled weights even when the manifest differs.
+
+### 2.8 Registry namespace and tagging (Cog convention)
+
+This section describes how Cog organizes weight artifacts in a registry. The namespace layout and tagging scheme are Cog conventions, not normative requirements of the format. Other producers may organize weight artifacts differently.
+
+```
+<repo>                                # OCI index (bundle)
+<repo>/weights/<name>                 # Named weight repository
+<repo>/weights/<name>:<timestamp>     # Timestamp tag (e.g., 20260416T172707Z)
+<repo>/weights/<name>@sha256:<digest> # Immutable digest reference
+```
+
+**Tagging scheme:** Weight imports are tagged with the import timestamp in ISO 8601 compact format (`YYYYMMDDTHHMMSSZ`). This applies uniformly regardless of source type (HuggingFace, filesystem, HTTP, registry). The timestamp answers the question "when was this version imported?" Source-specific identifiers (HF commit SHA, S3 path, etc.) are dev-time concerns tracked in `cog.yaml` and the lockfile, not in the weight artifact.
+
+Cog does not automatically create `:latest` tags. The lockfile records the manifest digest for reproducibility; timestamp tags exist for human-readable listing via `cog weights list`.
+
+## 3. Runtime State Protocol
+
+> **Status: Work in progress.** The design direction is settled (filesystem markers, provider writes, consumer reads). The exact file layout and semantics are being refined as the implementation evolves. The state directory name (`.cog/` vs `.weight/` or similar) is TBD.
+
+### 3.1 Design
+
+The **provider** (platform infra in prod, `cog weights pull` + local orchestration in dev) assembles a weight directory and communicates readiness via marker files.
The **consumer** (coglet) reads these markers to gate `setup()` -- blocking until all weights are ready without requiring any user code to handle the wait. In the future, per-weight markers could enable an async API where `setup()` begins processing weights (such as loading to the GPU) as they become available while others are still downloading. The consumer never writes state -- it is a pure observer.
+
+Filesystem markers are used instead of an HTTP API because they decouple provider and consumer in time and failure domains. The provider can write state before the container boots. The consumer can read state without the provider being alive. Either side can crash and restart independently. No orchestration, no lifecycle coupling, no retry logic. Multiple containers can share the same weight directory without complexity scaling proportionally -- they all observe the same markers. And the consumer interface is identical regardless of how the weight directory was assembled: attaching a ready-to-run cached volume, downloading all layers from scratch, fetching a diff of changed layers, or rebuilding from new layers all look the same to coglet.
+
+### 3.2 State markers
+
+The provider writes state into a `.cog/` subtree within each weight's target directory:
+
+```
+<target>/.cog/ready       # weight is usable
+<target>/.cog/failed      # delivery failed (contents: error message)
+<target>/.cog/downloading # delivery in progress
+```
+
+Coglet checks these with plain `stat()` calls, in order:
+
+```
+1. .cog/ready exists → weight usable, proceed
+2. .cog/failed exists → read error, surface it, fail
+3. .cog/downloading exists → in progress, poll
+4. .cog/ missing → provider hasn't started, wait with timeout
+```
+
+Markers are created atomically (write-to-temp + rename). A `ready` marker MUST NOT appear until all weight data is fully written and flushed to disk.
+
+If the weight directory is already fully assembled when mounted (e.g., reused from cache), the provider writes `ready` immediately.
+ +**Correctness is the provider's responsibility.** When `ready` is set, the weight directory MUST contain the exact files matching the configured weight set digest. Serving stale or mismatched weights is a catastrophic infra failure. Consumers MUST NOT verify weight content -- no checksumming, no manifest cross-checking, no redundant validation. The `ready` marker is the contract. + +### 3.3 Model image metadata + +`cog build` writes `/.cog/weights.json` into the model image. This file: + +- Signals to coglet that managed weights are active (presence = managed weights, absence = no managed weights). +- Tells coglet what weights the model expects before calling `setup()`. + +```json +{ + "weights": [ + { + "name": "z-image-turbo", + "target": "/src/weights", + "setDigest": "sha256:def456..." + } + ] +} +``` + +The `setDigest` is the weight set digest (§2.4). Coglet reads this file to know which weights to expect and where, then waits for each weight's state markers (§3.2) to report ready before invoking `setup()`. If the weight directory reports a different set digest than expected, coglet will refuse to start. + +### 3.4 Target directory constraints + +- Each weight's `target` must be unique within a model. +- Weight targets must be disjoint subtrees (no nesting). +- Both rules enforced at config validation time. +- Model code should ignore `.cog/` subdirectories in weight targets. + +## 4. Real Example: z-image-turbo (~32 GB) + +Source: [HuggingFace repo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) with 19 files (configs, tokenizers, safetensors shards). + +**v0:** 19 weight entries in cog.yaml, 19 separate manifests, 19 blobs in the OCI index. + +**v1:** 1 weight entry, 1 manifest, 8 layers. Using a 64 MB bundle threshold: the 12 small files (configs, JSONs, tokenizer, index files -- all under 64 MB, ~16 MB total) are bundled into a single compressed layer. 
The 7 large files (all safetensors shards, each above 64 MB) each get their own uncompressed standalone layer: + +| Layer | Contents | Size | Format | +| ----- | --------------------------------------------------------------- | ------- | ------------ | +| 1 | Bundle: 12 small files (configs, JSONs, tokenizer, index files) | ~16 MB | compressed | +| 2 | text_encoder/model-00001-of-00003.safetensors | ~3.9 GB | uncompressed | +| 3 | text_encoder/model-00002-of-00003.safetensors | ~3.9 GB | uncompressed | +| 4 | text_encoder/model-00003-of-00003.safetensors | ~99 MB | uncompressed | +| 5 | vae/diffusion_pytorch_model.safetensors | ~167 MB | uncompressed | +| 6 | transformer/diffusion_pytorch_model-00001-of-00003.safetensors | ~9.9 GB | uncompressed | +| 7 | transformer/diffusion_pytorch_model-00002-of-00003.safetensors | ~9.9 GB | uncompressed | +| 8 | transformer/diffusion_pytorch_model-00003-of-00003.safetensors | ~4.6 GB | uncompressed | + +Layer 1 is `tar+gzip` (small compressible text files). Layers 2-8 are `tar` (large binary safetensors where compression yields negligible savings). Consumers process each layer according to its media type regardless of the producer's threshold or compression choices. + +All 8 layers are independent. An extractor can download and unpack them in any order. Layer 1 writes to paths like `config.json`, `tokenizer.json`. Layers 2-8 each write to a single path like `text_encoder/model-00001-of-00003.safetensors`. No path conflicts.
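The split in the table follows mechanically from the illustrative 64 MB threshold of §1.2. A sketch of that assignment step (the threshold, function, and file sizes are hypothetical producer choices, not spec requirements):

```go
package main

import "fmt"

const bundleThreshold = 64 << 20 // 64 MB, the illustrative threshold from §1.2

type assignment struct {
	bundled    []string // packed together into compressed bundle layers
	standalone []string // one uncompressed single-entry tar layer each
}

// assign routes each file to a bundle or a standalone layer by size.
// Each file lands in exactly one bucket, preserving the §1.1 invariant
// that no path appears in more than one layer.
func assign(files map[string]int64) assignment {
	var a assignment
	for path, size := range files {
		if size < bundleThreshold {
			a.bundled = append(a.bundled, path)
		} else {
			a.standalone = append(a.standalone, path)
		}
	}
	return a
}

func main() {
	files := map[string]int64{
		"config.json":    1 << 10,
		"tokenizer.json": 2 << 20,
		"transformer/diffusion_pytorch_model-00001-of-00003.safetensors": 9900 << 20,
	}
	a := assign(files)
	fmt.Println(len(a.bundled), len(a.standalone)) // prints: 2 1
}
```

A producer with a different threshold (or none) produces a different layer layout from the same files, but -- per §2.4 -- the same weight set digest.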