Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions internal/core/persistence/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package persistence

import (
"encoding/hex"
"errors"
"fmt"
"path"
"strings"
Expand Down Expand Up @@ -47,54 +48,71 @@ func (c *Persistence) Save(tenantId string, pluginId string, maxSize int64, key
maxSize = c.maxStorageSize
}

if err := c.storage.Save(tenantId, pluginId, key, data); err != nil {
return err
newSize := int64(len(data))
var oldSize int64 = 0
if exist, err := c.storage.Exists(tenantId, pluginId, key); err == nil && exist {
if s, err2 := c.storage.StateSize(tenantId, pluginId, key); err2 == nil {
oldSize = s
}
}
delta := newSize - oldSize

allocatedSize := int64(len(data))

storage, err := db.GetOne[models.TenantStorage](
record, err := db.GetOne[models.TenantStorage](
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
)
if err != nil {
if allocatedSize > c.maxStorageSize || allocatedSize > maxSize {
return fmt.Errorf("allocated size is greater than max storage size")
if !errors.Is(err, db.ErrDatabaseNotFound) {
return err
}

if err == db.ErrDatabaseNotFound {
storage = models.TenantStorage{
TenantID: tenantId,
PluginID: pluginId,
Size: allocatedSize,
}
if err := db.Create(&storage); err != nil {
return err
}
} else {
return err
if newSize > c.maxStorageSize || newSize > maxSize {
return fmt.Errorf("allocated size is greater than max storage size")
}
} else {
if allocatedSize+storage.Size > maxSize || allocatedSize+storage.Size > c.maxStorageSize {
if delta > 0 && (record.Size+delta > maxSize || record.Size+delta > c.maxStorageSize) {
return fmt.Errorf("allocated size is greater than max storage size")
}
}

if err := c.storage.Save(tenantId, pluginId, key, data); err != nil {
return err
}

err = db.Run(
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Inc(map[string]int64{"size": allocatedSize}),
)
if err != nil {
if errors.Is(err, db.ErrDatabaseNotFound) {
rec := models.TenantStorage{TenantID: tenantId, PluginID: pluginId, Size: newSize}
if err := db.Create(&rec); err != nil {
return err
}
} else {
if delta != 0 {
if delta > 0 {
if err := db.Run(
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Inc(map[string]int64{"size": delta}),
); err != nil {
return err
}
} else {
if err := db.Run(
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Dec(map[string]int64{"size": -delta}),
); err != nil {
return err
}
}
}
}

// delete from cache
if _, err = cache.Del(c.getCacheKey(tenantId, pluginId, key)); err == cache.ErrNotFound {
if _, err = cache.Del(c.getCacheKey(tenantId, pluginId, key)); errors.Is(err, cache.ErrNotFound) {
return nil
}
return err
return nil
}

// TODO: raises specific error to avoid confusion
Expand Down
99 changes: 99 additions & 0 deletions internal/core/persistence/persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/langgenius/dify-cloud-kit/oss/factory"
"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/types/app"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/pkg/utils/cache"
"github.com/langgenius/dify-plugin-daemon/pkg/utils/strings"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -69,6 +70,104 @@ func TestPersistenceStoreAndLoad(t *testing.T) {
assert.Equal(t, string(cacheDataBytes), "data")
}

func TestPersistenceOverwriteAdjustsCounter(t *testing.T) {
// init deps
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
defer cache.Close()

db.Init(&app.Config{
DBType: app.DB_TYPE_POSTGRESQL,
DBUsername: "postgres",
DBPassword: "difyai123456",
DBHost: "localhost",
DBPort: 5432,
DBDatabase: "dify_plugin_daemon",
DBSslMode: "disable",
})
defer db.Close()

oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}})
assert.Nil(t, err)

InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 1024 * 1024})

tenant := "tenant_" + strings.RandomString(6)
plugin := "plugin_" + strings.RandomString(6)
key := "k_" + strings.RandomString(6)

// write 4 bytes
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("abcd")))
st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// overwrite with 2 bytes -> size should be 2
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("bb")))
st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(2), st.Size)

// overwrite with 3 bytes -> size should be 3
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("ccc")))
st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(3), st.Size)

// and data should be latest
data, err := persistence.Load(tenant, plugin, key)
assert.Nil(t, err)
assert.Equal(t, "ccc", string(data))
}

func TestPersistenceOverwriteLimitEnforcedByDelta(t *testing.T) {
// init deps
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
defer cache.Close()

db.Init(&app.Config{
DBType: app.DB_TYPE_POSTGRESQL,
DBUsername: "postgres",
DBPassword: "difyai123456",
DBHost: "localhost",
DBPort: 5432,
DBDatabase: "dify_plugin_daemon",
DBSslMode: "disable",
})
defer db.Close()

oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}})
assert.Nil(t, err)

// set a small global limit 5 bytes
InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 5})

tenant := "tenant_" + strings.RandomString(6)
plugin := "plugin_" + strings.RandomString(6)
key := "k_" + strings.RandomString(6)

// write 4 bytes OK
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("aaaa")))
st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// overwrite with 6 bytes -> delta = +2, 4+2=6 > 5 -> expect error, no change
if err := persistence.Save(tenant, plugin, -1, key, []byte("abcdef")); err == nil {
t.Fatalf("expected limit error, got nil")
}

st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// stored data should remain old value
data, err := persistence.Load(tenant, plugin, key)
assert.Nil(t, err)
assert.Equal(t, "aaaa", string(data))
}

func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) {
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
Expand Down
Loading