-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.go
More file actions
241 lines (215 loc) · 7.72 KB
/
Copy pathmain.go
File metadata and controls
241 lines (215 loc) · 7.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
package main
import (
	"bytes"
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
)
// main configures and runs a reverse proxy placed in front of a backend server
// while a site moves from an old hostname to a new one.
//
// Requests whose Host is one of the "docs."/"blog." sub-domains of the old
// host are redirected to the matching sub-domain of the new host. Requests
// whose Host is the old host itself are redirected (HTTP 301) to the new host
// unless -no-redirect is set or shouldNotRedirect reports true. Every other
// request is forwarded to the backend. Unless -no-cors is set, proxied
// responses are given CORS headers that allow the old host's origin.
//
// When serving HTTPS on port 443, a second server on :80 redirects plain HTTP
// requests to HTTPS. The process runs until it receives an interrupt or
// termination signal and then shuts the servers down gracefully; it exits
// with status 1 if the flags are invalid or the main server cannot serve.
func main() {
	const (
		flagnameBackendURLStr  = "backend-url"
		flagnameHost           = "host"
		flagnamePort           = "port"
		flagnameOldHost        = "old-host"
		flagnameCertFile       = "cert"
		flagnamePrivKey        = "cert-key"
		flagnameNoredirect     = "no-redirect"
		flagnameNoCORS         = "no-cors"
		flagnameLogEachRequest = "log-all-requests"
	)
	// The flag package writes each flag's default into these variables when
	// the flag is registered, so no explicit initial values are needed.
	var (
		backendURLStr  string
		host           string
		oldHost        string
		certfile       string
		privkey        string
		port           int
		noredirect     bool
		noCORS         bool
		logEachRequest bool
	)

	flag.StringVar(&backendURLStr, flagnameBackendURLStr, "", "Backend server URL")
	flag.StringVar(&host, flagnameHost, "", "Current hostname (the hostname to which requests are to be redirected)")
	flag.StringVar(&oldHost, flagnameOldHost, "", "Old hostname")
	flag.StringVar(&certfile, flagnameCertFile, "", "SSL certificate")
	flag.StringVar(&privkey, flagnamePrivKey, "", "The private key of the SSL certificate")
	flag.BoolVar(&noredirect, flagnameNoredirect, false, "Turn off all redirects")
	flag.BoolVar(&noCORS, flagnameNoCORS, false, "Turn off CORS headers for old-host origin")
	flag.IntVar(&port, flagnamePort, -1, "Port on which to listen to (defaults to 80 or 443 depending on whether or not an SSL certificate is provided)")
	flag.BoolVar(&logEachRequest, flagnameLogEachRequest, false, "Log every HTTP request")
	flag.Parse()

	if backendURLStr == "" || host == "" || oldHost == "" {
		fmt.Printf("The flags -%s, -%s, and -%s all have to be set for the program to continue\n", flagnameBackendURLStr, flagnameHost, flagnameOldHost)
		os.Exit(1)
	}

	backendURL, err := url.Parse(backendURLStr)
	if err != nil {
		fmt.Printf("Failed to parse backend URL: %v\n", err)
		os.Exit(1)
	}
	if backendURL.Scheme == "" || backendURL.Host == "" {
		fmt.Println("Invalid or incomplete backend URL")
		os.Exit(1)
	}

	// Providing a certificate switches the server to HTTPS; the port then
	// defaults to the scheme's standard port unless one was given.
	scheme := "http"
	if certfile != "" {
		scheme = "https"
	}
	if port == -1 {
		if scheme == "http" {
			port = 80
		} else {
			port = 443
		}
	}

	oldHostOrigin := constructOrigin(scheme, oldHost, port)

	proxy := httputil.NewSingleHostReverseProxy(backendURL)
	proxy.ModifyResponse = func(r *http.Response) error {
		if noCORS {
			return nil
		}
		// Leave the response alone only when the backend already set its own
		// Access-Control-Allow-Origin and the request did not come from the
		// old host's origin.
		if r.Header.Get("Access-Control-Allow-Origin") != "" && r.Request.Header.Get("Origin") != oldHostOrigin {
			return nil
		}
		r.Header.Set("Access-Control-Allow-Origin", oldHostOrigin)
		r.Header.Set("Access-Control-Allow-Methods", "GET, POST, HEAD, OPTIONS")
		r.Header.Set("Access-Control-Allow-Headers", "X-Service-Worker-Version, X-Csrf-Token, Csrf-Token, X-Service-Worker-Cache")
		r.Header.Set("Access-Control-Expose-Headers", "Csrf-Token, X-Service-Worker-Cache")
		r.Header.Set("Access-Control-Allow-Credentials", "true")
		if r.Request.Method == "OPTIONS" {
			// Since the backend may return an HTTP response with a
			// status code not in the range of 200-299, thus failing the
			// preflight response for the client, here, the response is
			// altered to contain an empty body with an HTTP status code
			// of 204.
			r.Body.Close()
			r.Body = io.NopCloser(bytes.NewBuffer(nil)) // empty body
			r.StatusCode = http.StatusNoContent
			// Status carries the code followed by the reason phrase.
			r.Status = fmt.Sprintf("%d %s", r.StatusCode, http.StatusText(r.StatusCode))
			r.ContentLength = 0
			// Drop the backend's length header: it described the body that
			// was just discarded.
			r.Header.Del("Content-Length")
		}
		return nil
	}
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		log.Printf("Proxy error: %v", err)
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusBadGateway)
		// Best effort: if this write fails there is nothing further to do.
		w.Write([]byte("502 Bad Gateway\n"))
	}

	// Sub-domain redirect mappings:
	redirectTable := map[string]string{
		fmt.Sprintf("docs.%s", oldHost): fmt.Sprintf("https://docs.%s", host),
		fmt.Sprintf("blog.%s", oldHost): fmt.Sprintf("https://blog.%s", host),
	}

	server := &http.Server{
		Addr: fmt.Sprintf("%s:%d", host, port),
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if redirected := redirectDomains(redirectTable, w, r); redirected {
				return
			}

			shouldRedirect := r.Host == oldHost && !(noredirect || shouldNotRedirect(r))

			if logEachRequest {
				var serviceWorkerInfo, redirectInfo, preflightInfo string
				if v := r.Header.Get("X-Service-Worker-Version"); v != "" {
					serviceWorkerInfo = fmt.Sprintf(" (X-Service-Worker-Version: %s)", v)
				}
				if r.Method == "OPTIONS" {
					preflightInfo = fmt.Sprintf(" (Preflighted; origin=%s oldhostorigin=%s)", r.Header.Get("Origin"), oldHostOrigin)
				}
				if shouldRedirect {
					redirectInfo = " (Redirected)"
				}
				log.Printf("Proxying request: %s %s%s%s%s", r.Method, r.URL.Path, serviceWorkerInfo, redirectInfo, preflightInfo)
			}

			if shouldRedirect {
				// Same path and query, but on the new host.
				target := *r.URL
				target.Host = host
				target.Scheme = scheme
				http.Redirect(w, r, target.String(), http.StatusMovedPermanently)
				return
			}

			proxy.ServeHTTP(w, r)
		}),
	}

	// SIGKILL cannot be intercepted by a process, so listen for SIGTERM (the
	// signal sent by a plain `kill`) alongside the interrupt signal.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// serveErr receives a fatal error from the main server. It is buffered so
	// that the serving goroutine never blocks on it.
	serveErr := make(chan error, 1)
	go func() {
		log.Printf("Reverse proxy started on %s (%s; for hosts %s, %s), forwarding to %s\n", server.Addr, scheme, host, oldHost, backendURLStr)
		var err error
		if scheme == "https" {
			err = server.ListenAndServeTLS(certfile, privkey)
		} else {
			err = server.ListenAndServe()
		}
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			serveErr <- err
		}
	}()

	// If the server's listening for HTTPS connections on the default HTTPS
	// port, then start another server on the default HTTP port, which will
	// route all traffic to the HTTPS server.
	var redirectServer *http.Server
	if scheme == "https" && port == 443 {
		redirectServer = &http.Server{
			Addr: ":80",
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				target := *r.URL
				target.Scheme = "https"
				target.Host = r.Host
				http.Redirect(w, r, target.String(), http.StatusMovedPermanently)
			}),
		}
		go func() {
			log.Println("Starting redirect server (HTTP -> HTTPS) on " + redirectServer.Addr)
			// A failure here is only logged: the HTTPS server is still useful
			// without the plain-HTTP redirect.
			if err := redirectServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
				log.Printf("ListenAndServe (redirect) error: %v\n", err)
			}
		}()
	}

	// Wait until a shutdown signal arrives or the main server fails.
	select {
	case <-ctx.Done():
	case serveError := <-serveErr:
		// Without the main server there is nothing left to do, so exit with a
		// failure status instead of idling until a signal arrives.
		log.Printf("ListenAndServe (%s) error: %v\n", server.Addr, serveError)
		os.Exit(1)
	}

	log.Println("Gracefully shutting down server (press again to immediately exit)...")

	// A second signal cancels this context, which makes Shutdown return
	// without waiting for open connections.
	shutdownCtx, stopShutdown := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stopShutdown()

	if redirectServer != nil {
		if err := redirectServer.Shutdown(shutdownCtx); err != nil {
			log.Printf("Error gracefully shutting down redirect server: %v\n", err)
		}
	}
	if err := server.Shutdown(shutdownCtx); err != nil {
		log.Printf("Error gracefully shutting down server: %v\n", err)
	} else {
		log.Println("Server exited gracefully")
	}
}
// shouldNotRedirect reports whether the request is exempt from redirection.
// A request is exempt when it asks for the service worker script or the web
// manifest, carries an X-Service-Worker-Version or
// Service-Worker-Navigation-Preload header, targets a path under /api/, or
// sends a cookie named "SID".
func shouldNotRedirect(r *http.Request) bool {
	path := r.URL.Path
	switch {
	case path == "/service-worker.js", path == "/manifest.json":
		return true
	case r.Header.Get("X-Service-Worker-Version") != "":
		return true
	case r.Header.Get("Service-Worker-Navigation-Preload") != "":
		return true
	case strings.HasPrefix(path, "/api/"):
		return true
	}
	return sessionCookieExists(r, "SID")
}
// constructOrigin builds an origin string of the form scheme://host[:port].
// The port is left out when it is the standard one for the scheme (80 for
// "http", 443 for "https") and is appended in every other case.
func constructOrigin(scheme, host string, port int) string {
	standardPort := (scheme == "http" && port == 80) || (scheme == "https" && port == 443)
	if standardPort {
		return fmt.Sprintf("%s://%s", scheme, host)
	}
	return fmt.Sprintf("%s://%s:%s", scheme, host, strconv.Itoa(port))
}
// sessionCookieExists reports whether looking up the cookie named cookieName
// on the request yields anything other than http.ErrNoCookie.
func sessionCookieExists(r *http.Request, cookieName string) bool {
	_, lookupErr := r.Cookie(cookieName)
	return lookupErr != http.ErrNoCookie
}
// redirectDomains looks up the request's Host in table and, when an entry
// exists, answers with an HTTP 301 redirect to the URL stored for that host.
// It returns true if a redirect was written and false if the host has no
// entry, in which case nothing is written to w.
func redirectDomains(table map[string]string, w http.ResponseWriter, r *http.Request) bool {
	target, ok := table[r.Host]
	if !ok {
		return false
	}
	http.Redirect(w, r, target, http.StatusMovedPermanently)
	return true
}