diff --git a/internal/cache/cache.go b/internal/cache/cache.go index f76648e..b9126ff 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -20,11 +20,12 @@ func (e *Entry) IsExpired() bool { return time.Since(e.Timestamp) > e.TTL } -// Cache is a thread-safe in-memory cache with TTL. +// Cache is a thread-safe in-memory cache with TTL and per-register storage. type Cache struct { mu sync.RWMutex entries map[string]*Entry defaultTTL time.Duration + keepStale bool // when true, cleanup won't delete expired entries // For request coalescing inflight map[string]*inflightRequest @@ -41,10 +42,12 @@ type inflightRequest struct { } // New creates a new cache with the specified default TTL. -func New(defaultTTL time.Duration) *Cache { +// If keepStale is true, expired entries are retained for stale serving. +func New(defaultTTL time.Duration, keepStale bool) *Cache { c := &Cache{ entries: make(map[string]*Entry), defaultTTL: defaultTTL, + keepStale: keepStale, inflight: make(map[string]*inflightRequest), done: make(chan struct{}), } @@ -60,9 +63,14 @@ func (c *Cache) Close() { close(c.done) } -// Key generates a cache key from request parameters. -func Key(slaveID byte, functionCode byte, address uint16, quantity uint16) string { - return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, address, quantity) +// RegKey generates a cache key for a single register or coil. +func RegKey(slaveID byte, functionCode byte, address uint16) string { + return fmt.Sprintf("%d:%d:%d", slaveID, functionCode, address) +} + +// RangeKey generates a coalescing key for a request range. +func RangeKey(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) string { + return fmt.Sprintf("%d:%d:%d:%d", slaveID, functionCode, startAddr, quantity) } // Get retrieves a value from the cache. @@ -100,11 +108,6 @@ func (c *Cache) GetStale(key string) ([]byte, bool) { // Set stores a value in the cache with the default TTL. 
func (c *Cache) Set(key string, data []byte) { - c.SetWithTTL(key, data, c.defaultTTL) -} - -// SetWithTTL stores a value in the cache with a specific TTL. -func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.mu.Lock() defer c.mu.Unlock() @@ -115,7 +118,7 @@ func (c *Cache) SetWithTTL(key string, data []byte, ttl time.Duration) { c.entries[key] = &Entry{ Data: dataCopy, Timestamp: time.Now(), - TTL: ttl, + TTL: c.defaultTTL, } } @@ -126,17 +129,88 @@ func (c *Cache) Delete(key string) { delete(c.entries, key) } -// GetOrFetch retrieves a value from the cache or fetches it using the provided function. -// This implements request coalescing - multiple concurrent requests for the same key -// will share a single fetch operation. -// Returns the data, a boolean indicating if it was a cache hit, and any error. -func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, bool, error) { - // Check cache first - if data, ok := c.Get(key); ok { - return data, true, nil +// GetRange retrieves all values for a contiguous register range. +// Returns the per-register/coil values and true only if ALL are cached and fresh. +func (c *Cache) GetRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok || entry.IsExpired() { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// GetRangeStale retrieves all values for a contiguous register range, ignoring TTL. +// Returns the per-register/coil values and true only if ALL are present (even if expired). 
+func (c *Cache) GetRangeStale(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) ([][]byte, bool) { + if quantity == 0 { + return nil, false + } + + c.mu.RLock() + defer c.mu.RUnlock() + + values := make([][]byte, quantity) + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + entry, ok := c.entries[key] + if !ok { + return nil, false + } + data := make([]byte, len(entry.Data)) + copy(data, entry.Data) + values[i] = data + } + return values, true +} + +// SetRange stores individual values for a contiguous register range. +// All entries are stored with the same timestamp for consistency. +func (c *Cache) SetRange(slaveID byte, functionCode byte, startAddr uint16, values [][]byte) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for i, v := range values { + key := RegKey(slaveID, functionCode, startAddr+uint16(i)) + dataCopy := make([]byte, len(v)) + copy(dataCopy, v) + c.entries[key] = &Entry{ + Data: dataCopy, + Timestamp: now, + TTL: c.defaultTTL, + } + } +} + +// DeleteRange removes all entries for a contiguous register range. +func (c *Cache) DeleteRange(slaveID byte, functionCode byte, startAddr uint16, quantity uint16) { + c.mu.Lock() + defer c.mu.Unlock() + + for i := uint16(0); i < quantity; i++ { + key := RegKey(slaveID, functionCode, startAddr+i) + delete(c.entries, key) } +} - // Check if there's already an in-flight request +// Coalesce ensures only one fetch runs for a given key at a time. +// Other callers with the same key wait for and share the first caller's result. +// This handles request coalescing only — it does not interact with cache storage. 
+func (c *Cache) Coalesce(ctx context.Context, key string, fetch func(context.Context) ([]byte, error)) ([]byte, error) { c.inflightMu.Lock() if req, ok := c.inflight[key]; ok { c.inflightMu.Unlock() @@ -144,14 +218,14 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C select { case <-req.done: if req.err != nil { - return nil, false, req.err + return nil, req.err } // Return a copy data := make([]byte, len(req.result)) copy(data, req.result) - return data, false, nil + return data, nil case <-ctx.Done(): - return nil, false, ctx.Err() + return nil, ctx.Err() } } @@ -165,22 +239,38 @@ func (c *Cache) GetOrFetch(ctx context.Context, key string, fetch func(context.C // Fetch the data data, err := fetch(ctx) - // Store result + // Store result for waiters req.result = data req.err = err - // Cache successful results - if err == nil { - c.Set(key, data) - } - // Clean up and notify waiters c.inflightMu.Lock() delete(c.inflight, key) c.inflightMu.Unlock() close(req.done) - return data, false, err + if err != nil { + return nil, err + } + + result := make([]byte, len(data)) + copy(result, data) + return result, nil +} + +// cleanupOnce runs a single cleanup pass, removing expired entries. +// Skips deletion when keepStale is true. +func (c *Cache) cleanupOnce() { + if c.keepStale { + return + } + c.mu.Lock() + for key, entry := range c.entries { + if entry.IsExpired() { + delete(c.entries, key) + } + } + c.mu.Unlock() } // cleanup periodically removes expired entries. 
@@ -193,13 +283,7 @@ func (c *Cache) cleanup() { case <-c.done: return case <-ticker.C: - c.mu.Lock() - for key, entry := range c.entries { - if entry.IsExpired() { - delete(c.entries, key) - } - } - c.mu.Unlock() + c.cleanupOnce() } } } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 8741070..445c010 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -9,7 +9,7 @@ import ( ) func TestCache_GetSet(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() // Test miss @@ -29,7 +29,7 @@ func TestCache_GetSet(t *testing.T) { } func TestCache_TTL(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -49,7 +49,7 @@ func TestCache_TTL(t *testing.T) { } func TestCache_GetStale(t *testing.T) { - c := New(50 * time.Millisecond) + c := New(50*time.Millisecond, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -68,7 +68,7 @@ func TestCache_GetStale(t *testing.T) { } func TestCache_Delete(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() c.Set("key1", []byte("value1")) @@ -79,16 +79,129 @@ func TestCache_Delete(t *testing.T) { } } -func TestCache_Key(t *testing.T) { - key := Key(1, 0x03, 100, 10) +func TestRegKey(t *testing.T) { + key := RegKey(1, 0x03, 100) + expected := "1:3:100" + if key != expected { + t.Errorf("expected %s, got %s", expected, key) + } +} + +func TestRangeKey(t *testing.T) { + key := RangeKey(1, 0x03, 100, 10) expected := "1:3:100:10" if key != expected { t.Errorf("expected %s, got %s", expected, key) } } -func TestCache_GetOrFetch(t *testing.T) { - c := New(time.Second) +func TestCache_GetRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + // Store 3 registers + c.Set(RegKey(1, 0x03, 10), []byte{0x00, 0x01}) + c.Set(RegKey(1, 0x03, 11), []byte{0x00, 0x02}) + c.Set(RegKey(1, 0x03, 12), []byte{0x00, 0x03}) + + // Full 
range hit + values, ok := c.GetRange(1, 0x03, 10, 3) + if !ok { + t.Error("expected range hit") + } + if len(values) != 3 { + t.Fatalf("expected 3 values, got %d", len(values)) + } + for i, expected := range []byte{0x01, 0x02, 0x03} { + if values[i][1] != expected { + t.Errorf("value[%d]: expected 0x%02X, got 0x%02X", i, expected, values[i][1]) + } + } + + // Partial range miss + _, ok = c.GetRange(1, 0x03, 10, 5) + if ok { + t.Error("expected range miss (registers 13-14 not cached)") + } +} + +func TestCache_SetRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}} + c.SetRange(1, 0x03, 100, values) + + // Each register should be independently accessible + data, ok := c.Get(RegKey(1, 0x03, 100)) + if !ok { + t.Error("expected hit for register 100") + } + if data[1] != 0x0A { + t.Errorf("expected 0x0A, got 0x%02X", data[1]) + } + + data, ok = c.Get(RegKey(1, 0x03, 101)) + if !ok { + t.Error("expected hit for register 101") + } + if data[1] != 0x0B { + t.Errorf("expected 0x0B, got 0x%02X", data[1]) + } +} + +func TestCache_DeleteRange(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + values := [][]byte{{0x00, 0x0A}, {0x00, 0x0B}, {0x00, 0x0C}} + c.SetRange(1, 0x03, 100, values) + + // Delete middle register + c.DeleteRange(1, 0x03, 101, 1) + + // Register 100 still cached + if _, ok := c.Get(RegKey(1, 0x03, 100)); !ok { + t.Error("register 100 should still be cached") + } + // Register 101 deleted + if _, ok := c.Get(RegKey(1, 0x03, 101)); ok { + t.Error("register 101 should be deleted") + } + // Register 102 still cached + if _, ok := c.Get(RegKey(1, 0x03, 102)); !ok { + t.Error("register 102 should still be cached") + } + // Full range now misses + if _, ok := c.GetRange(1, 0x03, 100, 3); ok { + t.Error("expected range miss after deleting register 101") + } +} + +func TestCache_GetRangeStale(t *testing.T) { + c := New(50*time.Millisecond, false) + defer c.Close() + + 
c.SetRange(1, 0x03, 10, [][]byte{{0x00, 0x01}, {0x00, 0x02}}) + time.Sleep(100 * time.Millisecond) + + // Fresh get should miss + if _, ok := c.GetRange(1, 0x03, 10, 2); ok { + t.Error("expected range miss after TTL") + } + + // Stale get should succeed + values, ok := c.GetRangeStale(1, 0x03, 10, 2) + if !ok { + t.Error("expected stale range hit") + } + if len(values) != 2 { + t.Fatalf("expected 2 stale values, got %d", len(values)) + } +} + +func TestCache_Coalesce(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -98,39 +211,20 @@ func TestCache_GetOrFetch(t *testing.T) { return []byte("fetched"), nil } - // First call should fetch (cache miss) - data, hit, err := c.GetOrFetch(ctx, "key1", fetch) + data, err := c.Coalesce(ctx, "key1", fetch) if err != nil { t.Errorf("unexpected error: %v", err) } - if hit { - t.Error("expected cache miss on first call") - } if string(data) != "fetched" { t.Errorf("expected fetched, got %s", string(data)) } if fetchCount != 1 { t.Errorf("expected 1 fetch, got %d", fetchCount) } - - // Second call should hit cache - data, hit, err = c.GetOrFetch(ctx, "key1", fetch) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !hit { - t.Error("expected cache hit on second call") - } - if string(data) != "fetched" { - t.Errorf("expected fetched, got %s", string(data)) - } - if fetchCount != 1 { - t.Errorf("expected 1 fetch (cache hit), got %d", fetchCount) - } } -func TestCache_RequestCoalescing(t *testing.T) { - c := New(time.Second) +func TestCache_CoalescingConcurrent(t *testing.T) { + c := New(time.Second, false) defer c.Close() ctx := context.Background() @@ -153,7 +247,7 @@ func TestCache_RequestCoalescing(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - results[0], _, errors[0] = c.GetOrFetch(ctx, "key1", fetch) + results[0], errors[0] = c.Coalesce(ctx, "key1", fetch) }() // Wait for fetch to start @@ -165,7 +259,7 @@ func TestCache_RequestCoalescing(t *testing.T) { 
wg.Add(1) go func() { defer wg.Done() - results[i], _, errors[i] = c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + results[i], errors[i] = c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { atomic.AddInt32(&fetchCount, 1) return []byte("should not be called"), nil }) @@ -196,7 +290,7 @@ func TestCache_RequestCoalescing(t *testing.T) { } func TestCache_ContextCancellation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -204,7 +298,7 @@ func TestCache_ContextCancellation(t *testing.T) { // Start a slow fetch go func() { - c.GetOrFetch(ctx, "key1", func(ctx context.Context) ([]byte, error) { + c.Coalesce(ctx, "key1", func(ctx context.Context) ([]byte, error) { close(fetchStarted) time.Sleep(time.Second) return []byte("fetched"), nil @@ -217,7 +311,7 @@ func TestCache_ContextCancellation(t *testing.T) { ctx2, cancel2 := context.WithCancel(context.Background()) cancel2() // Cancel immediately - _, _, err := c.GetOrFetch(ctx2, "key1", func(ctx context.Context) ([]byte, error) { + _, err := c.Coalesce(ctx2, "key1", func(ctx context.Context) ([]byte, error) { return []byte("should not be called"), nil }) @@ -229,7 +323,7 @@ func TestCache_ContextCancellation(t *testing.T) { } func TestCache_DataIsolation(t *testing.T) { - c := New(time.Second) + c := New(time.Second, false) defer c.Close() original := []byte("original") @@ -253,3 +347,54 @@ func TestCache_DataIsolation(t *testing.T) { t.Error("cache data was mutated via returned slice") } } + +func TestCache_RangeDataIsolation(t *testing.T) { + c := New(time.Second, false) + defer c.Close() + + original := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + c.SetRange(1, 0x03, 0, original) + + // Mutate original + original[0][0] = 0xFF + + // Cache should be unaffected + values, ok := c.GetRange(1, 0x03, 0, 2) + if !ok { + t.Error("expected range hit") + } + if values[0][0] != 0x00 { + 
t.Error("cache data was mutated via original slice") + } +} + +func TestCache_KeepStale(t *testing.T) { + // With keepStale=false, cleanup removes expired entries + c := New(50*time.Millisecond, false) + c.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + c.cleanupOnce() + + if _, ok := c.GetStale("key1"); ok { + t.Error("expected stale data to be gone after cleanup with keepStale=false") + } + c.Close() + + // With keepStale=true, cleanup skips deletion + c2 := New(50*time.Millisecond, true) + c2.Set("key1", []byte("value1")) + time.Sleep(100 * time.Millisecond) + + c2.cleanupOnce() + + // Entry should still be accessible via GetStale after cleanup + data, ok := c2.GetStale("key1") + if !ok { + t.Error("expected stale data to survive cleanup with keepStale=true") + } + if string(data) != "value1" { + t.Errorf("expected value1, got %s", string(data)) + } + c2.Close() +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index ddd28bd..c09fac1 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -28,7 +28,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Proxy, error) { cfg: cfg, logger: logger, client: modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger), - cache: cache.New(cfg.CacheTTL), + cache: cache.New(cfg.CacheTTL, cfg.CacheServeStale), } p.server = modbus.NewServer(p, logger) @@ -104,41 +104,49 @@ func (p *Proxy) HandleRequest(ctx context.Context, req *modbus.Request) ([]byte, } func (p *Proxy) handleRead(ctx context.Context, req *modbus.Request) ([]byte, error) { - key := cache.Key(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) - - // Use GetOrFetch for request coalescing - data, cacheHit, err := p.cache.GetOrFetch(ctx, key, func(ctx context.Context) ([]byte, error) { - p.logger.Debug("cache miss", + // Check per-register cache + values, cacheHit := p.cache.GetRange(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + if cacheHit { + 
p.logger.Debug("cache hit", "slave_id", req.SlaveID, "func", fmt.Sprintf("0x%02X", req.FunctionCode), "addr", req.Address, "qty", req.Quantity, ) + return assembleResponse(req.FunctionCode, req.Quantity, values), nil + } + + // Cache miss — fetch with coalescing + p.logger.Debug("cache miss", + "slave_id", req.SlaveID, + "func", fmt.Sprintf("0x%02X", req.FunctionCode), + "addr", req.Address, + "qty", req.Quantity, + ) + rangeKey := cache.RangeKey(req.SlaveID, req.FunctionCode, req.Address, req.Quantity) + data, err := p.cache.Coalesce(ctx, rangeKey, func(ctx context.Context) ([]byte, error) { return p.client.Execute(ctx, req) }) if err != nil { // Try serving stale data if configured if p.cfg.CacheServeStale { - if stale, ok := p.cache.GetStale(key); ok { + if staleValues, ok := p.cache.GetRangeStale(req.SlaveID, req.FunctionCode, req.Address, req.Quantity); ok { p.logger.Warn("upstream error, serving stale", "slave_id", req.SlaveID, "error", err, ) - return stale, nil + return assembleResponse(req.FunctionCode, req.Quantity, staleValues), nil } } return nil, err } - if cacheHit { - p.logger.Debug("cache hit", - "slave_id", req.SlaveID, - "func", fmt.Sprintf("0x%02X", req.FunctionCode), - "addr", req.Address, - "qty", req.Quantity, - ) + // Decompose response and store per-register + regValues := decomposeResponse(req.FunctionCode, req.Quantity, data) + if regValues != nil { + p.cache.SetRange(req.SlaveID, req.FunctionCode, req.Address, regValues) } return data, nil @@ -171,7 +179,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e return nil, err } - // Invalidate exact matching cache entries for all read function codes + // Invalidate per-register cache entries for the written range p.invalidateCache(req) return resp, nil @@ -181,7 +189,7 @@ func (p *Proxy) handleWrite(ctx context.Context, req *modbus.Request) ([]byte, e } func (p *Proxy) invalidateCache(req *modbus.Request) { - // Invalidate exact matches for all read function 
codes that could overlap
+	// Invalidate per-register entries for all read function codes
 	readFuncs := []byte{
 		modbus.FuncReadCoils,
 		modbus.FuncReadDiscreteInputs,
@@ -190,9 +198,93 @@ func (p *Proxy) invalidateCache(req *modbus.Request) {
 	}
 
 	for _, fc := range readFuncs {
-		key := cache.Key(req.SlaveID, fc, req.Address, req.Quantity)
-		p.cache.Delete(key)
+		p.cache.DeleteRange(req.SlaveID, fc, req.Address, req.Quantity)
+	}
+}
+
+// Shared byte slices for coil values — safe to reuse since SetRange copies.
+var (
+	coilOn  = []byte{1}
+	coilOff = []byte{0}
+)
+
+// decomposeResponse extracts per-register/coil values from a Modbus read response.
+// Response format: [funcCode, byteCount, data...]
+// For registers (FC 0x03, 0x04): each register is 2 bytes.
+// For coils/discrete inputs (FC 0x01, 0x02): each coil is 1 bit, stored as 1 byte (0 or 1).
+func decomposeResponse(functionCode byte, quantity uint16, data []byte) [][]byte {
+	if len(data) < 2 {
+		return nil
 	}
+
+	payload := data[2:] // Skip funcCode and byteCount
+
+	switch functionCode {
+	case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters:
+		values := make([][]byte, quantity)
+		for i := uint16(0); i < quantity; i++ {
+			offset := i * 2
+			if int(offset+2) > len(payload) {
+				return nil
+			}
+			reg := make([]byte, 2)
+			copy(reg, payload[offset:offset+2])
+			values[i] = reg
+		}
+		return values
+
+	case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs:
+		values := make([][]byte, quantity)
+		for i := uint16(0); i < quantity; i++ {
+			byteIdx := i / 8
+			bitIdx := i % 8
+			if int(byteIdx) >= len(payload) {
+				return nil
+			}
+			if payload[byteIdx]&(1<<bitIdx) != 0 {
+				values[i] = coilOn
+			} else {
+				values[i] = coilOff
+			}
+		}
+		return values
+	}
+
+	return nil
+}
+
+// assembleResponse builds a Modbus read response from per-register/coil values.
+func assembleResponse(functionCode byte, quantity uint16, values [][]byte) []byte {
+	switch functionCode {
+	case modbus.FuncReadHoldingRegisters, modbus.FuncReadInputRegisters:
+		byteCount := quantity * 2
+		resp := make([]byte, 2+byteCount)
+		resp[0] = functionCode
+		resp[1] = byte(byteCount)
+		for i, v := range values {
+			if len(v) >= 2 {
+				resp[2+i*2] = v[0]
+				resp[2+i*2+1] = v[1]
+			}
+		}
+		return resp
+
+	case modbus.FuncReadCoils, modbus.FuncReadDiscreteInputs:
+		byteCount := (quantity + 7) / 8
+		resp := make([]byte, 2+byteCount)
+		resp[0] = functionCode
+		resp[1] = byte(byteCount)
+		for i, v := range values {
+			if len(v) > 0 && v[0] != 0 {
+				byteIdx := i / 8
+				bitIdx := uint(i % 8)
+				
resp[2+byteIdx] |= 1 << bitIdx + } + } + return resp + } + + return nil } func (p *Proxy) buildFakeWriteResponse(req *modbus.Request) []byte { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 99364cc..f36d764 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -26,7 +26,8 @@ func (m *mockClient) Execute(ctx context.Context, req *modbus.Request) ([]byte, func TestProxy_HandleReadCacheHit(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - c := cache.New(time.Second) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ @@ -38,15 +39,17 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { cache: c, } - // Pre-populate cache - key := cache.Key(1, modbus.FuncReadHoldingRegisters, 0, 10) - c.Set(key, []byte{0x03, 0x14, 0x00, 0x01}) // Function code + byte count + data + // Pre-populate cache with per-register values + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, [][]byte{ + {0x00, 0x01}, + {0x00, 0x02}, + }) req := &modbus.Request{ SlaveID: 1, FunctionCode: modbus.FuncReadHoldingRegisters, Address: 0, - Quantity: 10, + Quantity: 2, } resp, err := p.HandleRequest(context.Background(), req) @@ -54,8 +57,15 @@ func TestProxy_HandleReadCacheHit(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if string(resp) != string([]byte{0x03, 0x14, 0x00, 0x01}) { - t.Errorf("unexpected response: %v", resp) + // Expected assembled response: funcCode + byteCount + reg0 + reg1 + expected := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } } } @@ -73,12 +83,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := cache.New(time.Second, false) + defer 
c.Close() + p := &Proxy{ cfg: &config.Config{ ReadOnly: tt.mode, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -104,13 +117,15 @@ func TestProxy_HandleWriteReadOnlyMode(t *testing.T) { func TestProxy_HandleUnknownFunction(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() p := &Proxy{ cfg: &config.Config{ ReadOnly: config.ReadOnlyOn, }, logger: logger, - cache: cache.New(time.Second), + cache: c, } req := &modbus.Request{ @@ -183,3 +198,214 @@ func TestProxy_BuildFakeWriteResponse(t *testing.T) { }) } } + +func TestDecomposeResponse_Registers(t *testing.T) { + // Response: FC 0x03, byteCount=4, reg0=0x0001, reg1=0x0002 + data := []byte{0x03, 0x04, 0x00, 0x01, 0x00, 0x02} + values := decomposeResponse(modbus.FuncReadHoldingRegisters, 2, data) + + if len(values) != 2 { + t.Fatalf("expected 2 values, got %d", len(values)) + } + if values[0][0] != 0x00 || values[0][1] != 0x01 { + t.Errorf("reg0: expected 0x0001, got 0x%02X%02X", values[0][0], values[0][1]) + } + if values[1][0] != 0x00 || values[1][1] != 0x02 { + t.Errorf("reg1: expected 0x0002, got 0x%02X%02X", values[1][0], values[1][1]) + } +} + +func TestDecomposeResponse_Coils(t *testing.T) { + // Response: FC 0x01, byteCount=2, coils 0-9 + // 0xCD = 1100_1101: coils 0,2,3,6,7 on + // 0x01 = 0000_0001: coil 8 on + data := []byte{0x01, 0x02, 0xCD, 0x01} + values := decomposeResponse(modbus.FuncReadCoils, 10, data) + + if len(values) != 10 { + t.Fatalf("expected 10 values, got %d", len(values)) + } + + expected := []byte{1, 0, 1, 1, 0, 0, 1, 1, 1, 0} + for i, exp := range expected { + if values[i][0] != exp { + t.Errorf("coil %d: expected %d, got %d", i, exp, values[i][0]) + } + } +} + +func TestAssembleResponse_Registers(t *testing.T) { + values := [][]byte{{0x00, 0x01}, {0x00, 0x02}} + resp := assembleResponse(modbus.FuncReadHoldingRegisters, 2, values) + + expected := []byte{0x03, 0x04, 
0x00, 0x01, 0x00, 0x02} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestAssembleResponse_Coils(t *testing.T) { + // Coils 0,2,3,6,7 on, 8 on — should produce 0xCD 0x01 + values := [][]byte{{1}, {0}, {1}, {1}, {0}, {0}, {1}, {1}, {1}, {0}} + resp := assembleResponse(modbus.FuncReadCoils, 10, values) + + expected := []byte{0x01, 0x02, 0xCD, 0x01} + if len(resp) != len(expected) { + t.Fatalf("expected %d bytes, got %d", len(expected), len(resp)) + } + for i := range expected { + if resp[i] != expected[i] { + t.Errorf("byte %d: expected 0x%02X, got 0x%02X", i, expected[i], resp[i]) + } + } +} + +func TestDecomposeAssemble_Roundtrip(t *testing.T) { + tests := []struct { + name string + funcCode byte + quantity uint16 + data []byte + }{ + { + name: "holding registers", + funcCode: modbus.FuncReadHoldingRegisters, + quantity: 3, + data: []byte{0x03, 0x06, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03}, + }, + { + name: "input registers", + funcCode: modbus.FuncReadInputRegisters, + quantity: 2, + data: []byte{0x04, 0x04, 0xFF, 0xFF, 0x00, 0x00}, + }, + { + name: "coils", + funcCode: modbus.FuncReadCoils, + quantity: 10, + data: []byte{0x01, 0x02, 0xCD, 0x01}, + }, + { + name: "discrete inputs", + funcCode: modbus.FuncReadDiscreteInputs, + quantity: 8, + data: []byte{0x02, 0x01, 0xAC}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := decomposeResponse(tt.funcCode, tt.quantity, tt.data) + if values == nil { + t.Fatal("decomposeResponse returned nil") + } + + reassembled := assembleResponse(tt.funcCode, tt.quantity, values) + if len(reassembled) != len(tt.data) { + t.Fatalf("length mismatch: expected %d, got %d", len(tt.data), len(reassembled)) + } + for i := range tt.data { + if reassembled[i] != tt.data[i] { + t.Errorf("byte %d: 
expected 0x%02X, got 0x%02X", i, tt.data[i], reassembled[i]) + } + } + }) + } +} + +func TestProxy_WriteInvalidatesOverlappingReads(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 (simulating a previous read of range 0-9) + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to register 5 — should invalidate register 5 + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteSingleRegister, + Address: 5, + Quantity: 1, + }) + + // Full range 0-9 should now miss (register 5 is gone) + _, ok := c.GetRange(1, modbus.FuncReadHoldingRegisters, 0, 10) + if ok { + t.Error("expected range miss after write invalidation of register 5") + } + + // Registers 0-4 and 6-9 should still be cached individually + for i := uint16(0); i < 10; i++ { + _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)) + if i == 5 { + if ok { + t.Error("register 5 should be invalidated") + } + } else { + if !ok { + t.Errorf("register %d should still be cached", i) + } + } + } +} + +func TestProxy_WriteInvalidatesMultipleRegisters(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c := cache.New(time.Second, false) + defer c.Close() + + p := &Proxy{ + cfg: &config.Config{ + ReadOnly: config.ReadOnlyOn, + }, + logger: logger, + cache: c, + } + + // Cache registers 0-9 + regs := make([][]byte, 10) + for i := range regs { + regs[i] = []byte{0x00, byte(i)} + } + c.SetRange(1, modbus.FuncReadHoldingRegisters, 0, regs) + + // Write to registers 3-5 (write multiple) + p.invalidateCache(&modbus.Request{ + SlaveID: 1, + FunctionCode: modbus.FuncWriteMultipleRegs, + Address: 3, + Quantity: 3, + }) + + // Registers 3,4,5 
should be gone + for i := uint16(3); i <= 5; i++ { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); ok { + t.Errorf("register %d should be invalidated", i) + } + } + + // Registers 0,1,2,6,7,8,9 should still be cached + for _, i := range []uint16{0, 1, 2, 6, 7, 8, 9} { + if _, ok := c.Get(cache.RegKey(1, modbus.FuncReadHoldingRegisters, i)); !ok { + t.Errorf("register %d should still be cached", i) + } + } +}