Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Guidance for coding agents working in this repository.
- Prefer splitting a code line only when it triggers the `lll` linter, do not split a command or arguments list for each element
- Use `netip` types instead of `net` types whenever possible
- Use constants instead of variables whenever possible, especially function-local inline constants.
- Prefer using pure functions over methods when possible. Especially if the method does not need any fields from the receiving struct, it should be a pure function.
- Do not use `time.Sleep`, prefer using a `time.Timer` with a `select` statement also listening on a context cancelation
- `panic`:
- should only be used when a programming error is encountered and you should NOT return errors for programming errors (such as passing nil objects)
Expand Down Expand Up @@ -127,6 +128,7 @@ The Go formatter used is gofumpt.
### Errors

- Always prefer wrapping errors with some context with `fmt.Errorf("doing this: %w", err)`
- Use `errors.New("error message")` when creating a 'bottom' constant string error without additional context, instead of `fmt.Errorf`
- In rare cases, you can just use `return err` notably:
- If the function is called **recursively**, since we don't wrap the wrapping multiple times for each recursion
- If the current function only statement is the call to another function, for example:
Expand Down Expand Up @@ -179,6 +181,8 @@ The Go formatter used is gofumpt.

- Do not use `http.DefaultClient`, use a custom `*http.Client` with a fixed timeout and share with dependency injections.
- Do not check for injected dependencies being `nil`, prefer to just panic on a nil pointer. By default it's fine to panic if a developer injects a dependency `nil`. `nil` does not mean use a default.
- Prefer using a `switch { case ...}` statement over multiple consecutive `if` statements to have shorter code.
- Prefer using `[...]T` instead of `[]T` when the length is fixed and known at compile time, to avoid unnecessary allocations.

## Validation checklist

Expand Down
2 changes: 2 additions & 0 deletions internal/firewall/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type firewallImpl interface { //nolint:interfacebloat
AcceptIpv6MulticastOutput(ctx context.Context, intf string) error
AcceptOutput(ctx context.Context, protocol, intf string,
ip netip.Addr, port uint16, remove bool) error
AcceptOutputFromIPPortToIPPort(ctx context.Context, protocol, intf string,
source, destination netip.AddrPort, remove bool) error
AcceptOutputFromIPToSubnet(ctx context.Context, intf string, assignedIP netip.Addr,
subnet netip.Prefix, remove bool) error
AcceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error
Expand Down
23 changes: 23 additions & 0 deletions internal/firewall/iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ func (c *Config) AcceptOutput(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction)
}

func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
if source.Addr().BitLen() != destination.Addr().BitLen() {
return fmt.Errorf("source and destination address families do not match")
}
Comment thread
qdm12 marked this conversation as resolved.

interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}

instruction := fmt.Sprintf("%s OUTPUT -s %s --sport %d -d %s %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), source.Addr(), source.Port(), destination.Addr(),
interfaceFlag, protocol, protocol, destination.Port())
Comment thread
qdm12 marked this conversation as resolved.
Outdated
if destination.Addr().Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output from %s to %s: %s", source, destination, needIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}

// AcceptOutputFromIPToSubnet accepts outgoing traffic from sourceIP to destinationSubnet
// on the interface intf. If intf is empty, it is set to "*" which means all interfaces.
// If remove is true, the rule is removed instead of added.
Expand Down
7 changes: 7 additions & 0 deletions internal/firewall/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@ func (c *Config) AcceptOutput(ctx context.Context, protocol, intf string,
) error {
return c.impl.AcceptOutput(ctx, protocol, intf, ip, port, remove)
}

func (c *Config) AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error {
return c.impl.AcceptOutputFromIPPortToIPPort(ctx, protocol, intf,
source, destination, remove)
}
56 changes: 56 additions & 0 deletions internal/restrictednet/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package restrictednet

import (
"context"
"fmt"
"net/http"

"github.com/qdm12/dns/v2/pkg/provider"
)

