diff --git a/conn.go b/conn.go index a9fbf2b..25d88e0 100644 --- a/conn.go +++ b/conn.go @@ -37,16 +37,20 @@ type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial NetNS int // fd referencing the network namespace netlink will interact with. - lasting bool // establish a lasting connection to be used across multiple netlink operations. - mu sync.Mutex // protects the following state - messages []netlink.Message - err error - nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + sockOptions []SockOption } // ConnOption is an option to change the behavior of the nftables Conn returned by Open. type ConnOption func(*Conn) +// SockOption is an option to change the behavior of the netlink socket used by the nftables Conn. +type SockOption func(*netlink.Conn) error + // New returns a netlink connection for querying and modifying nftables. Some // aspects of the new netlink connection can be configured using the options // WithNetNSFd, WithTestDial, and AsLasting. @@ -101,6 +105,14 @@ func WithTestDial(f nltest.Func) ConnOption { } } +// WithSockOptions sets the specified socket options when creating a new netlink +// connection. +func WithSockOptions(opts ...SockOption) ConnOption { + return func(cc *Conn) { + cc.sockOptions = append(cc.sockOptions, opts...) + } +} + // netlinkCloser is returned by netlinkConn(UnderLock) and must be called after // being done with the returned netlink connection in order to properly close // this connection, if necessary. @@ -284,11 +296,28 @@ func (cc *Conn) FlushRuleset() { } func (cc *Conn) dialNetlink() (*netlink.Conn, error) { + var ( + conn *netlink.Conn + err error + ) + if cc.TestDial != nil { - return nltest.Dial(cc.TestDial), nil + conn = nltest.Dial(cc.TestDial) + } else { + conn, err = netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) + } + + if err != nil { + return nil, err + } + + for _, opt := range cc.sockOptions { + if err := opt(conn); err != nil { + return nil, err + } } - return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) + return conn, nil } func (cc *Conn) setErr(err error) {