diff --git a/dirs.go b/dirs.go index 60ad6d3..a582a85 100644 --- a/dirs.go +++ b/dirs.go @@ -15,6 +15,7 @@ import ( "strings" "sync" + "aslevy.com/go-doc/internal/dlog" "golang.org/x/mod/semver" ) @@ -41,6 +42,7 @@ var dirs Dirs // dirsInit starts the scanning of package directories in GOROOT and GOPATH. Any // extra paths passed to it are included in the channel. func dirsInit(extra ...Dir) { + dlog.Printf("GOROOT: %s", buildCtx.GOROOT) if buildCtx.GOROOT == "" { stdout, err := exec.Command("go", "env", "GOROOT").Output() if err != nil { diff --git a/dirsextra.go b/dirsextra.go index 59048b5..45f743f 100644 --- a/dirsextra.go +++ b/dirsextra.go @@ -1,9 +1,33 @@ package main import ( + "bytes" + "log" + "os/exec" + "aslevy.com/go-doc/internal/godoc" ) +var GOMODCACHE, GOMOD string + +func init() { + stdout, err := exec.Command("go", "env", "GOROOT", "GOMODCACHE", "GOMOD").Output() + if err != nil { + if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 { + log.Fatalf("failed to determine GOROOT: 'go env GOROOT' failed:\n%s", ee.Stderr) + } + log.Fatalf("failed to determine GOROOT: $GOROOT is not set and could not run 'go env GOROOT':\n\t%s", err) + } + + lines := bytes.Split(stdout, []byte("\n")) + if len(lines) < 3 { + panic("failed to parse stdout from `go env GOROOT GOMODCACHE GOMOD`\n" + string(stdout)) + } + buildCtx.GOROOT = string(bytes.TrimSpace(lines[0])) + GOMODCACHE = string(bytes.TrimSpace(lines[1])) + GOMOD = string(bytes.TrimSpace(lines[2])) +} + var xdirs godoc.Dirs = dirs.PackageDirs() func (dirs *Dirs) PackageDirs() *PackageDirs { return (*PackageDirs)(dirs) } diff --git a/doc_test.go b/doc_test.go index 2104112..94d7a43 100644 --- a/doc_test.go +++ b/doc_test.go @@ -15,7 +15,7 @@ import ( "strings" "testing" - "aslevy.com/go-doc/internal/index" + "aslevy.com/go-doc/internal/modpkg" ) func TestMain(m *testing.M) { @@ -31,9 +31,9 @@ func TestMain(m *testing.M) { panic(err) } - os.Setenv("GODOC_FORMAT", "text") // Use the text format. 
- os.Setenv("GODOC_PAGER", "-") // Disable paging. - os.Setenv(index.SyncEnvVar, index.ModeOff) // Disable index. + os.Setenv("GODOC_FORMAT", "text") // Use the text format. + os.Setenv("GODOC_PAGER", "-") // Disable paging. + os.Setenv(modpkg.SyncEnvVar, modpkg.ModeOff) // Disable modpkg. dirsInit( Dir{importPath: "testdata", dir: testdataDir}, diff --git a/go.mod b/go.mod index 246a5c2..68894a8 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.13.0 + github.com/onsi/ginkgo/v2 v2.12.0 + github.com/onsi/gomega v1.27.10 github.com/schollz/progressbar/v3 v3.13.1 github.com/stretchr/testify v1.8.1 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c @@ -19,10 +21,14 @@ require ( ) require ( - github.com/aymanbagabas/go-osc52 v1.2.1 // indirect + github.com/aymanbagabas/go-osc52 v1.0.3 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/dlclark/regexp2 v1.7.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-logr/logr v1.2.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/css v1.0.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect @@ -41,6 +47,8 @@ require ( golang.org/x/net v0.30.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/tools v0.26.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/gc/v3 v3.0.0-20241004144649-1aea3fae8852 // indirect modernc.org/libc v1.61.0 // indirect diff --git a/go.sum b/go.sum index 600ef2a..9580e70 100644 --- a/go.sum +++ b/go.sum @@ -6,9 +6,8 @@ github.com/alecthomas/assert/v2 v2.2.0 h1:f6L/b7KE2bfA+9O4FL3CM/xJccDEwPVYd5fALB 
github.com/alecthomas/assert/v2 v2.2.0/go.mod h1:b/+1DI2Q6NckYi+3mXyH3wFb8qG37K/DuK80n7WefXA= github.com/alecthomas/repr v0.1.0 h1:ENn2e1+J3k09gyj2shc0dHr/yjaWSHRlrJ4DPMevDqE= github.com/alecthomas/repr v0.1.0/go.mod h1:2kn6fqh/zIyPLmm3ugklbEi5hg5wS435eygvNfaDQL8= +github.com/aymanbagabas/go-osc52 v1.0.3 h1:DTwqENW7X9arYimJrPeGZcV0ln14sGMt3pHZspWD+Mg= github.com/aymanbagabas/go-osc52 v1.0.3/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4= -github.com/aymanbagabas/go-osc52 v1.2.1 h1:q2sWUyDcozPLcLabEMd+a+7Ea2DitxZVN9hTxab9L4E= -github.com/aymanbagabas/go-osc52 v1.2.1/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,6 +18,14 @@ github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -53,6 +60,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/onsi/ginkgo/v2 v2.12.0 h1:UIVDowFPwpg6yMUpPjGkYvf06K3RAiJXUhCxEwQVHRI= +github.com/onsi/ginkgo/v2 v2.12.0/go.mod h1:ZNEzXISYlqpb8S36iN71ifqLi3vVD1rVJGvWRCJOUpQ= +github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= +github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= @@ -67,6 +78,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= @@ -96,9 +108,13 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/completion/match.go b/internal/completion/match.go index 2ea94bd..a6c9174 100644 --- a/internal/completion/match.go +++ b/internal/completion/match.go @@ -14,6 +14,8 @@ const indent = " " // using Zsh parameter expansion and avoids external dependencies like jq for // parsing JSON. type Match struct { + Tag Tag + Pkg string Type string Match string @@ -22,16 +24,14 @@ type Match struct { DisplayIndent bool Describe string - - Tag string } // String returns the string representation of the match which is the following // format. Empty fields are omitted if shown in [brackets]. 
// -// [:][[.].]:: +// [:][[.].]:: // -// Note: If m.Tag is TagStructFields or TagInterfaceMethods, `.` is also +// Note: If m.Tag is TagStructField or TagInterfaceMethod, `.` is also // prepended to ``. func (m Match) String() string { var match string @@ -50,7 +50,7 @@ func (m Match) String() string { display = indent } switch m.Tag { - case TagStructFields, TagInterfaceMethods: + case TagStructField, TagInterfaceMethod: // Because of the way struct fields and interface methods are // rendered, there is nothing which identifies their associated // type, so we need to add the type prefix so the user can @@ -124,60 +124,63 @@ func WithTag(tag Tag) MatchOption { type Tag = string const ( - // TagPackages contains matches for packages. - TagPackages Tag = "packages" + // TagPackage contains matches for packages. + TagPackage Tag = "package" + + // TagPackageDir contains matches for package directories. + TagPackageDir Tag = "package-dir" - // TagConsts contains the first const in each non-typed const group + // TagConst contains the first const in each non-typed const group // declaration, just as go doc displays consts in the package summary. // // Typed const groups are shown under the types tag with the given // type, just as go doc organizes them. - TagConsts Tag = "consts" + TagConst Tag = "const" - // TagAllConsts contains all consts, including subsequent names in + // TagConstAll contains all consts, including subsequent names in // grouped const declarations, and typed consts. // // Since any const name in a const group will return the same output // from go doc, this tag should only be checked last as a fallback. - TagAllConsts Tag = "all-consts" + TagConstAll Tag = "const-all" - // TagVars contains the first var in each non-typed var group + // TagVar contains the first var in each non-typed var group // declaration, just as go doc displays vars in the package summary. 
// // Typed var groups are shown under the types tag with the given type, // just as go doc organizes them. - TagVars Tag = "vars" + TagVar Tag = "var" - // TagAllVars contains all vars, including subsequent names in grouped + // TagVarAll contains all vars, including subsequent names in grouped // var declarations, and typed vars. // // Since any var name in a var group will return the same output from // go doc, this tag should only be checked last as a fallback. - TagAllVars Tag = "all-vars" + TagVarAll Tag = "var-all" - // TagFuncs contains all functions in the package, except for factory + // TagFunc contains all functions in the package, except for factory // functions for exported types, which are listed under the types tag // with the type they provide. - TagFuncs Tag = "funcs" + TagFunc Tag = "func" - // TagTypes contains all types with their associated var and const + // TagType contains all types with their associated var and const // declarations and factory functions. - TagTypes Tag = "types" + TagType Tag = "type" - // TagTypeMethods contains all methods in the form ".". - TagTypeMethods Tag = "type-methods" + // TagTypeMethod contains all methods in the form ".". + TagTypeMethod Tag = "type-method" - // TagInterfaceMethods contains all interface methods in the form + // TagInterfaceMethod contains all interface methods in the form // "." - TagInterfaceMethods Tag = "interface-methods" + TagInterfaceMethod Tag = "interface-method" - // TagStructFields contains all struct fields in the form + // TagStructField contains all struct fields in the form // "." - TagStructFields Tag = "struct-fields" + TagStructField Tag = "struct-field" - // TagMethods contains all methods without the preceding "." + // TagMethod contains all methods without the preceding "." // // Usually these should only be shown after no other matches have been // found. 
- TagMethods Tag = "methods" + TagMethod Tag = "method" ) diff --git a/internal/completion/pkgs.go b/internal/completion/pkgs.go index 8b916b7..b4c02ea 100644 --- a/internal/completion/pkgs.go +++ b/internal/completion/pkgs.go @@ -78,7 +78,7 @@ func (c Completer) completePackageImportPaths(partial string) (matched bool) { match, WithDisplay(dir.ImportPath), WithDescription(desc), - WithTag(TagPackages), + WithTag(TagPackage), )) } return @@ -196,7 +196,7 @@ func (c Completer) completePackageFilePaths(partial string) (matched bool) { pkgDir = "." + sep + pkgDir } - c.suggest(NewMatch(pkgDir, WithDescription(desc), WithTag(TagPackages))) + c.suggest(NewMatch(pkgDir, WithDescription(desc), WithTag(TagPackage))) } return } diff --git a/internal/completion/symbols.go b/internal/completion/symbols.go index e445ccd..51fcd51 100644 --- a/internal/completion/symbols.go +++ b/internal/completion/symbols.go @@ -17,10 +17,10 @@ func (c Completer) completeSymbol(pkg godoc.PackageInfo, partialSymbol string) ( values := make([]*doc.Value, 0, len(pkgDoc.Consts)+len(pkgDoc.Vars)) values = append(values, pkgDoc.Consts...) values = append(values, pkgDoc.Vars...) - tag := TagConsts + tag := TagConst for i, value := range values { if i == len(pkgDoc.Consts) { - tag = TagVars + tag = TagVar } tag := tag var passName bool @@ -56,14 +56,14 @@ func (c Completer) completeSymbol(pkg godoc.PackageInfo, partialSymbol string) ( // their respective type. continue } - matched = c.suggestIfMatchPrefix(pkg, partialSymbol, fnc.Name, fnc.Doc, fnc.Decl, false, WithTag(TagFuncs)) || matched + matched = c.suggestIfMatchPrefix(pkg, partialSymbol, fnc.Name, fnc.Doc, fnc.Decl, false, WithTag(TagFunc)) || matched } // TYPES for _, typ := range pkgDoc.Types { // Suggest the type itself. 
typSpec := pkg.FindTypeSpec(typ.Decl, typ.Name) - matched = c.suggestIfMatchPrefix(pkg, partialSymbol, typ.Name, typ.Doc, typSpec, false, WithTag(TagTypes)) || matched + matched = c.suggestIfMatchPrefix(pkg, partialSymbol, typ.Name, typ.Doc, typSpec, false, WithTag(TagType)) || matched // Typed consts and vars. values := make([]*doc.Value, 0, len(typ.Consts)+len(typ.Vars)) @@ -71,7 +71,7 @@ func (c Completer) completeSymbol(pkg godoc.PackageInfo, partialSymbol string) ( values = append(values, typ.Vars...) for _, value := range values { for _, name := range value.Names { - matched = c.suggestIfMatchPrefix(pkg, partialSymbol, name, value.Doc, value.Decl, false, WithTag(TagTypes), WithDisplayIndent(true)) || matched + matched = c.suggestIfMatchPrefix(pkg, partialSymbol, name, value.Doc, value.Decl, false, WithTag(TagType), WithDisplayIndent(true)) || matched // Remaining names were already suggested under // all-consts and all-vars above. break @@ -80,7 +80,7 @@ func (c Completer) completeSymbol(pkg godoc.PackageInfo, partialSymbol string) ( // Constructors for _, fnc := range typ.Funcs { - matched = c.suggestIfMatchPrefix(pkg, partialSymbol, fnc.Name, fnc.Doc, fnc.Decl, false, WithTag(TagTypes), WithDisplayIndent(true)) || matched + matched = c.suggestIfMatchPrefix(pkg, partialSymbol, fnc.Name, fnc.Doc, fnc.Decl, false, WithTag(TagType), WithDisplayIndent(true)) || matched } if !c.isExported(typ.Name) { @@ -90,7 +90,7 @@ func (c Completer) completeSymbol(pkg godoc.PackageInfo, partialSymbol string) ( // Methods without the preceding `.` for _, method := range typ.Methods { - matched = c.suggestIfMatchPrefix(pkg, partialSymbol, method.Name, method.Doc, method.Decl, false, WithTag(TagMethods)) || matched + matched = c.suggestIfMatchPrefix(pkg, partialSymbol, method.Name, method.Doc, method.Decl, false, WithTag(TagMethod)) || matched } } @@ -128,7 +128,7 @@ func (c Completer) completeTypeDotMethodOrField(pkg godoc.PackageInfo, docTyp *d // Type Methods (.) 
for _, method := range docTyp.Methods { - matched = c.suggestIfMatchPrefix(pkg, partial, method.Name, method.Doc, method.Decl, false, withType, WithTag(TagTypeMethods)) || matched + matched = c.suggestIfMatchPrefix(pkg, partial, method.Name, method.Doc, method.Decl, false, withType, WithTag(TagTypeMethod)) || matched } // Interface and struct types require special handling. @@ -141,7 +141,7 @@ func (c Completer) completeTypeDotMethodOrField(pkg godoc.PackageInfo, docTyp *d continue } name := iMethod.Names[0].Name - matched = c.suggestIfMatchPrefix(pkg, partial, name, iMethod.Doc.Text(), iMethod, false, withType, WithTag(TagInterfaceMethods)) || matched + matched = c.suggestIfMatchPrefix(pkg, partial, name, iMethod.Doc.Text(), iMethod, false, withType, WithTag(TagInterfaceMethod)) || matched } // An interface has no fields or other methods so we are done // with this type. @@ -152,7 +152,7 @@ func (c Completer) completeTypeDotMethodOrField(pkg godoc.PackageInfo, docTyp *d for _, field := range typ.Fields.List { docs := field.Doc.Text() for _, name := range field.Names { - matched = c.suggestIfMatchPrefix(pkg, partial, name.Name, docs, field, false, withType, WithTag(TagStructFields)) || matched + matched = c.suggestIfMatchPrefix(pkg, partial, name.Name, docs, field, false, withType, WithTag(TagStructField)) || matched } } } diff --git a/internal/dlog/log.go b/internal/dlog/log.go index eac8d14..f8f3e2c 100644 --- a/internal/dlog/log.go +++ b/internal/dlog/log.go @@ -60,6 +60,7 @@ func EnableFlag() flag.Value { return defaultLogger.EnableFlag() } func Print(v ...any) { defaultLogger.Output(2, fmt.Sprint(v...)) } func Printf(format string, v ...any) { defaultLogger.Output(2, fmt.Sprintf(format, v...)) } func Println(v ...any) { defaultLogger.Output(2, fmt.Sprintln(v...)) } +func Output(calldepth int, s string) { defaultLogger.Output(calldepth+2, s) } func Dump(v ...any) { defaultLogger.Dump(v...) 
} func Child(prefix string) Logger { return defaultLogger.Child(prefix) } @@ -90,6 +91,8 @@ type Logger interface { Println(...any) SetOutput(io.Writer) + Output(calldepth int, s string) + // Dump prints the spew representation of the arguments. Dump(...any) } @@ -112,6 +115,8 @@ func newLogger(output io.Writer, prefix string, flag int) *logger { } } +func (l *logger) Output(calldepth int, s string) { l.Logger.Output(calldepth+2, s) } + func (l *logger) Child(child string) Logger { prefix := l.Prefix() if child != "" { diff --git a/internal/flags/flags.go b/internal/flags/flags.go index 1643e56..e04ec21 100644 --- a/internal/flags/flags.go +++ b/internal/flags/flags.go @@ -11,7 +11,6 @@ import ( "aslevy.com/go-doc/internal/completion" "aslevy.com/go-doc/internal/dlog" "aslevy.com/go-doc/internal/godoc" - "aslevy.com/go-doc/internal/index" "aslevy.com/go-doc/internal/install" "aslevy.com/go-doc/internal/open" "aslevy.com/go-doc/internal/outfmt" @@ -26,7 +25,6 @@ func addAllFlags(fs *flag.FlagSet) { godoc.AddFlags(fs) pager.AddFlags(fs) open.AddFlags(fs) - index.AddFlags(fs) outfmt.AddFlags(fs) } diff --git a/internal/godoc/dirs.go b/internal/godoc/dirs.go index c482dd9..90a7955 100644 --- a/internal/godoc/dirs.go +++ b/internal/godoc/dirs.go @@ -4,19 +4,49 @@ package godoc -import "errors" - -// A PackageDir describes a directory holding code by specifying -// the expected import path and the file system directory. +import ( + "errors" + "path/filepath" + "strings" +) + +// A PackageDir describes a directory holding code by specifying the expected +// import path and the file system directory. 
type PackageDir struct { ImportPath string // import path for that dir Dir string // file system directory + Version string // module version (if applicable) +} + +type PackageDirOption func(*PackageDir) + +func WithVersion(version string) PackageDirOption { + return func(pkg *PackageDir) { + pkg.Version = version + } } -func NewPackageDir(importPath, dir string) PackageDir { return PackageDir{importPath, dir} } +func NewPackageDir(importPath, dir string, opts ...PackageDirOption) PackageDir { + pkgDir := PackageDir{ + ImportPath: importPath, + Dir: dir, + } + for _, opt := range opts { + opt(&pkgDir) + } + if pkgDir.Version == "" { + pkgDir.Version = parseVersionFromDir(dir) + } + return pkgDir +} +func parseVersionFromDir(dir string) string { + base := filepath.Base(dir) + _, version, _ := strings.Cut(base, "@") + return version +} -// Dirs exposes the functionality of the cmd/go-doc.Dirs type that is -// needed by the completion package. +// Dirs exposes the functionality of the cmd/go-doc.Dirs type that is needed by +// the completion package. type Dirs interface { // Next returns the next PackageDir in the list of packages. 
Next() (PackageDir, bool) diff --git a/internal/index/bench-summary.txt b/internal/index/bench-summary.txt deleted file mode 100644 index 1c3492d..0000000 --- a/internal/index/bench-summary.txt +++ /dev/null @@ -1,8 +0,0 @@ -BenchmarkLoadForceSync_stdlib-10 42 28325832 ns/op 3654710 B/op 39460 allocs/op -BenchmarkLoadReSync_stdlib-10 4437 259719 ns/op 21227 B/op 501 allocs/op -BenchmarkLoadSkipSync_stdlib-10 13027 91987 ns/op 10872 B/op 214 allocs/op -BenchmarkLoadSync_InMemory_stdlib-10 26 41634599 ns/op 4259348 B/op 60304 allocs/op -BenchmarkLoadSync_stdlib-10 4345 270973 ns/op 22168 B/op 518 allocs/op -BenchmarkRandomPartialSearchPath-10 1000000 1150 ns/op 80 B/op 6 allocs/op -BenchmarkSearch_exact_stdlib-10 15174 79066 ns/op 2531 B/op 60 allocs/op -BenchmarkSearch_partials_stdlib-10 5442 207300 ns/op 13486 B/op 311 allocs/op diff --git a/internal/index/benchmarks.txt b/internal/index/benchmarks.txt deleted file mode 100644 index 77a7395..0000000 --- a/internal/index/benchmarks.txt +++ /dev/null @@ -1,48 +0,0 @@ -goos: darwin -goarch: arm64 -pkg: aslevy.com/go-doc/internal/index -BenchmarkSearch_partials_stdlib-10 5442 207300 ns/op 13486 B/op 311 allocs/op ---- BENCH: BenchmarkSearch_partials_stdlib-10 - search_test.go:166: num matches: 1 - search_test.go:166: num matches: 2 - search_test.go:166: num matches: 1 -BenchmarkSearch_exact_stdlib-10 15174 79066 ns/op 2531 B/op 60 allocs/op ---- BENCH: BenchmarkSearch_exact_stdlib-10 - search_test.go:166: num matches: 1 - search_test.go:166: num matches: 1 - search_test.go:166: num matches: 1 - search_test.go:166: num matches: 1 -BenchmarkRandomPartialSearchPath-10 1000000 1150 ns/op 80 B/op 6 allocs/op ---- BENCH: BenchmarkRandomPartialSearchPath-10 - search_test.go:224: path: internal/pkgbits - search_test.go:224: path: internal/syscall - search_test.go:224: path: internal/ld - search_test.go:224: path: runtime/cgo -BenchmarkLoadSync_stdlib-10 4345 270973 ns/op 22168 B/op 518 allocs/op ---- BENCH: 
BenchmarkLoadSync_stdlib-10 - sync_test.go:44: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:44: index sync {CreatedAt:2023-07-29 01:01:14 +0000 UTC UpdatedAt:2023-07-29 01:01:14 +0000 UTC BuildRevision: GoVersion:go1.20.6} - sync_test.go:44: index sync {CreatedAt:2023-07-29 01:01:15 +0000 UTC UpdatedAt:2023-07-29 01:01:15 +0000 UTC BuildRevision: GoVersion:go1.20.6} - sync_test.go:44: index sync {CreatedAt:2023-07-29 01:01:15 +0000 UTC UpdatedAt:2023-07-29 01:01:16 +0000 UTC BuildRevision: GoVersion:go1.20.6} - sync_test.go:44: index sync {CreatedAt:2023-07-29 01:01:16 +0000 UTC UpdatedAt:2023-07-29 01:01:17 +0000 UTC BuildRevision: GoVersion:go1.20.6} -BenchmarkLoadSync_InMemory_stdlib-10 26 41634599 ns/op 4259348 B/op 60304 allocs/op ---- BENCH: BenchmarkLoadSync_InMemory_stdlib-10 - sync_test.go:63: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:63: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} -BenchmarkLoadReSync_stdlib-10 4437 259719 ns/op 21227 B/op 501 allocs/op ---- BENCH: BenchmarkLoadReSync_stdlib-10 - sync_test.go:90: index sync {CreatedAt:2023-07-29 01:01:18 +0000 UTC UpdatedAt:2023-07-29 01:01:18 +0000 UTC BuildRevision: GoVersion:go1.20.6} - sync_test.go:90: index sync {CreatedAt:2023-07-29 01:01:18 +0000 UTC UpdatedAt:2023-07-29 01:01:18 +0000 UTC BuildRevision: GoVersion:go1.20.6} - sync_test.go:90: index sync {CreatedAt:2023-07-29 01:01:18 +0000 UTC UpdatedAt:2023-07-29 01:01:19 +0000 UTC BuildRevision: GoVersion:go1.20.6} -BenchmarkLoadForceSync_stdlib-10 42 28325832 ns/op 3654710 B/op 39460 allocs/op ---- BENCH: BenchmarkLoadForceSync_stdlib-10 - sync_test.go:117: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:117: 
index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} -BenchmarkLoadSkipSync_stdlib-10 13027 91987 ns/op 10872 B/op 214 allocs/op ---- BENCH: BenchmarkLoadSkipSync_stdlib-10 - sync_test.go:144: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:144: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:144: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} - sync_test.go:144: index sync {CreatedAt:0001-01-01 00:00:00 +0000 UTC UpdatedAt:0001-01-01 00:00:00 +0000 UTC BuildRevision: GoVersion:} -PASS -ok aslevy.com/go-doc/internal/index 13.769s diff --git a/internal/index/dirs.go b/internal/index/dirs.go deleted file mode 100644 index 938c218..0000000 --- a/internal/index/dirs.go +++ /dev/null @@ -1,87 +0,0 @@ -package index - -import ( - "context" - - "aslevy.com/go-doc/internal/godoc" - "golang.org/x/sync/errgroup" -) - -type Dirs struct { - idx *Index - - searchPath string - searchPartial bool - g *errgroup.Group - cancel context.CancelFunc - - next chan godoc.PackageDir - results []godoc.PackageDir - offset int -} - -var _ godoc.Dirs = (*Dirs)(nil) - -func NewDirs(pkgIdx *Index) godoc.Dirs { - return &Dirs{idx: pkgIdx} -} - -func (d *Dirs) Reset() { d.offset = 0 } -func (d *Dirs) Next() (pkg godoc.PackageDir, ok bool) { - if d.offset < len(d.results) { - pkg := d.results[d.offset] - d.offset++ - return pkg, true - } - - pkg, ok = <-d.next - if ok { - d.results = append(d.results, pkg) - d.offset++ - } - return pkg, ok -} -func (d *Dirs) FilterExact(path string) error { return d.filter(path) } -func (d *Dirs) FilterPartial(path string) error { return d.filter(path, WithMatchPartials()) } -func (d *Dirs) filter(path string, opts ...SearchOption) error { - o := 
newSearchOptions(opts...) - if d.searchPath == path && d.searchPartial == o.matchPartials { - return nil - } - - d.Reset() - if d.cancel != nil { - d.cancel() - d.g.Wait() - } - d.results = d.results[:0] - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - - rows, err := d.idx.searchRows(ctx, path, opts...) - if err != nil { - return err - } - - d.searchPath = path - d.searchPartial = o.matchPartials - d.next = make(chan godoc.PackageDir) - - d.cancel = cancel - d.g, ctx = errgroup.WithContext(ctx) - d.g.Go(func() error { - defer cancel() - defer close(d.next) - return scanPackageDirs(rows, func(pkg godoc.PackageDir) error { - select { - case <-ctx.Done(): - return ctx.Err() - case d.next <- pkg: - } - return nil - }) - }) - - return nil -} diff --git a/internal/index/dirs_test.go b/internal/index/dirs_test.go deleted file mode 100644 index bf361af..0000000 --- a/internal/index/dirs_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package index - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestDirs_partial(t *testing.T) { - require := require.New(t) - ctx := context.Background() - pkgIdx, err := Load(ctx, dbFilePath(t), stdlibCodeRoots(), loadOpts()) - require.NoError(err) - t.Cleanup(func() { require.NoError(pkgIdx.Close()) }) - - randomPartial, err := pkgIdx.randomPartial() - require.NoError(err) - t.Cleanup(func() { require.NoError(randomPartial.Close()) }) - - path, err := randomPartial.randomPartial() - require.NoError(err) - t.Log("filter path: ", path) - - dirs := NewDirs(pkgIdx) - require.NoError(dirs.FilterPartial(path)) - for { - pkg, ok := dirs.Next() - if !ok { - return - } - t.Log("pkg: ", pkg) - } -} - -func TestDirs_exact(t *testing.T) { - require := require.New(t) - ctx := context.Background() - pkgIdx, err := Load(ctx, dbFilePath(t), stdlibCodeRoots(), loadOpts()) - require.NoError(err) - t.Cleanup(func() { require.NoError(pkgIdx.Close()) }) - - randomPartial, err := pkgIdx.randomPartial() - 
require.NoError(err) - t.Cleanup(func() { require.NoError(randomPartial.Close()) }) - - path, err := randomPartial.randomPartial() - require.NoError(err) - t.Log("filter path: ", path) - - dirs := NewDirs(pkgIdx) - require.NoError(dirs.FilterExact(path)) - for { - pkg, ok := dirs.Next() - if !ok { - return - } - t.Log("pkg: ", pkg) - } -} diff --git a/internal/index/flags.go b/internal/index/flags.go deleted file mode 100644 index 234bc77..0000000 --- a/internal/index/flags.go +++ /dev/null @@ -1,42 +0,0 @@ -package index - -import ( - "flag" - "fmt" - "os" - "time" - - _dlog "aslevy.com/go-doc/internal/dlog" - "aslevy.com/go-doc/internal/flagvar" -) - -const ( - SyncEnvVar = "GODOC_INDEX_MODE" - ResyncEnvVar = "GODOC_INDEX_RESYNC" - DefaultResyncInterval = 20 * time.Minute - NoProgressBar = "GODOC_NO_PROGRESS_BAR" -) - -var ( - dlog = _dlog.Child("index") - Sync = ModeAutoSync - ResyncInterval = DefaultResyncInterval -) - -func AddFlags(fs *flag.FlagSet) { - debugDesc := "enable debug logging for index" - fs.Var(dlog.EnableFlag(), "debug-index", debugDesc) - fs.Var(dlogSearch.EnableFlag(), "debug-index-search", debugDesc+" search") - fs.Var(dlogSync.EnableFlag(), "debug-index-sync", debugDesc+" sync") - - Sync, _ = ParseMode(os.Getenv(SyncEnvVar)) - fs.Var(flagvar.Parse(&Sync, ParseMode), "index-mode", fmt.Sprintf("cached index modes: %s", modes())) - fs.DurationVar(&ResyncInterval, "index-resync", parseResyncInterval(os.Getenv(ResyncEnvVar)), "resync index if older than this duration") -} -func parseResyncInterval(s string) time.Duration { - d, err := time.ParseDuration(s) - if err != nil { - return DefaultResyncInterval - } - return d -} diff --git a/internal/index/index.go b/internal/index/index.go deleted file mode 100644 index dde474c..0000000 --- a/internal/index/index.go +++ /dev/null @@ -1,157 +0,0 @@ -package index - -import ( - "bufio" - "bytes" - "context" - "database/sql" - "fmt" - - "golang.org/x/sync/errgroup" - _ "modernc.org/sqlite" - - 
"aslevy.com/go-doc/internal/godoc" -) - -type Index struct { - options - - db *sql.DB - tx *sqlTx - - metadata - - cancel context.CancelFunc - g *errgroup.Group -} - -func Load(ctx context.Context, dbPath string, codeRoots []godoc.PackageDir, opts ...Option) (*Index, error) { - o := newOptions(opts...) - if o.mode == ModeOff { - return nil, nil - } - - dlog.Printf("loading %q", dbPath) - dlog.Printf("options: %+v", o) - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, fmt.Errorf("failed to open index database: %w", err) - } - - idx := Index{ - options: o, - db: db, - } - - if err := idx.initDB(ctx); err != nil { - return nil, err - } - - ctx, idx.cancel = context.WithCancel(ctx) - idx.g, ctx = errgroup.WithContext(ctx) - idx.g.Go(func() error { - defer idx.cancel() - return idx.syncCodeRoots(ctx, codeRoots) - }) - - return &idx, nil -} - -func (idx *Index) waitSync() error { return idx.g.Wait() } - -func (idx *Index) Close() error { - idx.cancel() - if err := idx.waitSync(); err != nil { - dlog.Printf("failed to sync: %v", err) - } - return idx.db.Close() -} - -func (idx *Index) initDB(ctx context.Context) error { - // Only proceed if the database matches our application ID. - if err := idx.assertApplicationID(ctx); err != nil { - return err - } - - // Check if the schema is up to date. 
- userVersion, err := idx.getUserVersion(ctx) - if err != nil { - return err - } - - if userVersion == 0 { // user version not set - return idx.applySchema(ctx) - } - - if userVersion != schemaCRC { - dlog.Printf("user_version (%d) != schema CRC (%d)", userVersion, schemaCRC) - return fmt.Errorf("database does not have the correct schema") - } - - return nil -} - -func (idx *Index) applySchema(ctx context.Context) error { - dlog.Printf("Applying schema...") - schemaVersion, err := idx.getSchemaVersion(ctx) - if err != nil { - return err - } - - if schemaVersion > 0 { - return fmt.Errorf("database schema_version (%d) is not zero", schemaVersion) - } - - if err := idx.enableForeignKeys(ctx); err != nil { - return err - } - - queries := schemaQueries() - for i, stmt := range queries { - _, err := idx.db.ExecContext(ctx, stmt) - if err != nil { - return fmt.Errorf("failed to apply schema query %d: %w", i+1, err) - } - } - - dlog.Printf("schema CRC: %d", schemaCRC) - return idx.setUserVersion(ctx, schemaCRC) -} - -// schemaQueries returns the individual queries in schema.sql. -func schemaQueries() []string { - const numQueries = 8 // number of queries in schema.sql - queries := make([]string, 0, numQueries) - scanner := bufio.NewScanner(bytes.NewReader(_schema)) - scanner.Split(sqlSplit) - for scanner.Scan() { - queries = append(queries, scanner.Text()) - } - if err := scanner.Err(); err != nil { - panic(fmt.Errorf("failed to scan schema.sql: %w", err)) - } - return queries -} -func sqlSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { - defer func() { - // Trim the token of any leading or trailing whitespace. - token = bytes.TrimSpace(token) - if len(token) == 0 { - // Ensure we don't return an empty token. - token = nil - } - }() - - semiColon := bytes.Index(data, []byte(";")) - if semiColon == -1 { - // No semi-colon yet... - if atEOF { - // That's everything... - return len(data), data, nil - } - // Ask for more data so we can find the EOL. 
- return 0, nil, nil - } - // We found a semi-colon... - return semiColon + 1, data[:semiColon+1], nil -} diff --git a/internal/index/option.go b/internal/index/option.go deleted file mode 100644 index e103d18..0000000 --- a/internal/index/option.go +++ /dev/null @@ -1,77 +0,0 @@ -package index - -import ( - "fmt" - "strings" - "time" -) - -type Mode = string - -const ( - ModeOff Mode = "off" - ModeAutoSync = "auto" - ModeForceSync = "force" - ModeSkipSync = "skip" -) - -func modes() string { - return strings.Join([]Mode{ModeOff, ModeAutoSync, ModeForceSync, ModeSkipSync}, ", ") -} - -func ParseMode(s string) (Mode, error) { - switch s { - case ModeOff, ModeAutoSync, ModeForceSync, ModeSkipSync: - return s, nil - } - return ModeAutoSync, fmt.Errorf("invalid index mode: %q", s) -} - -type Option func(*options) -type options struct { - mode Mode - resyncInterval time.Duration - disableProgressBar bool -} - -func newOptions(opts ...Option) options { - o := defaultOptions() - WithOptions(opts...)(&o) - return o -} -func defaultOptions() options { - return options{ - mode: ModeAutoSync, - resyncInterval: DefaultResyncInterval, - } -} - -func WithOptions(opts ...Option) Option { - return func(o *options) { - for _, opt := range opts { - opt(o) - } - } -} - -func WithAuto() Option { return WithMode(ModeAutoSync) } -func WithOff() Option { return WithMode(ModeOff) } -func WithForceSync() Option { return WithMode(ModeForceSync) } -func WithSkipSync() Option { return WithMode(ModeSkipSync) } -func WithMode(mode Mode) Option { - return func(o *options) { - o.mode = mode - } -} - -func WithResyncInterval(interval time.Duration) Option { - return func(o *options) { - o.resyncInterval = interval - } -} - -func WithNoProgressBar() Option { - return func(o *options) { - o.disableProgressBar = true - } -} diff --git a/internal/index/pragma.go b/internal/index/pragma.go deleted file mode 100644 index 8dcf091..0000000 --- a/internal/index/pragma.go +++ /dev/null @@ -1,76 +0,0 @@ 
-package index - -import ( - "context" - "fmt" -) - -// sqliteApplicationID is the magic number used to identify sqlite3 databases -// created by this application. -// -// See https://www.sqlite.org/fileformat.html#application_id -const ( - sqliteApplicationID uint32 = 0x0_90_D0C_90 // GO DOC GO - pragmaApplicationID = "application_id" -) - -func (idx *Index) assertApplicationID(ctx context.Context) error { - appID, err := idx.getApplicationID(ctx) - if err != nil { - return err - } - if appID == 0 { // app ID not set - return idx.setApplicationID(ctx) - } - if appID != sqliteApplicationID { - return fmt.Errorf("unrecognized database") - } - return nil -} -func (idx *Index) getApplicationID(ctx context.Context) (appID uint32, _ error) { - return appID, idx.getPragma(ctx, pragmaApplicationID, &appID) -} -func (idx *Index) setApplicationID(ctx context.Context) error { - return idx.setPragma(ctx, pragmaApplicationID, sqliteApplicationID) -} - -const pragmaUserVersion = "user_version" - -func (idx *Index) getUserVersion(ctx context.Context) (userVersion uint32, _ error) { - return userVersion, idx.getPragma(ctx, pragmaUserVersion, &userVersion) -} -func (idx *Index) setUserVersion(ctx context.Context, userVersion uint32) error { - return idx.setPragma(ctx, pragmaUserVersion, userVersion) -} - -func (idx *Index) getSchemaVersion(ctx context.Context) (int, error) { - const pragmaSchemaVersion = "schema_version" - var schemaVersion int - if err := idx.getPragma(ctx, pragmaSchemaVersion, &schemaVersion); err != nil { - return 0, err - } - return schemaVersion, nil -} - -func (idx *Index) enableForeignKeys(ctx context.Context) error { - const pragmaForeignKeys = "foreign_keys" - return idx.setPragma(ctx, pragmaForeignKeys, "on") -} - -func (idx *Index) getPragma(ctx context.Context, key string, val any) error { - query := fmt.Sprintf(`PRAGMA %s;`, key) - err := idx.db.QueryRowContext(ctx, query).Scan(val) - if err != nil { - return fmt.Errorf("failed to read pragma %s: %w", 
key, err) - } - return nil -} - -func (idx *Index) setPragma(ctx context.Context, key string, val any) error { - query := fmt.Sprintf(`PRAGMA %s=%v;`, key, val) - _, err := idx.db.ExecContext(ctx, query) - if err != nil { - return fmt.Errorf("failed to set pragma %s=%v: %w", key, val, err) - } - return nil -} diff --git a/internal/index/progressbar.go b/internal/index/progressbar.go deleted file mode 100644 index 89eee8f..0000000 --- a/internal/index/progressbar.go +++ /dev/null @@ -1,41 +0,0 @@ -package index - -import ( - "os" - "time" - - "github.com/schollz/progressbar/v3" -) - -type progressBar interface { - ChangeMax(int) - GetMax() int - Add(int) error - Finish() error - Clear() error -} - -type nopProgressBar struct{} - -func (nopProgressBar) ChangeMax(int) {} -func (nopProgressBar) GetMax() int { return 0 } -func (nopProgressBar) Add(int) error { return nil } -func (nopProgressBar) Finish() error { return nil } -func (nopProgressBar) Clear() error { return nil } - -func newProgressBar(o options, total int, description string) progressBar { - if o.disableProgressBar { - return nopProgressBar{} - } - return progressbar.NewOptions(total, - progressbar.OptionSetDescription("package index: "+description), - progressbar.OptionSetWriter(os.Stderr), - progressbar.OptionThrottle(time.Second/3), - progressbar.OptionShowCount(), // show current count e.g. 
3/5 - progressbar.OptionClearOnFinish(), // clear bar when done - progressbar.OptionSetPredictTime(false), - progressbar.OptionSetElapsedTime(false), - progressbar.OptionEnableColorCodes(true), - progressbar.OptionUseANSICodes(true), - ) -} diff --git a/internal/index/progressbar_test.go b/internal/index/progressbar_test.go deleted file mode 100644 index 93f51cb..0000000 --- a/internal/index/progressbar_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package index - -import ( - "testing" - "time" -) - -func TestProgressBar(t *testing.T) { - t.Skip() - t.Log("testing progress bar...") - pb := newProgressBar(options{}, 1000, "syncing...") - for i := 0; i < 1000; i++ { - pb.Add(1) - time.Sleep(2 * time.Millisecond) - } - pb.Finish() -} diff --git a/internal/index/schema.go b/internal/index/schema.go deleted file mode 100644 index e604a1a..0000000 --- a/internal/index/schema.go +++ /dev/null @@ -1,287 +0,0 @@ -// This file along with schema.sql define the schema for the database. -// -// For each SQL table there is a corresponding Go type and Index methods for -// selecting, inserting, or updating rows. - -package index - -import ( - "bytes" - "context" - "database/sql" - _ "embed" - "fmt" - "hash/crc32" - "runtime/debug" - "time" - - "aslevy.com/go-doc/internal/godoc" -) - -// _schema is the SQL schema for the index database. 
-// -//go:embed schema.sql -var _schema []byte - -var schemaCRC = func() uint32 { - crc := crc32.NewIEEE() - crc.Write(_schema) - return crc.Sum32() -}() - -type metadata struct { - CreatedAt time.Time - UpdatedAt time.Time - - BuildRevision string - GoVersion string -} - -func (idx *Index) selectMetadata(ctx context.Context) (metadata, error) { - const query = ` -SELECT createdAt, updatedAt, buildRevision, goVersion FROM metadata WHERE rowid=1; -` - return scanMetadata(idx.db.QueryRowContext(ctx, query)) -} -func scanMetadata(row sqlRow) (metadata, error) { - var meta metadata - return meta, row.Scan( - &meta.CreatedAt, - &meta.UpdatedAt, - &meta.BuildRevision, - &meta.GoVersion, - ) -} - -func (idx *Index) upsertMetadata(ctx context.Context) error { - const query = ` -INSERT INTO metadata(rowid, buildRevision, goVersion) VALUES (1, ?, ?) - ON CONFLICT(rowid) DO - UPDATE SET - updatedAt=CURRENT_TIMESTAMP, - buildRevision=excluded.buildRevision, - goVersion=excluded.goVersion; -` - if _, err := idx.tx.ExecContext(ctx, query, buildRevision, goVersion); err != nil { - return fmt.Errorf("failed to upsert metadata: %w", err) - } - return nil -} - -var buildRevision, goVersion string = func() (string, string) { - var buildRevision string - info, ok := debug.ReadBuildInfo() - if !ok { - panic("debug.ReadBuildInfo() failed") - } - for _, s := range info.Settings { - if s.Key == "vcs.revision" { - buildRevision = s.Value - break - } - } - return buildRevision, info.GoVersion -}() - -type class = int - -const ( - classStdlib class = iota - classLocal - classRequired - classNotRequired -) - -func classString(c class) string { - switch c { - case classStdlib: - return "stdlib" - case classLocal: - return "local" - case classRequired: - return "required" - case classNotRequired: - return "not required" - default: - return "unknown class" - } -} - -type module struct { - ID int64 - ImportPath string - Dir string - Class class - Vendor bool -} - -func (idx *Index) 
selectModule(ctx context.Context, importPath string) (module, error) { - stmt, err := idx.tx.PrepareContext(ctx, ` -SELECT rowid, importPath, dir, class, vendor FROM module WHERE importPath=?; -`) - if err != nil { - return module{}, err - } - return scanModule(stmt.QueryRowContext(ctx, importPath)) -} -func scanModule(row sqlRow) (module, error) { - var mod module - return mod, row.Scan(&mod.ID, &mod.ImportPath, &mod.Dir, &mod.Class, &mod.Vendor) -} - -type sqlRow interface { - Scan(dest ...any) error -} - -func (idx *Index) insertModule(ctx context.Context, pkgDir godoc.PackageDir, class class, vendor bool) (int64, error) { - stmt, err := idx.tx.PrepareContext(ctx, ` -INSERT INTO module (importPath, dir, class, vendor) VALUES (?, ?, ?, ?); -`) - if err != nil { - return -1, err - } - res, err := stmt.ExecContext(ctx, pkgDir.ImportPath, pkgDir.Dir, int(class), vendor) - if err != nil { - return -1, nil - } - return res.LastInsertId() -} - -func (idx *Index) updateModule(ctx context.Context, modID int64, pkgDir godoc.PackageDir, class class, vendor bool) error { - stmt, err := idx.tx.PrepareContext(ctx, ` -UPDATE module SET (dir, class, vendor) = (?, ?, ?) WHERE rowid=?; -`) - if err != nil { - return err - } - - _, err = stmt.ExecContext(ctx, pkgDir.Dir, int(class), vendor, modID) - return err -} - -func (idx *Index) pruneModules(ctx context.Context, keep []int64) error { - query := fmt.Sprintf(` -DELETE FROM module WHERE rowid NOT IN (%s); -`, placeholders(len(keep))) - _, err := idx.tx.ExecContext(ctx, query, pruneModulesArgs(keep)...) 
- return err -} -func placeholders(n int) string { - var buf bytes.Buffer - for i := 0; i < n; i++ { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteByte('?') - } - return buf.String() -} -func pruneModulesArgs(keep []int64) []any { - args := make([]any, 0, len(keep)) - for _, id := range keep { - args = append(args, id) - } - return args -} - -type package_ struct { - ID int64 - ModuleID int64 - RelativePath string - NumParts int -} - -func (idx *Index) selectPackageID(ctx context.Context, modID int64, relativePath string) (int64, error) { - stmt, err := idx.tx.PrepareContext(ctx, ` -SELECT rowid FROM package WHERE moduleId=? AND relativePath=?; -`) - if err != nil { - return -1, err - } - var id int64 - return id, stmt.QueryRowContext(ctx, modID, relativePath).Scan(&id) -} - -func (idx *Index) insertPackage(ctx context.Context, modID int64, relativePath string) (int64, error) { - stmt, err := idx.tx.PrepareContext(ctx, ` -INSERT INTO package(moduleId, relativePath) VALUES (?, ?); -`) - if err != nil { - return -1, err - } - res, err := stmt.ExecContext(ctx, modID, relativePath) - if err != nil { - return -1, fmt.Errorf("failed to insert package: %w", err) - } - return res.LastInsertId() -} - -func (idx *Index) prunePackages(ctx context.Context, modID int64, keep []int64) error { - dlog.Printf("pruning unused packages for module %d", modID) - query := fmt.Sprintf(` -DELETE FROM package WHERE moduleId=? AND rowid NOT IN (%s); -`, placeholders(len(keep))) - _, err := idx.tx.ExecContext(ctx, query, prunePackagesArgs(modID, keep)...) 
- if err != nil { - return fmt.Errorf("failed to prune packages: %w", err) - } - return nil -} -func prunePackagesArgs(modID int64, keep []int64) []any { - args := make([]any, 0, len(keep)+1) - args = append(args, modID) - for _, id := range keep { - args = append(args, id) - } - return args -} - -type partial struct { - ID int64 - PackageID int64 - Parts string - NumParts int -} - -func (idx *Index) insertPartial(ctx context.Context, pkgID int64, parts string) (int64, error) { - stmt, err := idx.tx.PrepareContext(ctx, ` -INSERT INTO partial(packageId, parts) VALUES (?, ?); -`) - if err != nil { - return -1, err - } - - res, err := stmt.ExecContext(ctx, pkgID, parts) - if err != nil { - return -1, fmt.Errorf("failed to insert partial: %w", err) - } - return res.LastInsertId() -} - -type sqlTx struct { - *sql.Tx - stmts map[string]*sql.Stmt -} - -func newSqlTx(tx *sql.Tx) *sqlTx { - return &sqlTx{ - Tx: tx, - stmts: make(map[string]*sql.Stmt), - } -} - -func (tx *sqlTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - stmt, ok := tx.stmts[query] - if ok { - return stmt, nil - } - stmt, err := tx.Tx.PrepareContext(ctx, query) - if err != nil { - return nil, err - } - tx.stmts[query] = stmt - return stmt, nil -} -func (tx *sqlTx) Prepare(query string) (*sql.Stmt, error) { - return tx.PrepareContext(context.Background(), query) -} diff --git a/internal/index/schema.sql b/internal/index/schema.sql deleted file mode 100644 index 6560e72..0000000 --- a/internal/index/schema.sql +++ /dev/null @@ -1,94 +0,0 @@ -CREATE TABLE metadata ( - rowid INTEGER PRIMARY KEY CHECK (rowid = 1), -- only one row - createdAt DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updatedAt DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - buildRevision TEXT NOT NULL, - goVersion TEXT NOT NULL -); - -CREATE TABLE module ( - rowid INTEGER PRIMARY KEY, - importPath TEXT UNIQUE NOT NULL, - dir TEXT NOT NULL CHECK (dir != ''), -- dir must not be empty - class INT NOT NULL CHECK 
(class >= 0 AND class <= 3), -- 0: stdlib, 1: local, 2: required, 3: not required - vendor BOOL DEFAULT false, - numParts INT GENERATED ALWAYS AS - (length(importPath) - length(replace(importPath, '/', '')) + -- number of slashes - iif(length(importPath)>0,1,0)) -- add 1 if path is not empty - STORED -); - -CREATE INDEX module_class ON module(class, importPath); - -CREATE TABLE package ( - rowid INTEGER PRIMARY KEY, - moduleId INT REFERENCES module(rowid) - ON DELETE CASCADE - ON UPDATE CASCADE, - relativePath TEXT NOT NULL, - numParts INT GENERATED ALWAYS AS - (length(relativePath) - length(replace(relativePath, '/', '')) + -- number of slashes - iif(length(relativePath)>0,1,0)) -- add 1 if path is not empty - STORED, - - UNIQUE(moduleId, relativePath) ON CONFLICT IGNORE -); - -CREATE VIEW modulePackage AS - SELECT - package.rowid, - trim(module.importPath || '/' || package.relativePath, '/') as packageImportPath, - rtrim(module.dir || '/' || package.relativePath, '/') as packageDir, - package.moduleId, - module.importPath as moduleImportPath, - relativePath, - class, - vendor, - package.numParts as relativeNumParts, - package.numParts + module.numParts as totalNumParts - FROM package - INNER JOIN module - ON package.moduleId=module.rowid - ORDER BY - class ASC, - moduleImportPath ASC, - relativeNumParts ASC, - relativePath ASC; - -CREATE TABLE partial ( - rowid INTEGER PRIMARY KEY, - packageId INT REFERENCES package(rowid) - ON DELETE CASCADE - ON UPDATE CASCADE, - parts TEXT NOT NULL CHECK (parts != ''), -- parts must not be empty - numParts INT GENERATED ALWAYS AS - (length(parts) - length(replace(parts, '/', '')) + 1) -- number of slashes + 1 - STORED, - - UNIQUE(packageId, parts) ON CONFLICT IGNORE -); - -CREATE INDEX partial_idx_numParts_parts ON partial(numParts, parts COLLATE NOCASE); - -CREATE VIEW partialPackage AS - SELECT - package.rowid, - packageImportPath, - packageDir, - moduleId, - moduleImportPath, - class, - relativePath, - relativeNumParts, - 
totalNumParts, - parts, - partial.numParts as partialNumParts - FROM partial - INNER JOIN modulePackage AS package - ON partial.packageId=package.rowid - ORDER BY - partialNumParts ASC, - class ASC, - moduleImportPath ASC, - relativeNumParts ASC, - relativePath ASC; diff --git a/internal/index/search.go b/internal/index/search.go deleted file mode 100644 index 536c170..0000000 --- a/internal/index/search.go +++ /dev/null @@ -1,192 +0,0 @@ -package index - -import ( - "context" - "database/sql" - "fmt" - "strings" - - "aslevy.com/go-doc/internal/godoc" -) - -var dlogSearch = dlog.Child("search") - -type SearchOption func(*searchOptions) - -type searchOptions struct { - matchPartials bool -} - -func newSearchOptions(opts ...SearchOption) searchOptions { - var o searchOptions - WithSearchOptions(opts...)(&o) - return o -} -func WithSearchOptions(opts ...SearchOption) SearchOption { - return func(o *searchOptions) { - for _, opt := range opts { - opt(o) - } - } -} -func WithMatchPartials() SearchOption { - return func(o *searchOptions) { - o.matchPartials = true - } -} - -func (idx *Index) Search(ctx context.Context, path string, opts ...SearchOption) ([]godoc.PackageDir, error) { - rows, err := idx.searchRows(ctx, path, opts...) 
- if err != nil { - return nil, err - } - if rows == nil { - return nil, nil - } - - var pkgs []godoc.PackageDir - return pkgs, scanPackageDirs(rows, func(pkg godoc.PackageDir) error { - pkgs = append(pkgs, pkg) - return nil - }) -} -func scanPackageDirs(rows *sql.Rows, handler func(godoc.PackageDir) error) error { - defer rows.Close() - for rows.Next() { - pkg, err := scanPackageDir(rows) - if err != nil { - return err - } - if err := handler(pkg); err != nil { - return err - } - } - return rows.Err() -} -func scanPackageDir(row sqlRow) (godoc.PackageDir, error) { - var pkg godoc.PackageDir - var min int - return pkg, row.Scan(&pkg.ImportPath, &pkg.Dir, &min) -} - -func (idx *Index) searchRows(ctx context.Context, path string, opts ...SearchOption) (*sql.Rows, error) { - if err := idx.waitSync(); err != nil { - return nil, err - } - - query, params, err := idx.searchQueryParams(ctx, path, opts...) - if err != nil { - return nil, err - } - // dlogSearch.Printf("query: \n%s", query) - // dlogSearch.Printf("params: \n%+v", params) - return idx.db.QueryContext(ctx, query, params...) -} - -func (idx *Index) searchQueryParams(ctx context.Context, path string, opts ...SearchOption) (query string, params []any, _ error) { - where, params, err := idx.searchWhereParams(ctx, path, opts...) - if err != nil { - return "", nil, err - } - - const selectQuery = ` -SELECT - packageImportPath, - packageDir, - min(partialNumParts) -FROM - partialPackage -WHERE %s -GROUP BY packageImportPath -ORDER BY - partialNumParts ASC, - class ASC, - moduleImportPath ASC, - relativeNumParts ASC, - relativePath ASC; -` - return fmt.Sprintf(selectQuery, where), params, err -} -func (idx *Index) searchWhereParams(ctx context.Context, path string, opts ...SearchOption) (where string, params []any, _ error) { - o := newSearchOptions(opts...) 
- if !o.matchPartials { - return idx.searchWhereParamsExact(path) - } - return idx.searchWhereParamsPartial(ctx, path) -} -func (idx *Index) searchWhereParamsExact(path string) (where string, params []any, _ error) { - if path == "" { - return "FALSE", nil, nil - } - - where = `( - partialNumParts = ? AND - parts = ? -)` - params = []any{ - strings.Count(path, "/") + 1, - path, - } - return -} -func (idx *Index) searchWhereParamsPartial(ctx context.Context, path string) (where string, params []any, _ error) { - if path == "" { - return "TRUE", nil, nil - } - - maxParts, err := idx.maxPartialNumParts(ctx) - if err != nil { - return "", nil, err - } - - const whereQuery = `( - ? AND - partialNumParts = ? AND - parts LIKE ? -)` - numParts, like := searchLike(path) - - var queryBldr strings.Builder - for i := 1; i <= maxParts; i++ { - if queryBldr.Len() > 0 { - queryBldr.WriteString(` OR `) - } - queryBldr.WriteString(whereQuery) - - validWhere := i >= numParts - params = append(params, validWhere, i, like.String()) - - if !validWhere { - continue - } - - like.WriteString("/%") - } - return queryBldr.String(), params, nil -} -func searchLike(path string) (numParts int, like *strings.Builder) { - like = new(strings.Builder) - parts := strings.Split(path, "/") - numParts = len(parts) - like.Grow(len(path) + numParts) - for _, part := range parts { - // like must not start with "%/" to avoid causing sqlite to - // perform a full table scan. So if like is empty and the part - // is empty, move on. 
- // https://www.sqlite.org/optoverview.html#the_like_optimization - if like.Len() > 0 { - like.WriteByte('/') - } else if len(part) == 0 { - numParts-- - continue - } - like.WriteString(part) - like.WriteByte('%') - } - return numParts, like -} -func (idx *Index) maxPartialNumParts(ctx context.Context) (int, error) { - const query = `SELECT MAX(numParts) FROM partial;` - var max int - return max, idx.db.QueryRowContext(ctx, query).Scan(&max) -} diff --git a/internal/index/search_test.go b/internal/index/search_test.go deleted file mode 100644 index f41b727..0000000 --- a/internal/index/search_test.go +++ /dev/null @@ -1,280 +0,0 @@ -package index - -import ( - "context" - "database/sql" - "flag" - "fmt" - "go/build" - "path/filepath" - "testing" - - "aslevy.com/go-doc/internal/benchmark" - "aslevy.com/go-doc/internal/godoc" - "github.com/stretchr/testify/require" -) - -func init() { AddFlags(flag.CommandLine) } - -type indexTest struct { - name string - mods []godoc.PackageDir - searchTests []searchTest -} -type searchTest struct { - paths []string - partial bool - results []string -} - -func (test indexTest) run(t *testing.T) { - require := require.New(t) - t.Helper() - ctx := context.Background() - pkgs, err := Load(ctx, dbMem, test.mods, loadOpts()) - require.NoError(err) - t.Cleanup(func() { - require.NoError(pkgs.Close()) - }) - for _, searchTest := range test.searchTests { - searchTest.run(t, pkgs) - } -} - -func (test searchTest) run(t *testing.T, pkgs *Index) { - t.Helper() - ctx := context.Background() - for _, path := range test.paths { - name := "exact:" - var opts []SearchOption - if test.partial { - name = "partial:" - opts = append(opts, WithMatchPartials()) - } - t.Run(name+path, func(t *testing.T) { - results, err := pkgs.Search(ctx, path, opts...) 
- require.NoError(t, err) - require.Equal(t, test.results, importPaths(results)) - }) - } -} - -func importPaths(pkgs []godoc.PackageDir) []string { - paths := make([]string, len(pkgs)) - for i, pkg := range pkgs { - paths[i] = pkg.ImportPath - } - return paths -} - -var GOROOT = build.Default.GOROOT - -func stdlibCodeRoots() []godoc.PackageDir { - return []godoc.PackageDir{ - godoc.NewPackageDir("", filepath.Join(GOROOT, "src")), - godoc.NewPackageDir("cmd", filepath.Join(GOROOT, "src", "cmd")), - } -} - -var indexTests = []indexTest{{ - name: "stdlib", - mods: stdlibCodeRoots(), - searchTests: []searchTest{{ - paths: []string{"json", "jso"}, - partial: true, - results: []string{ - "encoding/json", - "net/rpc/jsonrpc", - }, - }, { - paths: []string{"encoding/json", "encoding/jso", "e/j"}, - partial: true, - results: []string{"encoding/json"}, - }, { - paths: []string{"http"}, - partial: true, - results: []string{"net/http", "net/http/httptest", "net/http/httptrace", "net/http/httputil", "net/http/cgi", "net/http/cookiejar", "net/http/fcgi", "net/http/internal", "net/http/pprof", "net/http/internal/ascii", "net/http/internal/testcert"}, - }, { - paths: []string{"http"}, - partial: false, - results: []string{"net/http"}, - }, { - paths: []string{"ht"}, - partial: true, - results: []string{"html", "net/http", "net/http/httptest", "net/http/httptrace", "net/http/httputil", "html/template", "net/http/cgi", "net/http/cookiejar", "net/http/fcgi", "net/http/internal", "net/http/pprof", "net/http/internal/ascii", "net/http/internal/testcert"}, - }, { - paths: []string{"a"}, - partial: true, - results: []string{"arena", "crypto/aes", "encoding/ascii85", "encoding/asn1", "go/ast", "hash/adler32", "internal/abi", "runtime/asan", "sync/atomic", "crypto/internal/alias", "runtime/internal/atomic", "net/http/internal/ascii", "runtime/race/internal/amd64v1", "runtime/race/internal/amd64v3", "cmd/addr2line", "cmd/api", "cmd/asm", "cmd/internal/archive", "cmd/asm/internal/arch", 
"cmd/asm/internal/asm", "cmd/compile/internal/abi", "cmd/compile/internal/abt", "cmd/compile/internal/amd64", "cmd/compile/internal/arm", "cmd/compile/internal/arm64", "cmd/go/internal/auth", "cmd/internal/obj/arm", "cmd/internal/obj/arm64", "cmd/link/internal/amd64", "cmd/link/internal/arm", "cmd/link/internal/arm64", "archive/tar", "archive/zip", "cmd/asm/internal/flags", "cmd/asm/internal/lex"}, - }, { - paths: []string{"c/a"}, - partial: true, - results: []string{"crypto/aes", "cmd/addr2line", "cmd/api", "cmd/asm", "cmd/asm/internal/arch", "cmd/asm/internal/asm", "cmd/asm/internal/flags", "cmd/asm/internal/lex"}, - }, { - paths: []string{"as"}, - partial: true, - results: []string{"encoding/ascii85", "encoding/asn1", "go/ast", "runtime/asan", "net/http/internal/ascii", "cmd/asm", "cmd/asm/internal/asm", "cmd/asm/internal/arch", "cmd/asm/internal/flags", "cmd/asm/internal/lex"}, - }}, -}} - -func TestSearch(t *testing.T) { - for _, test := range indexTests { - t.Run(test.name, test.run) - } -} - -func BenchmarkSearch_partials_stdlib(b *testing.B) { - b.Helper() - const partial = true - benchmarkSearch_stdlib(b, partial) -} -func BenchmarkSearch_exact_stdlib(b *testing.B) { - b.Helper() - const partial = false - benchmarkSearch_stdlib(b, partial) -} - -func benchmarkSearch_stdlib(b *testing.B, partial bool) { - require := require.New(b) - - ctx := context.Background() - codeRoots := stdlibCodeRoots() - - var searchOpts []SearchOption - if partial { - searchOpts = append(searchOpts, WithMatchPartials()) - } - var matches []godoc.PackageDir - var err error - - var pkgIdx *Index - var randomPartial *randomPartial - benchmark.Run(b, func() { - pkgIdx, err = Load(ctx, dbFilePath(b), codeRoots, loadOpts()) - require.NoError(err) - b.Cleanup(func() { require.NoError(pkgIdx.Close()) }) - - randomPartial, err = pkgIdx.randomPartial() - require.NoError(err) - b.Cleanup(func() { require.NoError(randomPartial.Close()) }) - }, func() { - path, err := 
randomPartial.randomPartial() - require.NoError(err) - matches, err = pkgIdx.Search(ctx, path, searchOpts...) - require.NoError(err) - }) - b.Log("num matches: ", len(matches)) -} - -func TestRandomPartialSearchPath(t *testing.T) { - require := require.New(t) - - ctx := context.Background() - codeRoots := stdlibCodeRoots() - - pkgIdx, err := Load(ctx, dbFilePath(t), codeRoots, loadOpts()) - require.NoError(err) - t.Cleanup(func() { require.NoError(pkgIdx.Close()) }) - - randomPartial, err := pkgIdx.randomPartial() - require.NoError(err) - t.Cleanup(func() { require.NoError(randomPartial.Close()) }) - - paths := make(map[string]struct{}) - var duplicates int - const total = 1000 - for i := 0; i < total; i++ { - path, err := randomPartial.randomPartial() - require.NoError(err) - if _, duplicate := paths[path]; duplicate { - // t.Log("duplicate path:", path, i) - duplicates++ - continue - } - // t.Log("unique path:", path, i) - paths[path] = struct{}{} - } - - t.Log("duplicates:", duplicates) - require.Less(duplicates, total/2, "too many duplicates") -} - -func BenchmarkRandomPartialSearchPath(b *testing.B) { - require := require.New(b) - var path string - var pkgIdx *Index - var randomPartial *randomPartial - var err error - benchmark.Run(b, func() { - ctx := context.Background() - codeRoots := stdlibCodeRoots() - opts := WithOptions(WithNoProgressBar()) - - pkgIdx, err = Load(ctx, dbMem, codeRoots, opts) - require.NoError(err) - b.Cleanup(func() { require.NoError(pkgIdx.Close()) }) - - randomPartial, err = pkgIdx.randomPartial() - require.NoError(err) - b.Cleanup(func() { require.NoError(randomPartial.Close()) }) - }, func() { - path, err = randomPartial.randomPartial() - require.NoError(err) - }) - b.Log("path: ", path) -} - -func (pkgIdx *Index) randomPartial() (*randomPartial, error) { - if err := pkgIdx.waitSync(); err != nil { - return nil, err - } - stmt, err := pkgIdx.db.Prepare(` -SELECT parts FROM partial ORDER BY RANDOM(); -`) - if err != nil { - return 
nil, err - } - rows, err := stmt.Query() - if err != nil { - return nil, err - } - return &randomPartial{ - stmt: stmt, - rows: rows, - }, nil -} - -type randomPartial struct { - stmt *sql.Stmt - rows *sql.Rows -} - -func (r *randomPartial) Close() error { - if r.stmt != nil { - if err := r.stmt.Close(); err != nil { - return err - } - } - return nil -} -func (r *randomPartial) randomPartial() (string, error) { - if !r.rows.Next() { - if err := r.rows.Close(); err != nil { - return "", err - } - rows, err := r.stmt.Query() - if err != nil { - return "", err - } - r.rows = rows - if !r.rows.Next() { - return "", fmt.Errorf("no rows") - } - } - return scanImportPath(r.rows) - -} -func scanImportPath(rows sqlRow) (string, error) { - var path string - return path, rows.Scan(&path) -} diff --git a/internal/index/sync.go b/internal/index/sync.go deleted file mode 100644 index 96c6df0..0000000 --- a/internal/index/sync.go +++ /dev/null @@ -1,276 +0,0 @@ -package index - -import ( - "context" - "database/sql" - "errors" - "fmt" - "log" - "os" - "path" - "path/filepath" - "strings" - "time" - - "aslevy.com/go-doc/internal/godoc" -) - -var dlogSync = dlog.Child("sync") - -func (idx *Index) needsSync(ctx context.Context) (bool, error) { - switch idx.options.mode { - case ModeOff, ModeSkipSync: - return false, nil - case ModeForceSync: - return true, nil - } - var err error - idx.metadata, err = idx.selectMetadata(ctx) - if ignoreErrNoRows(err) != nil { - return false, err - } - - if idx.metadata.BuildRevision != buildRevision || - idx.metadata.GoVersion != goVersion { - return true, nil - } - - dlogSync.Printf("created at: %v", idx.CreatedAt.Local()) - dlogSync.Printf("updated at: %v", idx.UpdatedAt.Local()) - return time.Since(idx.UpdatedAt) > idx.options.resyncInterval, nil -} -func ignoreErrNoRows(err error) error { - if errors.Is(err, sql.ErrNoRows) { - return nil - } - return err -} - -func (idx *Index) syncCodeRoots(ctx context.Context, codeRoots []godoc.PackageDir) 
(retErr error) { - needsSync, err := idx.needsSync(ctx) - if err != nil { - return err - } - if !needsSync { - return nil - } - - dlogSync.Println("syncing code roots...") - commitIfNilErr, err := idx.beginTx(ctx) - if err != nil { - return err - } - defer commitIfNilErr(&retErr) - - pb := newProgressBar(idx.options, len(codeRoots)+1, "syncing code roots") - defer pb.Finish() - - var keep []int64 - for _, codeRoot := range codeRoots { - modIDs, err := idx.syncCodeRoot(ctx, codeRoot) - if err != nil { - return err - } - keep = append(keep, modIDs...) - pb.Add(1) - } - - const vendor = false - if err := idx.pruneModules(ctx, keep); err != nil { - return err - } - pb.Add(1) - - return idx.upsertMetadata(ctx) -} -func (idx *Index) beginTx(ctx context.Context) (commitIfNilErr func(*error), _ error) { - tx, err := idx.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - idx.tx = newSqlTx(tx) - return func(retErr *error) { - if *retErr != nil { - tx.Rollback() - return - } - *retErr = tx.Commit() - }, nil -} - -func (idx *Index) syncCodeRoot(ctx context.Context, root godoc.PackageDir) (modIDs []int64, _ error) { - class, vendor := parseClassVendor(root) - if vendor { - // ImportPath is empty for vendor directories, so we use the - // Dir instead so as not to conflict with the stdlib, which - // uses the empty import path. - // - // The root vendor module is a place holder, so its import path - // will never be used to form a package path. 
- root.ImportPath = root.Dir - return idx.syncVendoredModules(ctx, root) - } - return idx.syncModule(ctx, root, class) -} - -func parseClassVendor(root godoc.PackageDir) (class, bool) { - if isVendor(root.Dir) { - return classRequired, true - } - switch root.ImportPath { - case "", "cmd": - return classStdlib, false - } - if _, hasVersion := parseVersion(root.Dir); hasVersion { - return classRequired, false - } - return classLocal, false -} -func parseVersion(dir string) (string, bool) { - _, version, found := strings.Cut(filepath.Base(dir), "@") - return version, found -} -func isVendor(dir string) bool { return filepath.Base(dir) == "vendor" } - -func (idx *Index) syncModule(ctx context.Context, root godoc.PackageDir, class int) (modIDs []int64, _ error) { - const vendor = false - modID, needsSync, err := idx.upsertModule(ctx, root, class, vendor) - if err != nil { - return nil, err - } - modIDs = append(modIDs, modID) - - if !needsSync && idx.options.mode != ModeForceSync { - dlogSync.Printf("code root %q is already synced", root.ImportPath) - return modIDs, nil - } - - return modIDs, idx.syncModulePackages(ctx, modID, root) -} - -func (idx *Index) upsertModule(ctx context.Context, root godoc.PackageDir, class class, vendor bool) (modID int64, needsSync bool, _ error) { - mod, err := idx.selectModule(ctx, root.ImportPath) - if ignoreErrNoRows(err) != nil { - return -1, false, err - } - if mod.Dir == root.Dir { - // The module is already in the database and the directory - // hasn't changed, so we assume we are synced. 
- return mod.ID, false, nil - } - - if mod.ID < 1 { - mod.ID, err = idx.insertModule(ctx, root, class, vendor) - if err != nil { - return -1, false, err - } - } else { - if err := idx.updateModule(ctx, mod.ID, root, class, vendor); err != nil { - return -1, false, err - } - } - return mod.ID, true, nil -} - -func (idx *Index) syncModulePackages(ctx context.Context, modID int64, root godoc.PackageDir) error { - dlogSync.Printf("syncing module packages for %q in %q", root.ImportPath, root.Dir) - root.Dir = filepath.Clean(root.Dir) // because filepath.Join will do it anyway - - // this is the queue of directories to examine in this pass. - this := []godoc.PackageDir{} - // next is the queue of directories to examine in the next pass. - next := []godoc.PackageDir{root} - - var keep []int64 - for len(next) > 0 { - dlogSync.Printf("descending") - this, next = next, this[0:0] - for _, pkg := range this { - dlogSync.Printf("walking %q", pkg) - fd, err := os.Open(pkg.Dir) - if err != nil { - log.Print(err) - continue - } - - entries, err := fd.Readdir(0) - fd.Close() - if err != nil { - log.Print(err) - continue - } - hasGoFiles := false - for _, entry := range entries { - name := entry.Name() - // For plain files, remember if this directory contains any .go - // source files, but ignore them otherwise. - if !entry.IsDir() { - if !hasGoFiles && strings.HasSuffix(name, ".go") { - hasGoFiles = true - pkgID, err := idx.syncPackage(ctx, modID, root, pkg) - if err != nil { - return err - } - if pkgID > 0 { - keep = append(keep, pkgID) - } - } - continue - } - // Entry is a directory. - - // The go tool ignores directories starting with ., _, or named "testdata". - if name[0] == '.' || name[0] == '_' || name == "testdata" { - continue - } - // Ignore vendor directories and stop at module boundaries. 
- if name == "vendor" { - continue - } - if fi, err := os.Stat(filepath.Join(pkg.Dir, name, "go.mod")); err == nil && !fi.IsDir() { - continue - } - // Remember this (fully qualified) directory for the next pass. - subPkg := godoc.NewPackageDir( - path.Join(pkg.ImportPath, name), - filepath.Join(pkg.Dir, name), - ) - dlogSync.Printf("queuing %q", subPkg.ImportPath) - next = append(next, subPkg) - } - } - } - - return idx.prunePackages(ctx, modID, keep) -} - -func (idx *Index) syncPackage(ctx context.Context, modID int64, root, pkg godoc.PackageDir) (int64, error) { - dlogSync.Printf("syncing package %q in %q", pkg.ImportPath, pkg.Dir) - relativePath := strings.TrimPrefix(pkg.ImportPath[len(root.ImportPath):], "/") - pkgID, err := idx.selectPackageID(ctx, modID, relativePath) - if ignoreErrNoRows(err) != nil { - return -1, err - } - if pkgID > 0 { - return pkgID, nil - } - - pkgID, err = idx.insertPackage(ctx, modID, relativePath) - if err != nil { - return -1, err - } - - return pkgID, idx.syncPartials(ctx, pkgID, pkg.ImportPath) -} - -func (idx *Index) syncPartials(ctx context.Context, pkgID int64, importPath string) error { - dlogSync.Printf("syncing partials for package %q", importPath) - lastSlash := len(importPath) - for lastSlash > 0 { - lastSlash = strings.LastIndex(importPath[:lastSlash], "/") - if _, err := idx.insertPartial(ctx, pkgID, importPath[lastSlash+1:]); err != nil { - return fmt.Errorf("failed to insert partial: %w", err) - } - } - return nil -} diff --git a/internal/index/sync_test.go b/internal/index/sync_test.go deleted file mode 100644 index 60ccde3..0000000 --- a/internal/index/sync_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package index - -import ( - "context" - "path/filepath" - "testing" - - "aslevy.com/go-doc/internal/benchmark" - "github.com/stretchr/testify/require" -) - -const dbMem = ":memory:" - -func dbFilePath(tb testing.TB) string { - const ( - dbFile = "file:" - dbName = "index.sqlite3" - ) - path := dbFile + 
filepath.Join(tb.TempDir(), dbName) - // tb.Log("db path: ", path) - return path -} - -func loadOpts() Option { return WithOptions(WithNoProgressBar(), WithResyncInterval(0)) } - -// BenchmarkLoadSync_stdlib benchmarks the time it takes to sync an index of -// the stdlib from scratch and write it to the filesystem. -func BenchmarkLoadSync_stdlib(b *testing.B) { - require := require.New(b) - - ctx := context.Background() - dbPath := dbFilePath(b) - codeRoots := stdlibCodeRoots() - - var idx *Index - var err error - benchmark.Run(b, nil, func() { - idx, err = Load(ctx, dbPath, codeRoots, loadOpts()) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }) - - b.Logf("index sync %+v", idx.metadata) -} - -// BenchmarkLoadSync_InMemory_stdlib is like BenchmarkLoadSync_stdlib, but uses -// an in memory database instead of the filesystem. -func BenchmarkLoadSync_InMemory_stdlib(b *testing.B) { - require := require.New(b) - - ctx := context.Background() - codeRoots := stdlibCodeRoots() - - var idx *Index - var err error - benchmark.Run(b, nil, func() { - idx, err = Load(ctx, dbMem, codeRoots, loadOpts()) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }) - b.Logf("index sync %+v", idx.metadata) -} - -// BenchmarkLoadReSync_stdlib benchmarks the time it takes to re-sync an -// existing index of the stdlib when it has not changed. 
-func BenchmarkLoadReSync_stdlib(b *testing.B) { - require := require.New(b) - - ctx := context.Background() - dbPath := dbFilePath(b) - codeRoots := stdlibCodeRoots() - opts := loadOpts() - - var idx *Index - var err error - benchmark.Run(b, func() { - // sync initially prior to running the benchmark - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }, func() { - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }) - b.Logf("index sync %+v", idx.metadata) -} - -// BenchmarkLoadForceSync_stdlib benchmarks the time it takes to re-sync an -// existing index of the stdlib when it has not changed. -func BenchmarkLoadForceSync_stdlib(b *testing.B) { - require := require.New(b) - - ctx := context.Background() - dbPath := dbFilePath(b) - codeRoots := stdlibCodeRoots() - opts := WithOptions(loadOpts(), WithForceSync()) - - var idx *Index - var err error - benchmark.Run(b, func() { - // sync initially prior to running the benchmark - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }, func() { - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }) - b.Logf("index sync %+v", idx.metadata) -} - -// BenchmarkLoadSkipSync_stdlib benchmarks the time it takes to load an -// existing index of the stdlib without syncing. 
-func BenchmarkLoadSkipSync_stdlib(b *testing.B) { - require := require.New(b) - - ctx := context.Background() - dbPath := dbFilePath(b) - codeRoots := stdlibCodeRoots() - opts := WithOptions(loadOpts(), WithSkipSync()) - - var idx *Index - var err error - benchmark.Run(b, func() { - // sync initially prior to running the benchmark - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }, func() { - idx, err = Load(ctx, dbPath, codeRoots, opts) - require.NoError(err) - require.NoError(idx.waitSync()) - require.NoError(idx.Close()) - }) - b.Logf("index sync %+v", idx.metadata) -} diff --git a/internal/index/testdata/module/go.mod b/internal/index/testdata/module/go.mod deleted file mode 100644 index 6499e32..0000000 --- a/internal/index/testdata/module/go.mod +++ /dev/null @@ -1,31 +0,0 @@ -module example.com/module - -go 1.19 - -require ( - aslevy.com/go-doc v0.0.0-20211002150000-000000000000 - github.com/alecthomas/chroma v0.10.0 - github.com/charmbracelet/glamour v0.6.1-0.20221114002222-bf21e0bca6f3 - github.com/davecgh/go-spew v1.1.1 -) - -require ( - github.com/aymanbagabas/go-osc52 v1.2.1 // indirect - github.com/aymerick/douceur v0.2.0 // indirect - github.com/dlclark/regexp2 v1.7.0 // indirect - github.com/gorilla/css v1.0.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.17 // indirect - github.com/mattn/go-runewidth v0.0.14 // indirect - github.com/microcosm-cc/bluemonday v1.0.21 // indirect - github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.13.0 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/rivo/uniseg v0.4.3 // indirect - github.com/yuin/goldmark v1.5.3 // indirect - github.com/yuin/goldmark-emoji v1.0.1 // indirect - golang.org/x/net v0.4.0 // indirect - golang.org/x/sys v0.4.0 // indirect -) - -replace aslevy.com/go-doc => ../../../../ diff --git 
a/internal/index/testdata/module/go.sum b/internal/index/testdata/module/go.sum deleted file mode 100644 index d9643b3..0000000 --- a/internal/index/testdata/module/go.sum +++ /dev/null @@ -1,62 +0,0 @@ -github.com/alecthomas/chroma v0.10.0 h1:7XDcGkCQopCNKjZHfYrNLraA+M7e0fMiJ/Mfikbfjek= -github.com/alecthomas/chroma v0.10.0/go.mod h1:jtJATyUxlIORhUOFNA9NZDWGAQ8wpxQQqNSB4rjA/1s= -github.com/aymanbagabas/go-osc52 v1.0.3/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4= -github.com/aymanbagabas/go-osc52 v1.2.1 h1:q2sWUyDcozPLcLabEMd+a+7Ea2DitxZVN9hTxab9L4E= -github.com/aymanbagabas/go-osc52 v1.2.1/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4= -github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= -github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= -github.com/charmbracelet/glamour v0.6.1-0.20221114002222-bf21e0bca6f3 h1:X2wFHf/10YT3wqEGzAqoMAiLUvc/v05S6GwSCUj4AV8= -github.com/charmbracelet/glamour v0.6.1-0.20221114002222-bf21e0bca6f3/go.mod h1:Rp5bKbGkf2NwGd1UzSdkbIL9ff8z0dxSmIBUtC8R0BA= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= -github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo= -github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= -github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty 
v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= -github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= -github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg= -github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM= -github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= -github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= -github.com/muesli/termenv v0.13.0 h1:wK20DRpJdDX8b7Ek2QfhvqhRQFZ237RGRO0RQ/Iqdy0= -github.com/muesli/termenv v0.13.0/go.mod h1:sP1+uffeLaEYpyOTb8pLCUctGcGLnoFjSn4YJK5e2bc= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.3 h1:utMvzDsuh3suAEnhH0RdHmoPbU648o6CvXxTx4SBMOw= -github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.5.3 h1:3HUJmBFbQW9fhQOzMgseU134xfi6hU+mjWywx5Ty+/M= -github.com/yuin/goldmark v1.5.3/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/goldmark-emoji v1.0.1 h1:ctuWEyzGBwiucEqxzwe0SOYDXPAucOrE9NQC18Wa1os= -github.com/yuin/goldmark-emoji v1.0.1/go.mod h1:2w1E6FEWLcDQkoTE+7HU6QF1F6SLlNGjRIBbIZQFqkQ= -golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= -golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/index/testdata/module/main.go b/internal/index/testdata/module/main.go deleted 
file mode 100644 index 01ef11a..0000000 --- a/internal/index/testdata/module/main.go +++ /dev/null @@ -1,14 +0,0 @@ -package main - -import ( - "log" - - _ "aslevy.com/go-doc/testdata/codeblocks" - _ "github.com/alecthomas/chroma" - _ "github.com/charmbracelet/glamour" - _ "github.com/davecgh/go-spew/spew" -) - -func main() { - log.Println("Hello, world!") -} diff --git a/internal/index/vendor.go b/internal/index/vendor.go deleted file mode 100644 index 3e68d94..0000000 --- a/internal/index/vendor.go +++ /dev/null @@ -1,72 +0,0 @@ -package index - -import ( - "context" - "log" - "os" - - "aslevy.com/go-doc/internal/godoc" - "aslevy.com/go-doc/internal/vendored" -) - -func (idx *Index) syncVendoredModules(ctx context.Context, vendorRoot godoc.PackageDir) ([]int64, error) { - const vendor = true - modID, needsSync, err := idx.upsertModule(ctx, vendorRoot, classLocal, vendor) - if err != nil { - return nil, err - } - - if !needsSync && idx.vendorUnchanged(vendorRoot) { - return idx.vendoredModuleIDs(ctx) - } - - modIDs := []int64{modID} - if err := vendored.Parse(ctx, vendorRoot.Dir, func(ctx context.Context, mod godoc.PackageDir, pkgs ...godoc.PackageDir) error { - pkgKeep := make([]int64, len(pkgs)) - modID, _, err := idx.upsertModule(ctx, mod, classRequired, vendor) - if err != nil { - return err - } - modIDs = append(modIDs, modID) - for _, pkg := range pkgs { - pkgID, err := idx.syncPackage(ctx, modID, mod, pkg) - if err != nil { - return err - } - pkgKeep = append(pkgKeep, pkgID) - } - return idx.prunePackages(ctx, modID, pkgKeep) - }); err != nil { - return nil, err - } - return modIDs, nil -} - -func (idx *Index) vendorUnchanged(vendor godoc.PackageDir) bool { - info, err := os.Stat(vendor.Dir) - if err != nil { - log.Printf("failed to stat %s: %v", vendor.Dir, err) - return true - } - return idx.UpdatedAt.After(info.ModTime()) -} - -func (idx *Index) vendoredModuleIDs(ctx context.Context) ([]int64, error) { - const query = `SELECT rowid FROM module WHERE 
vendor = true;` - rows, err := idx.db.QueryContext(ctx, query) - if err != nil { - return nil, err - } - defer rows.Close() - - var modIDs []int64 - for rows.Next() { - var modID int64 - if err := rows.Scan(&modID); err != nil { - return nil, err - } - modIDs = append(modIDs, modID) - } - - return modIDs, rows.Err() -} diff --git a/internal/install/assets/plugin/_golang b/internal/install/assets/plugin/_golang index 9ca0a52..a6c8b6a 100644 --- a/internal/install/assets/plugin/_golang +++ b/internal/install/assets/plugin/_golang @@ -207,7 +207,7 @@ __go_package_symbols() { local -a allSyms local argNum=${#line} - local -a DISABLE_OPTS=( "-debug=false" "-debug-index=false" "-install-completion=false" "-open=false" ) + local -a DISABLE_OPTS=( "-debug=false" "-install-completion=false" "-open=false" ) allSyms=("${(@f)$(go-doc -complete -arg ${argNum} ${DISABLE_OPTS} ${GODOC_OPTS} ${words[2,-1]})}") || return 1 # completions for the third argument are always prefixed with the type from diff --git a/internal/modpkg/db/metadata.go b/internal/modpkg/db/metadata.go new file mode 100644 index 0000000..466c5ce --- /dev/null +++ b/internal/modpkg/db/metadata.go @@ -0,0 +1,153 @@ +package db + +import ( + "context" + _ "embed" + "errors" + "fmt" + "hash/crc32" + "io" + "os" + "path/filepath" + "runtime/debug" + "time" + + "aslevy.com/go-doc/internal/sql" +) + +type Metadata struct { + CreatedAt time.Time + UpdatedAt time.Time + + BuildRevision string + GoVersion string + + MainModule +} + +type MainModule struct { + Dir string + GoModHash int32 + GoSumHash int32 + Vendor bool +} + +func (stored Metadata) NeedsSync() (*Metadata, error) { + current, err := NewMetadata(stored.MainModule.Dir) + if err != nil { + return ¤t, err + } + if stored.MainModule == current.MainModule { + return nil, nil + } + return ¤t, nil +} + +func NewMetadata(mainModuleDir string) (meta Metadata, rerr error) { + meta.MainModule.Dir = mainModuleDir + + var err error + meta.BuildRevision, meta.GoVersion, 
err = parseBuildInfo()
+	rerr = errors.Join(rerr, err)
+
+	meta.GoModHash, err = hashGoModFile(mainModuleDir)
+	rerr = errors.Join(rerr, err)
+
+	meta.GoSumHash, err = hashGoSumFile(mainModuleDir)
+	rerr = errors.Join(rerr, err)
+
+	meta.Vendor, err = usingVendor(mainModuleDir)
+	rerr = errors.Join(rerr, err)
+
+	return
+}
+func parseBuildInfo() (string, string, error) {
+	info, ok := debug.ReadBuildInfo()
+	if !ok {
+		return "", "", fmt.Errorf("debug.ReadBuildInfo() failed")
+	}
+	return parseBuildRevision(info), info.GoVersion, nil
+}
+func parseBuildRevision(info *debug.BuildInfo) string {
+	for _, s := range info.Settings {
+		if s.Key == "vcs.revision" {
+			if s.Value != "" {
+				return s.Value
+			}
+			break
+		}
+	}
+	return "unknown"
+}
+
+func hashGoModFile(mainModDir string) (int32, error) {
+	return fileCRC32(filepath.Join(mainModDir, "go.mod"))
+}
+func hashGoSumFile(mainModDir string) (int32, error) {
+	return fileCRC32(filepath.Join(mainModDir, "go.sum"))
+}
+func fileCRC32(filePath string) (int32, error) {
+	f, err := os.Open(filePath)
+	if err != nil {
+		return 0, fmt.Errorf("failed to open %q: %w", filePath, err)
+	}
+	defer f.Close()
+
+	crc := crc32.NewIEEE()
+	if _, err := io.Copy(crc, f); err != nil {
+		return 0, fmt.Errorf("failed to write file %q to CRC32 hash: %w", filePath, err)
+	}
+	return int32(crc.Sum32()), nil
+}
+func usingVendor(mainModDir string) (bool, error) {
+	vendorPath := filepath.Join(mainModDir, "vendor")
+	fi, err := os.Stat(vendorPath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+		return false, fmt.Errorf("failed to stat %q: %w", vendorPath, err)
+	}
+	if !fi.IsDir() {
+		return false, fmt.Errorf("vendor path %q is not a directory", vendorPath)
+	}
+	return true, nil
+}
+
+func (db *DB) SelectMetadata(ctx context.Context) (Metadata, error) {
+	return selectMetadata(ctx, db.db)
+}
+
+//go:embed sql/metadata_select.sql
+var querySelectMetadata string
+
+func selectMetadata(ctx context.Context, db sql.Querier) 
(Metadata, error) { + var meta Metadata + row := db.QueryRowContext(ctx, querySelectMetadata) + return meta, row.Scan( + &meta.CreatedAt, + &meta.UpdatedAt, + &meta.BuildRevision, + &meta.GoVersion, + &meta.GoModHash, + &meta.GoSumHash, + &meta.Vendor, + ) +} + +//go:embed sql/metadata_upsert.sql +var queryUpsertMetadata string + +func (s *Sync) upsertMetadata(ctx context.Context, meta *Metadata) error { + _, err := s.tx.ExecContext(ctx, queryUpsertMetadata, + sql.Named("build_revision", meta.BuildRevision), + sql.Named("go_version", meta.GoVersion), + sql.Named("go_mod_hash", meta.GoModHash), + sql.Named("go_sum_hash", meta.GoSumHash), + sql.Named("vendor", meta.Vendor), + ) + if err != nil { + return fmt.Errorf("failed to upsert metadata: %w", err) + } + return nil +} diff --git a/internal/modpkg/db/module.go b/internal/modpkg/db/module.go new file mode 100644 index 0000000..0f80322 --- /dev/null +++ b/internal/modpkg/db/module.go @@ -0,0 +1,97 @@ +package db + +import ( + "context" + _ "embed" + "errors" + "fmt" + + "aslevy.com/go-doc/internal/godoc" + "aslevy.com/go-doc/internal/sql" +) + +type Module struct { + ID int64 + godoc.PackageDir +} + +//go:embed query/module_upsert.sql +var queryModuleUpsertSql string + +func prepareUpsertModule(ctx context.Context, db sql.Querier) (*sql.Stmt, error) { + return db.PrepareContext(ctx, queryModuleUpsertSql) +} + +func (s *Sync) upsertModule(ctx context.Context, mod *Module) (needSync bool, _ error) { + row := s.stmt.upsertModule.QueryRowContext( + ctx, + sql.Named("import_path", mod.ImportPath), + sql.Named("version", mod.Version), + ) + return needSync, row.Scan( + &needSync, + &mod.ID, + ) +} + +func SelectAllModules(ctx context.Context, db sql.Querier) (_ []Module, rerr error) { + return selectModulesFromWhere(ctx, db, "module", "") +} +func selectModulesFromWhere(ctx context.Context, db sql.Querier, from, where string, args ...any) (_ []Module, rerr error) { + query := buildSelectModulesFromWhereQuery(from, where) 
+	rows, err := db.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to select from module: %w", err)
+	}
+	defer func() {
+		if err := rows.Close(); err != nil {
+			rerr = errors.Join(rerr, fmt.Errorf("failed to close rows: %w", err))
+		}
+	}()
+	return scanModules(ctx, rows)
+}
+func buildSelectModulesFromWhereQuery(from, where string) string {
+	query := `
+SELECT
+	rowid,
+	import_path,
+	dir,
+	class
+FROM
+	`[1:] // remove leading newline
+	query += from
+	if where != "" {
+		query += `
+WHERE
+	`
+		query += where
+	}
+	query += ";"
+	return query
+}
+func scanModules(ctx context.Context, rows *sql.Rows) (mods []Module, _ error) {
+	for rows.Next() {
+		mod, err := scanModule(rows)
+		if err != nil {
+			return nil, err
+		}
+		mods = append(mods, mod)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("failed to load next module: %w", err)
+	}
+	return mods, nil
+}
+func scanModule(row sql.RowScanner) (mod Module, _ error) {
+	if err := row.Scan(&mod.ID, &mod.ImportPath, &mod.Dir, new(int64)); err != nil {
+		return mod, fmt.Errorf("failed to scan module: %w", err)
+	}
+	return mod, nil
+}
+
+func (s *Sync) selectModulesToPrune(ctx context.Context) ([]Module, error) {
+	return selectModulesFromWhere(ctx, s.tx, "module", "keep=FALSE ORDER BY rowid")
+}
+func (s *Sync) selectModulesThatNeedSync(ctx context.Context) ([]Module, error) {
+	return selectModulesFromWhere(ctx, s.tx, "module", "sync=TRUE ORDER BY rowid")
+}
diff --git a/internal/modpkg/db/open.go b/internal/modpkg/db/open.go
new file mode 100644
index 0000000..89f192f
--- /dev/null
+++ b/internal/modpkg/db/open.go
@@ -0,0 +1,166 @@
+package db
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"aslevy.com/go-doc/internal/sql"
+)
+
+type DB struct {
+	db *sql.DB
+
+	stored Metadata
+}
+
+func (db *DB) Close() error {
+	return db.db.Close()
+}
+
+const goDocDBPath = ".go-doc/go-doc.sqlite3"
+
+func Open(ctx context.Context, mainModDir string) (_ *DB, rerr error) {
+	dbPath := filepath.Join(mainModDir, goDocDBPath)
+	if err := ensureDBPathDirExists(dbPath); err != nil {
+		return nil, err
+	}
+
+	db, err := open(ctx, dbPath)
+	if err == nil {
+		// The database is ready to use.
+		return db, nil
+	}
+	if !errors.Is(err, errSchemaChecksumMismatch) {
+		// The error is not a schema checksum mismatch, so we can't
+		// recover.
+		return nil, err
+	}
+
+	// TODO: log that we are moving the old database to .old
+
+	// The schema checksum mismatch means that the database schema is
+	// incompatible with the current version of the code. We need to remove
+	// the database and re-build it. We'll just rename it to be safe.
+	if err := os.Rename(dbPath, dbPath+".old"); err != nil {
+		return nil, fmt.Errorf("failed to remove existing database with incompatible schema: %w", err)
+	}
+
+	return open(ctx, dbPath)
+}
+
+func open(ctx context.Context, dbPath string) (_ *DB, rerr error) {
+	sqldb, err := sql.Open("sqlite", dbPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open database: %w", err)
+	}
+	// Close the database if we fail to initialize for any reason.
+	defer func() {
+		if rerr == nil {
+			// Success, so leave the database open.
+			return
+		}
+		// Failed to initialize for some reason so close the database. 
+ if err := sqldb.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close database: %w", err)) + } + }() + + db := DB{ + db: sqldb, + } + + if err := db.initialize(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + return &db, nil +} +func ensureDBPathDirExists(dbPath string) error { + dirPath := filepath.Dir(dbPath) + if err := os.Mkdir(dirPath, 0755); err != nil && !os.IsExist(err) { + return fmt.Errorf("failed to create database directory %s: %w", dirPath, err) + } + return nil +} + +func (db *DB) initialize(ctx context.Context) (rerr error) { + ready, err := db.checkSchema(ctx) + if err != nil { + return err + } + + // Always enable foreign keys and recursive triggers. + if err := db.enableForeignKeys(ctx); err != nil { + return err + } + if err := db.enableRecursiveTriggers(ctx); err != nil { + return err + } + + if !ready { + // The WAL journal mode is persistent so we only need to set it + // if the database is not ready. This must occur outside of the + // following transaction. 
+ if err := db.journalModeWAL(ctx); err != nil { + return err + } + } + + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.RollbackOnError(&rerr) + + if !ready { + if err := applySchema(ctx, tx); err != nil { + return err + } + } + + if ready { + if db.stored, err = selectMetadata(ctx, tx); err != nil && + !errors.Is(err, sql.ErrNoRows) { + return err + } + } + + return tx.Commit() +} + +var errSchemaChecksumMismatch = errors.New("schema checksum mismatch") + +func (db *DB) checkSchema(ctx context.Context) (ready bool, _ error) { + appID, err := getApplicationID(ctx, db.db) + if err != nil { + return false, err + } + + userVersion, err := getUserVersion(ctx, db.db) + if err != nil { + return false, err + } + + schemaVersion, err := getSchemaVersion(ctx, db.db) + if err != nil { + return false, err + } + + if appID+userVersion+schemaVersion == 0 { + // Database is uninitialized. + return false, nil + } + + if appID != sqliteApplicationID { + return false, fmt.Errorf("unrecognized database application ID") + } + + if userVersion != schemaChecksum { + return false, errSchemaChecksumMismatch + } + + return true, nil +} diff --git a/internal/modpkg/db/package.go b/internal/modpkg/db/package.go new file mode 100644 index 0000000..bb4fa25 --- /dev/null +++ b/internal/modpkg/db/package.go @@ -0,0 +1,108 @@ +package db + +import ( + "context" + _ "embed" + "errors" + "fmt" + + "aslevy.com/go-doc/internal/sql" +) + +type Package struct { + ID int64 + ModuleID int64 + RelativePath string + NumParts int +} + +func (s *Sync) prepareStmtUpsertPackage(ctx context.Context) (err error) { + s.stmt.upsertPkg, err = prepareStmtUpsertPackage(ctx, s.tx) + return +} + +func prepareStmtUpsertPackage(ctx context.Context, db sql.Querier) (*sql.Stmt, error) { + stmt, err := db.PrepareContext(ctx, queryUpsertPackage) + if err != nil { + return nil, fmt.Errorf("failed to prepare upsert package statement: %w", err) + 
} + return stmt, nil +} + +//go:embed sql/package_upsert.sql +var queryUpsertPackage string + +func (s *Sync) upsertPackage(ctx context.Context, pkg *Package) error { + row := s.stmt.upsertPkg.QueryRowContext(ctx, + sql.Named("module_id", pkg.ModuleID), + sql.Named("relative_path", pkg.RelativePath), + ) + if err := row.Err(); err != nil { + return fmt.Errorf("failed to upsert package: %w", err) + } + if err := row.Scan(&pkg.ID); err != nil { + return fmt.Errorf("failed to scan upserted package: %w", err) + } + return nil +} + +func SelectAllPackages(ctx context.Context, db sql.Querier) ([]Package, error) { + return selectPackagesFromWhere(ctx, db, "package", "") +} +func SelectModulePackages(ctx context.Context, db sql.Querier, modId int64) ([]Package, error) { + return selectPackagesFromWhere(ctx, db, "package", "module_id = ? ORDER BY rowid", modId) +} +func selectPackagesFromWhere(ctx context.Context, db sql.Querier, from, where string, args ...interface{}) (_ []Package, rerr error) { + query := ` +SELECT + rowid, + module_id, + relative_path, + num_parts +FROM + ` + query += from + if where != "" { + query += ` +WHERE + ` + query += where + } + query += ";" + + rows, err := db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("failed to select packages: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close rows: %w", err)) + } + }() + return scanPackages(ctx, rows) +} + +func scanPackages(ctx context.Context, rows *sql.Rows) (pkgs []Package, _ error) { + for rows.Next() { + pkg, err := scanPackage(rows) + if err != nil { + return nil, err + } + pkgs = append(pkgs, pkg) + } + if err := rows.Err(); err != nil { + return pkgs, fmt.Errorf("failed to load next package: %w", err) + } + return pkgs, nil +} + +func scanPackage(row sql.RowScanner) (pkg Package, _ error) { + if err := row.Scan(&pkg.ID, &pkg.ModuleID, &pkg.RelativePath, &pkg.NumParts); err != nil { + return pkg, fmt.Errorf("failed to scan package: %w", err) + } + return pkg, nil +} + +func (s *Sync) selectPackagesPrune(ctx context.Context) ([]Package, error) { + return selectPackagesFromWhere(ctx, s.tx, "package", "keep=FALSE ORDER BY rowid") +} diff --git a/internal/modpkg/db/part.go b/internal/modpkg/db/part.go new file mode 100644 index 0000000..4c919e8 --- /dev/null +++ b/internal/modpkg/db/part.go @@ -0,0 +1,98 @@ +package db + +import ( + "context" + _ "embed" + "errors" + "fmt" + "path" + + "aslevy.com/go-doc/internal/sql" +) + +type Part struct { + ID int64 + Name string + ParentID *int64 + PackageID *int64 + PathDepth int64 +} + +type PartClosure struct { + AncestorID int64 + DescendantID int64 + Depth int64 +} + +//go:embed sql/part_select_by_package_id.sql +var querySelectPackageParts string + +func SelectPackageParts(ctx context.Context, db sql.Querier, packageID int64, parts []Part) ([]Part, error) { + rows, err := db.QueryContext(ctx, querySelectPackageParts, + sql.Named("package_id", packageID), + ) + if err != nil { + return nil, fmt.Errorf("failed to query parts for packageID: %w", err) + } + defer rows.Close() + + for rows.Next() { + var part Part + if err := rows.Scan(&part.ID, &part.Name, 
&part.ParentID, &part.PackageID); err != nil { + return nil, fmt.Errorf("failed to scan Part: %w", err) + } + parts = append(parts, part) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan all parts: %w", err) + } + return parts, nil +} + +type ModulePackage struct { + PackageID int64 + PackageImportPath string + Dir string +} + +//go:embed sql/package_select_by_parts.sql +var querySelectPackagesByParts string + +func selectPackagesByParts(ctx context.Context, db sql.Querier, parts []string, pkgs []ModulePackage) (_ []ModulePackage, rerr error) { + rows, err := selectPackagesByPartsRows(ctx, db, true, parts) + if err != nil { + return nil, err + } + + defer func() { + if err := rows.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close rows: %w", err)) + } + }() + + for rows.Next() { + var pkg ModulePackage + if err := rows.Scan(&pkg.PackageID, &pkg.PackageImportPath, &pkg.Dir); err != nil { + return nil, fmt.Errorf("failed to scan ModulePackage: %w", err) + } + pkgs = append(pkgs, pkg) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to scan all ModulePackages: %w", err) + } + return pkgs, nil +} + +func (db *DB) SelectPackagesByPartsRows(ctx context.Context, exact bool, parts []string) (*sql.Rows, error) { + return selectPackagesByPartsRows(ctx, db.db, exact, parts) +} +func selectPackagesByPartsRows(ctx context.Context, db sql.Querier, exact bool, parts []string) (*sql.Rows, error) { + rows, err := db.QueryContext(ctx, querySelectPackagesByParts, + sql.Named("search_path", path.Join(parts...)), + sql.Named("exact", exact), + ) + if err != nil { + return nil, fmt.Errorf("failed to query packages by parts: %w", err) + } + return rows, nil +} diff --git a/internal/modpkg/db/pragma.go b/internal/modpkg/db/pragma.go new file mode 100644 index 0000000..3e3ceba --- /dev/null +++ b/internal/modpkg/db/pragma.go @@ -0,0 +1,89 @@ +package db + +import ( + "context" + "fmt" + + 
"aslevy.com/go-doc/internal/sql" +) + +// sqliteApplicationID is the magic number used to identify sqlite3 databases +// created by this application. +// +// See https://www.sqlite.org/fileformat.html#application_id +const ( + sqliteApplicationID = int32(0x0_90_D0C_90) // GO DOC GO + pragmaApplicationID = "application_id" + pragmaUserVersion = "user_version" + pragmaSchemaVersion = "schema_version" + pragmaForeignKeys = "foreign_keys" + pragmaRecursiveTriggers = "recursive_triggers" + pragmaJournalMode = "journal_mode" +) + +func assertApplicationID(ctx context.Context, db sql.Querier) error { + appID, err := getApplicationID(ctx, db) + if err != nil { + return err + } + if appID == 0 { // app ID not set + return setApplicationID(ctx, db) + } + if appID != sqliteApplicationID { + return fmt.Errorf("unrecognized database application ID") + } + return nil +} +func getApplicationID(ctx context.Context, db sql.Querier) (appID int32, err error) { + err = getPragma(ctx, db, pragmaApplicationID, &appID) + return +} +func setApplicationID(ctx context.Context, db sql.Querier) error { + return setPragma(ctx, db, pragmaApplicationID, sqliteApplicationID) +} + +func getUserVersion(ctx context.Context, db sql.Querier) (userVersion int32, err error) { + err = getPragma(ctx, db, pragmaUserVersion, &userVersion) + return +} +func setUserVersion(ctx context.Context, db sql.Querier, userVersion int32) error { + return setPragma(ctx, db, pragmaUserVersion, userVersion) +} + +func getSchemaVersion(ctx context.Context, db sql.Querier) (schemaVersion int32, err error) { + err = getPragma(ctx, db, pragmaSchemaVersion, &schemaVersion) + return +} + +func (db *DB) enableForeignKeys(ctx context.Context) error { + return setPragma(ctx, db.db, pragmaForeignKeys, true) +} + +func (db *DB) enableRecursiveTriggers(ctx context.Context) error { + return setPragma(ctx, db.db, pragmaRecursiveTriggers, true) +} + +func (db *DB) journalModeWAL(ctx context.Context) error { + return setPragma(ctx, 
db.db, pragmaJournalMode, "wal") +} + +func getPragma(ctx context.Context, db sql.Querier, key string, val any) error { + query := fmt.Sprintf(`PRAGMA %s;`, key) + row := db.QueryRowContext(ctx, query) + if err := row.Err(); err != nil { + return err + } + if err := row.Scan(val); err != nil { + return fmt.Errorf("failed to scan %s: %w", query, err) + } + return nil +} + +func setPragma(ctx context.Context, db sql.Querier, key string, val any) error { + query := fmt.Sprintf(`PRAGMA %s=%v;`, key, val) + _, err := db.ExecContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to set %s: %w", query, err) + } + return nil +} diff --git a/internal/modpkg/db/query/module_upsert.sql b/internal/modpkg/db/query/module_upsert.sql new file mode 100644 index 0000000..05905c6 --- /dev/null +++ b/internal/modpkg/db/query/module_upsert.sql @@ -0,0 +1,25 @@ +-- Upsert modules into the database, marking them to be kept since still in +-- use, and marking the module's packages for sync if they may have changed. +INSERT INTO module ( + import_path, + version +) +VALUES + $import_path, + $version +ON CONFLICT ( + import_path +) DO +UPDATE SET + sync = ( + excluded.version == "" -- sync if the version is empty + OR + version != excluded.version -- or if the version has changed + ), + -- Keep this module since it's still in use. 
+ keep = TRUE, + version = excluded.version +RETURNING + rowid, + sync +;--- diff --git a/internal/modpkg/db/registerfunction.go b/internal/modpkg/db/registerfunction.go new file mode 100644 index 0000000..009f309 --- /dev/null +++ b/internal/modpkg/db/registerfunction.go @@ -0,0 +1,19 @@ +package db + +import ( + "database/sql/driver" + "strings" + + "modernc.org/sqlite" +) + +func init() { + sqlite.MustRegisterDeterministicScalarFunction("concat_ws", -1, func(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + sep := args[0].(string) + elems := make([]string, len(args)-1) + for i, arg := range args[1:] { + elems[i] = arg.(string) + } + return strings.Join(elems, sep), nil + }) +} diff --git a/internal/modpkg/db/schema.go b/internal/modpkg/db/schema.go new file mode 100644 index 0000000..d266297 --- /dev/null +++ b/internal/modpkg/db/schema.go @@ -0,0 +1,173 @@ +package db + +import ( + "bufio" + "bytes" + "context" + _ "embed" + "errors" + "fmt" + "hash/crc32" + + "aslevy.com/go-doc/internal/sql" + _ "modernc.org/sqlite" +) + +//go:embed schema/metadata.sql +var schemaMetadataSql []byte + +//go:embed schema/modpkg.sql +var schemaModPkgSql []byte + +//go:embed schema/importpath.sql +var schemaImportPathSql []byte + +var schemaQueries = mustSplitSqlQueries(schemaMetadataSql, schemaModPkgSql, schemaImportPathSql) + +// schemaChecksum is the CRC32 checksum of schema. +var schemaChecksum int32 = func() int32 { + crc := crc32.NewIEEE() + for _, query := range schemaQueries { + if _, err := crc.Write(minifySql(query)); err != nil { + panic(err) + } + } + return int32(crc.Sum32()) +}() + +// applySchema execs all schemaQueries against the db. 
+func applySchema(ctx context.Context, db sql.Querier) error { + if err := execQueries(ctx, db, schemaQueries...); err != nil { + return err + } + + if err := setApplicationID(ctx, db); err != nil { + return err + } + + return setUserVersion(ctx, db, schemaChecksum) +} + +func mustSplitSqlQueries(sqlScript ...[]byte) (queries []string) { + queries, err := splitSqlQueries(sqlScript...) + if err != nil { + panic(err) + } + return queries +} + +func splitSqlQueries(sqlScripts ...[]byte) (queries []string, err error) { + for _, sql := range sqlScripts { + qrys, err := splitSql(sql) + if err != nil { + return nil, err + } + queries = append(queries, qrys...) + } + return queries, nil +} + +func minifySql(query string) []byte { + var minified bytes.Buffer + minified.Grow(len(query)) + scanner := bufio.NewScanner(bytes.NewReader([]byte(query))) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + line := scanner.Bytes() + sqlLine, _, _ := bytes.Cut(line, []byte(commentPrefix)) + sqlLine = bytes.TrimSpace(sqlLine) + if len(sqlLine) == 0 { + continue + } + _, _ = minified.Write(sqlLine) + _, _ = minified.Write([]byte("\n")) + } + if err := scanner.Err(); err != nil { + panic(err) + } + return minified.Bytes() +} + +func execQueries(ctx context.Context, db sql.Querier, queries ...string) error { + for _, query := range queries { + _, err := db.ExecContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to apply query: %w\n%s\n", err, query) + } + } + return nil +} + +func splitSql(sql []byte) (queries []string, _ error) { + scanner := bufio.NewScanner(bytes.NewReader(sql)) + scanner.Split(scanSqlQueries) + for scanner.Scan() { + query := scanner.Text() + if query == "" { + continue + } + queries = append(queries, query) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("failed to split SQL statements: %w", err) + } + return queries, nil +} + +const ( + commentPrefix = "--" + stmtDelimiter = ";---" +) + +func scanSqlQueries(data []byte, atEOF 
bool) (advance int, token []byte, rerr error) { + defer func() { + if (rerr != nil && + !errors.Is(rerr, bufio.ErrFinalToken)) || + advance == 0 { + return + } + // Trim the token of any leading or trailing whitespace. + token = bytes.TrimSpace(token) + // Trim leading comment lines. + for { + adv, tkn, err := bufio.ScanLines(token, true) + if err != nil { + rerr = err + return + } + if adv == 0 { + return + } + if len(tkn) > 0 { + tkn = bytes.TrimSpace(tkn) + if isComment := bytes.HasPrefix(tkn, []byte(commentPrefix)); !isComment { + return + } + } + token = token[adv:] + } + }() + + stmtDelim := bytes.Index(data, []byte(stmtDelimiter)) + if stmtDelim == -1 { + // No complete statement yet... + if atEOF { + // That's everything... don't treat this as an error to + // allow for trailing whitespace, comments, or + // statements that don't use the stmtDelimeter. + return len(data), data, bufio.ErrFinalToken + } + // Ask for more data so we can find the EOL. + return 0, nil, nil + } + // We found the stmtDelimiter, now find the next newline. + newline := bytes.Index(data[stmtDelim+len([]byte(stmtDelimiter)):], []byte("\n")) + if newline == -1 { + if atEOF { + return len(data), data, bufio.ErrFinalToken + } + return 0, nil, nil + } + + return stmtDelim + len([]byte(stmtDelimiter)) + newline + 1, data[:stmtDelim+1], nil +} diff --git a/internal/modpkg/db/schema/importpath.sql b/internal/modpkg/db/schema/importpath.sql new file mode 100644 index 0000000..2adc112 --- /dev/null +++ b/internal/modpkg/db/schema/importpath.sql @@ -0,0 +1,391 @@ +-- The schema for the database. +-- +-- Use ;--- to separate statements. This is a simple hack to allow for +-- splitting complete statements. Omitting the ;--- won't result in an error, +-- but that statement will be executed together with all subsequent statements +-- until the next ;--- or the end of the file. + + +-- import_path_segment stores all segments for all package import paths as +-- a trie. 
Root segments have a NULL parent_id. +CREATE TABLE import_path_segment ( + rowid INTEGER PRIMARY KEY, + parent_id INT NOT NULL + REFERENCES import_path_segment(rowid) + ON DELETE RESTRICT + ON UPDATE CASCADE, + name TEXT NOT NULL CHECK (name != ''), + UNIQUE(parent_id, name), + + path_depth INT NOT NULL CHECK (path_depth > 0), + package_id INT UNIQUE + REFERENCES package(rowid) + ON DELETE SET NULL + ON UPDATE CASCADE +);--- + +CREATE INDEX import_path_segment_name ON import_path_segment(name);--- + +-- import_path_segment_closure stores the relationships between all import path +-- segments. +CREATE TABLE import_path_segment_closure ( + ancestor_id INT NOT NULL + REFERENCES import_path_segment(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + descendant_id INT NOT NULL + REFERENCES import_path_segment(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + PRIMARY KEY(ancestor_id, descendant_id), + + distance INT NOT NULL CHECK (distance >= 0) +) WITHOUT ROWID;--- + +CREATE INDEX import_path_segment_closure_descendant_id ON import_path_segment_closure(descendant_id, ancestor_id, distance);--- +CREATE INDEX import_path_segment_closure_distance ON import_path_segment_closure(distance);--- + +-- import_path_segment_descendant_stats calculates the number of immediate +-- children, number of total descendants, and the maximum descentant distance +-- for each segment. 
+CREATE VIEW import_path_segment_descendant_stats ( + segment_id, + num_children, + num_descendants, + max_descendant_distance +) AS +SELECT + closure.ancestor_id AS segment_id, + count(closure.descendant_id) + FILTER ( + WHERE closure.distance = 1 + ) AS num_children, + count(closure.descendant_id) + FILTER ( + WHERE closure.distance > 0 + ) AS num_descendants, + max(closure.distance) AS descendant_max_distance +FROM + import_path_segment_closure AS closure +GROUP BY + closure.ancestor_id +;--- + +-- insert_next_import_path_segment is used by an INSTEAD OF INSERT trigger to +-- recursively split an import path into segments. +CREATE VIEW insert_next_import_path_segment ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path +) AS +VALUES ( + NULL, + NULL, + NULL, + NULL, + NULL, + NULL +);--- + +-- split_import_path_segment_on_insert_package is a trigger that fires whenever +-- a new package is inserted. +-- +-- This initiates a recursive trigger chain that splits the package's path and +-- inserts them into the part table. +CREATE TRIGGER + split_import_path_segment_on_insert_package +AFTER + INSERT ON + package +BEGIN + + INSERT INTO insert_next_import_path_segment ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path + ) + SELECT + package.rowid AS package_id, + package.total_num_segments AS total_num_segments, + 0 AS path_depth, + NULL AS segment_parent_id, + '' AS segment_name, + package.import_path || '/' AS remaining_path + FROM + module_package AS package + WHERE + package.rowid = new.rowid; + +END;--- + +-- recursive_package_part_splitter is a recursive trigger that fires instead of +-- inserting to the insert_next_import_path_segment view. +-- +-- It inserts the current part into the part and part_package tables. +-- +-- It then inserts the next part, if any, into the insert_next_import_path_segment view. 
+CREATE TRIGGER + recursive_package_part_splitter +INSTEAD OF + INSERT ON + insert_next_import_path_segment + WHEN + new.path_depth <= new.total_num_segments +BEGIN + + -- insert the segment + INSERT INTO import_path_segment ( + parent_id, + name, + path_depth, + package_id + ) + SELECT + new.segment_parent_id AS parent_id, + new.segment_name AS name, + new.path_depth AS path_depth, + -- the package_id is only set on the final segment, otherwise it is NULL + iif( + new.path_depth = new.total_num_segments, + new.package_id, + NULL + ) AS package_id + WHERE + -- the first iteration is invalid, so we skip it + new.path_depth > 0 + ON CONFLICT DO + UPDATE SET + package_id = excluded.package_id + WHERE + new.path_depth = new.total_num_segments; + + INSERT INTO + insert_next_import_path_segment ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path + ) + SELECT + new.package_id AS package_id, + new.total_num_segments AS total_num_segments, + new.path_depth + 1 AS path_depth, + + -- the first iteration will have NULL segment_parent_id + iif( + new.path_depth = 0, + NULL, + + -- if the segment was inserted, changes() will be 1, so we can avoid the + -- subquery and use last_insert_rowid() + iif( + changes() > 0, + last_insert_rowid(), + ( + SELECT + rowid + FROM + import_path_segment_view + WHERE + parent_id IS new.segment_parent_id + AND + name IS new.segment_name + ) + ) + ) AS segment_parent_id, + + substr(new.remaining_path, 1, slash-1) AS segment_name, + substr(new.remaining_path, slash+1) AS remaining_path + FROM ( + -- find the position of the first slash + SELECT + instr(new.remaining_path, '/') AS slash + ) + WHERE + -- when new.path_depth = new.total_num_segments, we have inserted all + -- segments and we are done + new.path_depth < new.total_num_segments; + +END;--- + +-- insert_import_path_segment_closure_on_insert_import_path_segment populates +-- the import_path_segment_closure table for each new segment. 
+-- +-- Each segment is its own ancestor, and all of its parent's ancestors are also +-- its ancestors, with a distance of 1 more than the distance to the parent. +CREATE TRIGGER + insert_import_path_segment_closure_on_insert_import_path_segment +AFTER + INSERT ON + import_path_segment +BEGIN + + INSERT INTO + import_path_segment_closure ( + ancestor_id, + descendant_id, + distance + ) + -- the new segment is its own ancestor, with a distance of 0 + SELECT + new.rowid AS ancestor_id, + new.rowid AS descendant_id, + 0 AS distance + UNION ALL + -- all of the new segment's parent's ancestors, are also its ancestors but + -- with a distance of 1 more than the distance to the parent + SELECT + closure.ancestor_id AS ancestor_id, + new.rowid AS descendant_id, + closure.distance + 1 AS distance + FROM + import_path_segment_closure AS closure + WHERE + closure.ancestor_id IS new.parent_id; + +END;--- + +-- delete_import_path_segment_with_null_package_id_and_no_children fires +-- whenever an import path segment's package_id is set to NULL, which occurs +-- automatically when a package is deleted. Segment's must either have +-- children, or be the final segment of a package's import path, otherwise they +-- are deleted. This ensures the trie is kept in sync with the package table. +-- +-- Ancestors of a segment with a NULL package_id and no children are cleaned up +-- by the subsequently defined trigger. +CREATE TRIGGER + delete_import_path_segment_with_null_package_id_and_no_children +AFTER + UPDATE OF + package_id + ON + import_path_segment + WHEN + new.package_id IS NULL +BEGIN + + DELETE FROM + import_path_segment AS segment + WHERE + segment.rowid = new.rowid + AND + 0 = ( + count(*) FILTER ( + WHERE segment.rowid IS new.rowid + ) + ); + +END;--- + +-- recursively_prune_leaf_parts_with_null_package_id recursively deletes leaf +-- parts that have a NULL package_id. 
+CREATE TRIGGER + recursively_delete_import_path_segment_with_null_package_id_and_no_children +AFTER + DELETE ON + import_path_segment +BEGIN + + DELETE FROM + import_path_segment AS segment + WHERE + segment.rowid = old.parent_id + AND + segment.package_id IS NULL + AND + 0 = ( + count(*) FILTER ( + WHERE segment.parent_id IS old.parent_id + ) + ); + +END;--- + +-- set_package_keep_false_for_modules_with_sync_true fires whenever a module is +-- updated such that sync is set to TRUE. It sets package.keep to FALSE for all +-- of the module's packages. As packages are re-synced, package.keep is set +-- back to TRUE so that only packages that are no longer available in the +-- module are left with keep set to FALSE, resulting in them being deleted to +-- finalize a sync. +CREATE TRIGGER + set_package_keep_false_for_modules_with_sync_true +AFTER + UPDATE OF + sync + ON + module + WHEN + new.sync = TRUE +BEGIN + + UPDATE + package + SET + keep = FALSE + WHERE + package.module_id = new.rowid; + +END;--- + +-- after_update_metadata_prune_module_package fires after the metadata has been +-- updated. It performs some sanity checks to ensure that the sync was +-- completed successfully and then deletes all modules and packages that have +-- keep set to FALSE. 
+CREATE TRIGGER + after_update_metadata_prune_module_package +AFTER + UPDATE ON + metadata +BEGIN + + SELECT + RAISE(ABORT, 'invalid sync: no modules were synced') + WHERE + NOT EXISTS ( + SELECT + 1 + FROM + module + WHERE + keep = TRUE + ); + + SELECT + RAISE(ABORT, 'invalid sync: no packages were synced for one or more modules') + WHERE + EXISTS ( + SELECT + 1 + FROM ( + SELECT + count(*) AS num_pkgs + FROM + package + WHERE + keep = TRUE + GROUP BY + module_id + ) + WHERE + num_pkgs = 0 + ); + + DELETE FROM + module + WHERE + keep = FALSE; + + DELETE FROM + package + WHERE + keep = FALSE; + +END;--- diff --git a/internal/modpkg/db/schema/metadata.sql b/internal/modpkg/db/schema/metadata.sql new file mode 100644 index 0000000..16a8350 --- /dev/null +++ b/internal/modpkg/db/schema/metadata.sql @@ -0,0 +1,47 @@ +-- The schema for the database. +-- +-- Use ;--- to separate statements. This is a simple hack to allow for +-- splitting complete statements. Omitting the ;--- won't result in an error, +-- but that statement will be executed together with all subsequent statements +-- until the next ;--- or the end of the file. + +-- metadata stores information about the database. +-- +-- rowid is the primary key and is always 1 to ensure there is only one row. +-- +-- created_at is the time the database was created. +-- +-- updated_at is the time the database was last updated. +-- +-- build_revision is the git revision of the go-doc build which last updated +-- this database. +-- +-- go_version is the version of Go used to build the go-doc binary. +-- +-- go_root is the path to the Go root directory. +-- +-- go_mod_cache is the path to the Go module cache. +-- +-- main_mod_id is the module_id of the main module. +-- +-- go_mod_hash is the CRC32 hash of the go.mod file. +-- +-- go_sum_hash is the CRC32 hash of the go.sum file. +-- +-- vendor is a boolean that indicates whether the main module is vendored. 
+CREATE TABLE metadata ( + rowid INTEGER PRIMARY KEY + NOT NULL + CHECK (rowid = 1), + + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + + build_revision TEXT NOT NULL CHECK (build_revision != ''), + go_version TEXT NOT NULL CHECK (go_version != ''), + + main_mod_dir TEXT NOT NULL CHECK (main_mod_dir != ''), + go_mod_hash INT NOT NULL CHECK (go_mod_hash != 0), + go_sum_hash INT NOT NULL CHECK (go_sum_hash != 0), + vendor BOOL NOT NULL DEFAULT FALSE +) WITHOUT ROWID;--- diff --git a/internal/modpkg/db/schema/modpkg.sql b/internal/modpkg/db/schema/modpkg.sql new file mode 100644 index 0000000..2cf478a --- /dev/null +++ b/internal/modpkg/db/schema/modpkg.sql @@ -0,0 +1,145 @@ +-- The schema for the database. +-- +-- Use ;--- to separate statements. This is a simple hack to allow for +-- splitting complete statements. Omitting the ;--- won't result in an error, +-- but that statement will be executed together with all subsequent statements +-- until the next ;--- or the end of the file. + +-- module stores all required modules and the directory they are located in. +-- +-- import_path is the module's import path. +-- +-- version is the module's version, if any. +-- +-- relative_dir is the directory the module is located in, relative to the +-- parent_dir.dir referenced by parent_dir_id. +-- +-- parent_dir_id is the parent_dir's rowid. +-- +-- num_segments is the number of slash separated parts in the module's import path. +-- +-- sync is a boolean that indicates whether the module's packages should be +-- synced. Newly inserted modules have sync set to TRUE. Upserted modules have +-- sync set to true if the module's dir has changed. +-- +-- keep is a boolean that indicates whether the module should be kept. At the +-- beginning of a sync, all module's keep are set to FALSE. Newly inserted +-- modules have keep set to TRUE. Upserted modules have keep set to TRUE. 
After +-- syncing all modules, any modules that have keep set to FALSE are deleted. +CREATE TABLE module ( + rowid INTEGER PRIMARY KEY, + + import_path TEXT NOT NULL UNIQUE CHECK ( + -- must not have leading or trailing slashes + import_path = trim(import_path, '/') + ), + + version TEXT NOT NULL, + + num_segments INT NOT NULL GENERATED ALWAYS AS ( + iif( + length(import_path) = 0, + 0, + 1 + + length(import_path) + - length(replace(import_path, '/', '')) + ) + ) STORED, + + sync BOOL NOT NULL DEFAULT TRUE, + keep BOOL NOT NULL DEFAULT TRUE +);--- + +-- package stores all packages for all modules. +-- +-- module_id is the module the package belongs to. +-- +-- relative_path is the package's path relative to the module's import_path. +-- This can be empty if the module's import path is an importable package. +-- +-- num_segments is the number of slash separated parts in the package's relative path. +-- +-- keep is a boolean that indicates whether the package should be kept. +-- Whenever a module requires a sync, keep is set to FALSE for all of its +-- packages. Newly inserted packages have keep set to TRUE. When existing +-- packages are upserted, keep is set back to TRUE. After syncing all packages, +-- any that have keep set to FALSE are deleted. +CREATE TABLE package ( + rowid INTEGER PRIMARY KEY, + + module_id INT NOT NULL + REFERENCES module(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + + in_mod_path TEXT NOT NULL UNIQUE CHECK ( + -- must not have leading or trailing slashes + in_mod_path = trim(in_mod_path, '/') + ), + UNIQUE(module_id, in_mod_path), + + num_segments INT NOT NULL GENERATED ALWAYS AS ( + iif( + length(in_mod_path) = 0, + 0, + 1 + + length(in_mod_path) + - length(replace(in_mod_path, '/', '')) + ) + ) STORED, + + keep BOOL NOT NULL DEFAULT TRUE +);--- + +-- module_package is a view that joins module and package information. +-- +-- package_id is the package's rowid. +-- +-- package_import_path is the package's import path. 
+-- +-- package_dir is the directory the package is located in. +-- +-- module_id is the module's rowid. +-- +-- module_import_path is the module's import path. +-- +-- relative_path is the package's path relative to the module's import_path. +-- +-- class is an integer that represents the type of module. +-- +-- relative_num_segments is the number of slash separated parts in the package's relative_path. +-- +-- total_num_segments is the number of slash separated parts in the package_import_path. +CREATE VIEW module_package ( + rowid, -- package.rowid + package_id, -- package.rowid + import_path, + total_num_segments, + module_id, + module_import_path, + module_num_segments +) AS +SELECT + package.rowid AS rowid, + package.rowid AS package_id, + concat_ws( + '/', + module.import_path, + package.in_mod_path + ) AS import_path, + package.num_segments + + module.num_segments AS total_num_segments, + module.rowid AS module_id, + module.import_path AS module_path, + module.num_segments AS module_num_segments +FROM + package, + module +ON + package.module_id = module.rowid +ORDER BY + module.num_segments ASC, + module.import_path ASC, + package.num_segments ASC, + package.in_mod_path ASC +;--- diff --git a/internal/modpkg/db/schema_suite_test.go b/internal/modpkg/db/schema_suite_test.go new file mode 100644 index 0000000..5aff275 --- /dev/null +++ b/internal/modpkg/db/schema_suite_test.go @@ -0,0 +1,458 @@ +//go:build disable +// +build disable + +package db + +import ( + "context" + "path/filepath" + "testing" + + "aslevy.com/go-doc/internal/dlog" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func init() { + dlog.Enable() +} + +func TestSchema(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Schema Suite") +} + +func tempDBPath() string { + GinkgoHelper() + const dbName = "index.sqlite3" + tempDir := GinkgoT().TempDir() + // tempDir = "." 
+ return filepath.Join(tempDir, dbName) +} + +var _ = Describe("Schema", func() { + var ( + dbPath string + db *DB + meta Metadata + allMods []Module + allPkgs []Package + ) + When("OpenDB is called", func() { + BeforeEach(func(ctx context.Context) { + dbPath = tempDBPath() + By("opening db " + dbPath) + var err error + db, err = OpenDB(ctx, dbPath) + Expect(err). + To(Succeed(), "OpenDB") + DeferCleanup(func() { + By("closing the database") + Expect(db.Close()). + To(Succeed(), "sql.DB.Close") + }) + }) + It("initializes the database", func(ctx context.Context) { + var foreignKeys bool + Expect(getPragma(ctx, db.db, pragmaForeignKeys, &foreignKeys)).To(Succeed(), "failed to get foreign_keys pragma") + Expect(foreignKeys).To(BeTrue(), "foreign_keys should be enabled") + + var recursiveTriggers bool + Expect(getPragma(ctx, db.db, pragmaRecursiveTriggers, &recursiveTriggers)).To(Succeed(), "failed to get recursive_triggers pragma") + Expect(recursiveTriggers).To(BeTrue(), "recursive_triggers should be enabled") + + applicationID, err := getApplicationID(ctx, db.db) + Expect(err).To(Succeed(), "failed to get application_id pragma") + Expect(applicationID).To(Equal(sqliteApplicationID), "application_id should be set") + + userVersion, err := getUserVersion(ctx, db.db) + Expect(err).To(Succeed(), "failed to get user_version pragma") + Expect(userVersion).To(Equal(schemaChecksum), "user_version should be set") + }) + + When("first synced", func() { + BeforeEach(func(ctx context.Context) { + sync, err := db.StartSync(ctx) + Expect(err).To(Succeed(), "NewSync") + Expect(sync).ToNot(BeNil(), "NewSync: should not be nil") + + By("syncing modules") + allMods = initModules() + for i := range allMods { + needsSync, err := sync.AddModule(ctx, &allMods[i]) + Expect(err).To(Succeed(), "Sync.AddModule") + Expect(needsSync).To(BeTrue(), "Sync.AddModule: all modules should need sync") + } + + By("syncing packages") + allPkgs = initPackages() + for i := range allPkgs { + 
Expect(sync.AddPackage(ctx, &allPkgs[i])). + To(Succeed(), "Sync.AddPackage") + } + + By("finishing sync") + meta = initMetadata() + Expect(sync.Finish(ctx, meta, nil)). + To(Succeed(), "Sync.Finish") + }) + It("has all modules and packages", func(ctx context.Context) { + Expect(SelectAllModules(ctx, db.db)). + To(Equal(allMods), "all modules should be synced") + + Expect(SelectAllPackages(ctx, db.db)). + To(Equal(allPkgs), "all packages should be synced") + }) + When("re-synced", func() { + var modPrune, needSync []Module + var pkgPrune []Package + BeforeEach(func(ctx context.Context) { + By("closing the database") + Expect(db.Close()). + To(Succeed(), "sql.DB.Close") + + By("re-opening the database") + var err error + db, err = OpenDB(ctx, dbPath) + Expect(err). + To(Succeed(), "OpenDB") + + allMods = initModules() + allPkgs = initPackages() + meta = initMetadata() + modPrune = modPrune[:0] + needSync = needSync[:0] + pkgPrune = pkgPrune[:0] + }) + JustBeforeEach(func(ctx context.Context) { + var err error + sync, err := db.StartSync(ctx) + Expect(err).To(Succeed(), "failed to sync modules") + Expect(sync).ToNot(BeNil(), "sync should not be nil") + + for i, mod := range allMods { + syncMod, err := sync.AddModule(ctx, &mod) + Expect(err).To(Succeed(), + "failed to add required modules") + Expect(mod).To(Equal(allMods[i]), "module should not be modified") + if syncMod { + needSync = append(needSync, allMods[i]) + } + } + + if len(needSync) > 0 { + modIDs := make(map[int64]struct{}, len(needSync)) + for _, mod := range needSync { + modIDs[mod.ID] = struct{}{} + } + + By("re-syncing packages") + for _, pkg := range allPkgs { + if _, ok := modIDs[pkg.ModuleID]; !ok { + continue + } + Expect(sync.AddPackage(ctx, &pkg)).To(Succeed(), + "failed to sync packages") + } + } + + modPrune, err = sync.selectModulesPrune(ctx) + Expect(err).To(Succeed(), + "failed to select modules to prune") + + pkgPrune, err = sync.selectPackagesPrune(ctx) + Expect(err).To(Succeed(), + "failed 
to select packages to prune") + + Expect(sync.Finish(ctx, meta, nil)).To(Succeed(), + "sync should finish successfully") + + Expect(SelectAllModules(ctx, db.db)).To(Equal(allMods), + "SelectAllModules should return all modules") + + }) + + When("the modules have not changed", func() { + It("should return no modules", func() { + Expect(needSync).To(BeEmpty(), "SyncModules should return no modules") + Expect(modPrune).To(BeEmpty(), "there should be no modules to prune") + Expect(pkgPrune).To(BeEmpty(), "there should be no packages to prune") + }) + }) + + When("a new module is added", func() { + BeforeEach(func() { + By("adding a new module") + newMod := Module{ + ID: allMods[len(allMods)-1].ID + 1, + ImportPath: "github.com/onsi/gomega", + RelativeDir: "/home/adam/go/pkg/mod/github.com/onsi/gomega@v1.10.3", + } + allMods = append(allMods, newMod) + allPkgs = append(allPkgs, Package{ + ID: allPkgs[len(allPkgs)-1].ID + 1, + RelativePath: "", + ModuleID: newMod.ID, + NumParts: 0, + }, Package{ + ID: allPkgs[len(allPkgs)-1].ID + 2, + RelativePath: "types", + ModuleID: newMod.ID, + NumParts: 1, + }) + }) + + It("should return the new module", func() { + Expect(needSync).To(Equal(allMods[len(allMods)-1:]), "SyncModules should return the new module") + Expect(modPrune).To(BeEmpty(), "there should be no modules to prune") + Expect(pkgPrune).To(BeEmpty(), "there should be no packages to prune") + }) + }) + + When("a module is removed", func() { + var removed []Module + BeforeEach(func() { + By("removing a module") + removed = allMods[len(allMods)-1:] + allMods = allMods[:len(allMods)-1] + }) + + It("should prune the removed module and its packages", func(ctx context.Context) { + Expect(needSync).To(BeEmpty(), "SyncModules should return no modules") + Expect(SelectModulePackages(ctx, db.db, removed[0].ID)). 
+ To(BeEmpty(), "SelectModulePackages should return no packages") + Expect(modPrune).To(Equal(removed), "the removed module should be pruned") + Expect(pkgPrune).To(BeEmpty(), "there should be no packages to prune") + }) + }) + + When("a module is updated", func() { + var updated []Module + var modPkgs []Package + BeforeEach(func(ctx context.Context) { + By("changing the directory of a module") + allMods[0].RelativeDir = "/home/adam/go/pkg/mod/github.com/stretchr/testify@v1.8.2" + updated = allMods[:1] + + var err error + modPkgs, err = SelectModulePackages(ctx, db.db, updated[0].ID) + Expect(err).To(Succeed(), "failed to select module packages") + Expect(modPkgs).ToNot(BeEmpty(), "module packages should not be empty") + }) + + It("should return the updated module", func() { + Expect(needSync).To(Equal(updated), "SyncModules should return the updated module") + Expect(modPrune).To(BeEmpty(), "there should be no modules to prune") + Expect(pkgPrune).To(BeEmpty(), "the module's packages should be potentially pruned") + }) + + When("the module's packages are unchanged", func() { + It("should retain the module's packages", func(ctx context.Context) { + Expect(SelectModulePackages(ctx, db.db, updated[0].ID)). + To(Equal(modPkgs), "synced packages are not correct") + }) + }) + + When("a module's package is removed", func() { + BeforeEach(func() { + By("removing a package") + allPkgs = modPkgs[1:] + }) + It("should prune the removed package", func(ctx context.Context) { + Expect(SelectModulePackages(ctx, db.db, updated[0].ID)). 
+ To(Equal(allPkgs), "synced packages are not correct") + }) + }) + + When("a module's packages are added and removed", func() { + BeforeEach(func() { + modPkgs = append(modPkgs, Package{ + ID: int64(len(allPkgs) + 1), + ModuleID: updated[0].ID, + RelativePath: "added", + NumParts: 1, + }) + allPkgs = modPkgs[1:] + }) + + It("should prune the removed package, add the added package, and retain the pre-existing packages", func(ctx context.Context) { + Expect(SelectModulePackages(ctx, db.db, updated[0].ID)). + To(Equal(allPkgs), "synced packages are not correct") + }) + }) + }) + + When("modules are removed, added, and updated", func() { + var removed Module + BeforeEach(func() { + By("removing, adding, and updating modules") + removed = allMods[0] + allMods = allMods[1:] + allMods[1].RelativeDir = "/home/adam/go/pkg/mod/github.com/muesli/reflow@v0.3.1" + newMod := Module{ + ID: allMods[len(allMods)-1].ID + 1, + ImportPath: "github.com/onsi/gomega", + RelativeDir: "/home/adam/go/pkg/mod/github.com/onsi/gomega@v1.10.3", + } + allMods = append(allMods, newMod) + allPkgs = append(allPkgs, Package{ + ID: allPkgs[len(allPkgs)-1].ID + 1, + RelativePath: "", + ModuleID: newMod.ID, + NumParts: 0, + }, Package{ + ID: allPkgs[len(allPkgs)-1].ID + 2, + RelativePath: "types", + ModuleID: newMod.ID, + NumParts: 1, + }) + }) + It("should return the new and updated modules", func() { + Expect(needSync).To(Equal(allMods[1:])) + }) + It("should remove the removed module's packages", func(ctx context.Context) { + Expect(SelectModulePackages(ctx, db.db, removed.ID)). 
+ To(BeEmpty(), "SelectModulePackages should return no packages") + }) + }) + }) + }) + }) + + // Describe("Metadata", func() { + // var md Metadata + // var originalCreatedAt time.Time + // JustBeforeEach(func(ctx context.Context) { + // By("selecting the metadata") + // // Save last loaded created at time for use in next + // // When block + // originalCreatedAt = md.CreatedAt + // var err error + // md, err = SelectMetadata(ctx, db) + // Expect(err).To(Succeed(), "failed to select metadata") + // }) + + // It("should initialize the metadata", func() { + // Expect(md.CreatedAt).To(BeTemporally("~", time.Now(), time.Second), "CreatedAt should be set to now") + // Expect(md.UpdatedAt).To(Equal(md.CreatedAt), "UpdatedAt should be the same as CreatedAt") + // Expect(md.BuildRevision).ToNot(BeEmpty(), "BuildRevision should be set") + // Expect(md.GoVersion).ToNot(BeEmpty(), "GoVersion should be set") + // }) + + // When("the metadata already exists", func() { + // BeforeEach(func(ctx context.Context) { + // By("updating the metadata") + // time.Sleep(time.Second) + // Expect(UpsertMetadata(ctx, db)).To(Succeed(), "failed to upsert metadata") + // }) + + // It("should update the metadata", func() { + // Expect(originalCreatedAt).ToNot(BeZero(), "originalCreatedAt should be set") + // Expect(md.CreatedAt).To(Equal(originalCreatedAt), "CreatedAt should not have changed") + // Expect(md.UpdatedAt).To(BeTemporally(">", md.CreatedAt), "UpdatedAt should be after CreatedAt") + // Expect(md.BuildRevision).ToNot(BeEmpty(), "BuildRevision should be set") + // Expect(md.GoVersion).ToNot(BeEmpty(), "GoVersion should be set") + // }) + // }) + // }) + +}) + +func initMetadata() Metadata { + return Metadata{ + BuildRevision: "test", + GoVersion: "test", + GoModHash: 1024, + GoSumHash: 1024, + } +} + +func initModules() []Module { + return []Module{{ + ID: 1, + ImportPath: "github.com/stretchr/testify", + RelativeDir: "/home/adam/go/pkg/mod/github.com/stretchr/testify@v1.8.1", + }, { 
+ ID: 2, + ImportPath: "github.com/muesli/reflow", + RelativeDir: "/home/adam/go/pkg/mod/github.com/muesli/reflow@v0.3.0", + }, { + ID: 3, + ImportPath: "github.com/onsi/ginkgo/v2", + RelativeDir: "/home/adam/go/pkg/mod/github.com/onsi/ginkgo/v2@v2.11.0", + }} +} + +func initPackages() []Package { + return []Package{{ + ID: 1, + ModuleID: 1, + RelativePath: "", + NumParts: 0, + }, { + ID: 2, + ModuleID: 1, + RelativePath: "assert", + NumParts: 1, + }, { + ID: 3, + ModuleID: 1, + RelativePath: "require", + NumParts: 1, + }, { + ID: 4, + ModuleID: 2, + RelativePath: "indent", + NumParts: 1, + }, { + ID: 5, + ModuleID: 2, + RelativePath: "wordwrap", + NumParts: 1, + }, { + ID: 6, + ModuleID: 2, + RelativePath: "ansi", + NumParts: 1, + }, { + ID: 7, + ModuleID: 2, + RelativePath: "padding", + NumParts: 1, + }, { + ID: 8, + ModuleID: 3, + RelativePath: "", + NumParts: 0, + }, { + ID: 9, + ModuleID: 3, + RelativePath: "types", + NumParts: 1, + }, { + ID: 10, + ModuleID: 3, + RelativePath: "config", + NumParts: 1, + }, { + ID: 11, + ModuleID: 3, + RelativePath: "integration", + NumParts: 1, + }, { + ID: 12, + ModuleID: 3, + RelativePath: "docs", + NumParts: 1, + }, { + ID: 13, + ModuleID: 3, + RelativePath: "extensions/global", + NumParts: 2, + }, { + ID: 14, + ModuleID: 3, + RelativePath: "extensions/table", + NumParts: 2, + }} +} diff --git a/internal/modpkg/db/schema_test.go b/internal/modpkg/db/schema_test.go new file mode 100644 index 0000000..32449b6 --- /dev/null +++ b/internal/modpkg/db/schema_test.go @@ -0,0 +1,74 @@ +package db + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSchema(t *testing.T) { + require := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + db, err := Open(ctx, tmpDir) + require.NoError(err, "Open") + require.NotNil(db, "Open") + t.Cleanup(func() { require.NoError(db.Close(), "DB.Close") }) + + populateDB(t, ctx, db, 
+ NewModulePackages("github.com/good-coder/foobar", "v1.0.0", + "", "foo", "bar", "baz", "foo/baz", "foo/bir/baz"), + NewModulePackages("github.com/bad-coder/foobar", "v1.0.1", + "", "foo", "bur", "buz", "foo/buz", "foo/bur/buz"), + ) + // initial sync + // add modules and packages + // + // check tables + +} + +type ModulePackages struct { + ImportPath string + Version string + Packages []string +} + +func NewModulePackages(importPath, version string, packages ...string) ModulePackages { + for i, pkg := range packages { + pkg = strings.TrimPrefix(pkg, importPath) + pkg = strings.TrimPrefix(pkg, "/") + packages[i] = pkg + } + return ModulePackages{ + ImportPath: importPath, + Version: version, + Packages: packages, + } +} + +func populateDB(t *testing.T, ctx context.Context, db *DB, modPkgs ...ModulePackages) { + require := require.New(t) + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + sync, err := db.Sync(ctx) + require.NoError(err, "DB.StartSyncIfNeeded") + require.NotNil(sync, "DB.StartSyncIfNeeded") + + for _, mp := range modPkgs { + mod, err := sync.AddModule(ctx, mp.ImportPath, mp.Version) + require.NoError(err, "Sync.AddModule") + require.NotNil(mod, "Sync.AddModule") + + for _, pkg := range mp.Packages { + require.NoError(sync.AddPackage(ctx, mod, pkg), "Sync.AddPackage") + } + } + + require.NoError(sync.Finish(ctx), "Sync.Finish") +} diff --git a/internal/modpkg/db/sql/metadata_select.sql b/internal/modpkg/db/sql/metadata_select.sql new file mode 100644 index 0000000..90c865c --- /dev/null +++ b/internal/modpkg/db/sql/metadata_select.sql @@ -0,0 +1,15 @@ +SELECT + created_at, + updated_at, + + build_revision, + go_version, + + go_mod_hash, + go_sum_hash, + vendor +FROM + metadata +WHERE + rowid = 1 +LIMIT 1; diff --git a/internal/modpkg/db/sql/metadata_upsert.sql b/internal/modpkg/db/sql/metadata_upsert.sql new file mode 100644 index 0000000..abed080 --- /dev/null +++ b/internal/modpkg/db/sql/metadata_upsert.sql @@ -0,0 +1,39 @@ + +INSERT 
INTO
+    metadata(
+        rowid,
+
+        build_revision,
+        go_version,
+
+        go_mod_hash,
+        go_sum_hash,
+        vendor
+    )
+VALUES (
+    1,
+    $build_revision,
+    $go_version,
+
+    $go_mod_hash,
+    $go_sum_hash,
+    $vendor
+)
+ON CONFLICT(rowid) DO
+    UPDATE SET
+        updated_at = CURRENT_TIMESTAMP,
+        (
+            build_revision,
+            go_version,
+
+            go_mod_hash,
+            go_sum_hash,
+            vendor
+        ) = (
+            excluded.build_revision,
+            excluded.go_version,
+
+            excluded.go_mod_hash,
+            excluded.go_sum_hash,
+            excluded.vendor
+        );
diff --git a/internal/modpkg/db/sql/module_update_parent_dir.sql b/internal/modpkg/db/sql/module_update_parent_dir.sql
new file mode 100644
index 0000000..164c80b
--- /dev/null
+++ b/internal/modpkg/db/sql/module_update_parent_dir.sql
@@ -0,0 +1,18 @@
+UPDATE
+    module
+SET
+    parent_dir_id = ( -- set the parent dir id to
+        iif(
+            $vendor,
+            $parent_dir_id_vendor, -- the vendor parent dir id if switching _to_ vendor mode
+            $parent_dir_id_gomodcache -- the mod cache dir id if switching _from_ vendor mode
+        )
+    )
+WHERE
+    parent_dir_id = ( -- only for modules with the parent dir id set to
+        iif(
+            NOT $vendor,
+            $parent_dir_id_vendor, -- the vendor parent dir if switching _from_ vendor mode
+            $parent_dir_id_gomodcache -- the mod cache dir if switching _to_ vendor mode
+        )
+    );
diff --git a/internal/modpkg/db/sql/module_update_set_sync_keep_false.sql b/internal/modpkg/db/sql/module_update_set_sync_keep_false.sql
new file mode 100644
index 0000000..9d6bfce
--- /dev/null
+++ b/internal/modpkg/db/sql/module_update_set_sync_keep_false.sql
@@ -0,0 +1,5 @@
+UPDATE
+    module
+SET
+    sync = FALSE,
+    keep = FALSE;
diff --git a/internal/modpkg/db/sql/module_upsert.sql b/internal/modpkg/db/sql/module_upsert.sql
new file mode 100644
index 0000000..05905c6
--- /dev/null
+++ b/internal/modpkg/db/sql/module_upsert.sql
@@ -0,0 +1,25 @@
+-- Upsert modules into the database, marking them to be kept since still in
+-- use, and marking the module's packages for sync if they may have changed.
+INSERT INTO module (
+    import_path,
+    version
+)
+VALUES (
+    $import_path,
+    $version)
+ON CONFLICT (
+    import_path
+) DO
+UPDATE SET
+    sync = (
+        excluded.version == '' -- sync if the version is empty
+        OR
+        version != excluded.version -- or if the version has changed
+    ),
+    -- Keep this module since it's still in use.
+    keep = TRUE,
+    version = excluded.version
+RETURNING
+    rowid,
+    sync
+;---
diff --git a/internal/modpkg/db/sql/package_select_by_parts.sql b/internal/modpkg/db/sql/package_select_by_parts.sql
new file mode 100644
index 0000000..8415055
--- /dev/null
+++ b/internal/modpkg/db/sql/package_select_by_parts.sql
@@ -0,0 +1,61 @@
+WITH RECURSIVE
+    matches (
+        matched_path,
+        remaining_path,
+        part_id,
+        part_path_depth
+    )
+AS
+    (
+        VALUES (
+            '',
+            $search_path || '/',
+            NULL,
+            -- no parts matched yet: depth 0
+            0
+        )
+        UNION
+        SELECT
+            concat_ws('/', matched_path, part.name) AS matched_path,
+            substr(remaining_path, instr(remaining_path, '/')+1) AS remaining_path,
+            part.rowid AS part_id,
+            part.path_depth AS part_path_depth
+        FROM
+            matches,
+            part
+        WHERE
+            remaining_path != '' -- we still have parts to match
+        AND
+            (
+                matches.part_id IS NULL -- handles initial case where no parts have been matched
+                OR
+                part.parent_id = matches.part_id -- the part is a child of a previously matched part
+            )
+        AND
+            ( -- the part name matches the next part in the path
+                (
+                    $exact IS TRUE
+                    AND
+                    part.name = substr(remaining_path, 1, instr(remaining_path, '/')-1) -- match exactly
+                )
+                OR
+                name LIKE substr(remaining_path, 1, instr(remaining_path, '/')-1) || '%' -- match prefix
+            )
+        ORDER BY
+            part_path_depth DESC,
+            length(matched_path) ASC
+    )
+SELECT
+    package_import_path,
+    dir,
+    matched_path
+FROM
+    matches, part_package USING (part_id),
+    package_view USING (package_id)
+WHERE
+    remaining_path = ''
+ORDER BY
+    (total_num_parts - part_path_depth) ASC,
+    length(matched_path) ASC,
+    total_num_parts ASC
+;
diff --git a/internal/modpkg/db/sql/package_select_by_parts_2.sql
b/internal/modpkg/db/sql/package_select_by_parts_2.sql new file mode 100644 index 0000000..5735dd3 --- /dev/null +++ b/internal/modpkg/db/sql/package_select_by_parts_2.sql @@ -0,0 +1,270 @@ +WITH RECURSIVE +split_search_path( + start_pos, + end_pos, + prev_part, + prev_path_depth, + remaining +) AS ( + SELECT + 1 AS start_pos, + 0 AS end_pos, + '' AS prev_part, + 0 AS prev_path_depth, + trim($search_path, '/') || '/' AS remaining + UNION ALL + SELECT + end_pos + 1 AS start_pos, + instr( + remaining, + '/' + ) AS end_pos, + + iif(end_pos > 0, + substr(remaining, start_pos, end_pos), + '' + ) AS prev_part, + + iif(end_pos > 0, + prev_path_depth + 1, + 0 + ) AS prev_path_depth, + + iif(end_pos > 0, + substr(remaining, end_pos + 1), + remaining + ) AS remaining + FROM + split_search_path + WHERE + remaining != '' +), + +search AS ( + SELECT + prev_part || iif($exact, '', '%') AS name, + prev_path_depth AS path_depth, + count(*) - prev_path_depth AS min_deepest_descendant_distance + FROM + split_search_path + WHERE + prev_part != '' +), + +match_search ( + first_id, + part_id, + search_path_depth +) AS ( + SELECT + part.rowid AS first_id, + part.rowid AS part_id, + search.path_depth AS search_path_depth + FROM + part_view AS part, + search + ON + part.name LIKE search.name + WHERE + search.path_depth = 1 + AND + part.deepest_descendant_distance >= search.min_deepest_descendant_distance + UNION ALL + SELECT + match.first_id AS first_id, + part.rowid AS part_id, + search.path_depth AS search_path_depth + FROM + part_view AS part, + match_search AS match, + search + ON + part.parent_id = match.part_id + AND + part.name LIKE search.name + AND + part.max_descendant_distance >= search.min_max_descendant_distance + AND + search.path_depth = match.search_path_depth + 1 +), + +match AS ( + SELECT + first_id, + part_id + FROM + match_search + WHERE + search_path_depth = ( + SELECT + count(*) + FROM + search + ) +), + +package_match AS ( + SELECT + match.first_id, + match.part_id 
+ FROM + match, + part + ON + part.rowid = match.part_id + WHERE + part.package_id IS NOT NULL +), + +dir_match ( + first_id, + part_id, + package_id, + num_children +) AS ( + SELECT + match.first_id AS first_id, + match.part_id AS part_id, + part.package_id AS package_id, + part.num_children AS num_children, + 0 AS unambiguous_depth + FROM + match, + part + ON + part.rowid = match.part_id + WHERE + part.num_children > 0 + UNION ALL + SELECT + dir_match.first_id AS first_id, + part.rowid AS part_id, + part.package_id AS package_id, + part.num_children AS num_children, + dir_match.unambiguous_depth + 1 AS unambiguous_depth + FROM + dir_match, + part + ON + part.parent_id = dir_match.part_id + WHERE + dir_match.num_children = 1 + AND ( + dir_match.package_id IS NULL + OR + unambiguous_depth = 0 + ) +), + +follow_unambiguous AS ( + SELECT + match.first_id, + match.part_id + FROM + dir_match, + part + ON + part.rowid = dir_match.part_id + WHERE + part.num_children = 1 + UNION ALL + SELECT + follow_unambiguous.first_id, + part.rowid + FROM + follow_unambiguous, + part + ON + part.parent_id = follow_unambiguous.part_id + WHERE + part.package_id IS NOT NULL + OR + part.num_children = 1 +), + +descend_dir_math AS ( + SELECT + deepest_match.first_id, + deepest_match.part_id + FROM + deepest_match, + part + ON + part.rowid = deepest_match.part_id + WHERE + part.package_id IS NULL + AND + part.num_children = 1 + UNION ALL + SELECT + descend_dir_match.first_id AS first_id, + part.rowid AS part_id + FROM + descend_dir_match, + part + ON + part.parent_id = descend_dir_match.part_id + WHERE + + part.package_id IS NULL +), + + + + +unambiguous_child_match( + first_id, + part_id +) AS ( + SELECT + dir_match.first_id, + dir_match.part_id + FROM + dir_match, + part + ON + part.rowid = dir_match.part_id + WHERE + part.num_children = 1 + UNION ALL + SELECT + unambiguous_child_match.first_id, + part.rowid + FROM + unambiguous_child_match, + part + ON + part.parent_id = 
unambiguous_child_match.part_id + WHERE + part.package_id IS NOT NULL + OR + part.num_children = 1 + +build_full_path(part_id, full_path) AS ( + -- Build full path from root to deepest matching parts + SELECT dm.part_id, dm.matched_path + FROM deepest_matches dm + WHERE dm.parent_id IS NULL + UNION ALL + SELECT p.rowid, concat_ws('/', bfp.full_path, p.name) + FROM part p + JOIN build_full_path bfp ON p.rowid = bfp.part_id +), +descend_unambiguous(part_id, parent_id, first_matched_parent_id, depth, is_package, num_children, matched_path) AS ( + -- Descend to unambiguous parts + SELECT dm.part_id, dm.parent_id, dm.first_matched_parent_id, dm.depth, dm.is_package, dm.num_children, bfp.full_path + FROM deepest_matches dm + JOIN build_full_path bfp ON dm.part_id = bfp.part_id + UNION ALL + SELECT p.rowid, p.parent_id, du.first_matched_parent_id, du.depth + 1, p.package_id IS NOT NULL, p.num_children, + concat_ws('/', du.matched_path, p.name) + FROM part p + JOIN descend_unambiguous du ON p.parent_id = du.part_id + WHERE du.num_children = 1 +) +SELECT ... +FROM descend_unambiguous +JOIN ... -- Join with package or module tables +WHERE ... -- Final conditions + + diff --git a/internal/modpkg/db/sql/package_upsert.sql b/internal/modpkg/db/sql/package_upsert.sql new file mode 100644 index 0000000..ce66454 --- /dev/null +++ b/internal/modpkg/db/sql/package_upsert.sql @@ -0,0 +1,16 @@ +-- Upsert packages into the database, marking them to be kept since still in +-- use. 
+INSERT INTO + package ( + module_id, + relative_path + ) +VALUES ( + $module_id, + $relative_path +) +ON CONFLICT(module_id, relative_path) + DO UPDATE SET + keep = TRUE +RETURNING + rowid; diff --git a/internal/modpkg/db/sql/parent_dir_upsert.sql b/internal/modpkg/db/sql/parent_dir_upsert.sql new file mode 100644 index 0000000..7df5297 --- /dev/null +++ b/internal/modpkg/db/sql/parent_dir_upsert.sql @@ -0,0 +1,21 @@ +INSERT INTO + parent_dir ( + rowid, + key, + dir + ) +VALUES ( + $rowid, $key, $dir +) +ON CONFLICT +DO UPDATE SET + ( + rowid, + key, + dir + ) = ( + excluded.rowid, + excluded.key, + excluded.dir + ) +; diff --git a/internal/modpkg/db/sql/part_count_children.sql b/internal/modpkg/db/sql/part_count_children.sql new file mode 100644 index 0000000..38cb481 --- /dev/null +++ b/internal/modpkg/db/sql/part_count_children.sql @@ -0,0 +1,18 @@ + +SELECT + count(*) +FROM + part +WHERE + package_id IS NOT NULL +AND + rowid IN ( + SELECT DISTINCT + descendant_id + FROM + part_closure + WHERE + ancestor_id = $part_id + ) +; + diff --git a/internal/modpkg/db/sql/part_select_by_package_id.sql b/internal/modpkg/db/sql/part_select_by_package_id.sql new file mode 100644 index 0000000..38e6e25 --- /dev/null +++ b/internal/modpkg/db/sql/part_select_by_package_id.sql @@ -0,0 +1,18 @@ +SELECT + rowid, + name, + parent_id, + package_id +FROM + part +WHERE + rowid IN ( + SELECT + part_id + FROM + part_package + WHERE + package_id = $package_id + ) +ORDER BY + path_depth ASC; diff --git a/internal/modpkg/db/sql/schema.sql b/internal/modpkg/db/sql/schema.sql new file mode 100644 index 0000000..aeee5d4 --- /dev/null +++ b/internal/modpkg/db/sql/schema.sql @@ -0,0 +1,630 @@ +-- The schema for the database. +-- +-- Use ;--- to separate statements. This is a simple hack to allow for +-- splitting complete statements. 
Omitting the ;--- won't result in an error, +-- but that statement will be executed together with all subsequent statements +-- until the next ;--- or the end of the file. + +-- metadata stores information about the database. +-- +-- rowid is the primary key and is always 1 to ensure there is only one row. +-- +-- created_at is the time the database was created. +-- +-- updated_at is the time the database was last updated. +-- +-- build_revision is the git revision of the go-doc build which last updated +-- this database. +-- +-- go_version is the version of Go used to build the go-doc binary. +-- +-- go_root is the path to the Go root directory. +-- +-- go_mod_cache is the path to the Go module cache. +-- +-- main_mod_id is the module_id of the main module. +-- +-- go_mod_hash is the CRC32 hash of the go.mod file. +-- +-- go_sum_hash is the CRC32 hash of the go.sum file. +-- +-- vendor is a boolean that indicates whether the main module is vendored. +CREATE TABLE metadata ( + rowid INTEGER PRIMARY KEY + NOT NULL + CHECK (rowid = 1), + + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + + build_revision TEXT NOT NULL CHECK (build_revision != ''), + go_version TEXT NOT NULL CHECK (go_version != ''), + + go_mod_hash INT NOT NULL CHECK (go_mod_hash != 0), + go_sum_hash INT NOT NULL CHECK (go_sum_hash != 0), + vendor BOOL NOT NULL DEFAULT FALSE +) WITHOUT ROWID;--- + +-- module stores all required modules and the directory they are located in. +-- +-- import_path is the module's import path. +-- +-- version is the module's version, if any. +-- +-- relative_dir is the directory the module is located in, relative to the +-- parent_dir.dir referenced by parent_dir_id. +-- +-- parent_dir_id is the parent_dir's rowid. +-- +-- num_segments is the number of slash separated parts in the module's import path. +-- +-- sync is a boolean that indicates whether the module's packages should be +-- synced. 
Newly inserted modules have sync set to TRUE. Upserted modules have +-- sync set to true if the module's dir has changed. +-- +-- keep is a boolean that indicates whether the module should be kept. At the +-- beginning of a sync, all module's keep are set to FALSE. Newly inserted +-- modules have keep set to TRUE. Upserted modules have keep set to TRUE. After +-- syncing all modules, any modules that have keep set to FALSE are deleted. +CREATE TABLE module ( + rowid INTEGER PRIMARY KEY, + + import_path TEXT NOT NULL UNIQUE CHECK ( + -- must not have leading or trailing slashes + import_path = trim(import_path, '/') + ), + + version TEXT NOT NULL, + go_sum_hash TEXT NOT NULL, + + num_segments INT NOT NULL GENERATED ALWAYS AS ( + iif( + length(import_path) = 0, + 0, + 1 + + length(import_path) + - length(replace(import_path, '/', '')) + ) + ) STORED, + + sync BOOL NOT NULL DEFAULT TRUE, + keep BOOL NOT NULL DEFAULT TRUE +);--- + +-- package stores all packages for all modules. +-- +-- module_id is the module the package belongs to. +-- +-- relative_path is the package's path relative to the module's import_path. +-- This can be empty if the module's import path is an importable package. +-- +-- num_segments is the number of slash separated parts in the package's relative path. +-- +-- keep is a boolean that indicates whether the package should be kept. +-- Whenever a module requires a sync, keep is set to FALSE for all of its +-- packages. Newly inserted packages have keep set to TRUE. When existing +-- packages are upserted, keep is set back to TRUE. After syncing all packages, +-- any that have keep set to FALSE are deleted. 
+CREATE TABLE package ( + rowid INTEGER PRIMARY KEY, + + module_id INT NOT NULL + REFERENCES module(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + + in_mod_path TEXT NOT NULL UNIQUE CHECK ( + -- must not have leading or trailing slashes + in_mod_path = trim(in_mod_path, '/') + ), + UNIQUE(module_id, in_mod_path), + + num_segments INT NOT NULL GENERATED ALWAYS AS ( + iif( + length(in_mod_path) = 0, + 0, + 1 + + length(in_mod_path) + - length(replace(in_mod_path, '/', '')) + ) + ) STORED, + + keep BOOL NOT NULL DEFAULT TRUE +);--- + +-- package_view is a view that joins module and package information. +-- +-- package_id is the package's rowid. +-- +-- package_import_path is the package's import path. +-- +-- package_dir is the directory the package is located in. +-- +-- module_id is the module's rowid. +-- +-- module_import_path is the module's import path. +-- +-- relative_path is the package's path relative to the module's import_path. +-- +-- class is an integer that represents the type of module. +-- +-- relative_num_segments is the number of slash separated parts in the package's relative_path. +-- +-- total_num_segments is the number of slash separated parts in the package_import_path. 
+CREATE VIEW package_view (
+    rowid,
+    import_path,
+    num_segments,
+    module_id,
+    module_import_path,
+    module_num_segments
+) AS SELECT
+    package.rowid AS rowid,
+    concat_ws(
+        '/',
+        module.import_path,
+        package.in_mod_path
+    ) AS import_path,
+    package.num_segments AS num_segments,
+    module.rowid AS module_id,
+    module.import_path AS module_import_path,
+    module.num_segments AS module_num_segments
+FROM
+    package,
+    module
+ON
+    package.module_id = module.rowid
+ORDER BY
+    module.num_segments ASC,
+    module.import_path ASC,
+    package.num_segments ASC,
+    package.in_mod_path ASC
+;---
+
+CREATE TABLE import_path_segment_name (
+    rowid INTEGER PRIMARY KEY,
+    name TEXT UNIQUE NOT NULL CHECK (name != '')
+);---
+
+CREATE TABLE import_path_segment (
+    rowid INTEGER PRIMARY KEY,
+    parent_id INT NOT NULL
+        REFERENCES import_path_segment(rowid)
+            ON DELETE CASCADE
+            ON UPDATE RESTRICT,
+    name_id INT NOT NULL
+        REFERENCES import_path_segment_name(rowid)
+            ON DELETE RESTRICT
+            ON UPDATE CASCADE,
+    UNIQUE(parent_id, name_id),
+
+    path_depth INT NOT NULL CHECK (path_depth > 0),
+    package_id INT UNIQUE
+        REFERENCES package(rowid)
+            ON DELETE SET NULL
+            ON UPDATE CASCADE
+);---
+
+CREATE VIEW import_path_segment_view (
+    rowid,
+    parent_id,
+    name,
+    path_depth,
+    package_id,
+    num_children,
+    num_descendants,
+    max_descendant_distance
+) AS
+SELECT
+    segment.rowid AS rowid,
+    segment.parent_id AS parent_id,
+    segment_name.name AS name,
+    segment.path_depth AS path_depth,
+    segment.package_id AS package_id,
+    closure.num_children,
+    closure.num_descendants,
+    closure.max_descendant_distance
+FROM
+    import_path_segment AS segment
+JOIN
+    import_path_segment_name AS segment_name
+ON
+    segment.name_id = segment_name.rowid
+JOIN
+    import_path_segment_closure_view AS closure
+ON
+    closure.segment_id = segment.rowid
+;---
+
+CREATE TABLE import_path_segment_package (
+    package_id INT NOT NULL
+        REFERENCES package(rowid)
+            ON DELETE CASCADE
+            ON UPDATE CASCADE,
+    segment_id INT NOT NULL
+
REFERENCES import_path_segment(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + PRIMARY KEY(package_id, segment_id), + + path_depth INT NOT NULL CHECK (path_depth > 0), + UNIQUE(package_id, path_depth) +) WITHOUT ROWID;--- + +CREATE TABLE import_path_segment_closure ( + ancestor_id INT NOT NULL + REFERENCES import_path_segment(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + descendant_id INT NOT NULL + REFERENCES import_path_segment(rowid) + ON DELETE CASCADE + ON UPDATE CASCADE, + PRIMARY KEY(ancestor_id, descendant_id), + distance INT NOT NULL CHECK (distance >= 0) +) WITHOUT ROWID;--- + +CREATE VIEW import_path_segment_closure_view ( + segment_id, + num_children, + num_descendants, + max_descendant_distance +) +AS SELECT + closure.ancestor_id AS segment_id, + count(closure.descendant_id) + FILTER ( + WHERE closure.distance = 1 + ) AS num_children, + count(closure.descendant_id) AS num_descendants, + max(closure.distance) AS descendant_max_distance +FROM + import_path_segment_closure AS closure +GROUP BY + closure.ancestor_id +;--- + + +-- import_path_segment_split_view is a view that exists purely to allow for an INSTEAD OF +-- trigger to be used, which automatically splits a package path into parts. +CREATE VIEW + import_path_segment_split_view ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path + ) +AS +VALUES ( + NULL, + NULL, + NULL, + NULL, + NULL, + NULL +);--- + +-- import_path_segment_splitter_on_insert_package is a trigger that fires whenever +-- a new package is inserted. +-- +-- This initiates a recursive trigger chain that splits the package's path and +-- inserts them into the part table. 
+CREATE TRIGGER + import_path_segment_splitter_on_insert_package +AFTER + INSERT ON + package +BEGIN + INSERT INTO + import_path_segment_split_view ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path + ) + SELECT + package.rowid AS package_id, + package.num_segments + + module.num_segments AS total_num_segments, + 0 AS path_depth, + NULL AS segment_parent_id, + '' AS segment_name, + package.import_path || '/' AS remaining_path + FROM + package, module + ON + package.module_id = module.rowid + WHERE + package.rowid = new.rowid; +END;--- + +-- recursive_package_part_splitter is a recursive trigger that fires instead of +-- inserting to the import_path_segment_split_view view. +-- +-- It inserts the current part into the part and part_package tables. +-- +-- It then inserts the next part, if any, into the import_path_segment_split_view view. +CREATE TRIGGER + recursive_package_part_splitter +INSTEAD OF + INSERT ON + import_path_segment_split_view + WHEN + new.path_depth <= new.total_num_segments +BEGIN + + -- insert the segment name + INSERT INTO + import_path_segment_name ( + name + ) + SELECT + new.segment_name + WHERE + new.path_depth > 0 + ON CONFLICT DO + NOTHING; + + -- insert the segment + INSERT INTO + import_path_segment ( + parent_id, + name_id, + path_depth, + package_id + ) + SELECT + new.segment_parent_id AS parent_id, + -- if the segment name was inserted, changes() will be 1, so we can use + -- last_insert_rowid() and avoid the subquery + iif( + changes() > 0, + last_insert_rowid(), + ( + SELECT + rowid + FROM + import_path_segment_name + WHERE + name IS new.segment_name + ) + ) AS name_id, + new.path_depth AS path_depth, + + -- the package_id is only set on the final segment, otherwise it is NULL + iif( + new.path_depth = new.total_num_segments, + new.package_id, + NULL + ) AS package_id + WHERE + new.path_depth > 0 + ON CONFLICT DO + UPDATE SET + package_id = excluded.package_id + WHERE + 
new.path_depth = new.total_num_segments; + + INSERT INTO + import_path_segment_split_view ( + package_id, + total_num_segments, + path_depth, + segment_parent_id, + segment_name, + remaining_path + ) + SELECT + new.package_id AS package_id, + new.total_num_segments AS total_num_segments, + new.path_depth+1 AS path_depth, + + -- the first iteration will have NULL segment_parent_id + iif( + new.path_depth = 0, + NULL, + + -- if the segment was inserted, changes() will be 1, so we can avoid the + -- subquery and use last_insert_rowid() + iif( + changes() > 0, + last_insert_rowid(), + ( + SELECT + rowid + FROM + import_path_segment_view + WHERE + parent_id IS new.segment_parent_id + AND + name IS new.segment_name + ) + ) + ) AS segment_parent_id, + + substr(new.remaining_path, 1, slash-1) AS segment_name, + substr(new.remaining_path, slash+1) AS remaining_path + FROM ( + -- find the position of the first slash + SELECT + instr(new.remaining_path, '/') AS slash + ) + WHERE + -- when new.path_depth = new.total_num_segments, we have inserted all + -- segments and we are done + new.path_depth < new.total_num_segments; + +END;--- + +-- insert_part_path_closure is a trigger that fires whenever a new part is +-- inserted into the part table. It populates the part_path table with the new +-- part and all of its ancestors. 
+CREATE TRIGGER + build_import_path_segment_closure +AFTER + INSERT ON + import_path_segment +BEGIN + + -- insert all ancestors of the new segment, including the segment itself + INSERT INTO + import_path_segment_closure ( + ancestor_id, + descendant_id, + distance + ) + -- the new segment is its own ancestor, with a distance of 0 + SELECT + new.rowid AS ancestor_id, + new.rowid AS descendant_id, + 0 AS distance + UNION ALL + -- all ancestors of the new segment's parent, are also ancestors of the new + -- segment with a distance of 1 more than the distance to the parent + SELECT + closure.ancestor_id AS ancestor_id, + new.rowid AS descendant_id, + closure.distance + 1 AS distance + FROM + import_path_segment_closure AS closure + WHERE + closure.ancestor_id IS new.parent_id; + +END;--- + +-- prune_leaf_parts_with_null_package_id is a trigger that fires whenever +-- a part is updated such that its package_id is set to NULL. It deletes the +-- part if it has no children. +CREATE TRIGGER + prune_leaf_segments_with_null_package_id +AFTER + UPDATE OF + package_id + ON + import_path_segment + WHEN + new.package_id IS NULL +BEGIN + + DELETE FROM + import_path_segment + WHERE + rowid = new.rowid + AND + NOT EXISTS ( + SELECT + 1 + FROM + import_path_segment AS segment + WHERE + segment.parent_id IS new.rowid + ); + +END;--- + +-- recursively_prune_leaf_parts_with_null_package_id recursively deletes leaf +-- segments that have a NULL package_id. +CREATE TRIGGER + recursively_prune_leaf_parts_with_null_package_id +AFTER + DELETE ON + import_path_segment +BEGIN + + DELETE FROM + import_path_segment + WHERE + package_id IS NULL + AND + NOT EXISTS ( + SELECT + 1 + FROM + import_path_segment AS child + WHERE + child.parent_id IS import_path_segment.rowid + ); + +END;--- + +-- set_package_keep_false_for_modules_with_sync_true fires whenever a module is +-- updated such that sync is set to TRUE. It sets package.keep to FALSE for all +-- of the module's packages. 
As packages are re-synced, package.keep is set +-- back to TRUE so that only packages that are no longer available in the +-- module are left with keep set to FALSE, resulting in them being deleted to +-- finalize a sync. +CREATE TRIGGER + set_package_keep_false_for_modules_with_sync_true +AFTER + UPDATE OF + sync + ON + module + WHEN + new.sync = TRUE +BEGIN + UPDATE + package + SET + keep = FALSE + WHERE + package.module_id = new.rowid; +END;--- + +-- after_update_metadata_prune_module_package fires after the metadata has been +-- updated. It performs some sanity checks to ensure that the sync was +-- completed successfully and then deletes all modules and packages that have +-- keep set to FALSE. +CREATE TRIGGER + after_update_metadata_prune_module_package +AFTER + UPDATE ON + metadata +BEGIN + SELECT + RAISE(ABORT, 'invalid sync: no modules were synced') + WHERE + NOT EXISTS ( + SELECT + 1 + FROM + module + WHERE + keep = TRUE + ); + + -- a grouped count(*) is always >= 1, so a module with zero kept packages + -- simply produces no group; check for such modules with NOT EXISTS instead + SELECT + RAISE(ABORT, 'invalid sync: no packages were synced for one or more modules') + WHERE + EXISTS ( + SELECT + 1 + FROM + module + WHERE + module.keep = TRUE + AND NOT EXISTS ( + SELECT + 1 + FROM + package + WHERE + package.module_id = module.rowid + AND + package.keep = TRUE + ) + ); + + DELETE FROM + module + WHERE + keep = FALSE; + + DELETE FROM + package + WHERE + keep = FALSE; + +END;--- diff --git a/internal/modpkg/db/sync.go b/internal/modpkg/db/sync.go new file mode 100644 index 0000000..b63e010 --- /dev/null +++ b/internal/modpkg/db/sync.go @@ -0,0 +1,121 @@ +package db + +import ( + "context" + _ "embed" + "errors" + "fmt" + + "aslevy.com/go-doc/internal/sql" +) + +type Sync struct { + tx *sql.Tx + db *DB + Current *Metadata + stmt syncStmts +} + +type syncStmts struct { + upsertModule *sql.Stmt + upsertPkg *sql.Stmt +} + +func (db *DB) Sync(ctx context.Context) (_ *Sync, rerr error) { + current, err := db.stored.NeedsSync() + if err != nil { + return nil, err + } + if current == nil { + return nil, nil + } + + tx, err := db.db.BeginTx(ctx, nil) + if 
err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.RollbackOnError(&rerr) + + upsertModStmt, err := prepareUpsertModule(ctx, tx) + if err != nil { + return nil, err + } + + if err := setAllModuleSyncKeepFalse(ctx, tx); err != nil { + return nil, err + } + + return &Sync{ + tx: tx, + db: db, + Current: current, + stmt: syncStmts{ + upsertModule: upsertModStmt, + }, + }, nil +} + +//go:embed sql/module_update_set_sync_keep_false.sql +var queryModuleUpdateSetSyncKeepFalse string + +func setAllModuleSyncKeepFalse(ctx context.Context, db sql.Querier) error { + _, err := db.ExecContext(ctx, queryModuleUpdateSetSyncKeepFalse) + return err +} + +func (s *Sync) AddModule(ctx context.Context, importPath, version string) (_ *Module, rerr error) { + defer s.tx.RollbackOnError(&rerr) + + var mod Module + mod.ImportPath = importPath + mod.Version = version + + needSync, err := s.upsertModule(ctx, &mod) + if err != nil { + return nil, err + } + if !needSync { + return nil, nil + } + if s.stmt.upsertPkg == nil { + if err := s.prepareStmtUpsertPackage(ctx); err != nil { + return nil, err + } + } + + return &mod, nil +} + +func (s *Sync) AddPackage(ctx context.Context, mod *Module, pkgImportPath string) (rerr error) { + defer s.tx.RollbackOnError(&rerr) + return s.upsertPackage(ctx, &Package{ + ModuleID: mod.ID, + RelativePath: pkgImportPath[len(mod.ImportPath):], + }) +} + +func (s *Sync) Finish(ctx context.Context) (rerr error) { + defer func() { + if s.stmt.upsertPkg != nil { + if err := s.stmt.upsertPkg.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close upsert package statement: %w", err)) + } + } + if err := s.stmt.upsertModule.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close upsert module statement: %w", err)) + } + }() + + if err := s.finish(ctx); err != nil { + return fmt.Errorf("failed to finish sync: %w", err) + } + + return nil +} +func (s *Sync) finish(ctx context.Context) 
(rerr error) { + defer s.tx.RollbackOnError(&rerr) + if err := s.upsertMetadata(ctx, s.Current); err != nil { + return err + } + return s.tx.Commit() +} diff --git a/internal/modpkg/dirs.go b/internal/modpkg/dirs.go new file mode 100644 index 0000000..4321bf3 --- /dev/null +++ b/internal/modpkg/dirs.go @@ -0,0 +1,86 @@ +package modpkg + +import ( + "context" + "fmt" + "strings" + + "aslevy.com/go-doc/internal/dlog" + "aslevy.com/go-doc/internal/godoc" +) + +func (modPkg *ModPkg) FilterExact(importPath string) error { + return modPkg.filter(context.TODO(), importPath, true) +} +func (modPkg *ModPkg) FilterPartial(importPath string) error { + return modPkg.filter(context.TODO(), importPath, false) +} +func (modPkg *ModPkg) filter(ctx context.Context, importPath string, exact bool) error { + if modPkg.search == importPath && modPkg.exact == exact { + return nil + } + + if modPkg.g != nil { + if err := modPkg.g.Wait(); err != nil { + return fmt.Errorf("failed to wait for sync: %w", err) + } + modPkg.g = nil + } else if modPkg.rows != nil { + if err := modPkg.rows.Close(); err != nil { + return fmt.Errorf("failed to close previously open package search query: %w", err) + } + modPkg.offset = 0 + modPkg.results = modPkg.results[:0] + } + + modPkg.exact = exact + modPkg.search = importPath + + parts := strings.Split(importPath, "/") + var err error + modPkg.rows, err = modPkg.db.SelectPackagesByPartsRows(ctx, modPkg.exact, parts) + return err +} + +func (modPkg *ModPkg) Reset() { modPkg.offset = 0 } +func (modPkg *ModPkg) Next() (godoc.PackageDir, bool) { + if modPkg.offset < len(modPkg.results) { + // We have a result in memory. + next := modPkg.offset + modPkg.offset++ + return modPkg.results[next], true + } + // We don't have a result in memory. + + // The database query was already closed so we have no more results. + if modPkg.rows == nil { + return godoc.PackageDir{}, false + } + + // Check if there is a next row. 
+ if !modPkg.rows.Next() { + // See if there is an error or if we're just at the end of the + // results. + if err := modPkg.rows.Err(); err != nil { + dlog.Printf("failed to read next package from db: %v", err) + } + // Close the database query since we're done with it. + if err := modPkg.rows.Close(); err != nil { + dlog.Printf("failed to close rows for package query: %v", err) + } + // Set the database query to nil so we know we're done with it. + modPkg.rows = nil + // Return that there are no more results. + return godoc.PackageDir{}, false + } + + // Read the next row from the database query. + var pkg godoc.PackageDir + if err := modPkg.rows.Scan(&pkg.ImportPath, &pkg.Dir); err != nil { + dlog.Printf("failed to scan next package from db: %v", err) + return godoc.PackageDir{}, false + } + modPkg.results = append(modPkg.results, pkg) + modPkg.offset++ + return pkg, true +} diff --git a/internal/modpkg/modcache.go b/internal/modpkg/modcache.go new file mode 100644 index 0000000..ac603e4 --- /dev/null +++ b/internal/modpkg/modcache.go @@ -0,0 +1,114 @@ +package modpkg + +import ( + "context" + "fmt" + "log" + "os" + "path" + "path/filepath" + "runtime" + "strings" + + "aslevy.com/go-doc/internal/dlog" + "aslevy.com/go-doc/internal/godoc" + "aslevy.com/go-doc/internal/modpkg/db" + "aslevy.com/go-doc/internal/progressbar" + "golang.org/x/sync/errgroup" +) + +func (modPkg *ModPkg) syncFromGoModCache(ctx context.Context, progressBar *progressbar.ProgressBar, sync *db.Sync, coderoots []godoc.PackageDir) (rerr error) { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(runtime.NumCPU()*2 + 1) + g.Go(func() error { + for _, root := range coderoots { + mod, err := sync.AddModule(ctx, root.ImportPath, root.Version) + if err != nil { + return err + } + if mod == nil { + progressBar.Add(1) + continue + } + + g.Go(func() error { + if err := modPkg.syncModulePackages(ctx, sync, mod); err != nil { + return fmt.Errorf("failed to sync module packages: %w", err) + } + 
progressBar.Add(1) + return nil + }) + } + return nil + }) + return g.Wait() +} + +var dlogSync = dlog.Child("sync") + +func (modPkg *ModPkg) syncModulePackages(ctx context.Context, sync *db.Sync, mod *db.Module) error { + mod.Dir = filepath.Clean(mod.Dir) // because filepath.Join will do it anyway + dlogSync.Printf("syncing packages for module %q in %q", mod.ImportPath, mod.Dir) + + // this is the queue of directories to examine in this pass. + this := []godoc.PackageDir{} + // next is the queue of directories to examine in the next pass. + next := []godoc.PackageDir{mod.PackageDir} + + for len(next) > 0 && ctx.Err() == nil { + dlogSync.Printf("descending") + this, next = next, this[0:0] + for _, pkg := range this { + dlogSync.Printf("walking %q", pkg) + fd, err := os.Open(pkg.Dir) + if err != nil { + log.Print(err) + continue + } + + entries, err := fd.Readdir(0) + fd.Close() + if err != nil { + log.Print(err) + continue + } + hasGoFiles := false + for _, entry := range entries { + name := entry.Name() + // For plain files, remember if this directory contains any .go + // source files, but ignore them otherwise. + if !entry.IsDir() { + if !hasGoFiles && strings.HasSuffix(name, ".go") { + hasGoFiles = true + if err := sync.AddPackage(ctx, mod, pkg.ImportPath); err != nil { + return err + } + } + continue + } + // Entry is a directory. + + // The go tool ignores directories starting with ., _, or named "testdata". + if name[0] == '.' || name[0] == '_' || name == "testdata" { + continue + } + // Ignore vendor directories and stop at module boundaries. + if name == "vendor" { + continue + } + if fi, err := os.Stat(filepath.Join(pkg.Dir, name, "go.mod")); err == nil && !fi.IsDir() { + continue + } + // Remember this (fully qualified) directory for the next pass. 
+ subPkg := godoc.PackageDir{ + ImportPath: path.Join(pkg.ImportPath, name), + Dir: filepath.Join(pkg.Dir, name), + } + dlogSync.Printf("queuing %q", subPkg.ImportPath) + next = append(next, subPkg) + } + } + } + + return ctx.Err() +} diff --git a/internal/modpkg/modpkg.go b/internal/modpkg/modpkg.go new file mode 100644 index 0000000..5f721ef --- /dev/null +++ b/internal/modpkg/modpkg.go @@ -0,0 +1,82 @@ +package modpkg + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "aslevy.com/go-doc/internal/godoc" + "aslevy.com/go-doc/internal/modpkg/db" + "aslevy.com/go-doc/internal/progressbar" + "golang.org/x/sync/errgroup" +) + +type ModPkg struct { + db *db.DB + g *errgroup.Group + + rows *sql.Rows + results []godoc.PackageDir + offset int + + search string + exact bool +} + +func New(ctx context.Context, mainModDir string, coderoots []godoc.PackageDir) (*ModPkg, error) { + db, err := db.Open(ctx, mainModDir) + if err != nil { + return nil, err + } + + modPkg := ModPkg{db: db} + modPkg.g, ctx = errgroup.WithContext(ctx) + modPkg.g.Go(func() error { return modPkg.sync(ctx, coderoots) }) + + return &modPkg, nil +} + +func (modPkg *ModPkg) sync(ctx context.Context, coderoots []godoc.PackageDir) (rerr error) { + + sync, err := modPkg.db.Sync(ctx) + if err != nil { + return err + } + if sync == nil { + // No need to sync. 
+ return nil + } + progressBar := progressbar.New(len(coderoots)) + defer func() { + if err := sync.Finish(ctx); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to finish sync: %w", err)) + } + if err := progressBar.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close progress bar: %w", err)) + } + }() + + if sync.Current.Vendor { + return modPkg.syncFromVendorDir(ctx, sync, coderoots[0]) + } + return modPkg.syncFromGoModCache(ctx, progressBar, sync, coderoots) +} + +func (modPkg *ModPkg) Close() error { + var rerr error + if modPkg.g != nil { + if err := modPkg.g.Wait(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to wait for sync: %w", err)) + } + } else if modPkg.rows != nil { + if err := modPkg.rows.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close previously open package search query: %w", err)) + } + } + if err := modPkg.db.Close(); err != nil { + rerr = errors.Join(rerr, fmt.Errorf("failed to close module/package database: %w", err)) + } + + return rerr +} diff --git a/internal/modpkg/vendor.go b/internal/modpkg/vendor.go new file mode 100644 index 0000000..706807f --- /dev/null +++ b/internal/modpkg/vendor.go @@ -0,0 +1,26 @@ +package modpkg + +import ( + "context" + + "aslevy.com/go-doc/internal/godoc" + "aslevy.com/go-doc/internal/modpkg/db" + "aslevy.com/go-doc/internal/vendored" +) + +func (modPkg *ModPkg) syncFromVendorDir(ctx context.Context, sync *db.Sync, vendor godoc.PackageDir) error { + return vendored.Parse(ctx, vendor.Dir, func(ctx context.Context, modDir godoc.PackageDir) (vendored.PackageHandler, error) { + mod, err := sync.AddModule(ctx, modDir.ImportPath, modDir.Version) + if err != nil { + return nil, err + } + if mod == nil { + return nil, nil + } + + handlePackage := func(ctx context.Context, pkgImportPath string) error { + return sync.AddPackage(ctx, mod, pkgImportPath) + } + return handlePackage, nil + }) +} diff --git a/internal/progressbar/progressbar.go 
b/internal/progressbar/progressbar.go new file mode 100644 index 0000000..49e6ba0 --- /dev/null +++ b/internal/progressbar/progressbar.go @@ -0,0 +1,24 @@ +package progressbar + +import ( + "os" + "time" + + "github.com/schollz/progressbar/v3" +) + +type ProgressBar = progressbar.ProgressBar + +func New(totalNumMods int) *progressbar.ProgressBar { + return progressbar.NewOptions(totalNumMods, + progressbar.OptionSetDescription("indexing modules..."), + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionThrottle(time.Second/3), + progressbar.OptionShowCount(), // show current count e.g. 3/5 + progressbar.OptionClearOnFinish(), // clear bar when done + progressbar.OptionSetPredictTime(false), + progressbar.OptionSetElapsedTime(false), + progressbar.OptionEnableColorCodes(true), + progressbar.OptionUseANSICodes(true), + ) +} diff --git a/internal/slices/pop.go b/internal/slices/pop.go new file mode 100644 index 0000000..134ff19 --- /dev/null +++ b/internal/slices/pop.go @@ -0,0 +1,9 @@ +package slices + +func PopFirst[T any](s []T) (T, []T) { + return s[0], s[1:] +} + +func PopLast[T any](s []T) (T, []T) { + return s[len(s)-1], s[:len(s)-1] +} diff --git a/internal/sql/conn.go b/internal/sql/conn.go new file mode 100644 index 0000000..300d744 --- /dev/null +++ b/internal/sql/conn.go @@ -0,0 +1,45 @@ +package sql + +import ( + "context" + "database/sql" +) + +type Conn struct{ *sql.Conn } + +func (db *DB) Conn(ctx context.Context) (*Conn, error) { + conn, err := db.DB.Conn(ctx) + if err != nil { + return nil, err + } + return &Conn{conn}, nil +} + +func (conn *Conn) Begin() (*Tx, error) { return conn.BeginTx(context.Background(), nil) } +func (conn *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + tx, err := conn.Conn.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{tx}, nil +} + +func (conn *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + return _ExecContext(conn.Conn, ctx, 
query, args...) +} + +func (conn *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + return _PrepareContext(conn.Conn, ctx, query) +} + +func (conn *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + return _QueryContext(conn.Conn, ctx, query, args...) +} + +func (conn *Conn) QueryRow(query string, args ...any) *Row { + return conn.QueryRowContext(context.Background(), query, args...) +} +func (conn *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + row := conn.Conn.QueryRowContext(ctx, query, args...) + return &Row{row, query, args} +} diff --git a/internal/sql/db.go b/internal/sql/db.go new file mode 100644 index 0000000..12515cc --- /dev/null +++ b/internal/sql/db.go @@ -0,0 +1,54 @@ +package sql + +import ( + "context" + "database/sql" +) + +type DB struct{ *sql.DB } + +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{db}, nil +} + +func (db *DB) Begin() (*Tx, error) { return db.BeginTx(context.Background(), nil) } +func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{tx}, nil +} + +func (db *DB) Exec(query string, args ...any) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} +func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + return _ExecContext(db.DB, ctx, query, args...) +} + +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + return _PrepareContext(db.DB, ctx, query) +} + +func (db *DB) Query(query string, args ...any) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) 
+} +func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + return _QueryContext(db.DB, ctx, query, args...) +} + +func (db *DB) QueryRow(query string, args ...any) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + row := db.DB.QueryRowContext(ctx, query, args...) + return &Row{row, query, args} +} diff --git a/internal/sql/err.go b/internal/sql/err.go new file mode 100644 index 0000000..0db680c --- /dev/null +++ b/internal/sql/err.go @@ -0,0 +1,44 @@ +package sql + +import ( + "errors" + "fmt" + "strings" +) + +func newErrFailedQuery(err error, verb, query string, args ...any) error { + if errors.Is(err, &ErrFailedQuery{}) { + // Don't wrap errors that are already ErrFailedQuery. + return err + } + return ErrFailedQuery{ + Err: err, + Verb: verb, + Query: query, + Args: args, + } +} + +type ErrFailedQuery struct { + Err error + Verb string + Query string + Args []any +} + +func (err ErrFailedQuery) Unwrap() error { return err.Err } +func (err ErrFailedQuery) Error() string { + format := ` +failed to %s: %v + +query: +%s + +args: %v +`[1:] // skip leading newline + return fmt.Sprintf(format, + err.Verb, err.Err, + strings.TrimSpace(err.Query), + err.Args, + ) +} diff --git a/internal/sql/querier.go b/internal/sql/querier.go new file mode 100644 index 0000000..479c004 --- /dev/null +++ b/internal/sql/querier.go @@ -0,0 +1,53 @@ +package sql + +import ( + "context" + "database/sql" +) + +type Querier interface { + Exec(query string, args ...any) (Result, error) + Prepare(query string) (*Stmt, error) + Query(query string, args ...any) (*Rows, error) + QueryRow(query string, args ...any) *Row + + QuerierContext +} + +type QuerierContext interface { + ExecContext(ctx context.Context, query string, args ...any) (Result, error) + PrepareContext(ctx context.Context, query string) (*Stmt, error) + QueryContext(ctx 
context.Context, query string, args ...any) (*Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *Row +} + +type sqlQuerierContext interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +func _ExecContext(q sqlQuerierContext, ctx context.Context, query string, args ...any) (Result, error) { + res, err := q.ExecContext(ctx, query, args...) + if err != nil { + return nil, newErrFailedQuery(err, "exec", query, args...) + } + return res, nil +} + +func _PrepareContext(q sqlQuerierContext, ctx context.Context, query string) (*Stmt, error) { + stmt, err := q.PrepareContext(ctx, query) + if err != nil { + return nil, newErrFailedQuery(err, "prepare", query) + } + return &Stmt{stmt, query}, nil +} + +func _QueryContext(q sqlQuerierContext, ctx context.Context, query string, args ...any) (*Rows, error) { + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return nil, newErrFailedQuery(err, "query", query, args...) + } + return rows, nil +} diff --git a/internal/sql/row.go b/internal/sql/row.go new file mode 100644 index 0000000..6b223bf --- /dev/null +++ b/internal/sql/row.go @@ -0,0 +1,25 @@ +package sql + +import ( + "database/sql" +) + +type Row struct { + *sql.Row + query string + args []any +} + +func (row *Row) Err() error { + if err := row.Row.Err(); err != nil { + return newErrFailedQuery(err, "query", row.query, row.args...) + } + return nil +} + +func (row *Row) Scan(dest ...any) error { + if err := row.Row.Scan(dest...); err != nil { + return newErrFailedQuery(err, "scan", row.query, row.args...) 
+ } + return nil +} diff --git a/internal/sql/scanner.go b/internal/sql/scanner.go new file mode 100644 index 0000000..1c579e2 --- /dev/null +++ b/internal/sql/scanner.go @@ -0,0 +1,11 @@ +package sql + +var ( + _ RowScanner = (*Row)(nil) + _ RowScanner = (*Rows)(nil) +) + +type RowScanner interface { + Scan(dest ...any) error + Err() error +} diff --git a/internal/sql/sql.go b/internal/sql/sql.go new file mode 100644 index 0000000..f679a4f --- /dev/null +++ b/internal/sql/sql.go @@ -0,0 +1,51 @@ +// Package sql is a wrapper around database/sql that provides some convenience +// features for error handling and transactions. Most exported symbols are +// passthrough to the original package except for DB, Conn, Stmt, Tx, and Row. +package sql + +import "database/sql" + +const ( + LevelDefault = sql.LevelDefault + LevelReadUncommitted = sql.LevelReadUncommitted + LevelReadCommitted = sql.LevelReadCommitted + LevelWriteCommitted = sql.LevelWriteCommitted + LevelRepeatableRead = sql.LevelRepeatableRead + LevelSnapshot = sql.LevelSnapshot + LevelSerializable = sql.LevelSerializable + LevelLinearizable = sql.LevelLinearizable +) + +var ( + ErrConnDone = sql.ErrConnDone + ErrNoRows = sql.ErrNoRows + ErrTxDone = sql.ErrTxDone + + Drivers = sql.Drivers + Register = sql.Register + + OpenDB = sql.OpenDB + + Named = sql.Named +) + +type ( + ColumnType = sql.ColumnType + DBStats = sql.DBStats + IsolationLevel = sql.IsolationLevel + NamedArg = sql.NamedArg + NullBool = sql.NullBool + NullByte = sql.NullByte + NullFloat64 = sql.NullFloat64 + NullInt16 = sql.NullInt16 + NullInt32 = sql.NullInt32 + NullInt64 = sql.NullInt64 + NullString = sql.NullString + NullTime = sql.NullTime + Out = sql.Out + RawBytes = sql.RawBytes + Result = sql.Result + Rows = sql.Rows + Scanner = sql.Scanner + TxOptions = sql.TxOptions +) diff --git a/internal/sql/stmt.go b/internal/sql/stmt.go new file mode 100644 index 0000000..6e12d2d --- /dev/null +++ b/internal/sql/stmt.go @@ -0,0 +1,41 @@ +package sql 
+ +import ( + "context" + "database/sql" +) + +type Stmt struct { + *sql.Stmt + query string +} + +func (stmt *Stmt) Exec(args ...any) (Result, error) { + return stmt.ExecContext(context.Background(), args...) +} +func (stmt *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) { + res, err := stmt.Stmt.ExecContext(ctx, args...) + if err != nil { + return nil, newErrFailedQuery(err, "exec", stmt.query, args...) + } + return res, nil +} + +func (stmt *Stmt) Query(args ...any) (*Rows, error) { + return stmt.QueryContext(context.Background(), args...) +} +func (stmt *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) { + rows, err := stmt.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, newErrFailedQuery(err, "query", stmt.query, args...) + } + return rows, nil +} + +func (stmt *Stmt) QueryRow(args ...any) *Row { + return stmt.QueryRowContext(context.Background(), args...) +} +func (stmt *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row { + row := stmt.Stmt.QueryRowContext(ctx, args...) + return &Row{row, stmt.query, args} +} diff --git a/internal/sql/tx.go b/internal/sql/tx.go new file mode 100644 index 0000000..362ac0c --- /dev/null +++ b/internal/sql/tx.go @@ -0,0 +1,93 @@ +package sql + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +// Tx is a wrapper around *sql.Tx that provides a RollbackOnError method +// which can be deferred to rollback the transaction depending on the returned +// error. +// +// The Rollback and Commit methods ignore the sql.ErrTxDone error and otherwise +// wrap the returned error if not nil. +type Tx struct{ *sql.Tx } + +func (tx *Tx) Exec(query string, args ...any) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + return _ExecContext(tx.Tx, ctx, query, args...) 
+} + +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + return _PrepareContext(tx.Tx, ctx, query) +} + +func (tx *Tx) Query(query string, args ...any) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + return _QueryContext(tx.Tx, ctx, query, args...) +} + +func (tx *Tx) QueryRow(query string, args ...any) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + row := tx.Tx.QueryRowContext(ctx, query, args...) + return &Row{row, query, args} +} + +// RollbackOnError calls tx.Rollback() if *rerr is not nil or if recovering +// from a panic. This function should be deferred and passed a pointer to the +// caller's final error. +// +// func doSomething(tx *sqlTx) (rerr error) { +// defer tx.RollbackOnError(&rerr) +// if err := tx.Exec(...); err != nil { +// return err // will rollback +// } +// return nil // will not rollback +// } +// +// The transaction is left open if not rolled back. The user is responsible for +// calling tx.Commit(). +func (tx *Tx) RollbackOnError(rerr *error) { + if p := recover(); p != nil { + *rerr = errors.Join(*rerr, fmt.Errorf("panic: %v", p)) + } + if *rerr == nil { + return + } + if err := tx.Rollback(); err != nil { + *rerr = errors.Join(*rerr, err) + } +} + +// Rollback calls tx.Tx.Rollback() and ignores sql.ErrTxDone errors and +// otherwise wraps the returned error if not nil. +func (tx *Tx) Rollback() error { + if err := tx.Tx.Rollback(); err != nil && !errors.Is(err, ErrTxDone) { + return fmt.Errorf("failed to rollback transaction: %w", err) + } + return nil +} + +// Commit calls tx.Tx.Commit() and wraps the returned error if not nil. 
+func (tx *Tx) Commit() error { + if err := tx.Tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil +} + +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return tx.StmtContext(context.Background(), stmt) } +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + return &Stmt{tx.Tx.StmtContext(ctx, stmt.Stmt), stmt.query} +} diff --git a/internal/vendored/parse.go b/internal/vendored/parse.go index 61aff4b..d9cbee8 100644 --- a/internal/vendored/parse.go +++ b/internal/vendored/parse.go @@ -10,6 +10,7 @@ import ( "strings" "aslevy.com/go-doc/internal/godoc" + "aslevy.com/go-doc/internal/slices" ) type ModulePackages map[godoc.PackageDir][]godoc.PackageDir @@ -20,13 +21,17 @@ func ParseModulePackages(ctx context.Context, vendorDir string) (ModulePackages, } func (modPkgs ModulePackages) parse(ctx context.Context, vendorDir string) error { - return Handler(func(_ context.Context, mod godoc.PackageDir, pkgs ...godoc.PackageDir) error { - modPkgs[mod] = append(modPkgs[mod], pkgs...) 
- return nil + return Handler(func(_ context.Context, mod godoc.PackageDir) (PackageHandler, error) { + return func(_ context.Context, pkgImportPath string) error { + modPkgs[mod] = append(modPkgs[mod], godoc.NewPackageDir(pkgImportPath, filepath.Join(vendorDir, pkgImportPath))) + return nil + }, nil }).parse(ctx, vendorDir) } -type Handler func(ctx context.Context, mod godoc.PackageDir, pkgs ...godoc.PackageDir) error +type Handler func(ctx context.Context, mod godoc.PackageDir) (PackageHandler, error) + +type PackageHandler func(ctx context.Context, pkgImportPath string) error func Parse(ctx context.Context, vendorDir string, handle Handler) error { return handle.parse(ctx, vendorDir) @@ -43,56 +48,167 @@ func (handle Handler) parse(ctx context.Context, vendorDir string) error { return handle.parseData(ctx, vendorDir, modTxtFile) } -func (handle Handler) parseData(ctx context.Context, vendorDir string, data io.Reader) error { - var mod godoc.PackageDir - var pkgs []godoc.PackageDir +func (handleMod Handler) parseData(ctx context.Context, vendorDir string, data io.Reader) error { + var handlePkg PackageHandler lines := bufio.NewScanner(data) for lines.Scan() && ctx.Err() == nil { - modImportPath, _, pkgImportPath := parseLine(lines.Text()) - if modImportPath != "" { - if mod.Dir != "" { - if err := handle(ctx, mod, pkgs...); err != nil { - return err - } - } - mod = godoc.NewPackageDir( - modImportPath, - filepath.Join(vendorDir, modImportPath), - ) - pkgs = pkgs[:0] + line := parseLine(lines.Text()) + if line == nil { continue } - if pkgImportPath != "" { - if mod.Dir == "" { - return fmt.Errorf("found package %q before a module", pkgImportPath) + if line.Kind == modulesTxtLineKindModule { + var err error + handlePkg, err = handleMod(ctx, godoc.NewPackageDir( + line.ImportPath, + filepath.Join(vendorDir, line.ImportPath), + godoc.WithVersion(line.FullVersion()), + )) + if err != nil { + return err } - if !strings.HasPrefix(pkgImportPath, mod.ImportPath) { - 
return fmt.Errorf("package %q is not in module %q", pkgImportPath, mod.ImportPath) - } - pkgs = append(pkgs, godoc.NewPackageDir(pkgImportPath, "")) + continue + } + if handlePkg == nil { + continue + } + if err := handlePkg(ctx, line.ImportPath); err != nil { + return err } } return nil } -func parseLine(line string) (modImportPath, modVersion, pkgImportPath string) { + +type modulesTxtLine struct { + Kind modulesTxtLineKind + + Explicit bool + GoVersion string + + ImportPath string + Version string + ReplaceImportPath string + ReplaceVersion string +} + +func (line modulesTxtLine) FullVersion() string { + if line.Version == "" { + // An empty version signals we should always sync the module. + return "" + } + + if line.ReplaceImportPath == "" { + // No replacement so just the version. + return line.Version + } + + // A replacement without a version is a relative path replacement and + // should generally always be synced. + if line.ReplaceVersion == "" { + return "" + } + + // Construct a version that depends on the module version, its + // replacement path, and its replacement version. If any one of these + // changes we will re-sync the module. + return fmt.Sprintf("%s=>%s@%s", line.Version, line.ReplaceImportPath, line.ReplaceVersion) +} + +type modulesTxtLineKind int + +const ( + modulesTxtLineKindUnknown modulesTxtLineKind = iota + modulesTxtLineKindModule + modulesTxtLineKindPackage +) + +func parseLine(line string) *modulesTxtLine { + var l modulesTxtLine fields := strings.Fields(line) + if len(fields) == 0 { - return + return nil + } + + next, remaining := slices.PopFirst(fields) + + const goVersionPrefix = "##" + if next == goVersionPrefix { + // This is a go version line which we ignore. 
+ return nil } - switch fields[0] { - case "#": - // module - if len(fields) < 3 { - return + + const ( + modulePrefix = "#" + replaceArrow = "=>" + versionPrefix = "v" + ) + if next == modulePrefix { + if len(remaining) < 2 { + // invalid module line + return nil } - modImportPath, modVersion = fields[1], fields[2] - if !strings.HasPrefix(modVersion, "v") { - modVersion = "" + l.Kind = modulesTxtLineKindModule + + l.ImportPath, remaining = slices.PopFirst(remaining) + + next, remaining = slices.PopFirst(remaining) + + if next == replaceArrow { + // Redundant replace line without complete versions + // which we skip. + // # path/to/mod => path/to/replace/mod v1.2.3 + // or + // # path/to/mod => ../path/to/replace/mod + return nil } - case "##": - // ignore - default: - pkgImportPath = fields[0] + + if !strings.HasPrefix(next, versionPrefix) { + // invalid module line + return nil + } + l.Version = next + + // simple module line + // # path/to/mod v1.2.3 + if len(remaining) == 0 { + return &l + } + + // module replace line + // # path/to/mod v1.2.3 => path/to/replace/mod v1.3.2 + // or + // # path/to/mod v1.2.3 => ../path/to/replace/mod + if len(remaining) < 2 { + // invalid module line + return nil + } + + next, remaining = slices.PopFirst(remaining) + if next != replaceArrow { + // invalid module line + return nil + } + + l.ReplaceImportPath, remaining = slices.PopFirst(remaining) + if len(remaining) == 0 { + // this is a relative path replacement, which has no + // version, so we clear the version. 
for example: + // # path/to/mod v1.2.3 => ../path/to/replace/mod + l.Version = "" + return &l + } + if !strings.HasPrefix(remaining[0], versionPrefix) { + // invalid module line + return nil + } + // this is a versioned replacement, like + // # path/to/mod v1.2.3 => path/to/replace/mod v1.3.2 + l.ReplaceVersion = remaining[0] + return &l } - return + + // this is a package line + l.Kind = modulesTxtLineKindPackage + l.ImportPath = next + return &l } diff --git a/main.go b/main.go index eb5399b..d0f8b24 100644 --- a/main.go +++ b/main.go @@ -44,6 +44,7 @@ package main import ( "bytes" + "context" "errors" "flag" "fmt" @@ -60,7 +61,6 @@ import ( "aslevy.com/go-doc/internal/dlog" "aslevy.com/go-doc/internal/flags" "aslevy.com/go-doc/internal/godoc" - "aslevy.com/go-doc/internal/index" "aslevy.com/go-doc/internal/outfmt" ) @@ -118,11 +118,15 @@ func do(writer io.Writer, flagSet *flag.FlagSet, args []string) (err error) { return err } } - godoc.NoImports = godoc.NoImports || short // don't show imports with -short - if pkgIdx := packageIndex(); pkgIdx != nil { - defer pkgIdx.Close() - xdirs = index.NewDirs(pkgIdx) + modPkg := openModPkg(context.Background()) + if modPkg != nil { + xdirs = modPkg + defer func() { + if err := modPkg.Close(); err != nil { + dlog.Printf("modpkg.Close: %v", err) + } + }() } completer := completion.NewCompleter(writer, xdirs, unexported, matchCase, flagSet.Args()) diff --git a/mainextra.go b/mainextra.go index 99aa049..59b6d51 100644 --- a/mainextra.go +++ b/mainextra.go @@ -1,37 +1,27 @@ package main import ( - "bytes" "context" - "os" - "os/exec" - "path/filepath" - "strings" + "log" "aslevy.com/go-doc/internal/dlog" "aslevy.com/go-doc/internal/godoc" - "aslevy.com/go-doc/internal/index" + "aslevy.com/go-doc/internal/modpkg" ) -func packageIndex() *index.Index { - localModuleRoot := moduleRootDir(goCmd()) - if localModuleRoot == "" { +func openModPkg(ctx context.Context) *modpkg.ModPkg { + if GOMOD == "" { + dlog.Printf("GOMOD is empty, not 
using modpkg")
 		return nil
 	}
-	path := indexCachePath(localModuleRoot)
-	if err := os.Mkdir(filepath.Dir(path), 0755); err != nil && !os.IsExist(err) {
-		dlog.Printf("failed to create index cache dir: %v", err)
-		return nil
-	}
-	pkgIdx, err := index.Load(context.Background(), path, dirsToIndexModules(codeRoots()...), index.WithMode(index.Sync))
+	modPkg, err := modpkg.New(ctx, GOMOD, dirsToIndexModules(codeRoots()...))
 	if err != nil {
-		dlog.Printf("index.Load: %v", err)
+		log.Fatalf("modpkg.New: %v", err)
+		return nil
 	}
-	return pkgIdx
-}
-func indexCachePath(localModuleRoot string) string {
-	return filepath.Join(localModuleRoot, ".go-doc", "packages.sqlite3")
+	return modPkg
 }
+
 func dirsToIndexModules(dirs ...Dir) []godoc.PackageDir {
 	mods := make([]godoc.PackageDir, len(dirs))
 	for i, dir := range dirs {
@@ -39,12 +29,3 @@ func dirsToIndexModules(dirs ...Dir) []godoc.PackageDir {
 	}
 	return mods
 }
-func moduleRootDir(goCmd string) string {
-	args := []string{"env", "GOMOD"}
-	stdout, err := exec.Command(goCmd, args...).Output()
-	if err != nil {
-		dlog.Printf("failed to run `%s %s`: %v", goCmd, strings.Join(args, " "), err)
-		return ""
-	}
-	return filepath.Dir(string(bytes.TrimSpace(stdout)))
-}