// Client is a client for making restricted network requests,
// such as opening temporary firewall rules for HTTPS connections.
// It is not meant to be high performance, although it can be used for
// multiple requests and concurrently.
type Client struct {
ipv6Supported bool
firewall Firewall
outboundInterface string
dohServers []provider.DoHServer
}

func New(firewall Firewall, defaultInterface string, ipv6Supported bool,
upstreamResolvers []provider.Provider,
) (*Client, error) {
dohServers := make([]provider.DoHServer, len(upstreamResolvers))
for i, upstreamResolver := range upstreamResolvers {
dohServers[i] = upstreamResolver.DoH
}
Comment thread
qdm12 marked this conversation as resolved.
Outdated

return &Client{
firewall: firewall,
outboundInterface: defaultInterface,
ipv6Supported: ipv6Supported,
dohServers: dohServers,
}, nil
}

func (c *Client) OpenHTTPSByDomain(ctx context.Context, domain string) (
httpClient *http.Client, cleanup func() error, err error,
) {
resolvedIPs, err := c.ResolveName(ctx, domain)
if err != nil {
return nil, nil, fmt.Errorf("resolving name: %w", err)
} else if len(resolvedIPs) == 0 {
return nil, nil, fmt.Errorf("no IP address found for name %q", domain)
}

selectedIP := resolvedIPs[0]

httpClient, cleanup, err = c.OpenHTTPS(domain, selectedIP)
if err != nil {
return nil, nil, fmt.Errorf("opening HTTPS: %w", err)
}

return httpClient, cleanup, nil
}
68 changes: 68 additions & 0 deletions internal/restrictednet/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package restrictednet

import (
"context"
"net/netip"
"testing"

"github.com/golang/mock/gomock"
"github.com/qdm12/dns/v2/pkg/provider"
"github.com/stretchr/testify/require"
)

type listenAddrPortMatcher struct {
expected netip.AddrPort
}

func (m listenAddrPortMatcher) Matches(x any) bool {
ip, ok := x.(netip.AddrPort)
if !ok {
return false
}
if m.expected.IsValid() {
return ip == m.expected
}
return ip.IsValid() && ip.Addr().IsValid() && ip.Port() > 0
}

func (m listenAddrPortMatcher) String() string {
if m.expected.IsValid() {
return "is the same as " + m.expected.String()
}
return "is a valid netip.AddrPort with a valid IP and non-zero port"
}

func Test_Client_OpenHTTPS(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
firewall := NewMockFirewall(ctrl)

destination := netip.MustParseAddrPort("1.2.3.4:443")
backgroundContext := context.Background()
sourceMatcher := listenAddrPortMatcher{}
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, false,
).DoAndReturn(func(_ context.Context,
_, _ string, source, _ netip.AddrPort, _ bool,
) error {
sourceMatcher.expected = source
return nil
})
firewall.EXPECT().AcceptOutputFromIPPortToIPPort(
backgroundContext, "tcp", "eth0", sourceMatcher, destination, true,
)

const ipv6Supported = false
upstreamResolvers := []provider.Provider{provider.Google()}
client, err := New(firewall, "eth0", ipv6Supported, upstreamResolvers)
require.NoError(t, err)

httpClient, cleanup, err := client.OpenHTTPS("api.example.com", netip.MustParseAddr("1.2.3.4"))
require.NoError(t, err)
require.NotNil(t, httpClient)
require.NotNil(t, cleanup)

err = cleanup()
require.NoError(t, err)
}
115 changes: 115 additions & 0 deletions internal/restrictednet/https.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package restrictednet

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"time"
)

// OpenHTTPS opens temporary restrictive firewall output for one HTTPS destination.
// The returned cleanup function must be called to remove the temporary firewall rule and close connections.
func (c *Client) OpenHTTPS(destinationTLSName string, destinationIP netip.Addr,
) (httpClient *http.Client, cleanup func() error, err error) {
listener, sourceAddrPort, err := bindSourcePort(destinationIP)
if err != nil {
return nil, nil, fmt.Errorf("binding source port: %w", err)
}

const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationIP, httpsPort)

