Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"mime/multipart"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -991,7 +992,7 @@ func (c *Context) ClientIP() string {

var (
trusted bool
remoteIP net.IP
remoteIP netip.Addr
)
// If gin is listening a unix socket, always trust it.
localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr)
Expand All @@ -1004,8 +1005,9 @@ func (c *Context) ClientIP() string {
// It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined by Engine.SetTrustedProxies()
remoteIP = net.ParseIP(c.RemoteIP())
if remoteIP == nil {
var err error
remoteIP, err = netip.ParseAddr(c.RemoteIP())
if err != nil {
return ""
}
trusted = c.engine.isTrustedProxy(remoteIP)
Expand Down
5 changes: 3 additions & 2 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -3128,9 +3129,9 @@ func TestRemoteIPFail(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest(http.MethodPost, "/", nil)
c.Request.RemoteAddr = "[:::]:80"
ip := net.ParseIP(c.RemoteIP())
ip, err := netip.ParseAddr(c.RemoteIP())
trust := c.engine.isTrustedProxy(ip)
assert.Nil(t, ip)
require.Error(t, err)
assert.False(t, trust)
}

Expand Down
61 changes: 20 additions & 41 deletions gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"html/template"
"net"
"net/http"
"net/netip"
"os"
"path"
"strings"
Expand Down Expand Up @@ -36,15 +37,9 @@ var (

var defaultPlatform string

var defaultTrustedCIDRs = []*net.IPNet{
{ // 0.0.0.0/0 (IPv4)
IP: net.IP{0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0},
},
{ // ::/0 (IPv6)
IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
var defaultTrustedCIDRs = []netip.Prefix{
netip.MustParsePrefix("0.0.0.0/0"), // IPv4
netip.MustParsePrefix("::/0"), // IPv6
}

// HandlerFunc defines the handler used by gin middleware as return value.
Expand Down Expand Up @@ -185,7 +180,7 @@ type Engine struct {
maxParams uint16
maxSections uint16
trustedProxies []string
trustedCIDRs []*net.IPNet
trustedCIDRs []netip.Prefix
}

var _ IRouter = (*Engine)(nil)
Expand Down Expand Up @@ -411,33 +406,31 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
return routes
}

func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
func (engine *Engine) prepareTrustedCIDRs() ([]netip.Prefix, error) {
if engine.trustedProxies == nil {
return nil, nil
}

cidr := make([]*net.IPNet, 0, len(engine.trustedProxies))
cidrs := make([]netip.Prefix, 0, len(engine.trustedProxies))
for _, trustedProxy := range engine.trustedProxies {
if !strings.Contains(trustedProxy, "/") {
ip := parseIP(trustedProxy)
if ip == nil {
return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy}
addr, err := netip.ParseAddr(trustedProxy)
if err != nil {
return cidrs, &net.ParseError{Type: "IP address", Text: trustedProxy}
}

switch len(ip) {
case net.IPv4len:
if addr.Is4() {
trustedProxy += "/32"
case net.IPv6len:
} else {
trustedProxy += "/128"
}
}
_, cidrNet, err := net.ParseCIDR(trustedProxy)
prefix, err := netip.ParsePrefix(trustedProxy)
if err != nil {
return cidr, err
return cidrs, err
}
cidr = append(cidr, cidrNet)
cidrs = append(cidrs, prefix.Masked())
}
return cidr, nil
return cidrs, nil
}

// SetTrustedProxies set a list of network origins (IPv4 addresses,
Expand All @@ -455,7 +448,7 @@ func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {

// isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true)
func (engine *Engine) isUnsafeTrustedProxies() bool {
return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::"))
return engine.isTrustedProxy(netip.MustParseAddr("0.0.0.0")) || engine.isTrustedProxy(netip.MustParseAddr("::"))
}

// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
Expand All @@ -466,7 +459,7 @@ func (engine *Engine) parseTrustedProxies() error {
}

// isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs
func (engine *Engine) isTrustedProxy(ip net.IP) bool {
func (engine *Engine) isTrustedProxy(ip netip.Addr) bool {
if engine.trustedCIDRs == nil {
return false
}
Expand All @@ -486,8 +479,8 @@ func (engine *Engine) validateHeader(header string) (clientIP string, valid bool
items := strings.Split(header, ",")
for i := len(items) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(items[i])
ip := net.ParseIP(ipStr)
if ip == nil {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
break
}

Expand Down Expand Up @@ -520,20 +513,6 @@ func (engine *Engine) updateRouteTrees() {
}
}

// parseIP parse a string representation of an IP and returns a net.IP with the
// minimum byte representation or nil if input is invalid.
func parseIP(ip string) net.IP {
parsedIP := net.ParseIP(ip)

if ipv4 := parsedIP.To4(); ipv4 != nil {
// return ip in a 4-byte representation
return ipv4
}

// return ip in a 16-byte representation or nil
return parsedIP
}

// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
Expand Down
25 changes: 9 additions & 16 deletions gin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -869,7 +870,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {

// valid ipv4 cidr
{
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
expectedTrustedCIDRs := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}
err := r.SetTrustedProxies([]string{"0.0.0.0/0"})

require.NoError(t, err)
Expand All @@ -885,7 +886,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {

// valid ipv4 address
{
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
expectedTrustedCIDRs := []netip.Prefix{netip.MustParsePrefix("192.168.1.33/32")}

err := r.SetTrustedProxies([]string{"192.168.1.33"})

Expand All @@ -902,7 +903,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {

// valid ipv6 address
{
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
expectedTrustedCIDRs := []netip.Prefix{netip.MustParsePrefix("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"})

require.NoError(t, err)
Expand All @@ -918,7 +919,7 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {

// valid ipv6 cidr
{
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
expectedTrustedCIDRs := []netip.Prefix{netip.MustParsePrefix("::/0")}
err := r.SetTrustedProxies([]string{"::/0"})

require.NoError(t, err)
Expand All @@ -934,10 +935,10 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {

// valid combination
{
expectedTrustedCIDRs := []*net.IPNet{
parseCIDR("::/0"),
parseCIDR("192.168.0.0/16"),
parseCIDR("172.16.0.1/32"),
expectedTrustedCIDRs := []netip.Prefix{
netip.MustParsePrefix("::/0"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
}
err := r.SetTrustedProxies([]string{
"::/0",
Expand Down Expand Up @@ -969,14 +970,6 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
}
}

func parseCIDR(cidr string) *net.IPNet {
_, parsedCIDR, err := net.ParseCIDR(cidr)
if err != nil {
fmt.Println(err)
}
return parsedCIDR
}

func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) {
for _, gotRoute := range gotRoutes {
if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method {
Expand Down
Loading