diff --git a/embedmd/command_test.go b/embedmd/command_test.go index da978cc..630aa76 100644 --- a/embedmd/command_test.go +++ b/embedmd/command_test.go @@ -13,7 +13,11 @@ package embedmd -import "testing" +import ( + "testing" + + "github.com/campoy/embedmd/internal/testutil" +) func TestParseCommand(t *testing.T) { tc := []struct { @@ -24,10 +28,10 @@ func TestParseCommand(t *testing.T) { }{ {name: "start to end", in: "(code.go /start/ /end/)", - cmd: command{path: "code.go", lang: "go", start: ptr("/start/"), end: ptr("/end/")}}, + cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/start/"), end: testutil.Ptr("/end/")}}, {name: "only start", in: "(code.go /start/)", - cmd: command{path: "code.go", lang: "go", start: ptr("/start/")}}, + cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/start/")}}, {name: "empty list", in: "()", err: "missing file name"}, @@ -54,10 +58,10 @@ func TestParseCommand(t *testing.T) { cmd: command{path: "test.md", lang: "markdown"}}, {name: "multi-line comments", in: `(doc.go /\/\*/ /\*\//)`, - cmd: command{path: "doc.go", lang: "go", start: ptr(`/\/\*/`), end: ptr(`/\*\//`)}}, + cmd: command{path: "doc.go", lang: "go", start: testutil.Ptr(`/\/\*/`), end: testutil.Ptr(`/\*\//`)}}, {name: "using $ as end", in: "(foo.go /start/ $)", - cmd: command{path: "foo.go", lang: "go", start: ptr("/start/"), end: ptr("$")}}, + cmd: command{path: "foo.go", lang: "go", start: testutil.Ptr("/start/"), end: testutil.Ptr("$")}}, {name: "extra arguments", in: "(foo.go /start/ $ extra)", err: "too many arguments"}, {name: "file name with directories", @@ -74,7 +78,7 @@ func TestParseCommand(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { cmd, err := parseCommand(tt.in) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { return } @@ -85,43 +89,12 @@ func TestParseCommand(t *testing.T) { if want.lang != got.lang { t.Errorf("case [%s]: expected language %q; got %q", tt.name, want.lang, got.lang) } - if !eqPtr(want.start, got.start) { - t.Errorf("case [%s]: expected start %v; got %v", tt.name, str(want.start), str(got.start)) + if !testutil.EqPtr(want.start, got.start) { + t.Errorf("case [%s]: expected start %v; got %v", tt.name, testutil.Str(want.start), testutil.Str(got.start)) } - if !eqPtr(want.end, got.end) { - t.Errorf("case [%s]: expected end %v; got %v", tt.name, str(want.end), str(got.end)) + if !testutil.EqPtr(want.end, got.end) { + t.Errorf("case [%s]: expected end %v; got %v", tt.name, testutil.Str(want.end), testutil.Str(got.end)) } }) } } - -func ptr(s string) *string { return &s } - -func str(s *string) string { - if s == nil { - return "" - } - return *s -} - -func eqPtr(a, b *string) bool { - if a == nil || b == nil { - return a == b - } - return *a == *b -} - -func eqErr(t *testing.T, id string, err error, msg string) bool { - t.Helper() - if err == nil && msg == "" { - return true - } - if err == nil && msg != "" { - t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg) - return false - } - if err != nil && msg != err.Error() { - t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err) - } - return false -} diff --git a/embedmd/embedmd.go b/embedmd/embedmd.go index c59cc2a..aed9eea 100644 --- a/embedmd/embedmd.go +++ b/embedmd/embedmd.go @@ -16,7 +16,7 @@ // // The format of an embedmd command is: // -// [embedmd]:# (pathOrURL language /start regexp/ /end regexp/) +// [embedmd]:# (pathOrURL language /start regexp/ /end regexp/) // // The embedded code will be extracted from the file at pathOrURL, // which can either be a relative path to a file in the local file @@ -29,27 +29,26 @@ // Omitting the the second regular expression will embed only the piece of // text that matches /regexp/: // -// [embedmd]:# (pathOrURL language /regexp/) +// [embedmd]:# (pathOrURL language /regexp/) // // To embed the whole line matching a regular expression you can use: // -// [embedmd]:# (pathOrURL language /.*regexp.*\n/) +// [embedmd]:# (pathOrURL language /.*regexp.*\n/) // // If you want to embed from a point to the end you should use: // -// [embedmd]:# (pathOrURL language /start regexp/ $) +// [embedmd]:# (pathOrURL language /start regexp/ $) // // Finally you can embed a whole file by omitting both regular expressions: // -// [embedmd]:# (pathOrURL language) +// [embedmd]:# (pathOrURL language) // // You can ommit the language in any of the previous commands, and the extension // of the file will be used for the snippet syntax highlighting. Note that while // this works Go files, since the file extension .go matches the name of the language // go, this will fail with other files like .md whose language name is markdown. // -// [embedmd]:# (file.ext) -// +// [embedmd]:# (file.ext) package embedmd import ( diff --git a/embedmd/embedmd_test.go b/embedmd/embedmd_test.go index 8c113c8..dba20d1 100644 --- a/embedmd/embedmd_test.go +++ b/embedmd/embedmd_test.go @@ -21,6 +21,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/campoy/embedmd/internal/testutil" ) const content = ` @@ -43,37 +45,37 @@ func TestExtract(t *testing.T) { {name: "no limits", out: string(content)}, {name: "only one line", - start: ptr("/func main.*\n/"), out: "func main() {\n"}, + start: testutil.Ptr("/func main.*\n/"), out: "func main() {\n"}, {name: "from package to end", - start: ptr("/package main/"), end: ptr("$"), out: string(content[1:])}, + start: testutil.Ptr("/package main/"), end: testutil.Ptr("$"), out: string(content[1:])}, {name: "not matching", - start: ptr("/gopher/"), err: "could not match \"/gopher/\""}, + start: testutil.Ptr("/gopher/"), err: "could not match \"/gopher/\""}, {name: "part of a line", - start: ptr("/fmt.P/"), end: ptr("/hello/"), out: "fmt.Println(\"hello"}, + start: testutil.Ptr("/fmt.P/"), end: testutil.Ptr("/hello/"), out: "fmt.Println(\"hello"}, {name: "function call", - start: ptr("/fmt\\.[^()]*/"), out: "fmt.Println"}, + start: testutil.Ptr("/fmt\\.[^()]*/"), out: "fmt.Println"}, {name: "from fmt to end of line", - start: ptr("/fmt.P.*\n/"), out: "fmt.Println(\"hello, test\")\n"}, + start: testutil.Ptr("/fmt.P.*\n/"), out: "fmt.Println(\"hello, test\")\n"}, {name: "from func to end of next line", - start: ptr("/func/"), end: ptr("/Println.*\n/"), out: "func main() {\n fmt.Println(\"hello, test\")\n"}, + start: testutil.Ptr("/func/"), end: testutil.Ptr("/Println.*\n/"), out: "func main() {\n fmt.Println(\"hello, test\")\n"}, {name: "from func to }", - start: ptr("/func main/"), end: ptr("/}/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"}, + start: testutil.Ptr("/func main/"), end: testutil.Ptr("/}/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"}, {name: "bad start regexp", - start: ptr("/(/"), err: "error parsing regexp: missing closing ): `(`"}, + start: testutil.Ptr("/(/"), err: "error parsing regexp: missing closing ): `(`"}, {name: "bad regexp", - start: ptr("something"), err: "missing slashes (/) around \"something\""}, + start: testutil.Ptr("something"), err: "missing slashes (/) around \"something\""}, {name: "bad end regexp", - start: ptr("/fmt.P/"), end: ptr("/)/"), err: "error parsing regexp: unexpected ): `)`"}, + start: testutil.Ptr("/fmt.P/"), end: testutil.Ptr("/)/"), err: "error parsing regexp: unexpected ): `)`"}, {name: "start and end of line ^$", - start: ptr("/^func main/"), end: ptr("/}$/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"}, + start: testutil.Ptr("/^func main/"), end: testutil.Ptr("/}$/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"}, } for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { b, err := extract([]byte(content), tt.start, tt.end) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { return } if string(b) != tt.out { @@ -107,7 +109,7 @@ func TestExtractFromFile(t *testing.T) { }, { name: "added line break", - cmd: command{path: "code.go", lang: "go", start: ptr("/fmt\\.Println/")}, + cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/fmt\\.Println/")}, files: map[string][]byte{"code.go": []byte(content)}, out: "```go\nfmt.Println\n```\n", }, @@ -118,7 +120,7 @@ func TestExtractFromFile(t *testing.T) { }, { name: "unmatched regexp", - cmd: command{path: "code.go", lang: "go", start: ptr("/potato/")}, + cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/potato/")}, files: map[string][]byte{"code.go": []byte(content)}, err: "could not extract content from code.go: could not match \"/potato/\"", }, @@ -133,7 +135,7 @@ func TestExtractFromFile(t *testing.T) { w := new(bytes.Buffer) err := e.runCommand(w, &tt.cmd) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { return } if w.String() != tt.out { @@ -267,7 +269,7 @@ func TestProcess(t *testing.T) { opts = append(opts, WithBaseDir(tt.dir)) } err := Process(&out, strings.NewReader(tt.in), opts...) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { return } if tt.out != out.String() { diff --git a/embedmd/parser_test.go b/embedmd/parser_test.go index 51007d4..a9b26c7 100644 --- a/embedmd/parser_test.go +++ b/embedmd/parser_test.go @@ -19,6 +19,8 @@ import ( "io" "strings" "testing" + + "github.com/campoy/embedmd/internal/testutil" ) func TestParser(t *testing.T) { @@ -94,7 +96,7 @@ func TestParser(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var out bytes.Buffer err := process(&out, strings.NewReader(tt.in), tt.run) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { return } if got := out.String(); got != tt.out { diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..9ac7711 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,40 @@ +// Package testutil provides shared test helpers. +package testutil + +import "testing" + +// EqErr checks whether err matches the expected message msg. +// Returns true only when both are empty (no error expected, none received). +func EqErr(t *testing.T, id string, err error, msg string) bool { + t.Helper() + if err == nil && msg == "" { + return true + } + if err == nil && msg != "" { + t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg) + return false + } + if err != nil && msg != err.Error() { + t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err) + } + return false +} + +// Ptr returns a pointer to the given string. +func Ptr(s string) *string { return &s } + +// Str returns the string value of a *string, or "" if nil. +func Str(s *string) string { + if s == nil { + return "" + } + return *s +} + +// EqPtr returns whether two *string values are equal. +func EqPtr(a, b *string) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} diff --git a/main.go b/main.go index c25e5cc..bb26330 100644 --- a/main.go +++ b/main.go @@ -24,9 +24,12 @@ // // embedmd supports two flags: // -d: will print the difference of the input file with what the output -// would have been if executed. +// +// would have been if executed. +// // -w: rewrites the given files rather than writing the output to the standard -// output. +// +// output. // // For more information on the format of the commands, read the documentation // of the github.com/campoy/embedmd/embedmd package. diff --git a/main_test.go b/main_test.go index 62d1b73..e073179 100644 --- a/main_test.go +++ b/main_test.go @@ -19,6 +19,8 @@ import ( "os" "strings" "testing" + + "github.com/campoy/embedmd/internal/testutil" ) func TestEmbedStreams(t *testing.T) { @@ -65,7 +67,7 @@ func TestEmbedStreams(t *testing.T) { buf := &bytes.Buffer{} stdout = buf foundDiff, err := embed(nil, tt.w, tt.d) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { continue } if got := buf.String(); tt.out != got { @@ -112,7 +114,7 @@ func TestEmbedFiles(t *testing.T) { } _, err := embed([]string{"docs.md"}, tt.w, tt.d) - if !eqErr(t, tt.name, err, tt.err) { + if !testutil.EqErr(t, tt.name, err, tt.err) { continue } if got := f.buf.String(); tt.out != got { @@ -122,20 +124,6 @@ func TestEmbedFiles(t *testing.T) { } } -func eqErr(t *testing.T, id string, err error, msg string) bool { - if err == nil && msg == "" { - return true - } - if err == nil && msg != "" { - t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg) - return false - } - if err != nil && msg != err.Error() { - t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err) - } - return false -} - type fakeFile struct { io.ReadCloser buf bytes.Buffer