diff --git a/.github/workflows/sonar.yml b/.github/workflows/sonar.yml index 2f0d35b..7ea2fcc 100644 --- a/.github/workflows/sonar.yml +++ b/.github/workflows/sonar.yml @@ -23,6 +23,8 @@ jobs: steps: - uses: actions/checkout@v6 + with: + fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v6 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dc7883d..e6297f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,6 +37,12 @@ jobs: go-version-file: 'go.mod' id: go + - name: Allow toolchain auto-download + run: echo "GOTOOLCHAIN=auto" >> $GITHUB_ENV + + - name: Check out code into the Go module directory + uses: actions/checkout@v6 + - name: Get dependencies run: | go get -v -t ./... diff --git a/arc.go b/arc.go index 10a20d1..91c5072 100644 --- a/arc.go +++ b/arc.go @@ -54,7 +54,7 @@ func (b arcbuilder) Build() *ArcFSM { type tuple struct { Status int - Type interface{} + Type any } // ArcFSM is a defined Finite-State-Machine that allows specific mutations of @@ -71,7 +71,7 @@ type ArcFSM struct { } // IsValidTransition validates status transition without committing the transaction -func (fsm *ArcFSM) IsValidTransition(from Status, to Status) bool { +func (fsm *ArcFSM) IsValidTransition(from, to Status) bool { s, ok := fsm.updates[from.ShiftStatus()] if !ok { return false diff --git a/gen_1_test.go b/gen_1_test.go index 00d4305..7d235f8 100644 --- a/gen_1_test.go +++ b/gen_1_test.go @@ -21,7 +21,7 @@ func (一 insert) Insert( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("insert into users set `status`=?, `created_at`=?, `updated_at`=? ") @@ -54,7 +54,7 @@ func (一 update) Update( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update users set `status`=?, `updated_at`=? ") @@ -92,7 +92,7 @@ func (一 complete) Update( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update users set `status`=?, `updated_at`=? ") diff --git a/gen_2_test.go b/gen_2_test.go index c033f42..86ca0c7 100644 --- a/gen_2_test.go +++ b/gen_2_test.go @@ -21,7 +21,7 @@ func (一 i) Insert( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("insert into tests set `status`=?, `created_at`=?, `updated_at`=? ") @@ -57,7 +57,7 @@ func (一 u) Update( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update tests set `status`=?, `updated_at`=? ") diff --git a/gen_3_test.go b/gen_3_test.go index 2c2644c..9570008 100644 --- a/gen_3_test.go +++ b/gen_3_test.go @@ -20,7 +20,7 @@ func (一 i_t) Insert( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) if 一.CreatedAt.IsZero() { @@ -70,7 +70,7 @@ func (一 u_t) Update( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) if 一.UpdatedAt.IsZero() { diff --git a/gen_4_test.go b/gen_4_test.go index b115892..1032d63 100644 --- a/gen_4_test.go +++ b/gen_4_test.go @@ -21,7 +21,7 @@ func (一 insert2) Insert( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("insert into users set `status`=?, `created_at`=?, `updated_at`=? ") @@ -57,7 +57,7 @@ func (一 move) Update( ) (int64, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update users set `status`=?, `updated_at`=? ") diff --git a/gen_string_test.go b/gen_string_test.go index ccfdc09..488388b 100644 --- a/gen_string_test.go +++ b/gen_string_test.go @@ -21,7 +21,7 @@ func (一 insertStr) Insert( ) (string, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("insert into usersStr set `id`=?, `status`=?, `created_at`=?, `updated_at`=? ") @@ -49,7 +49,7 @@ func (一 updateStr) Update( ) (string, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update usersStr set `status`=?, `updated_at`=? ") @@ -87,7 +87,7 @@ func (一 completeStr) Update( ) (string, error) { var ( q strings.Builder - args []interface{} + args []any ) q.WriteString("update usersStr set `status`=?, `updated_at`=? ") diff --git a/helper_test.go b/helper_test.go index 1f1ee4b..0669bc4 100644 --- a/helper_test.go +++ b/helper_test.go @@ -120,7 +120,7 @@ type Currency struct { Amount int64 } -func (c *Currency) Scan(src interface{}) error { +func (c *Currency) Scan(src any) error { var s sql.NullString if err := s.Scan(src); err != nil { return err diff --git a/shift.go b/shift.go index fe4842a..11cf0b6 100644 --- a/shift.go +++ b/shift.go @@ -123,7 +123,7 @@ type GenFSM[T primary] struct { } // IsValidTransition validates status transition without committing the transaction -func (fsm *GenFSM[T]) IsValidTransition(from Status, to Status) bool { +func (fsm *GenFSM[T]) IsValidTransition(from, to Status) bool { s, ok := fsm.states[from.ShiftStatus()] if !ok { return false @@ -293,11 +293,11 @@ func updateTx[T primary](ctx context.Context, tx *sql.Tx, from Status, to Status type status struct { st Status t reflex.EventType - req interface{} + req any insert bool next map[Status]bool } -func sameType(a interface{}, b interface{}) bool { +func sameType(a, b any) bool { return reflect.TypeOf(a) == reflect.TypeOf(b) } diff --git a/shift_internal_test.go b/shift_internal_test.go index 3882fc6..217c062 100644 --- a/shift_internal_test.go +++ b/shift_internal_test.go @@ -19,8 +19,8 @@ type yy y func Test(t *testing.T) { cases := []struct { name string - a interface{} - b interface{} + a any + b any res bool }{ { diff --git a/shiftgen/shiftgen.go b/shiftgen/shiftgen.go index 7549194..029138e 100644 --- a/shiftgen/shiftgen.go +++ b/shiftgen/shiftgen.go @@ -147,7 +147,7 @@ func parseInserters() ([]string, error) { if *inserter != "" { ii = append(ii, *inserter) } else if strings.TrimSpace(*inserters) != "" { - for _, i := range strings.Split(*inserters, ",") { + for i := range strings.SplitSeq(*inserters, ",") { ii = append(ii, strings.TrimSpace(i)) } } @@ -157,7 +157,7 @@ func parseInserters() ([]string, error) { func parseUpdaters() []string { var uu []string if strings.TrimSpace(*updaters) != "" { - for _, u := range strings.Split(*updaters, ",") { + for u := range strings.SplitSeq(*updaters, ",") { uu = append(uu, strings.TrimSpace(u)) } } @@ -304,7 +304,7 @@ func generateSrc(pkgPath, table string, inserters, updaters []string, statusFiel } func execTpl(out io.Writer, tpl string, data Data) error { - t := template.New("").Funcs(map[string]interface{}{ + t := template.New("").Funcs(map[string]any{ "col": quoteCol, }) diff --git a/shiftgen/shiftgen_test.go b/shiftgen/shiftgen_test.go index 02f7e4d..e018b68 100644 --- a/shiftgen/shiftgen_test.go +++ b/shiftgen/shiftgen_test.go @@ -10,6 +10,109 @@ import ( "github.com/stretchr/testify/require" ) +func TestToSnakeCase(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {name: "lowercase unchanged", input: "hello", want: "hello"}, + {name: "camel to snake", input: "CamelCase", want: "camel_case"}, + {name: "multiple words", input: "MyFieldName", want: "my_field_name"}, + {name: "acronym", input: "IDField", want: "id_field"}, + {name: "already snake", input: "snake_case", want: "snake_case"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := toSnakeCase(tt.input) + if got != tt.want { + t.Errorf("toSnakeCase(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestIDZeroValue(t *testing.T) { + tests := []struct { + name string + idType string + want string + }{ + {name: "int64", idType: "int64", want: "0"}, + {name: "string", idType: "string", want: `""`}, + {name: "unknown type", idType: "uuid", want: ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := Struct{IDType: tt.idType} + got := s.IDZeroValue() + if got != tt.want { + t.Errorf("IDZeroValue() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestParseUpdaters(t *testing.T) { + tests := []struct { + name string + value string + want []string + }{ + {name: "empty", value: "", want: nil}, + {name: "single", value: "UpdateReq", want: []string{"UpdateReq"}}, + {name: "multiple", value: "UpdateReq,CompleteReq", want: []string{"UpdateReq", "CompleteReq"}}, + {name: "with spaces", value: " UpdateReq , CompleteReq ", want: []string{"UpdateReq", "CompleteReq"}}, + {name: "whitespace only", value: " ", want: nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + orig := *updaters + *updaters = tt.value + t.Cleanup(func() { *updaters = orig }) + + got := parseUpdaters() + require.Equal(t, tt.want, got) + }) + } +} + +func TestParseInserters(t *testing.T) { + tests := []struct { + name string + inserterVal string + insertersVal string + want []string + wantErr bool + }{ + {name: "empty", inserterVal: "", insertersVal: "", want: nil}, + {name: "single inserter", inserterVal: "InsertReq", insertersVal: "", want: []string{"InsertReq"}}, + {name: "multiple inserters", inserterVal: "", insertersVal: "InsertA,InsertB", want: []string{"InsertA", "InsertB"}}, + {name: "inserters with spaces", inserterVal: "", insertersVal: " InsertA , InsertB ", want: []string{"InsertA", "InsertB"}}, + {name: "both set returns error", inserterVal: "InsertReq", insertersVal: "InsertA", want: nil, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origInserter := *inserter + origInserters := *inserters + *inserter = tt.inserterVal + *inserters = tt.insertersVal + t.Cleanup(func() { + *inserter = origInserter + *inserters = origInserters + }) + + got, err := parseInserters() + if tt.wantErr { + require.Error(t, err) + return + } + jtest.RequireNil(t, err) + require.Equal(t, tt.want, got) + }) + } +} + func TestGen(t *testing.T) { cc := []struct { dir string diff --git a/test_shift.go b/test_shift.go index 3113758..f4c91b2 100644 --- a/test_shift.go +++ b/test_shift.go @@ -87,8 +87,8 @@ func randomInsert(req any) (Inserter[int64], error) { } s := reflect.New(reflect.ValueOf(req).Type()).Elem() - for i := 0; i < s.NumField(); i++ { - f := s.Field(i) + for _, f := range s.Fields() { + f := f f.Set(randVal(f.Type())) } return s.Interface().(Inserter[int64]), nil @@ -117,15 +117,15 @@ func buildPaths(states map[int]status, from Status) [][]status { } var ( - intType = reflect.TypeOf((int)(0)) - int64Type = reflect.TypeOf((int64)(0)) - float64Type = reflect.TypeOf((float64)(0)) - timeType = reflect.TypeOf(time.Time{}) - sliceByteType = reflect.TypeOf([]byte(nil)) - boolType = reflect.TypeOf(false) - stringType = reflect.TypeOf("") - nullTimeType = reflect.TypeOf(sql.NullTime{}) - nullStringType = reflect.TypeOf(sql.NullString{}) + intType = reflect.TypeFor[int]() + int64Type = reflect.TypeFor[int64]() + float64Type = reflect.TypeFor[float64]() + timeType = reflect.TypeFor[time.Time]() + sliceByteType = reflect.TypeFor[[]byte]() + boolType = reflect.TypeFor[bool]() + stringType = reflect.TypeFor[string]() + nullTimeType = reflect.TypeFor[sql.NullTime]() + nullStringType = reflect.TypeFor[sql.NullString]() ) func randVal(t reflect.Type) reflect.Value { diff --git a/test_shift_test.go b/test_shift_test.go index dbc5463..1dfd02a 100644 --- a/test_shift_test.go +++ b/test_shift_test.go @@ -90,11 +90,11 @@ func TestTestFSM(t *testing.T) { } func (ii i) GetMetadata(ctx context.Context, tx *sql.Tx, id int64, status shift.Status) ([]byte, error) { - return []byte(fmt.Sprint(id)), nil + return fmt.Append(nil, id), nil } func (uu u) GetMetadata(ctx context.Context, tx *sql.Tx, from shift.Status, to shift.Status) ([]byte, error) { - return []byte(fmt.Sprint(uu.ID)), nil + return fmt.Append(nil, uu.ID), nil } func TestWithMeta(t *testing.T) { @@ -111,8 +111,7 @@ func TestWithMeta(t *testing.T) { err := shift.TestFSM(t, dbc, fsm) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() sc, err := events.ToStream(dbc)(context.Background(), "") require.NoError(t, err)