diff --git a/go.mod b/go.mod index 78879f506..7be3e8d8f 100644 --- a/go.mod +++ b/go.mod @@ -116,6 +116,7 @@ require ( github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mozillazg/go-httpheader v0.2.1 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect @@ -159,6 +160,7 @@ require ( gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gorm.io/driver/clickhouse v0.7.0 // indirect + gorm.io/driver/sqlite v1.6.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 002bc397f..3becb58fd 100644 --- a/go.sum +++ b/go.sum @@ -280,6 +280,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= @@ -544,6 +546,8 @@ gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c= gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/internal/service/install_task_service.go b/internal/service/install_task_service.go index ac01584d3..6ef6e7a95 100644 --- a/internal/service/install_task_service.go +++ b/internal/service/install_task_service.go @@ -1,44 +1,66 @@ package service import ( + "fmt" + "github.com/langgenius/dify-plugin-daemon/internal/db" "github.com/langgenius/dify-plugin-daemon/internal/types/exception" "github.com/langgenius/dify-plugin-daemon/internal/types/models" "github.com/langgenius/dify-plugin-daemon/pkg/entities" "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities" + "golang.org/x/sync/singleflight" "gorm.io/gorm" ) +var ( + installationTasksGroup singleflight.Group + installationTaskGroup singleflight.Group +) + func FetchPluginInstallationTasks( tenant_id string, page int, page_size int, ) *entities.Response { - tasks, err := db.GetAll[models.InstallTask]( - db.Equal("tenant_id", tenant_id), - db.OrderBy("created_at", true), - db.Page(page, page_size), - ) + key := fmt.Sprintf("tasks:%s:%d:%d", tenant_id, page, page_size) + v, err, _ := installationTasksGroup.Do(key, func() (interface{}, error) { + tasks, err := db.GetAll[models.InstallTask]( + db.Equal("tenant_id", tenant_id), + db.OrderBy("created_at", true), + db.Page(page, page_size), + ) + if err != nil { + return nil, err + } + return tasks, nil + }) if err != nil { return exception.InternalServerError(err).ToResponse() } - return entities.NewSuccessResponse(tasks) + return entities.NewSuccessResponse(v) } func FetchPluginInstallationTask( tenant_id string, task_id string, ) *entities.Response { - task, err := db.GetOne[models.InstallTask]( - db.Equal("id", task_id), - db.Equal("tenant_id", tenant_id), - ) + key := fmt.Sprintf("task:%s:%s", tenant_id, task_id) + v, err, _ := installationTaskGroup.Do(key, func() (interface{}, error) { + task, err := db.GetOne[models.InstallTask]( + db.Equal("id", task_id), + db.Equal("tenant_id", tenant_id), + ) + if err != nil { + return nil, err + } + return task, nil + }) if err != nil { return exception.InternalServerError(err).ToResponse() } - return entities.NewSuccessResponse(task) + return entities.NewSuccessResponse(v) } func DeletePluginInstallationTask( diff --git a/internal/service/install_task_service_test.go b/internal/service/install_task_service_test.go new file mode 100644 index 000000000..dde963dbb --- /dev/null +++ b/internal/service/install_task_service_test.go @@ -0,0 +1,159 @@ +package service + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/langgenius/dify-plugin-daemon/internal/db" + "github.com/langgenius/dify-plugin-daemon/internal/types/models" + "golang.org/x/sync/singleflight" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func setupTestDB(t *testing.T) *gorm.DB { + t.Helper() + testDB, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open test db: %v", err) + } + sqlDB, err := testDB.DB() + if err != nil { + t.Fatalf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + + if err := testDB.AutoMigrate(&models.InstallTask{}); err != nil { + t.Fatalf("failed to auto migrate: %v", err) + } + db.DifyPluginDB = testDB + t.Cleanup(func() { + sqlDB.Close() + }) + return testDB +} + +func TestFetchPluginInstallationTasks_Singleflight(t *testing.T) { + testDB := setupTestDB(t) + installationTasksGroup = singleflight.Group{} + + var queryCount atomic.Int32 + testDB.Callback().Query().Before("gorm:query").Register("test:count_tasks", func(tx *gorm.DB) { + queryCount.Add(1) + time.Sleep(100 * time.Millisecond) + }) + defer testDB.Callback().Query().Remove("test:count_tasks") + + const concurrency = 50 + var wg sync.WaitGroup + wg.Add(concurrency) + start := make(chan struct{}) + errs := make([]int, concurrency) + + for i := 0; i < concurrency; i++ { + go func(idx int) { + defer wg.Done() + <-start + resp := FetchPluginInstallationTasks("tenant-1", 1, 10) + errs[idx] = resp.Code + }(i) + } + + close(start) + wg.Wait() + + for i, code := range errs { + if code != 0 { + t.Errorf("goroutine %d: expected code 0, got %d", i, code) + } + } + + if count := queryCount.Load(); count != 1 { + t.Errorf("singleflight not working: expected 1 db query for same key, got %d", count) + } +} + +func TestFetchPluginInstallationTask_Singleflight(t *testing.T) { + testDB := setupTestDB(t) + installationTaskGroup = singleflight.Group{} + + // Insert a test record before registering the callback. + task := models.InstallTask{ + TenantID: "tenant-1", + Status: models.InstallTaskStatusPending, + TotalPlugins: 1, + } + if err := testDB.Create(&task).Error; err != nil { + t.Fatalf("failed to create test task: %v", err) + } + + var queryCount atomic.Int32 + testDB.Callback().Query().Before("gorm:query").Register("test:count_task", func(tx *gorm.DB) { + queryCount.Add(1) + time.Sleep(100 * time.Millisecond) + }) + defer testDB.Callback().Query().Remove("test:count_task") + + const concurrency = 50 + var wg sync.WaitGroup + wg.Add(concurrency) + start := make(chan struct{}) + errs := make([]int, concurrency) + + for i := 0; i < concurrency; i++ { + go func(idx int) { + defer wg.Done() + <-start + resp := FetchPluginInstallationTask("tenant-1", task.ID) + errs[idx] = resp.Code + }(i) + } + + close(start) + wg.Wait() + + for i, code := range errs { + if code != 0 { + t.Errorf("goroutine %d: expected code 0, got %d", i, code) + } + } + + if count := queryCount.Load(); count != 1 { + t.Errorf("singleflight not working: expected 1 db query for same key, got %d", count) + } +} + +func TestFetchPluginInstallationTasks_DifferentKeysNotDeduplicated(t *testing.T) { + testDB := setupTestDB(t) + installationTasksGroup = singleflight.Group{} + + var queryCount atomic.Int32 + testDB.Callback().Query().Before("gorm:query").Register("test:count_diff", func(tx *gorm.DB) { + queryCount.Add(1) + time.Sleep(50 * time.Millisecond) + }) + defer testDB.Callback().Query().Remove("test:count_diff") + + const numKeys = 3 + var wg sync.WaitGroup + wg.Add(numKeys) + start := make(chan struct{}) + + for i := 0; i < numKeys; i++ { + go func(idx int) { + defer wg.Done() + <-start + FetchPluginInstallationTasks(fmt.Sprintf("tenant-%d", idx), 1, 10) + }(i) + } + + close(start) + wg.Wait() + + if count := queryCount.Load(); count != int32(numKeys) { + t.Errorf("different keys should not be deduplicated: expected %d db queries, got %d", numKeys, count) + } +}