diff --git a/plugins/meta/mptcp/main.go b/plugins/meta/mptcp/main.go new file mode 100644 index 000000000..58eee0161 --- /dev/null +++ b/plugins/meta/mptcp/main.go @@ -0,0 +1,293 @@ +// Copyright 2025 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This is a "meta-plugin". It reads in its own netconf, it does not create +// any network interface but configures MPTCP endpoints and limits. + +package main + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/containernetworking/cni/pkg/skel" + "github.com/containernetworking/cni/pkg/types" + current "github.com/containernetworking/cni/pkg/types/100" + "github.com/containernetworking/cni/pkg/version" + "github.com/containernetworking/plugins/pkg/netlinksafe" + "github.com/containernetworking/plugins/pkg/ns" + bv "github.com/containernetworking/plugins/pkg/utils/buildversion" +) + +// EndpointConfig specifies the flags for MPTCP endpoints. +type EndpointConfig struct { + Signal bool `json:"signal"` + Subflow bool `json:"subflow"` + Backup bool `json:"backup"` + Fullmesh bool `json:"fullmesh"` +} + +// LimitsConfig specifies the MPTCP path manager limits. +// Nil fields are left unchanged from their current values. +type LimitsConfig struct { + Subflows *uint32 `json:"subflows,omitempty"` + AddAddrAccepted *uint32 `json:"addAddrAccepted,omitempty"` +} + +// MPTCPNetConf represents the MPTCP plugin configuration. +type MPTCPNetConf struct { + types.NetConf + + Endpoints *EndpointConfig `json:"endpoints,omitempty"` + Limits *LimitsConfig `json:"limits,omitempty"` +} + +func main() { + skel.PluginMainFuncs(skel.CNIFuncs{ + Add: cmdAdd, + Check: cmdCheck, + Del: cmdDel, + }, version.VersionsStartingFrom("0.3.1"), bv.BuildString("mptcp")) +} + +func cmdAdd(args *skel.CmdArgs) error { + conf, result, err := parseConf(args.StdinData) + if err != nil { + return err + } + + if conf.PrevResult == nil { + return fmt.Errorf("missing prevResult from earlier plugin") + } + + err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { + familyID, err := getMPTCPFamilyID() + if err != nil { + return err + } + + if conf.Endpoints != nil { + link, err := netlinksafe.LinkByName(args.IfName) + if err != nil { + return fmt.Errorf("failed to look up interface %q: %v", args.IfName, err) + } + ifIndex := link.Attrs().Index + + flags := endpointFlags(conf.Endpoints) + + for _, ipCfg := range result.IPs { + ip := ipCfg.Address.IP + if err := addEndpoint(familyID, ip, flags, ifIndex); err != nil { + // Treat EEXIST as success for idempotency + if !os.IsExist(err) { + return fmt.Errorf("failed to add MPTCP endpoint for %s: %v", ip, err) + } + } + } + } + + if conf.Limits != nil { + curSubflows, curAddAddr, err := getLimits(familyID) + if err != nil { + return fmt.Errorf("failed to get MPTCP limits: %v", err) + } + + newSubflows := curSubflows + if conf.Limits.Subflows != nil { + newSubflows = *conf.Limits.Subflows + } + newAddAddr := curAddAddr + if conf.Limits.AddAddrAccepted != nil { + newAddAddr = *conf.Limits.AddAddrAccepted + } + + if err := setLimits(familyID, newSubflows, newAddAddr); err != nil { + return fmt.Errorf("failed to set MPTCP limits: %v", err) + } + } + + return nil + }) + if err != nil { + return fmt.Errorf("cmdAdd failed: %v", err) + } + + return types.PrintResult(result, conf.CNIVersion) +} + +func cmdDel(args *skel.CmdArgs) error { + conf, result, err := parseConf(args.StdinData) + if err != nil { + return err + } + + if conf.Endpoints == nil { + return nil + } + + err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { + familyID, err := getMPTCPFamilyID() + if err != nil { + // MPTCP not available; nothing to clean up + return nil + } + + endpoints, err := listEndpoints(familyID) + if err != nil { + // Best effort cleanup + return nil + } + + // Build a set of IPs to remove. + // Prefer IPs from prevResult; fall back to matching by interface index. + targetIPs := make(map[string]bool) + if result != nil { + for _, ipCfg := range result.IPs { + targetIPs[ipCfg.Address.IP.String()] = true + } + } + + for _, ep := range endpoints { + if ep.Addr == nil { + continue + } + if len(targetIPs) > 0 { + if !targetIPs[ep.Addr.String()] { + continue + } + } else { + // No prevResult IPs; try matching by interface + link, err := netlinksafe.LinkByName(args.IfName) + if err != nil { + return nil + } + if ep.IfIdx != int32(link.Attrs().Index) { + continue + } + } + // Ignore errors during cleanup + _ = delEndpoint(familyID, ep.ID, ep.Addr) + } + + return nil + }) + if err != nil { + _, ok := err.(ns.NSPathNotExistErr) + if ok { + return nil + } + return err + } + + return nil +} + +func cmdCheck(args *skel.CmdArgs) error { + conf, result, err := parseConf(args.StdinData) + if err != nil { + return err + } + + if conf.PrevResult == nil { + return fmt.Errorf("missing prevResult from earlier plugin") + } + + err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { + familyID, err := getMPTCPFamilyID() + if err != nil { + return err + } + + if conf.Endpoints != nil { + endpoints, err := listEndpoints(familyID) + if err != nil { + return fmt.Errorf("failed to list MPTCP endpoints: %v", err) + } + + expectedFlags := endpointFlags(conf.Endpoints) + + for _, ipCfg := range result.IPs { + ip := ipCfg.Address.IP + found := false + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.Equal(ip) { + found = true + if ep.Flags != expectedFlags { + return fmt.Errorf("MPTCP endpoint %s has flags 0x%x, expected 0x%x", ip, ep.Flags, expectedFlags) + } + break + } + } + if !found { + return fmt.Errorf("MPTCP endpoint for %s not found", ip) + } + } + } + + if conf.Limits != nil { + curSubflows, curAddAddr, err := getLimits(familyID) + if err != nil { + return fmt.Errorf("failed to get MPTCP limits: %v", err) + } + + if conf.Limits.Subflows != nil && *conf.Limits.Subflows != curSubflows { + return fmt.Errorf("MPTCP subflows limit is %d, expected %d", curSubflows, *conf.Limits.Subflows) + } + if conf.Limits.AddAddrAccepted != nil && *conf.Limits.AddAddrAccepted != curAddAddr { + return fmt.Errorf("MPTCP add_addr_accepted limit is %d, expected %d", curAddAddr, *conf.Limits.AddAddrAccepted) + } + } + + return nil + }) + if err != nil { + return err + } + + return nil +} + +func parseConf(data []byte) (*MPTCPNetConf, *current.Result, error) { + conf := MPTCPNetConf{} + if err := json.Unmarshal(data, &conf); err != nil { + return nil, nil, fmt.Errorf("failed to load netconf: %v", err) + } + + if conf.Endpoints == nil && conf.Limits == nil { + return nil, nil, fmt.Errorf("at least one of 'endpoints' or 'limits' must be specified") + } + + if conf.Endpoints != nil { + if !conf.Endpoints.Signal && !conf.Endpoints.Subflow && + !conf.Endpoints.Backup && !conf.Endpoints.Fullmesh { + return nil, nil, fmt.Errorf("endpoints configured but no flags (signal, subflow, backup, fullmesh) are set") + } + } + + if conf.RawPrevResult == nil { + return &conf, ¤t.Result{}, nil + } + + if err := version.ParsePrevResult(&conf.NetConf); err != nil { + return nil, nil, fmt.Errorf("could not parse prevResult: %v", err) + } + + result, err := current.NewResultFromResult(conf.PrevResult) + if err != nil { + return nil, nil, fmt.Errorf("could not convert result to current version: %v", err) + } + + return &conf, result, nil +} diff --git a/plugins/meta/mptcp/mptcp.go b/plugins/meta/mptcp/mptcp.go new file mode 100644 index 000000000..3f9bdebdc --- /dev/null +++ b/plugins/meta/mptcp/mptcp.go @@ -0,0 +1,260 @@ +// Copyright 2025 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package main + +import ( + "fmt" + "net" + + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" + "golang.org/x/sys/unix" +) + +const ( + mptcpPMGenlName = "mptcp_pm" + mptcpPMGenlVer = 1 +) + +// MPTCP path manager commands (from linux/mptcp_pm.h). +const ( + mptcpPMCmdAddAddr = 1 + mptcpPMCmdDelAddr = 2 + mptcpPMCmdGetAddr = 3 + mptcpPMCmdSetLimits = 5 + mptcpPMCmdGetLimits = 6 +) + +// MPTCP path manager top-level attributes. +const ( + mptcpPMAttrAddr = 1 + mptcpPMAttrRcvAddAddrs = 2 + mptcpPMAttrSubflows = 3 +) + +// MPTCP path manager address attributes (nested inside mptcpPMAttrAddr). +const ( + mptcpPMAddrAttrFamily = 1 + mptcpPMAddrAttrID = 2 + mptcpPMAddrAttrAddr4 = 3 + mptcpPMAddrAttrAddr6 = 4 + mptcpPMAddrAttrPort = 5 + mptcpPMAddrAttrFlags = 6 + mptcpPMAddrAttrIfIdx = 7 +) + +// MPTCP endpoint flags. +const ( + mptcpPMAddrFlagSignal = 0x01 + mptcpPMAddrFlagSubflow = 0x02 + mptcpPMAddrFlagBackup = 0x04 + mptcpPMAddrFlagFullmesh = 0x08 +) + +// mptcpEndpoint represents a parsed MPTCP endpoint. +type mptcpEndpoint struct { + Family uint16 + ID uint8 + Addr net.IP + Flags uint32 + IfIdx int32 +} + +// getMPTCPFamilyID resolves the mptcp_pm generic netlink family ID. +// Must be called inside the target network namespace. +func getMPTCPFamilyID() (int, error) { + fam, err := netlink.GenlFamilyGet(mptcpPMGenlName) + if err != nil { + return -1, fmt.Errorf("MPTCP path manager not available: %v", err) + } + return int(fam.ID), nil +} + +// addEndpoint adds an MPTCP endpoint for the given IP address. +// The endpoint ID is auto-assigned by the kernel (ID=0). +func addEndpoint(familyID int, ip net.IP, flags uint32, ifIndex int) error { + req := nl.NewNetlinkRequest(familyID, unix.NLM_F_ACK) + + addrAttr := nl.NewRtAttr(unix.NLA_F_NESTED|mptcpPMAttrAddr, nil) + + if ip.To4() != nil { + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrFamily, nl.Uint16Attr(unix.AF_INET))) + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrAddr4, ip.To4())) + } else { + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrFamily, nl.Uint16Attr(unix.AF_INET6))) + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrAddr6, ip.To16())) + } + + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrID, nl.Uint8Attr(0))) + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrFlags, nl.Uint32Attr(flags))) + + if ifIndex > 0 { + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrIfIdx, nl.Uint32Attr(uint32(ifIndex)))) + } + + raw := []byte{mptcpPMCmdAddAddr, mptcpPMGenlVer, 0, 0} + raw = append(raw, addrAttr.Serialize()...) + req.AddRawData(raw) + + _, err := req.Execute(unix.NETLINK_GENERIC, 0) + return err +} + +// delEndpoint deletes an MPTCP endpoint by its ID and address. +func delEndpoint(familyID int, id uint8, ip net.IP) error { + req := nl.NewNetlinkRequest(familyID, unix.NLM_F_ACK) + + addrAttr := nl.NewRtAttr(unix.NLA_F_NESTED|mptcpPMAttrAddr, nil) + + if ip.To4() != nil { + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrFamily, nl.Uint16Attr(unix.AF_INET))) + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrAddr4, ip.To4())) + } else { + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrFamily, nl.Uint16Attr(unix.AF_INET6))) + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrAddr6, ip.To16())) + } + + addrAttr.AddChild(nl.NewRtAttr(mptcpPMAddrAttrID, nl.Uint8Attr(id))) + + raw := []byte{mptcpPMCmdDelAddr, mptcpPMGenlVer, 0, 0} + raw = append(raw, addrAttr.Serialize()...) + req.AddRawData(raw) + + _, err := req.Execute(unix.NETLINK_GENERIC, 0) + return err +} + +// listEndpoints lists all MPTCP endpoints in the current namespace. +func listEndpoints(familyID int) ([]mptcpEndpoint, error) { + req := nl.NewNetlinkRequest(familyID, unix.NLM_F_DUMP) + + raw := []byte{mptcpPMCmdGetAddr, mptcpPMGenlVer, 0, 0} + req.AddRawData(raw) + + msgs, err := req.Execute(unix.NETLINK_GENERIC, 0) + if err != nil { + return nil, err + } + + var endpoints []mptcpEndpoint + for _, msg := range msgs { + ep, err := deserializeEndpoint(msg) + if err != nil { + return nil, err + } + endpoints = append(endpoints, ep) + } + return endpoints, nil +} + +// deserializeEndpoint parses a generic netlink response message into an mptcpEndpoint. +func deserializeEndpoint(msg []byte) (mptcpEndpoint, error) { + ep := mptcpEndpoint{} + + if len(msg) < nl.SizeofGenlmsg { + return ep, fmt.Errorf("message too short: %d bytes", len(msg)) + } + + for attr := range nl.ParseAttributes(msg[nl.SizeofGenlmsg:]) { + if attr.Type&nl.NLA_TYPE_MASK != uint16(mptcpPMAttrAddr) { + continue + } + for nested := range nl.ParseAttributes(attr.Value) { + switch nested.Type & nl.NLA_TYPE_MASK { + case mptcpPMAddrAttrFamily: + ep.Family = nl.NativeEndian().Uint16(nested.Value) + case mptcpPMAddrAttrID: + ep.ID = nested.Value[0] + case mptcpPMAddrAttrAddr4: + ep.Addr = make(net.IP, net.IPv4len) + copy(ep.Addr, nested.Value) + case mptcpPMAddrAttrAddr6: + ep.Addr = make(net.IP, net.IPv6len) + copy(ep.Addr, nested.Value) + case mptcpPMAddrAttrFlags: + ep.Flags = nl.NativeEndian().Uint32(nested.Value) + case mptcpPMAddrAttrIfIdx: + ep.IfIdx = int32(nl.NativeEndian().Uint32(nested.Value)) + } + } + } + return ep, nil +} + +// setLimits configures the MPTCP path manager limits. +func setLimits(familyID int, subflows, addAddrAccepted uint32) error { + req := nl.NewNetlinkRequest(familyID, unix.NLM_F_ACK) + + attrs := []*nl.RtAttr{ + nl.NewRtAttr(mptcpPMAttrRcvAddAddrs, nl.Uint32Attr(addAddrAccepted)), + nl.NewRtAttr(mptcpPMAttrSubflows, nl.Uint32Attr(subflows)), + } + + raw := []byte{mptcpPMCmdSetLimits, mptcpPMGenlVer, 0, 0} + for _, a := range attrs { + raw = append(raw, a.Serialize()...) + } + req.AddRawData(raw) + + _, err := req.Execute(unix.NETLINK_GENERIC, 0) + return err +} + +// getLimits retrieves the current MPTCP path manager limits. +func getLimits(familyID int) (subflows, addAddrAccepted uint32, err error) { + req := nl.NewNetlinkRequest(familyID, 0) + + raw := []byte{mptcpPMCmdGetLimits, mptcpPMGenlVer, 0, 0} + req.AddRawData(raw) + + msgs, err := req.Execute(unix.NETLINK_GENERIC, 0) + if err != nil { + return 0, 0, err + } + + if len(msgs) < 1 { + return 0, 0, fmt.Errorf("empty response from MPTCP_PM_CMD_GET_LIMITS") + } + + for attr := range nl.ParseAttributes(msgs[0][nl.SizeofGenlmsg:]) { + switch attr.Type & nl.NLA_TYPE_MASK { + case uint16(mptcpPMAttrRcvAddAddrs): + addAddrAccepted = nl.NativeEndian().Uint32(attr.Value) + case uint16(mptcpPMAttrSubflows): + subflows = nl.NativeEndian().Uint32(attr.Value) + } + } + return subflows, addAddrAccepted, nil +} + +// endpointFlags converts an EndpointConfig into a bitmask of MPTCP endpoint flags. +func endpointFlags(cfg *EndpointConfig) uint32 { + var flags uint32 + if cfg.Signal { + flags |= mptcpPMAddrFlagSignal + } + if cfg.Subflow { + flags |= mptcpPMAddrFlagSubflow + } + if cfg.Backup { + flags |= mptcpPMAddrFlagBackup + } + if cfg.Fullmesh { + flags |= mptcpPMAddrFlagFullmesh + } + return flags +} diff --git a/plugins/meta/mptcp/mptcp_suite_test.go b/plugins/meta/mptcp/mptcp_suite_test.go new file mode 100644 index 000000000..13e550ecc --- /dev/null +++ b/plugins/meta/mptcp/mptcp_suite_test.go @@ -0,0 +1,27 @@ +// Copyright 2025 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMPTCP(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "plugins/meta/mptcp") +} diff --git a/plugins/meta/mptcp/mptcp_test.go b/plugins/meta/mptcp/mptcp_test.go new file mode 100644 index 000000000..687b36691 --- /dev/null +++ b/plugins/meta/mptcp/mptcp_test.go @@ -0,0 +1,752 @@ +// Copyright 2025 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/vishvananda/netlink" + + "github.com/containernetworking/cni/pkg/skel" + current "github.com/containernetworking/cni/pkg/types/100" + "github.com/containernetworking/plugins/pkg/netlinksafe" + "github.com/containernetworking/plugins/pkg/ns" + "github.com/containernetworking/plugins/pkg/testutils" +) + +func configForEndpoints(name, ifName, ip string, signal, subflow, backup, fullmesh bool) []byte { + return []byte(fmt.Sprintf(`{ + "name": "%s", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": { + "signal": %t, + "subflow": %t, + "backup": %t, + "fullmesh": %t + }, + "prevResult": { + "interfaces": [ + {"name": "%s", "sandbox": "netns"} + ], + "ips": [ + { + "address": "%s", + "interface": 0 + } + ] + } + }`, name, signal, subflow, backup, fullmesh, ifName, ip)) +} + +func configForEndpointsDualStack(name, ifName, ipv4, ipv6 string, signal, subflow bool) []byte { + return []byte(fmt.Sprintf(`{ + "name": "%s", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": { + "signal": %t, + "subflow": %t + }, + "prevResult": { + "interfaces": [ + {"name": "%s", "sandbox": "netns"} + ], + "ips": [ + { + "address": "%s", + "interface": 0 + }, + { + "address": "%s", + "interface": 0 + } + ] + } + }`, name, signal, subflow, ifName, ipv4, ipv6)) +} + +func configForLimits(name, ifName, ip string, subflows, addAddrAccepted uint32) []byte { + return []byte(fmt.Sprintf(`{ + "name": "%s", + "type": "mptcp", + "cniVersion": "1.0.0", + "limits": { + "subflows": %d, + "addAddrAccepted": %d + }, + "prevResult": { + "interfaces": [ + {"name": "%s", "sandbox": "netns"} + ], + "ips": [ + { + "address": "%s", + "interface": 0 + } + ] + } + }`, name, subflows, addAddrAccepted, ifName, ip)) +} + +func configForBoth(name, ifName, ip string, signal, subflow bool, subflows, addAddrAccepted uint32) []byte { + return []byte(fmt.Sprintf(`{ + "name": "%s", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": { + "signal": %t, + "subflow": %t + }, + "limits": { + "subflows": %d, + "addAddrAccepted": %d + }, + "prevResult": { + "interfaces": [ + {"name": "%s", "sandbox": "netns"} + ], + "ips": [ + { + "address": "%s", + "interface": 0 + } + ] + } + }`, name, signal, subflow, subflows, addAddrAccepted, ifName, ip)) +} + +var _ = Describe("mptcp plugin", func() { + var originalNS ns.NetNS + var targetNS ns.NetNS + const IFName = "dummy0" + + BeforeEach(func() { + var err error + originalNS, err = testutils.NewNS() + Expect(err).NotTo(HaveOccurred()) + + targetNS, err = testutils.NewNS() + Expect(err).NotTo(HaveOccurred()) + + // Check if MPTCP path manager is available in the kernel + err = targetNS.Do(func(ns.NetNS) error { + _, err := netlink.GenlFamilyGet("mptcp_pm") + return err + }) + if err != nil { + Skip("MPTCP path manager not available in kernel") + } + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + la := netlink.NewLinkAttrs() + la.Name = IFName + err := netlink.LinkAdd(&netlink.Dummy{LinkAttrs: la}) + Expect(err).NotTo(HaveOccurred()) + + link, err := netlinksafe.LinkByName(IFName) + Expect(err).NotTo(HaveOccurred()) + + addr, err := netlink.ParseAddr("10.0.0.2/24") + Expect(err).NotTo(HaveOccurred()) + err = netlink.AddrAdd(link, addr) + Expect(err).NotTo(HaveOccurred()) + + addr6, err := netlink.ParseAddr("fd00::2/64") + Expect(err).NotTo(HaveOccurred()) + err = netlink.AddrAdd(link, addr6) + Expect(err).NotTo(HaveOccurred()) + + err = netlink.LinkSetUp(link) + Expect(err).NotTo(HaveOccurred()) + + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(originalNS.Close()).To(Succeed()) + Expect(targetNS.Close()).To(Succeed()) + }) + + It("passes prevResult through unchanged", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + r, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + + result, err := current.GetResult(r) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Interfaces).To(HaveLen(1)) + Expect(result.Interfaces[0].Name).To(Equal(IFName)) + Expect(result.IPs).To(HaveLen(1)) + Expect(result.IPs[0].Address.String()).To(Equal("10.0.0.2/24")) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("adds MPTCP endpoint for IPv4 with signal+subflow flags", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Verify endpoint was created + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + + found := false + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.String() == "10.0.0.2" { + found = true + Expect(ep.Flags).To(Equal(uint32(mptcpPMAddrFlagSignal | mptcpPMAddrFlagSubflow))) + break + } + } + Expect(found).To(BeTrue(), "expected MPTCP endpoint for 10.0.0.2") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("adds MPTCP endpoint for IPv6", func() { + conf := configForEndpoints("test", IFName, "fd00::2/64", true, false, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + + found := false + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.String() == "fd00::2" { + found = true + Expect(ep.Flags).To(Equal(uint32(mptcpPMAddrFlagSignal))) + break + } + } + Expect(found).To(BeTrue(), "expected MPTCP endpoint for fd00::2") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("adds MPTCP endpoints for dual-stack", func() { + conf := configForEndpointsDualStack("test", IFName, "10.0.0.2/24", "fd00::2/64", true, true) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + + foundV4 := false + foundV6 := false + for _, ep := range endpoints { + if ep.Addr == nil { + continue + } + if ep.Addr.String() == "10.0.0.2" { + foundV4 = true + } + if ep.Addr.String() == "fd00::2" { + foundV6 = true + } + } + Expect(foundV4).To(BeTrue(), "expected MPTCP endpoint for 10.0.0.2") + Expect(foundV6).To(BeTrue(), "expected MPTCP endpoint for fd00::2") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sets MPTCP limits", func() { + conf := configForLimits("test", IFName, "10.0.0.2/24", 4, 4) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + subflows, addAddrAccepted, err := getLimits(familyID) + Expect(err).NotTo(HaveOccurred()) + Expect(subflows).To(Equal(uint32(4))) + Expect(addAddrAccepted).To(Equal(uint32(4))) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("configures both endpoints and limits", func() { + conf := configForBoth("test", IFName, "10.0.0.2/24", true, true, 8, 8) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + // Check endpoint + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + found := false + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.String() == "10.0.0.2" { + found = true + Expect(ep.Flags).To(Equal(uint32(mptcpPMAddrFlagSignal | mptcpPMAddrFlagSubflow))) + break + } + } + Expect(found).To(BeTrue(), "expected MPTCP endpoint for 10.0.0.2") + + // Check limits + subflows, addAddrAccepted, err := getLimits(familyID) + Expect(err).NotTo(HaveOccurred()) + Expect(subflows).To(Equal(uint32(8))) + Expect(addAddrAccepted).To(Equal(uint32(8))) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdDel removes endpoints", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + // Add first + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Delete + err = originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + err := testutils.CmdDelWithArgs(args, func() error { + return cmdDel(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Verify endpoint is gone + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.String() == "10.0.0.2" { + Fail("endpoint 10.0.0.2 should have been removed") + } + } + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdDel handles namespace already gone", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + + // Use a non-existent namespace path + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: "/var/run/netns/does-not-exist", + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + err := testutils.CmdDelWithArgs(args, func() error { + return cmdDel(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdCheck verifies endpoints exist with correct flags", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + // Add first + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Check should pass + err = originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + err := testutils.CmdCheckWithArgs(args, func() error { + return cmdCheck(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdCheck fails when endpoint is missing", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + // Don't add, just check -- should fail + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + err := testutils.CmdCheckWithArgs(args, func() error { + return cmdCheck(args) + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not found")) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdCheck verifies limits", func() { + conf := configForLimits("test", IFName, "10.0.0.2/24", 6, 6) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + // Add limits + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Check should pass + err = originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + err := testutils.CmdCheckWithArgs(args, func() error { + return cmdCheck(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + It("cmdAdd is idempotent", func() { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", true, true, false, false) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + // Add once + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + // Add again -- should not error + err = originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + DescribeTable("endpoint flags", + func(signal, subflow, backup, fullmesh bool, expectedFlags uint32) { + conf := configForEndpoints("test", IFName, "10.0.0.2/24", signal, subflow, backup, fullmesh) + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IFName, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + _, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + familyID, err := getMPTCPFamilyID() + Expect(err).NotTo(HaveOccurred()) + + endpoints, err := listEndpoints(familyID) + Expect(err).NotTo(HaveOccurred()) + + found := false + for _, ep := range endpoints { + if ep.Addr != nil && ep.Addr.String() == "10.0.0.2" { + found = true + Expect(ep.Flags).To(Equal(expectedFlags)) + break + } + } + Expect(found).To(BeTrue()) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }, + Entry("signal only", true, false, false, false, uint32(mptcpPMAddrFlagSignal)), + Entry("subflow only", false, true, false, false, uint32(mptcpPMAddrFlagSubflow)), + Entry("backup only", false, false, true, false, uint32(mptcpPMAddrFlagBackup)), + Entry("fullmesh only", false, false, false, true, uint32(mptcpPMAddrFlagFullmesh)), + Entry("signal+subflow", true, true, false, false, uint32(mptcpPMAddrFlagSignal|mptcpPMAddrFlagSubflow)), + Entry("subflow+backup", false, true, true, false, uint32(mptcpPMAddrFlagSubflow|mptcpPMAddrFlagBackup)), + Entry("signal+fullmesh", true, false, false, true, uint32(mptcpPMAddrFlagSignal|mptcpPMAddrFlagFullmesh)), + ) +}) + +var _ = Describe("config validation", func() { + It("rejects config with no endpoints or limits", func() { + conf := []byte(`{ + "name": "test", + "type": "mptcp", + "cniVersion": "1.0.0", + "prevResult": { + "interfaces": [{"name": "eth0", "sandbox": "netns"}], + "ips": [{"address": "10.0.0.2/24", "interface": 0}] + } + }`) + _, _, err := parseConf(conf) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("at least one")) + }) + + It("rejects endpoints with no flags set", func() { + conf := []byte(`{ + "name": "test", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": {}, + "prevResult": { + "interfaces": [{"name": "eth0", "sandbox": "netns"}], + "ips": [{"address": "10.0.0.2/24", "interface": 0}] + } + }`) + _, _, err := parseConf(conf) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no flags")) + }) + + It("accepts limits-only config", func() { + conf := []byte(`{ + "name": "test", + "type": "mptcp", + "cniVersion": "1.0.0", + "limits": { + "subflows": 4 + }, + "prevResult": { + "interfaces": [{"name": "eth0", "sandbox": "netns"}], + "ips": [{"address": "10.0.0.2/24", "interface": 0}] + } + }`) + parsed, result, err := parseConf(conf) + Expect(err).NotTo(HaveOccurred()) + Expect(parsed.Endpoints).To(BeNil()) + Expect(parsed.Limits).NotTo(BeNil()) + Expect(*parsed.Limits.Subflows).To(Equal(uint32(4))) + Expect(result.IPs).To(HaveLen(1)) + }) + + It("accepts endpoints-only config", func() { + conf := []byte(`{ + "name": "test", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": { + "signal": true + }, + "prevResult": { + "interfaces": [{"name": "eth0", "sandbox": "netns"}], + "ips": [{"address": "10.0.0.2/24", "interface": 0}] + } + }`) + parsed, _, err := parseConf(conf) + Expect(err).NotTo(HaveOccurred()) + Expect(parsed.Endpoints).NotTo(BeNil()) + Expect(parsed.Limits).To(BeNil()) + }) + + It("handles missing prevResult for DEL", func() { + conf := []byte(`{ + "name": "test", + "type": "mptcp", + "cniVersion": "1.0.0", + "endpoints": { + "signal": true + } + }`) + parsed, result, err := parseConf(conf) + Expect(err).NotTo(HaveOccurred()) + Expect(parsed).NotTo(BeNil()) + Expect(result.IPs).To(BeEmpty()) + }) +})