diff --git a/.gitignore b/.gitignore index 9397e2db2d..ab09cfb609 100644 --- a/.gitignore +++ b/.gitignore @@ -26,8 +26,9 @@ bin/* # Compiled test tools /weights-gen -# Test directory for weights testing +# Test directories for weights testing /test-weights/ +/test-weights-example/ weights.lock # Auto-:d version files from setuptools-scm diff --git a/AGENTS.md b/AGENTS.md index dc62e7b782..8ab31a22d9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -127,6 +127,15 @@ The main commands for working on the CLI are: - `mise run test:go` - Runs all Go unit tests - `go test ./pkg/...` - Runs tests directly with `go test` +### Building and running Go binaries + +**Never run `go build` and leave a binary in the repo.** Stray binaries bloat the repo and get accidentally committed. Follow these rules: + +- **To test execution**, use `go run ./cmd/` — no binary is produced. +- **To verify compilation**, use `go build ./cmd/` (without `-o`) — this still writes a binary to the working directory, so prefer `go vet ./cmd/` for a compile check that produces no artifact. +- **If you must produce a binary** (e.g. for integration tests), write it to a temp directory and clean up: `go build -o "$(mktemp -d)/binary" ./cmd/`. +- **For installable builds**, use `mise run build:cog` or `make install` — these have proper output paths. + ## Working on the Python SDK The Python SDK is developed in the `python/cog/` directory. It uses `uv` for virtual environments and `tox` for testing across multiple Python versions. diff --git a/cmd/cog-kong/build.go b/cmd/cog-kong/build.go new file mode 100644 index 0000000000..e5afb7054c --- /dev/null +++ b/cmd/cog-kong/build.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/util/console" +) + +// BuildCmd implements the "cog build" command. +type BuildCmd struct { + BuildFlags `embed:""` + + Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."` +} + +// Validate is called by Kong after parsing, before Run. It replaces Cobra's PreRunE. +func (cmd *BuildCmd) Validate() error { + return cmd.ValidateMutualExclusivity() +} + +// Run executes the build command. +func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, src *model.Source) error { + imageName := src.Config.Image + if cmd.Tag != "" { + imageName = cmd.Tag + } + if imageName == "" { + imageName = config.DockerImageName(src.ProjectDir) + } + + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, nil)) + if err != nil { + return err + } + + console.Infof("\nImage built as %s", m.ImageRef()) + + return nil +} diff --git a/cmd/cog-kong/cli.go b/cmd/cog-kong/cli.go new file mode 100644 index 0000000000..78790ec139 --- /dev/null +++ b/cmd/cog-kong/cli.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + + "github.com/alecthomas/kong" + + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/update" + "github.com/replicate/cog/pkg/util/console" +) + +// Globals holds flags available to every command. +// The AfterApply hook replaces Cobra's PersistentPreRun. +type Globals struct { + Debug bool `name:"debug" short:"d" env:"COG_DEBUG" help:"Show debugging output."` + Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."` + Profile bool `name:"profile" hidden:"" help:"Enable profiling."` + Version kong.VersionFlag `name:"version" short:"v" help:"Show version of Cog."` +} + +// AfterApply runs after flag parsing, before the command's Run. +// This is the Kong equivalent of Cobra's PersistentPreRun. +func (g *Globals) AfterApply(ctx context.Context) error { + if g.Debug { + global.Debug = true + console.SetLevel(console.DebugLevel) + } + if g.Profile { + global.ProfilingEnabled = true + } + global.ReplicateRegistryHost = g.Registry + + if err := update.DisplayAndCheckForRelease(ctx); err != nil { + console.Debugf("%s", err) + } + return nil +} diff --git a/cmd/cog-kong/context.go b/cmd/cog-kong/context.go new file mode 100644 index 0000000000..ed884fe9aa --- /dev/null +++ b/cmd/cog-kong/context.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/provider" + "github.com/replicate/cog/pkg/provider/setup" + "github.com/replicate/cog/pkg/registry" +) + +// provideDockerClient creates a Docker client, binding to the command.Command interface. +// Registered as a singleton provider so all commands share one connection. +func provideDockerClient(ctx context.Context) (command.Command, error) { + return docker.NewClient(ctx) +} + +// provideRegistryClient creates a registry client, binding to the registry.Client interface. +func provideRegistryClient() registry.Client { + return registry.NewRegistryClient() +} + +// provideProviderRegistry creates a provider registry with all built-in providers. +// This replaces the setup.Init() global side-effect pattern used in the cobra CLI. +func provideProviderRegistry() *provider.Registry { + return setup.NewRegistry() +} diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go new file mode 100644 index 0000000000..11c67b1aa7 --- /dev/null +++ b/cmd/cog-kong/flags.go @@ -0,0 +1,100 @@ +package main + +import ( + "fmt" + "os" + "strings" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/model" +) + +// ConfigFlag is an embeddable flag group for specifying the cog.yaml path. +// Any command that embeds ConfigFlag (directly or via BuildFlags) automatically +// gets a ProvideModelSource method discovered by Kong's DI system. +type ConfigFlag struct { + File string `name:"file" short:"f" default:"cog.yaml" help:"The name of the config file."` +} + +// ProvideModelSource is discovered by Kong's DI system (Provide* convention). +// It loads the model source from the config file path specified by --file. +func (c *ConfigFlag) ProvideModelSource() (*model.Source, error) { + return model.NewSource(c.File) +} + +// BuildFlags groups all flags shared across commands that build images. +// Embed this in any command struct that calls resolver.Build(). +type BuildFlags struct { + ConfigFlag `embed:""` + + NoCache bool `name:"no-cache" help:"Do not use cache when building the image."` + SeparateWeights bool `name:"separate-weights" help:"Separate model weights from code in image layers."` + Secrets []string `name:"secret" help:"Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file'."` + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` + UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image)."` + UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + OpenAPISchema string `name:"openapi-schema" type:"existingfile" help:"Load OpenAPI schema from a file."` + + // Hidden flags + Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` + Strip bool `name:"strip" hidden:"" help:"Whether to strip shared libraries for faster inference times."` + Precompile bool `name:"precompile" hidden:"" help:"Whether to precompile python files for faster load times."` +} + +// AfterApply syncs parsed flag values to package-level globals that the build +// pipeline reads. This runs after Kong parses flags but before Run(). +func (b *BuildFlags) AfterApply() error { + config.BuildSourceEpochTimestamp = b.Timestamp + return nil +} + +// BuildOptions constructs a model.BuildOptions from the current flag values. +// The imageName and annotations parameters vary by caller (build vs push). +func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions { + return model.BuildOptions{ + ImageName: imageName, + Secrets: b.Secrets, + NoCache: b.NoCache, + SeparateWeights: b.SeparateWeights, + UseCudaBaseImage: b.UseCudaBaseImage, + ProgressOutput: b.Progress, + SchemaFile: b.OpenAPISchema, + DockerfileFile: b.Dockerfile, + UseCogBaseImage: b.UseCogBaseImage, + Strip: b.Strip, + Precompile: b.Precompile, + Annotations: annotations, + OCIIndex: model.OCIIndexEnabled(), + } +} + +// ValidateMutualExclusivity ensures that at most one of --use-cog-base-image, +// --use-cuda-base-image, and --dockerfile is explicitly set. +func (b *BuildFlags) ValidateMutualExclusivity() error { + var flagsSet []string + if b.UseCogBaseImage != nil { + flagsSet = append(flagsSet, "--use-cog-base-image") + } + if b.UseCudaBaseImage != "auto" { + flagsSet = append(flagsSet, "--use-cuda-base-image") + } + if b.Dockerfile != "" { + flagsSet = append(flagsSet, "--dockerfile") + } + if len(flagsSet) > 1 { + return fmt.Errorf("The flags %s are mutually exclusive: you can only set one of them", strings.Join(flagsSet, " and ")) + } + return nil +} + +// progressDefault returns the default progress output based on environment. +func progressDefault() string { + if v := os.Getenv("BUILDKIT_PROGRESS"); v != "" { + return v + } + if os.Getenv("TERM") == "dumb" { + return "plain" + } + return "auto" +} diff --git a/cmd/cog-kong/main.go b/cmd/cog-kong/main.go new file mode 100644 index 0000000000..332468b4aa --- /dev/null +++ b/cmd/cog-kong/main.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/alecthomas/kong" + + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/util/console" +) + +// Build-time variables. Initialized from global defaults; overridden by -ldflags at build time. +var ( + version = global.Version + commit = global.Commit + buildTime = global.BuildTime +) + +// CLI is the root command struct. Kong parses into this. +type CLI struct { + Globals + + Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."` + Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` +} + +func main() { + ctx, cancel := newCancellationContext() + + var cli CLI + + initOpts := []kong.Option{ + // CLI metadata and variable interpolation for struct tags + kong.Name("cog"), + kong.Description("Containers for machine learning."), + kong.Vars{ + "version": fmt.Sprintf("cog version %s (built %s)", version, buildTime), + "commit": commit, + "progress_default": progressDefault(), + "registry_default": global.DefaultReplicateRegistryHost, + }, + kong.UsageOnError(), + + // bindings for lazily injecting dependencies into Run() methods + kong.BindTo(ctx, (*context.Context)(nil)), + kong.BindSingletonProvider(provideDockerClient), + kong.BindToProvider(provideRegistryClient), + kong.BindSingletonProvider(provideProviderRegistry), + } + + parser, err := kong.New(&cli, initOpts...) + if err != nil { + // Fatal error creating the parser — this is a bug, so panic to get a stack trace. + panic(err) + } + + kctx, err := parser.Parse(os.Args[1:]) + + // Unable to parse input to a valid command + if err != nil { + // If the command isn't runnable (i.e. `cog`) just print help and exit 0 (matches Cobra behavior). + var parseErr *kong.ParseError + // Exit code 80 is kong's internal code for "no runnable command selected" (e.g. bare `cog` with no subcommand). + if errors.As(err, &parseErr) && parseErr.ExitCode() == 80 && strings.HasPrefix(parseErr.Error(), "expected") { + _ = parseErr.Context.PrintUsage(false) + return + } + + // otherwise it's a real parse error (e.g. unexpected command or flag), so print the error and exit non-zero. + parser.FatalIfErrorf(err) + } + + err = kctx.Run() + cancel() + // command returned an error. Print and exit non-zero. + if err != nil { + parser.FatalIfErrorf(err) + } +} + +func newCancellationContext() (context.Context, context.CancelFunc) { + // First signal cancels the context, giving commands a chance to clean up. + // Second signal force-exits immediately. + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + + go func() { + // Block until the first signal cancels the context. + <-ctx.Done() + + // Now register for the second signal after the first one has been received. + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM) + + console.Debugf("Shutting down. Signal again to force quit.") + + <-sig + console.Warnf("Forced exit") + os.Exit(1) + }() + + return ctx, cancel +} diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go new file mode 100644 index 0000000000..9a92411024 --- /dev/null +++ b/cmd/cog-kong/push.go @@ -0,0 +1,86 @@ +package main + +import ( + "context" + "fmt" + + "github.com/replicate/go/uuid" + + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/provider" + "github.com/replicate/cog/pkg/registry" + "github.com/replicate/cog/pkg/util/console" +) + +// PushCmd implements the "cog push" command. +type PushCmd struct { + BuildFlags `embed:""` + + Image string `arg:"" optional:"" help:"Image name to push (e.g. registry.example.com/user/model)."` +} + +// Validate is called by Kong after parsing, before Run. +func (cmd *PushCmd) Validate() error { + return cmd.ValidateMutualExclusivity() +} + +// Run executes the push command: build then push. +func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry, src *model.Source) error { + imageName := src.Config.Image + if cmd.Image != "" { + imageName = cmd.Image + } + + if imageName == "" { + return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'") + } + + // Look up the provider for the target registry + p := providerReg.ForImage(imageName) + if p == nil { + return fmt.Errorf("no provider found for image '%s'", imageName) + } + + pushOpts := provider.PushOptions{ + Image: imageName, + Config: src.Config, + ProjectDir: src.ProjectDir, + } + + // Generate a push ID for annotations + buildID, _ := uuid.NewV7() + annotations := map[string]string{} + if buildID.String() != "" { + annotations["run.cog.push_id"] = buildID.String() + } + + // Build the model + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, annotations)) + if err != nil { + _ = p.PostPush(ctx, pushOpts, err) + return err + } + + // Log weights info + weights := m.WeightArtifacts() + if len(weights) > 0 { + console.Infof("\n%d weight artifact(s)", len(weights)) + } + + // Push the model + console.Infof("\nPushing image '%s'...", m.ImageRef()) + pushErr := resolver.Push(ctx, m, model.PushOptions{}) + + // PostPush: the provider handles formatting errors and showing success messages + if err := p.PostPush(ctx, pushOpts, pushErr); err != nil { + return err + } + + if pushErr != nil { + return fmt.Errorf("failed to push image: %w", pushErr) + } + + return nil +} diff --git a/go.mod b/go.mod index 6026f8aeb4..3150c289a0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/replicate/cog go 1.25 require ( + github.com/alecthomas/kong v1.14.0 github.com/anaskhan96/soup v1.2.5 github.com/aws/aws-sdk-go-v2 v1.39.4 github.com/aws/aws-sdk-go-v2/service/s3 v1.88.7 diff --git a/go.sum b/go.sum index d2aa4df8b8..d3fa9ab2bd 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,12 @@ github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1o github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/kong v1.14.0 h1:gFgEUZWu2ZmZ+UhyZ1bDhuutbKN1nTtJTwh19Wsn21s= +github.com/alecthomas/kong v1.14.0/go.mod h1:wrlbXem1CWqUV5Vbmss5ISYhsVPkBb1Yo7YKJghju2I= +github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= +github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/anaskhan96/soup v1.2.5 h1:V/FHiusdTrPrdF4iA1YkVxsOpdNcgvqT1hG+YtcZ5hM= github.com/anaskhan96/soup v1.2.5/go.mod h1:6YnEp9A2yywlYdM4EgDz9NEHclocMepEtku7wg6Cq3s= github.com/anchore/go-struct-converter v0.0.0-20221118182256-c68fdcfa2092 h1:aM1rlcoLz8y5B2r4tTLMiVTrMtpfY0O8EScKJxaSaEc= @@ -152,6 +158,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/in-toto/in-toto-golang v0.5.0 h1:hb8bgwr0M2hGdDsLjkJ3ZqJ8JFLL/tgYdAxF/XEFBbY= github.com/in-toto/in-toto-golang v0.5.0/go.mod h1:/Rq0IZHLV7Ku5gielPT4wPHJfH1GdHMCq8+WPxw8/BE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/pkg/provider/setup/setup.go b/pkg/provider/setup/setup.go index 88f96de05f..427afd60ed 100644 --- a/pkg/provider/setup/setup.go +++ b/pkg/provider/setup/setup.go @@ -11,16 +11,25 @@ import ( var once sync.Once -// Init initializes the default provider registry with all built-in providers -// This function is idempotent - it only runs once even if called multiple times -func Init() { - once.Do(func() { - registry := provider.DefaultRegistry() +// registerBuiltinProviders registers all built-in providers on the given registry. +// Providers are registered in priority order: Replicate first (more specific), +// then Generic as a fallback for any OCI registry. +func registerBuiltinProviders(reg *provider.Registry) { + reg.Register(replicate.New()) + reg.Register(generic.New()) +} - // Register Replicate provider first (more specific) - registry.Register(replicate.New()) +// NewRegistry creates a new provider registry with all built-in providers registered. +func NewRegistry() *provider.Registry { + reg := provider.NewRegistry() + registerBuiltinProviders(reg) + return reg +} - // Register Generic provider last (fallback for any OCI registry) - registry.Register(generic.New()) +// Init initializes the default provider registry with all built-in providers. +// This function is idempotent - it only runs once even if called multiple times. +func Init() { + once.Do(func() { + registerBuiltinProviders(provider.DefaultRegistry()) }) }