-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathrewritebody.go
139 lines (110 loc) · 3.08 KB
/
rewritebody.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Package plugin_rewritebody is a plugin that rewrites HTTP response bodies.
package plugin_rewritebody
import (
"bufio"
"bytes"
"context"
"fmt"
"log"
"net"
"net/http"
"regexp"
)
// Rewrite describes a single body rewrite rule: a regular expression and the
// text that replaces each of its matches.
type Rewrite struct {
	// Regex is the pattern source; New compiles it with regexp.Compile.
	Regex string `json:"regex,omitempty"`

	// Replacement is the text substituted for every match of Regex.
	Replacement string `json:"replacement,omitempty"`
}
// Config is the user-facing configuration of the plugin.
type Config struct {
	// LastModified controls whether the Last-Modified response header is kept.
	// When false, the header is removed before the response header is sent.
	LastModified bool `json:"lastModified,omitempty"`

	// Rewrites lists the rewrite rules, applied to the body in this order.
	Rewrites []Rewrite `json:"rewrites,omitempty"`
}
// CreateConfig returns a pointer to a zero-valued Config: LastModified is
// false and no rewrite rules are defined.
func CreateConfig() *Config {
	return new(Config)
}
// rewrite is the compiled, ready-to-apply form of a Rewrite rule.
type rewrite struct {
	// regex is the compiled pattern matched against the response body.
	regex *regexp.Regexp

	// replacement is the byte form of the configured replacement text.
	replacement []byte
}
// rewriteBody is the http.Handler that applies the configured rewrite rules
// to the body produced by the next handler.
type rewriteBody struct {
	// name is the plugin instance name given to New.
	name string

	// next is the handler whose response body is captured and rewritten.
	next http.Handler

	// rewrites holds the compiled rules, applied in order.
	rewrites []rewrite

	// lastModified mirrors Config.LastModified.
	lastModified bool
}
// New builds the rewrite body handler that wraps next.
//
// Every rule in config.Rewrites is compiled once, up front. If a pattern does
// not compile, New returns a nil handler and an error wrapping the compile
// error.
func New(_ context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
	compiled := make([]rewrite, 0, len(config.Rewrites))

	for _, rule := range config.Rewrites {
		re, err := regexp.Compile(rule.Regex)
		if err != nil {
			return nil, fmt.Errorf("error compiling regex %q: %w", rule.Regex, err)
		}

		compiled = append(compiled, rewrite{
			regex:       re,
			replacement: []byte(rule.Replacement),
		})
	}

	handler := &rewriteBody{
		name:         name,
		next:         next,
		rewrites:     compiled,
		lastModified: config.LastModified,
	}

	return handler, nil
}
// ServeHTTP runs the next handler against a buffering response writer, then
// sends the captured body to the client.
//
// When the response carries no Content-Encoding, or "identity", every rewrite
// rule is applied to the body in order before it is written. Any other
// Content-Encoding means the body is forwarded exactly as captured. Write
// failures are logged, not returned.
func (r *rewriteBody) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	capture := &responseWriter{
		ResponseWriter: rw,
		lastModified:   r.lastModified,
	}

	r.next.ServeHTTP(capture, req)

	body := capture.buffer.Bytes()

	switch capture.Header().Get("Content-Encoding") {
	case "", "identity":
		// Plain body: apply each rule to the output of the previous one.
		for i := range r.rewrites {
			body = r.rewrites[i].regex.ReplaceAll(body, r.rewrites[i].replacement)
		}

		if _, err := rw.Write(body); err != nil {
			log.Printf("unable to write rewrited body: %v", err)
		}
	default:
		// Encoded body: the regexes cannot be applied to it, pass it through.
		if _, err := rw.Write(body); err != nil {
			log.Printf("unable to write body: %v", err)
		}
	}
}
// responseWriter wraps an http.ResponseWriter and keeps the response body in
// memory instead of forwarding it, so the body can be rewritten before it is
// sent. Headers and the status code are still forwarded to the wrapped writer.
type responseWriter struct {
	// buffer accumulates every chunk passed to Write.
	buffer bytes.Buffer

	// lastModified, when false, makes WriteHeader drop the Last-Modified header.
	lastModified bool

	// wroteHeader records that WriteHeader has been called on this wrapper.
	wroteHeader bool

	http.ResponseWriter
}

// WriteHeader filters the response headers and then forwards statusCode to
// the wrapped writer.
func (r *responseWriter) WriteHeader(statusCode int) {
	if !r.lastModified {
		r.ResponseWriter.Header().Del("Last-Modified")
	}

	r.wroteHeader = true

	// Delegates the Content-Length Header creation to the final body write.
	r.ResponseWriter.Header().Del("Content-Length")

	r.ResponseWriter.WriteHeader(statusCode)
}

// Write appends p to the in-memory buffer; nothing is sent to the wrapped
// writer. If no header has been written yet, a 200 status is written first.
// It returns the result of the buffer write.
func (r *responseWriter) Write(p []byte) (int, error) {
	if !r.wroteHeader {
		r.WriteHeader(http.StatusOK)
	}

	return r.buffer.Write(p)
}

// Hijack hands the call to the wrapped writer when it implements
// http.Hijacker, and returns an error otherwise.
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := r.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.ResponseWriter)
	}

	return hijacker.Hijack()
}

// Flush flushes the wrapped writer when it implements http.Flusher and does
// nothing otherwise. Only headers are affected: the body stays buffered.
func (r *responseWriter) Flush() {
	flusher, ok := r.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}

	// net/http's response writers send an implicit 200 header on Flush when no
	// header has been written yet. Go through WriteHeader first so that this
	// implicit header gets the same filtering (Content-Length and, optionally,
	// Last-Modified removed) as an explicit one.
	if !r.wroteHeader {
		r.WriteHeader(http.StatusOK)
	}

	flusher.Flush()
}