diff --git a/socket_linux.go b/socket_linux.go index ebda532a..fbd7d438 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -338,37 +338,92 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { return pkgHandle.SocketDiagTCPInfo(family) } -// SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. +// SocketDiagTCPOptions represents the configuration options for socket diagnostics. +type SocketDiagTCPOptions struct { + // Families is a bitset representing the internet address families of interest to be used in the query. The bits + // are expected to be set according to the unix.AF_* constants. If Families is zero, both AF_INET and AF_INET6 + // address families will be retrieved by default. + Families uint32 + + // States is a bitset that specifies the TCP state of the target sockets to be retrieved. The bits are + // expected to be set according to TCP_* state constants. If States is zero, all states will be retrieved. + States uint32 +} + +// SocketDiagTCPWithOptions requests INET_DIAG_INFO for TCP protocol for specified options and returns related sockets. +// Currently only AF_INET and AF_INET6 address families are supported. // // If the returned error is [ErrDumpInterrupted], results may be inconsistent // or incomplete. -func (h *Handle) SocketDiagTCP(family uint8) ([]*Socket, error) { - // Construct the request - req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) - req.AddData(&socketRequest{ - Family: family, - Protocol: unix.IPPROTO_TCP, - Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), - States: uint32(0xfff), // all states - }) +func (h *Handle) SocketDiagTCPWithOptions(opts SocketDiagTCPOptions) ([]*Socket, error) { + if opts.Families == 0 { + opts.Families = (1 << unix.AF_INET) | (1 << unix.AF_INET6) // all IPv4 and IPv6 sockets + } + if opts.States == 0 { + opts.States = uint32(0xfff) // all states + } + + // check if any unsupported families are set + supportedFamilies := uint32((1 << unix.AF_INET) | (1 << unix.AF_INET6)) + if opts.Families&^supportedFamilies != 0 { + return nil, fmt.Errorf("unsupported address families specified: only AF_INET and AF_INET6 are supported") + } - // Do the query and parse the result var result []*Socket - executeErr := req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { - sockInfo := &Socket{} - if err := sockInfo.deserialize(msg); err != nil { - return false + var err error + for _, family := range []uint8{unix.AF_INET, unix.AF_INET6} { + if opts.Families&(1<