diff --git a/tun_windows.go b/tun_windows.go index 57d52e3..9cb1e96 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -172,15 +172,9 @@ func (t *NativeTun) Start() error { if err != nil { return err } - for _, routeRange := range routeRanges { - if routeRange.Addr().Is4() { - err = luid.AddRoute(routeRange, gateway4, 0) - } else { - err = luid.AddRoute(routeRange, gateway6, 0) - } - if err != nil { - return err - } + err = addRouteList(luid, routeRanges, gateway4, gateway6, 0) + if err != nil { + return err } err = windnsapi.FlushResolverCache() if err != nil { @@ -560,20 +554,45 @@ func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error { if err != nil { return err } - for _, routeRange := range routeRanges { - if routeRange.Addr().Is4() { - err = luid.AddRoute(routeRange, gateway4, 0) + err = addRouteList(luid, routeRanges, gateway4, gateway6, 0) + if err != nil { + return err + } + err = windnsapi.FlushResolverCache() + if err != nil { + return err + } + return nil +} + +func addRouteList(luid winipcfg.LUID, destinations []netip.Prefix, gateway4 netip.Addr, gateway6 netip.Addr, metric uint32) error { + row := winipcfg.MibIPforwardRow2{} + row.Init() + row.InterfaceLUID = luid + row.Metric = metric + nextHop4 := row.NextHop + nextHop6 := row.NextHop + if gateway4.IsValid() { + nextHop4.SetAddr(gateway4) + } + if gateway6.IsValid() { + nextHop6.SetAddr(gateway6) + } + for _, destination := range destinations { + err := row.DestinationPrefix.SetPrefix(destination) + if err != nil { + return err + } + if destination.Addr().Is4() { + row.NextHop = nextHop4 } else { - err = luid.AddRoute(routeRange, gateway6, 0) + row.NextHop = nextHop6 } + err = row.Create() if err != nil { return err } } - err = windnsapi.FlushResolverCache() - if err != nil { - return err - } return nil }