const remove = false
ctx := context.Background() // it's a quick firewall change, worth not passing a context
err = c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
_ = listener.Close()
return nil, nil, fmt.Errorf("allowing output traffic through firewall: %w", err)
}

httpClient = newHTTPSClient(destinationTLSName, destinationIP, sourceAddrPort)
cleanup = func() error {
var errs []error
httpClient.CloseIdleConnections()
const remove = true
err := c.firewall.AcceptOutputFromIPPortToIPPort(ctx, "tcp", c.outboundInterface,
sourceAddrPort, destinationAddrPort, remove)
if err != nil {
errs = append(errs, fmt.Errorf("removing output traffic rule: %w", err))
}
err = listener.Close()
if err != nil {
errs = append(errs, fmt.Errorf("closing listener: %w", err))
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
Comment thread
qdm12 marked this conversation as resolved.
Comment thread
qdm12 marked this conversation as resolved.
return httpClient, cleanup, nil
}

func newHTTPSClient(destinationTLSName string,
destinationIP netip.Addr, sourceAddress netip.AddrPort,
) *http.Client {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint:forcetypeassert
httpTransport.Proxy = nil
httpTransport.MaxIdleConns = 1
httpTransport.MaxIdleConnsPerHost = 1
httpTransport.IdleConnTimeout = time.Second
httpTransport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: destinationTLSName,
}
httpTransport.DialContext = newBoundDialContext(destinationIP, sourceAddress)

const timeout = 5 * time.Second
return &http.Client{
Timeout: timeout,
Transport: httpTransport,
}
}

func newBoundDialContext(destinationAddress netip.Addr,
sourceAddress netip.AddrPort,
) func(ctx context.Context, network, _ string) (net.Conn, error) {
const httpsPort = 443
destinationAddrPort := netip.AddrPortFrom(destinationAddress, httpsPort).String()
return func(ctx context.Context, network, _ string) (net.Conn, error) {
const timeout = 2 * time.Second
dialer := &net.Dialer{Timeout: timeout}
dialer.LocalAddr = net.TCPAddrFromAddrPort(sourceAddress)
connection, err := dialer.DialContext(ctx, network, destinationAddrPort)
if err != nil {
return nil, fmt.Errorf("%s dialing %s: %w", network, destinationAddrPort, err)
}
return connection, nil
}
}
Comment thread
qdm12 marked this conversation as resolved.

func bindSourcePort(destinationIP netip.Addr) (
listener net.Listener, sourceAddr netip.AddrPort, err error,
) {
var bindAddr netip.Addr
if destinationIP.Is4() {
bindAddr = netip.AddrFrom4([4]byte{})
} else {
bindAddr = netip.AddrFrom16([16]byte{})
}
Comment thread
qdm12 marked this conversation as resolved.
Outdated
Comment on lines +153 to +157

listener, err = net.ListenTCP("tcp", net.TCPAddrFromAddrPort(
netip.AddrPortFrom(bindAddr, 0)))
if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("binding TCP port: %w", err)
}
Comment thread
qdm12 marked this conversation as resolved.
Outdated

tcpAddr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert
sourceAddr = tcpAddr.AddrPort()

return listener, sourceAddr, nil
}
12 changes: 12 additions & 0 deletions internal/restrictednet/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package restrictednet

import (
"context"
"net/netip"
)

type Firewall interface {
AcceptOutputFromIPPortToIPPort(ctx context.Context,
protocol, intf string, source, destination netip.AddrPort, remove bool,
) error
}
3 changes: 3 additions & 0 deletions internal/restrictednet/mocks_generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package restrictednet

//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
50 changes: 50 additions & 0 deletions internal/restrictednet/mocks_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading