Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 14 additions & 41 deletions embedmd/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@

package embedmd

import "testing"
import (
"testing"

"github.com/campoy/embedmd/internal/testutil"
)

func TestParseCommand(t *testing.T) {
tc := []struct {
Expand All @@ -24,10 +28,10 @@ func TestParseCommand(t *testing.T) {
}{
{name: "start to end",
in: "(code.go /start/ /end/)",
cmd: command{path: "code.go", lang: "go", start: ptr("/start/"), end: ptr("/end/")}},
cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/start/"), end: testutil.Ptr("/end/")}},
{name: "only start",
in: "(code.go /start/)",
cmd: command{path: "code.go", lang: "go", start: ptr("/start/")}},
cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/start/")}},
{name: "empty list",
in: "()",
err: "missing file name"},
Expand All @@ -54,10 +58,10 @@ func TestParseCommand(t *testing.T) {
cmd: command{path: "test.md", lang: "markdown"}},
{name: "multi-line comments",
in: `(doc.go /\/\*/ /\*\//)`,
cmd: command{path: "doc.go", lang: "go", start: ptr(`/\/\*/`), end: ptr(`/\*\//`)}},
cmd: command{path: "doc.go", lang: "go", start: testutil.Ptr(`/\/\*/`), end: testutil.Ptr(`/\*\//`)}},
{name: "using $ as end",
in: "(foo.go /start/ $)",
cmd: command{path: "foo.go", lang: "go", start: ptr("/start/"), end: ptr("$")}},
cmd: command{path: "foo.go", lang: "go", start: testutil.Ptr("/start/"), end: testutil.Ptr("$")}},
{name: "extra arguments",
in: "(foo.go /start/ $ extra)", err: "too many arguments"},
{name: "file name with directories",
Expand All @@ -74,7 +78,7 @@ func TestParseCommand(t *testing.T) {
for _, tt := range tc {
t.Run(tt.name, func(t *testing.T) {
cmd, err := parseCommand(tt.in)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
return
}

Expand All @@ -85,43 +89,12 @@ func TestParseCommand(t *testing.T) {
if want.lang != got.lang {
t.Errorf("case [%s]: expected language %q; got %q", tt.name, want.lang, got.lang)
}
if !eqPtr(want.start, got.start) {
t.Errorf("case [%s]: expected start %v; got %v", tt.name, str(want.start), str(got.start))
if !testutil.EqPtr(want.start, got.start) {
t.Errorf("case [%s]: expected start %v; got %v", tt.name, testutil.Str(want.start), testutil.Str(got.start))
}
if !eqPtr(want.end, got.end) {
t.Errorf("case [%s]: expected end %v; got %v", tt.name, str(want.end), str(got.end))
if !testutil.EqPtr(want.end, got.end) {
t.Errorf("case [%s]: expected end %v; got %v", tt.name, testutil.Str(want.end), testutil.Str(got.end))
}
})
}
}

func ptr(s string) *string { return &s }

func str(s *string) string {
if s == nil {
return "<nil>"
}
return *s
}

func eqPtr(a, b *string) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}

func eqErr(t *testing.T, id string, err error, msg string) bool {
t.Helper()
if err == nil && msg == "" {
return true
}
if err == nil && msg != "" {
t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg)
return false
}
if err != nil && msg != err.Error() {
t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err)
}
return false
}
13 changes: 6 additions & 7 deletions embedmd/embedmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//
// The format of an embedmd command is:
//
// [embedmd]:# (pathOrURL language /start regexp/ /end regexp/)
// [embedmd]:# (pathOrURL language /start regexp/ /end regexp/)
//
// The embedded code will be extracted from the file at pathOrURL,
// which can either be a relative path to a file in the local file
Expand All @@ -29,27 +29,26 @@
// Omitting the the second regular expression will embed only the piece of
// text that matches /regexp/:
//
// [embedmd]:# (pathOrURL language /regexp/)
// [embedmd]:# (pathOrURL language /regexp/)
//
// To embed the whole line matching a regular expression you can use:
//
// [embedmd]:# (pathOrURL language /.*regexp.*\n/)
// [embedmd]:# (pathOrURL language /.*regexp.*\n/)
//
// If you want to embed from a point to the end you should use:
//
// [embedmd]:# (pathOrURL language /start regexp/ $)
// [embedmd]:# (pathOrURL language /start regexp/ $)
//
// Finally you can embed a whole file by omitting both regular expressions:
//
// [embedmd]:# (pathOrURL language)
// [embedmd]:# (pathOrURL language)
//
// You can ommit the language in any of the previous commands, and the extension
// of the file will be used for the snippet syntax highlighting. Note that while
// this works Go files, since the file extension .go matches the name of the language
// go, this will fail with other files like .md whose language name is markdown.
//
// [embedmd]:# (file.ext)
//
// [embedmd]:# (file.ext)
package embedmd

import (
Expand Down
36 changes: 19 additions & 17 deletions embedmd/embedmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"path/filepath"
"strings"
"testing"

"github.com/campoy/embedmd/internal/testutil"
)

const content = `
Expand All @@ -43,37 +45,37 @@ func TestExtract(t *testing.T) {
{name: "no limits",
out: string(content)},
{name: "only one line",
start: ptr("/func main.*\n/"), out: "func main() {\n"},
start: testutil.Ptr("/func main.*\n/"), out: "func main() {\n"},
{name: "from package to end",
start: ptr("/package main/"), end: ptr("$"), out: string(content[1:])},
start: testutil.Ptr("/package main/"), end: testutil.Ptr("$"), out: string(content[1:])},
{name: "not matching",
start: ptr("/gopher/"), err: "could not match \"/gopher/\""},
start: testutil.Ptr("/gopher/"), err: "could not match \"/gopher/\""},
{name: "part of a line",
start: ptr("/fmt.P/"), end: ptr("/hello/"), out: "fmt.Println(\"hello"},
start: testutil.Ptr("/fmt.P/"), end: testutil.Ptr("/hello/"), out: "fmt.Println(\"hello"},
{name: "function call",
start: ptr("/fmt\\.[^()]*/"), out: "fmt.Println"},
start: testutil.Ptr("/fmt\\.[^()]*/"), out: "fmt.Println"},
{name: "from fmt to end of line",
start: ptr("/fmt.P.*\n/"), out: "fmt.Println(\"hello, test\")\n"},
start: testutil.Ptr("/fmt.P.*\n/"), out: "fmt.Println(\"hello, test\")\n"},
{name: "from func to end of next line",
start: ptr("/func/"), end: ptr("/Println.*\n/"), out: "func main() {\n fmt.Println(\"hello, test\")\n"},
start: testutil.Ptr("/func/"), end: testutil.Ptr("/Println.*\n/"), out: "func main() {\n fmt.Println(\"hello, test\")\n"},
{name: "from func to }",
start: ptr("/func main/"), end: ptr("/}/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"},
start: testutil.Ptr("/func main/"), end: testutil.Ptr("/}/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"},

{name: "bad start regexp",
start: ptr("/(/"), err: "error parsing regexp: missing closing ): `(`"},
start: testutil.Ptr("/(/"), err: "error parsing regexp: missing closing ): `(`"},
{name: "bad regexp",
start: ptr("something"), err: "missing slashes (/) around \"something\""},
start: testutil.Ptr("something"), err: "missing slashes (/) around \"something\""},
{name: "bad end regexp",
start: ptr("/fmt.P/"), end: ptr("/)/"), err: "error parsing regexp: unexpected ): `)`"},
start: testutil.Ptr("/fmt.P/"), end: testutil.Ptr("/)/"), err: "error parsing regexp: unexpected ): `)`"},

{name: "start and end of line ^$",
start: ptr("/^func main/"), end: ptr("/}$/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"},
start: testutil.Ptr("/^func main/"), end: testutil.Ptr("/}$/"), out: "func main() {\n fmt.Println(\"hello, test\")\n}"},
}

for _, tt := range tc {
t.Run(tt.name, func(t *testing.T) {
b, err := extract([]byte(content), tt.start, tt.end)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
return
}
if string(b) != tt.out {
Expand Down Expand Up @@ -107,7 +109,7 @@ func TestExtractFromFile(t *testing.T) {
},
{
name: "added line break",
cmd: command{path: "code.go", lang: "go", start: ptr("/fmt\\.Println/")},
cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/fmt\\.Println/")},
files: map[string][]byte{"code.go": []byte(content)},
out: "```go\nfmt.Println\n```\n",
},
Expand All @@ -118,7 +120,7 @@ func TestExtractFromFile(t *testing.T) {
},
{
name: "unmatched regexp",
cmd: command{path: "code.go", lang: "go", start: ptr("/potato/")},
cmd: command{path: "code.go", lang: "go", start: testutil.Ptr("/potato/")},
files: map[string][]byte{"code.go": []byte(content)},
err: "could not extract content from code.go: could not match \"/potato/\"",
},
Expand All @@ -133,7 +135,7 @@ func TestExtractFromFile(t *testing.T) {

w := new(bytes.Buffer)
err := e.runCommand(w, &tt.cmd)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
return
}
if w.String() != tt.out {
Expand Down Expand Up @@ -267,7 +269,7 @@ func TestProcess(t *testing.T) {
opts = append(opts, WithBaseDir(tt.dir))
}
err := Process(&out, strings.NewReader(tt.in), opts...)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
return
}
if tt.out != out.String() {
Expand Down
4 changes: 3 additions & 1 deletion embedmd/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"io"
"strings"
"testing"

"github.com/campoy/embedmd/internal/testutil"
)

func TestParser(t *testing.T) {
Expand Down Expand Up @@ -94,7 +96,7 @@ func TestParser(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
var out bytes.Buffer
err := process(&out, strings.NewReader(tt.in), tt.run)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
return
}
if got := out.String(); got != tt.out {
Expand Down
40 changes: 40 additions & 0 deletions internal/testutil/testutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Package testutil provides shared test helpers.
package testutil

import "testing"

// EqErr checks whether err matches the expected message msg.
// Returns true only when both are empty (no error expected, none received).
func EqErr(t *testing.T, id string, err error, msg string) bool {
t.Helper()
if err == nil && msg == "" {
return true
}
if err == nil && msg != "" {
t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg)
return false
}
if err != nil && msg != err.Error() {
t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err)
}
return false
}

// Ptr returns a pointer to the given string.
func Ptr(s string) *string { return &s }

// Str returns the string value of a *string, or "<nil>" if nil.
func Str(s *string) string {
if s == nil {
return "<nil>"
}
return *s
}

// EqPtr returns whether two *string values are equal.
func EqPtr(a, b *string) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
7 changes: 5 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
//
// embedmd supports two flags:
// -d: will print the difference of the input file with what the output
// would have been if executed.
//
// would have been if executed.
//
// -w: rewrites the given files rather than writing the output to the standard
// output.
//
// output.
//
// For more information on the format of the commands, read the documentation
// of the github.com/campoy/embedmd/embedmd package.
Expand Down
20 changes: 4 additions & 16 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"os"
"strings"
"testing"

"github.com/campoy/embedmd/internal/testutil"
)

func TestEmbedStreams(t *testing.T) {
Expand Down Expand Up @@ -65,7 +67,7 @@ func TestEmbedStreams(t *testing.T) {
buf := &bytes.Buffer{}
stdout = buf
foundDiff, err := embed(nil, tt.w, tt.d)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
continue
}
if got := buf.String(); tt.out != got {
Expand Down Expand Up @@ -112,7 +114,7 @@ func TestEmbedFiles(t *testing.T) {
}

_, err := embed([]string{"docs.md"}, tt.w, tt.d)
if !eqErr(t, tt.name, err, tt.err) {
if !testutil.EqErr(t, tt.name, err, tt.err) {
continue
}
if got := f.buf.String(); tt.out != got {
Expand All @@ -122,20 +124,6 @@ func TestEmbedFiles(t *testing.T) {
}
}

func eqErr(t *testing.T, id string, err error, msg string) bool {
if err == nil && msg == "" {
return true
}
if err == nil && msg != "" {
t.Errorf("case [%s]: expected error message %q; but got nothing", id, msg)
return false
}
if err != nil && msg != err.Error() {
t.Errorf("case [%s]: expected error message %q; but got %q", id, msg, err)
}
return false
}

type fakeFile struct {
io.ReadCloser
buf bytes.Buffer
Expand Down