diff --git a/main.go b/main.go index bb26330..4218595 100644 --- a/main.go +++ b/main.go @@ -128,15 +128,6 @@ var openFile = func(name string) (file, error) { return os.OpenFile(name, os.O_RDWR, 0666) } -func readFile(path string) ([]byte, error) { - f, err := openFile(path) - if err != nil { - return nil, err - } - defer f.Close() - return io.ReadAll(f) -} - func processFile(path string, rewrite, doDiff bool) (foundDiff bool, err error) { if filepath.Ext(path) != ".md" { return false, fmt.Errorf("not a markdown file") @@ -148,17 +139,19 @@ func processFile(path string, rewrite, doDiff bool) (foundDiff bool, err error) } defer f.Close() + var original bytes.Buffer + var r io.Reader = f + if doDiff { + r = io.TeeReader(f, &original) + } + buf := new(bytes.Buffer) - if err := embedmd.Process(buf, f, embedmd.WithBaseDir(filepath.Dir(path))); err != nil { + if err := embedmd.Process(buf, r, embedmd.WithBaseDir(filepath.Dir(path))); err != nil { return false, err } if doDiff { - f, err := readFile(path) - if err != nil { - return false, fmt.Errorf("could not read %s for diff: %v", path, err) - } - data, err := diff(string(f), buf.String()) + data, err := diff(original.String(), buf.String()) if err != nil || len(data) == 0 { return false, err } diff --git a/main_test.go b/main_test.go index e073179..76c4bb7 100644 --- a/main_test.go +++ b/main_test.go @@ -99,7 +99,7 @@ func TestEmbedFiles(t *testing.T) { {name: "diffing a single file", in: "one\ntwo\nthree", d: true, - out: "@@ -1 +1,4 @@\n+one\n+two\n+three\n \n", + out: "@@ -1,3 +1,4 @@\n one\n two\n three\n+\n", }, }