diff --git a/device/device.go b/device/device.go
index 5fb0f59f5..0b7289b24 100644
--- a/device/device.go
+++ b/device/device.go
@@ -92,6 +92,12 @@ type Device struct {
 	ipcMutex sync.RWMutex
 	closed   chan struct{}
 	log      *Logger
+
+	// batchSizeOverride, when nonzero, replaces the value returned by
+	// BatchSize and is used to size per-goroutine eager buffer allocations.
+	// Must be set before calling Up so that the first bind/receive goroutines
+	// pick it up.
+	batchSizeOverride atomic.Int32
 }
 
 // deviceState represents the state of a Device.
@@ -301,6 +307,10 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
 	device.rate.limiter.Init()
 	device.indexTable.Init()
 
+	if MaxBatchSizeOverride > 0 {
+		device.batchSizeOverride.Store(int32(MaxBatchSizeOverride))
+	}
+
 	device.PopulatePools()
 
 	// create queues
@@ -331,8 +341,12 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
 // BatchSize returns the BatchSize for the device as a whole which is the max of
 // the bind batch size and the tun batch size. The batch size reported by device
 // is the size used to construct memory pools, and is the allowed batch size for
-// the lifetime of the device.
+// the lifetime of the device. A nonzero override set via SetMaxBatchSize
+// takes precedence and is returned as-is.
 func (device *Device) BatchSize() int {
+	if o := device.batchSizeOverride.Load(); o > 0 {
+		return int(o)
+	}
 	size := device.net.bind.BatchSize()
 	dSize := device.tun.device.BatchSize()
 	if size < dSize {
@@ -339,8 +353,20 @@ func (device *Device) BatchSize() int {
 		size = dSize
 	}
 	return size
 }
 
+// SetMaxBatchSize overrides the per-batch size used by the receive and TUN
+// read goroutines, and therefore the number of message buffers each of them
+// holds eagerly for the lifetime of the Device. Zero disables the override.
+// Must be called before Up; already-running goroutines keep the batch size
+// they started with.
+func (device *Device) SetMaxBatchSize(n int) {
+	if n < 0 {
+		n = 0
+	}
+	device.batchSizeOverride.Store(int32(n))
+}
+
 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 	device.peers.RLock()
 	defer device.peers.RUnlock()
@@ -522,7 +548,7 @@
 	device.net.stopping.Add(len(recvFns))
 	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
 	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
-	batchSize := netc.bind.BatchSize()
+	batchSize := device.BatchSize()
 	for _, fn := range recvFns {
 		go device.RoutineReceiveIncoming(batchSize, fn)
 	}
diff --git a/device/pool_size.go b/device/pool_size.go
new file mode 100644
index 000000000..3a6bc1680
--- /dev/null
+++ b/device/pool_size.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+// MaxBatchSizeOverride, when nonzero, replaces the per-Device batch size used
+// to size eager buffer allocations in RoutineReceiveIncoming and
+// RoutineReadFromTUN. Zero means "do not override" (Devices fall back to the
+// larger of bind.BatchSize() and tun.BatchSize()); zero is NOT "unlimited".
+// Changes affect Devices created after this assignment; use
+// SetMaxBatchSizeOverride to set it.
+var MaxBatchSizeOverride uint32 = 0
+
+// SetPreallocatedBuffersPerPool sets the cap on the number of buffers held by
+// each per-Device pool. Zero disables the cap (upstream default on
+// non-mobile platforms). Changes affect Devices created after this call.
+// To retune a live Device, use Device.SetPreallocatedBuffersPerPool.
+func SetPreallocatedBuffersPerPool(n uint32) {
+	PreallocatedBuffersPerPool = n
+}
+
+// SetMaxBatchSizeOverride sets the global batch size override applied to
+// Devices created after this call. Zero disables the override. Existing
+// Devices are unaffected; use Device.SetMaxBatchSize for per-instance.
+func SetMaxBatchSizeOverride(n uint32) {
+	MaxBatchSizeOverride = n
+}
+
+// SetPreallocatedBuffersPerPool updates the cap on this Device's pools in
+// place. Takes effect immediately; goroutines blocked in Get are unblocked if
+// the cap was raised. Has no effect if the Device was created with
+// PreallocatedBuffersPerPool == 0.
+func (device *Device) SetPreallocatedBuffersPerPool(n uint32) {
+	device.pool.messageBuffers.SetMax(n)
+	device.pool.inboundElements.SetMax(n)
+	device.pool.outboundElements.SetMax(n)
+	device.pool.inboundElementsContainer.SetMax(n)
+	device.pool.outboundElementsContainer.SetMax(n)
+}
diff --git a/device/pools.go b/device/pools.go
index 2c18f4179..179f83dc0 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -10,23 +10,24 @@ import (
 )
 
 type WaitPool struct {
-	pool  sync.Pool
-	cond  sync.Cond
-	lock  sync.Mutex
-	count uint32 // Get calls not yet Put back
-	max   uint32
+	pool    sync.Pool
+	cond    sync.Cond
+	lock    sync.Mutex
+	count   uint32 // Get calls not yet Put back
+	max     uint32
+	tracked bool // true if max was non-zero at construction; enables SetMax
 }
 
 func NewWaitPool(max uint32, new func() any) *WaitPool {
-	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+	p := &WaitPool{pool: sync.Pool{New: new}, max: max, tracked: max != 0}
 	p.cond = sync.Cond{L: &p.lock}
 	return p
 }
 
 func (p *WaitPool) Get() any {
-	if p.max != 0 {
+	if p.tracked {
 		p.lock.Lock()
-		for p.count >= p.max {
+		for p.max != 0 && p.count >= p.max {
 			p.cond.Wait()
 		}
 		p.count++
@@ -37,7 +38,7 @@ func (p *WaitPool) Get() any {
 func (p *WaitPool) Put(x any) {
 	p.pool.Put(x)
-	if p.max == 0 {
+	if !p.tracked {
 		return
 	}
 	p.lock.Lock()
 	p.count--
@@ -46,6 +47,19 @@ func (p *WaitPool) Put(x any) {
 	p.cond.Signal()
 }
 
+// SetMax updates the pool cap. Takes effect immediately; waiters are
+// broadcast so they re-check against the new value. Has no effect if the
+// pool was constructed with max == 0 (unbounded, fast-path Get/Put).
+func (p *WaitPool) SetMax(n uint32) {
+	if !p.tracked {
+		return
+	}
+	p.lock.Lock()
+	p.max = n
+	p.cond.Broadcast()
+	p.lock.Unlock()
+}
+
 func (device *Device) PopulatePools() {
 	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
 		s := make([]*QueueInboundElement, 0, device.BatchSize())
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 236dea16c..97255d0e9 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -10,10 +10,13 @@ import "golang.zx2c4.com/wireguard/conn"
 
 /* Reduce memory consumption for Android */
 const (
-	QueueStagedSize            = conn.IdealBatchSize
-	QueueOutboundSize          = 1024
-	QueueInboundSize           = 1024
-	QueueHandshakeSize         = 1024
-	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
-	PreallocatedBuffersPerPool = 4096
+	QueueStagedSize    = conn.IdealBatchSize
+	QueueOutboundSize  = 1024
+	QueueInboundSize   = 1024
+	QueueHandshakeSize = 1024
+	MaxSegmentSize     = (1 << 16) - 1 // largest possible UDP datagram
 )
+
+// PreallocatedBuffersPerPool caps the number of buffers held by each per-Device
+// pool. Use SetPreallocatedBuffersPerPool to change this before calling NewDevice.
+var PreallocatedBuffersPerPool uint32 = 4096
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index b06118576..26ce1aaf4 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -10,10 +10,14 @@ package device
 import "golang.zx2c4.com/wireguard/conn"
 
 const (
-	QueueStagedSize            = conn.IdealBatchSize
-	QueueOutboundSize          = 1024
-	QueueInboundSize           = 1024
-	QueueHandshakeSize         = 1024
-	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
-	PreallocatedBuffersPerPool = 0             // Disable and allow for infinite memory growth
+	QueueStagedSize    = conn.IdealBatchSize
+	QueueOutboundSize  = 1024
+	QueueInboundSize   = 1024
+	QueueHandshakeSize = 1024
+	MaxSegmentSize     = (1 << 16) - 1 // largest possible UDP datagram
 )
+
+// PreallocatedBuffersPerPool caps the number of buffers held by each per-Device
+// pool. Zero disables the cap and allows unbounded growth (upstream default).
+// Use SetPreallocatedBuffersPerPool to change this before calling NewDevice.
+var PreallocatedBuffersPerPool uint32 = 0
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
index 425b2eab6..3a7f92c43 100644
--- a/device/queueconstants_windows.go
+++ b/device/queueconstants_windows.go
@@ -6,10 +6,14 @@
 package device
 
 const (
-	QueueStagedSize            = 128
-	QueueOutboundSize          = 1024
-	QueueInboundSize           = 1024
-	QueueHandshakeSize         = 1024
-	MaxSegmentSize             = 65535 // Match with WINTUN_MAX_IP_PACKET_SIZE macro definition
-	PreallocatedBuffersPerPool = 0     // Disable and allow for infinite memory growth
+	QueueStagedSize    = 128
+	QueueOutboundSize  = 1024
+	QueueInboundSize   = 1024
+	QueueHandshakeSize = 1024
+	MaxSegmentSize     = 65535 // Match with WINTUN_MAX_IP_PACKET_SIZE macro definition
 )
+
+// PreallocatedBuffersPerPool caps the number of buffers held by each per-Device
+// pool. Zero disables the cap and allows unbounded growth (upstream default).
+// Use SetPreallocatedBuffersPerPool to change this before calling NewDevice.
+var PreallocatedBuffersPerPool uint32 = 0