diff --git a/agent/driver/network/firewall_windows.go b/agent/driver/network/firewall_windows.go index f8c9296..6299ee1 100644 --- a/agent/driver/network/firewall_windows.go +++ b/agent/driver/network/firewall_windows.go @@ -3,10 +3,63 @@ package network import ( + "fmt" "github.com/cloud-barista/cm-honeybee/agent/pkg/api/rest/model/onprem/network" "github.com/kumako/go-win64api" + "net" + "os/exec" + "strings" ) +func getPreferredInterface() (*net.IPNet, error) { + cmd := exec.Command("powershell", "-Command", + "(Get-NetIPInterface -AddressFamily IPv4 | Where-Object { $_.ConnectionState -eq 'Connected' } | Sort-Object InterfaceMetric | Select-Object -First 1 InterfaceAlias).InterfaceAlias") + output, err := cmd.Output() + if err != nil { + return nil, err + } + + interfaceName := strings.TrimSpace(string(output)) + if interfaceName == "" { + cmd = exec.Command("powershell", "-Command", + "(Get-NetIPInterface -AddressFamily IPv6 | Where-Object { $_.ConnectionState -eq 'Connected' } | Sort-Object InterfaceMetric | Select-Object -First 1 InterfaceAlias).InterfaceAlias") + output, err = cmd.Output() + if err != nil { + return nil, err + } + interfaceName = strings.TrimSpace(string(output)) + if interfaceName == "" { + return nil, fmt.Errorf("no active interface found") + } + } + + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return nil, err + } + + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + return ipnet, nil + } + } + + return nil, fmt.Errorf("no IP address found for interface") +} + +func getLocalSubnetCIDR() string { + ipNet, err := getPreferredInterface() + if err != nil { + return "*" + } + return ipNet.String() +} + const protocolUnknown = "unknown" func parseProtocol(fwRuleProtocol int32) string { @@ -28,60 +81,69 @@ func parseProtocol(fwRuleProtocol int32) string { func GetFirewallRules() ([]network.FirewallRule, error) { var fwRules = make([]network.FirewallRule, 0) - rules, err := winapi.FirewallRulesGet() if err != nil { return nil, err } priority := 0 + localSubnetCIDR := getLocalSubnetCIDR() + for _, rule := range rules { - if rule.Enabled { - var fwRule network.FirewallRule - - protocol := parseProtocol(rule.Protocol) - if protocol == protocolUnknown { - continue - } - - // Skip all of any-any allows - if (rule.LocalAddresses == "*" || rule.LocalAddresses == "LocalSubnet") && - (rule.RemoteAddresses == "*" || rule.RemoteAddresses == "LocalSubnet") && - (rule.LocalPorts == "*" || rule.LocalPorts == "LocalSubnet" || rule.LocalPorts == "") && - (rule.RemotePorts == "*" || rule.RemotePorts == "LocalSubnet" || rule.RemotePorts == "") { - continue - } - - fwRule.Protocol = protocol - - if rule.Direction == winapi.NET_FW_RULE_DIR_IN { - fwRule.Direction = "inbound" - fwRule.Src = rule.RemoteAddresses - fwRule.SrcPorts = rule.RemotePorts - fwRule.Dst = rule.LocalAddresses - fwRule.DstPorts = rule.LocalPorts - } else if rule.Direction == winapi.NET_FW_RULE_DIR_OUT { - fwRule.Direction = "outbound" - fwRule.Src = rule.LocalAddresses - fwRule.SrcPorts = rule.LocalPorts - fwRule.Dst = rule.RemoteAddresses - fwRule.DstPorts = rule.RemotePorts - } else { - continue - } - - if rule.Action == winapi.NET_FW_ACTION_ALLOW { - fwRule.Action = "allow" - } else if rule.Action == winapi.NET_FW_ACTION_BLOCK { - fwRule.Action = "deny" - } - - priority++ - fwRule.Priority = uint(priority) - - fwRules = append(fwRules, fwRule) + if !rule.Enabled { + continue } - } + // Skip all of any-any allows + if (rule.LocalAddresses == "*" || rule.LocalAddresses == "LocalSubnet") && + (rule.RemoteAddresses == "*" || rule.RemoteAddresses == "LocalSubnet") && + (rule.LocalPorts == "*" || rule.LocalPorts == "LocalSubnet" || rule.LocalPorts == "") && + (rule.RemotePorts == "*" || rule.RemotePorts == "LocalSubnet" || rule.RemotePorts == "") { + continue + } + + var fwRule network.FirewallRule + protocol := parseProtocol(rule.Protocol) + if protocol == protocolUnknown { + continue + } + + localAddr := rule.LocalAddresses + if localAddr == "LocalSubnet" { + localAddr = localSubnetCIDR + } + + remoteAddr := rule.RemoteAddresses + if remoteAddr == "LocalSubnet" { + remoteAddr = localSubnetCIDR + } + + fwRule.Protocol = protocol + if rule.Direction == winapi.NET_FW_RULE_DIR_IN { + fwRule.Direction = "inbound" + fwRule.Src = remoteAddr + fwRule.SrcPorts = rule.RemotePorts + fwRule.Dst = localAddr + fwRule.DstPorts = rule.LocalPorts + } else if rule.Direction == winapi.NET_FW_RULE_DIR_OUT { + fwRule.Direction = "outbound" + fwRule.Src = localAddr + fwRule.SrcPorts = rule.LocalPorts + fwRule.Dst = remoteAddr + fwRule.DstPorts = rule.RemotePorts + } else { + continue + } + + if rule.Action == winapi.NET_FW_ACTION_ALLOW { + fwRule.Action = "allow" + } else if rule.Action == winapi.NET_FW_ACTION_BLOCK { + fwRule.Action = "deny" + } + + priority++ + fwRule.Priority = uint(priority) + fwRules = append(fwRules, fwRule) + } return fwRules, nil }