diff --git a/auth.go b/auth.go index 270e00a8..16bc9c0d 100644 --- a/auth.go +++ b/auth.go @@ -47,7 +47,10 @@ var auth struct { allowedClient []netAddr - authed *TimeoutSet // cache authenticated users based on ip + // cache authenticated users based on ip and port + // add port to identify the users behind one ip address + // this may cause a user auth more than once + authed *TimeoutSet template *template.Template } @@ -114,6 +117,11 @@ func parseAllowedClient(val string) { mask = NewNbitIPv4Mask(32) } auth.allowedClient[i] = netAddr{ip.Mask(mask), mask} + + // TODO: add mask here, add record in usage + if (usageFlag) { + addAllowedClient(ipAndMask[0]) + } } } @@ -182,15 +190,20 @@ func initAuth() { func Authenticate(conn *clientConn, r *Request) (err error) { clientIP, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) if auth.authed.has(clientIP) { - debug.Printf("%s has already authed\n", clientIP) + debug.Printf("%s has already authed\n", conn.RemoteAddr().String()) return } if authIP(clientIP) { // IP is allowed return } - err = authUserPasswd(conn, r) - if err == nil { + err, user := authUserPasswd(conn, r) + if err == nil && user != ""{ auth.authed.add(clientIP) + // update the map of address to userid in usage + if usageFlag { + updateAddrToUser(clientIP, user) + } + } return } @@ -231,14 +244,14 @@ func calcRequestDigest(kv map[string]string, ha1, method string) string { return md5sum(strings.Join(arr, ":")) } -func checkProxyAuthorization(conn *clientConn, r *Request) error { +func checkProxyAuthorization(conn *clientConn, r *Request) (error, string) { if debug { debug.Printf("cli(%s) authorization: %s\n", conn.RemoteAddr(), r.ProxyAuthorization) } arr := strings.SplitN(r.ProxyAuthorization, " ", 2) if len(arr) != 2 { - return errors.New("auth: malformed ProxyAuthorization header: " + r.ProxyAuthorization) + return errors.New("auth: malformed ProxyAuthorization header: " + r.ProxyAuthorization), "" } authMethod := strings.ToLower(strings.TrimSpace(arr[0])) if authMethod == "digest" { @@ -246,7 +259,7 @@ func checkProxyAuthorization(conn *clientConn, r *Request) error { } else if authMethod == "basic" { return authBasic(conn, arr[1]) } - return errors.New("auth: method " + arr[0] + " unsupported, must use digest") + return errors.New("auth: method " + arr[0] + " unsupported, must use digest"), "" } func authPort(conn *clientConn, user string, au *authUser) error { @@ -262,73 +275,76 @@ func authPort(conn *clientConn, user string, au *authUser) error { return nil } -func authBasic(conn *clientConn, userPasswd string) error { +func authBasic(conn *clientConn, userPasswd string) (error, string) { b64, err := base64.StdEncoding.DecodeString(userPasswd) if err != nil { - return errors.New("auth:" + err.Error()) + return errors.New("auth:" + err.Error()), "" } arr := strings.Split(string(b64), ":") if len(arr) != 2 { - return errors.New("auth: malformed basic auth user:passwd") + return errors.New("auth: malformed basic auth user:passwd"), "" } user := arr[0] passwd := arr[1] au, ok := auth.user[user] if !ok || au.passwd != passwd { - return errAuthRequired + return errAuthRequired, user + } + if ret := authPort(conn, user, au); ret != nil { + return ret, user } - return authPort(conn, user, au) + return nil, user } -func authDigest(conn *clientConn, r *Request, keyVal string) error { +func authDigest(conn *clientConn, r *Request, keyVal string) (error, string) { authHeader := parseKeyValueList(keyVal) if len(authHeader) == 0 { - return errors.New("auth: empty authorization list") + return errors.New("auth: empty authorization list"), "" } nonceTime, err := strconv.ParseInt(authHeader["nonce"], 16, 64) if err != nil { - return fmt.Errorf("auth: nonce %v", err) + return fmt.Errorf("auth: nonce %v", err), "" } // If nonce time too early, reject. iOS will create a new connection to do // authentication. if time.Now().Sub(time.Unix(nonceTime, 0)) > time.Minute { - return errAuthRequired + return errAuthRequired, "" } user := authHeader["username"] au, ok := auth.user[user] if !ok { errl.Printf("cli(%s) auth: no such user: %s\n", conn.RemoteAddr(), authHeader["username"]) - return errAuthRequired + return errAuthRequired, "user" } if err = authPort(conn, user, au); err != nil { - return err + return err, user } if authHeader["qop"] != "auth" { - return errors.New("auth: qop wrong: " + authHeader["qop"]) + return errors.New("auth: qop wrong: " + authHeader["qop"]), user } response, ok := authHeader["response"] if !ok { - return errors.New("auth: no request-digest response") + return errors.New("auth: no request-digest response"), user } au.initHA1(user) digest := calcRequestDigest(authHeader, au.ha1, r.Method) if response != digest { errl.Printf("cli(%s) auth: digest not match, maybe password wrong", conn.RemoteAddr()) - return errAuthRequired + return errAuthRequired, user } - return nil + return nil, user } -func authUserPasswd(conn *clientConn, r *Request) (err error) { +func authUserPasswd(conn *clientConn, r *Request) (err error, user string) { if r.ProxyAuthorization != "" { // client has sent authorization header - err = checkProxyAuthorization(conn, r) - if err == nil { - return + err, user = checkProxyAuthorization(conn, r) + if err == nil && user != ""{ + return nil, user } else if err != errAuthRequired { sendErrorPage(conn, statusBadReq, "Bad authorization request", err.Error()) return @@ -344,13 +360,13 @@ func authUserPasswd(conn *clientConn, r *Request) (err error) { } buf := new(bytes.Buffer) if err := auth.template.Execute(buf, data); err != nil { - return fmt.Errorf("error generating auth response: %v", err) + return fmt.Errorf("error generating auth response: %v", err), "" } if bool(debug) && verbose { debug.Printf("authorization response:\n%s", buf.String()) } if _, err := conn.Write(buf.Bytes()); err != nil { - return fmt.Errorf("send auth response error: %v", err) + return fmt.Errorf("send auth response error: %v", err), "" } - return errAuthRequired + return errAuthRequired, "" } diff --git a/config.go b/config.go index c3bf294e..2749f573 100644 --- a/config.go +++ b/config.go @@ -76,6 +76,12 @@ type Config struct { // not config option saveReqLine bool // for http and cow parent, should save request line from client + + // capacity limitation file + UserCapacityFile string + UsageResetDate int + + RestartInterval time.Duration } var config Config @@ -559,6 +565,18 @@ func (p configParser) ParseUserPasswdFile(val string) { config.UserPasswdFile = val } +func (p configParser) ParseUserCapacityFile(val string) { + err := isFileExists(val) + if err != nil { + Fatal("userCapacityFile:", err) + } + config.UserCapacityFile = val +} + +func (p configParser) ParseUsageResetDate(val string) { + config.UsageResetDate = parseInt(val, "usageResetDate") +} + func (p configParser) ParseAllowedClient(val string) { config.AllowedClient = val } @@ -567,6 +585,10 @@ func (p configParser) ParseAuthTimeout(val string) { config.AuthTimeout = parseDuration(val, "authTimeout") } +func (p configParser) ParseRestartInterval(val string) { + config.RestartInterval = parseDuration(val, "restartInterval") +} + func (p configParser) ParseCore(val string) { config.Core = parseInt(val, "core") } diff --git a/doc/sample-config/rc b/doc/sample-config/rc index 362f16fc..b10071ca 100644 --- a/doc/sample-config/rc +++ b/doc/sample-config/rc @@ -115,10 +115,26 @@ listen = http://127.0.0.1:7777 # 注意:如有重复用户,COW 会报错退出 #userPasswdFile = /path/to/file +# 如需开启用户流量监控,必须使用文件来管理用户名密码, +# 同时在另一个文件中设定每个用户的流量,流量为整数,以MB为单位, 每行内容如下 +# username:capacity +# 这里的username必须与密码文件中的username对应 +# 同时设定该文件的路径 +# 以及自动重置用量的日期 +# 若不需要重置,请将日期设置为-1 +# 如果添加了allowedClient,并且要对其限流,则需要将IP作为username和对应的流量加入到userCapacityFile中 +# 暂时只支持单个IP的记录 +#userCapacityFile = /path/to/file +#usageResetDate = 12 + # 认证失效时间 # 语法:2h3m4s 表示 2 小时 3 分钟 4 秒 #authTimeout = 2h +# 代理自动重启时间 +# 语法:2h3m4s 表示 2 小时 3 分钟 4 秒 +#restartInterval = 2h + ############################# # 高级选项 ############################# diff --git a/doc/sample-config/rc-en b/doc/sample-config/rc-en index 206bcf6c..0b3b2a60 100644 --- a/doc/sample-config/rc-en +++ b/doc/sample-config/rc-en @@ -134,10 +134,26 @@ listen = http://127.0.0.1:7777 # COW will report error and exit if there's duplicated user. #userPasswdFile = /path/to/file +# To enable data transfer recording, the userPasswdFile must be enabled. +# List all those content in another file like this: +# username:capacity +# The username here should match those in userPasswd file. +# Set the path to the capacity file +# and usage reset date. +# Set the reset date to -1 if it is unnecessary. +# If the allowedClient is enable and their capcity is limited, the IP address(as username) and capacity must been added to userCapacityFile +# The sub-network is not supported yet. +#userCapacityFile = /path/to/file +#usageResetDate = 12 + # Time interval to keep authentication information. # Syntax: 2h3m4s means 2 hours 3 minutes 4 seconds #authTimeout = 2h +# Time interval to restart proxy. +# Syntax: 2h3m4s means 2 hours 3 minutes 4 seconds +#restartInterval = 2h + ############################# # Advanced options ############################# diff --git a/main.go b/main.go index 422ecbef..da68ba60 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,8 @@ var ( relaunch bool ) +var usageFlag bool + // This code is from goagain func lookPath() (argv0 string, err error) { argv0, err = exec.LookPath(os.Args[0]) @@ -28,7 +30,9 @@ func lookPath() (argv0 string, err error) { return } + func main() { + usageFlag = false quit = make(chan struct{}) // Parse flags after load config to allow override options in config cmdLineConfig := parseCmdLineConfig() @@ -41,12 +45,16 @@ func main() { initSelfListenAddr() initLog() + + usageFlag = initUsage() + initAuth() initSiteStat() initPAC() // initPAC uses siteStat, so must init after site stat initStat() + initParentPool() /* @@ -77,6 +85,18 @@ func main() { go proxy.Serve(&wg, quit) } + //add the usage recorder + if usageFlag { + wg.Add(1) + go startUsageRecorder(&wg, quit) + } + + // start restart deamon + if config.RestartInterval != 0 { + wg.Add(1) + pid := os.Getpid() + go restartDeamon(pid, &wg, quit) + } wg.Wait() if relaunch { diff --git a/main_unix.go b/main_unix.go index bd780bd9..0665a6ea 100644 --- a/main_unix.go +++ b/main_unix.go @@ -6,6 +6,8 @@ import ( "os" "os/signal" "syscall" + "time" + "sync" ) func sigHandler() { @@ -28,3 +30,31 @@ func sigHandler() { } */ } + +func restartDeamon(pid int, wg *sync.WaitGroup, quit <-chan struct{}) { + defer func() { + wg.Done() + }() + + duration := int(config.RestartInterval.Seconds()) + interval := 0 + debug.Println("Pid: ", pid, "restart interval: ", duration) + for { + select { + case <- quit: + debug.Println("exit the restart deamon") + return + default: + time.Sleep(time.Second) + interval += 1 + if (interval > duration) { + info.Println("Restart proxy now!") + // connPool.CloseAll() + syscall.Kill(pid, syscall.SIGUSR1) + return + } + } + } + + +} diff --git a/proxy.go b/proxy.go index 601ac65e..083eeaaf 100644 --- a/proxy.go +++ b/proxy.go @@ -501,6 +501,14 @@ func (c *clientConn) serve() { return } + if usageFlag { + if checkUsage(c.RemoteAddr().String()) != true { + sendErrorPage(c, statusForbidden, "Run out of capacity", + genErrMsg(&r, nil, "Please contact proxy admin.")) + return + } + } + if r.ExpectContinue { sendErrorPage(c, statusExpectFailed, "Expect header not supported", "Please contact COW's developer if you see this.") @@ -985,8 +993,13 @@ var connectBuf = leakybuf.NewLeakyBuf(512, connectBufSize) func copyServer2Client(sv *serverConn, c *clientConn, r *Request) (err error) { buf := connectBuf.Get() + total := 0 defer func() { connectBuf.Put(buf) + // update usage for user + if (usageFlag) { + go accumulateUsage(c.RemoteAddr().String(), total) + } }() /* @@ -997,7 +1010,6 @@ func copyServer2Client(sv *serverConn, c *clientConn, r *Request) (err error) { } */ - total := 0 const directThreshold = 8192 readTimeoutSet := false for { @@ -1285,6 +1297,10 @@ func (sv *serverConn) doRequest(c *clientConn, r *Request, rp *Response) (err er r.state = rsSent if err = c.readResponse(sv, r, rp); err == nil { sv.updateVisit() + // response received successfully + if (usageFlag) { + accumulateUsage(c.RemoteAddr().String(), int(rp.ContLen)) + } } return err } diff --git a/usage.go b/usage.go new file mode 100644 index 00000000..2a9d6005 --- /dev/null +++ b/usage.go @@ -0,0 +1,308 @@ +package main + +import ( + "os" + "strings" + "strconv" + "github.com/cyfdecyf/bufio" + "errors" + "time" + "bytes" + "fmt" + "sync" + "net" +) + +var recordPath string + +var userUsage struct { + usage map[string]int + capacity map[string]int + addrToUser map[string]string + lastSavedts time.Time + + // channel for update + updateMsg chan string + updateSig chan bool +} + +// var tempUsage map[string]int + +func parseCapacity(line string) (user string, capacity int, err error) { + arr := strings.Split(line, ":") + n := len(arr) + if n != 2 { + err = errors.New("User capacity limitation: " + line + + " syntax wrong, should be username:capacity") + return "", 0, err + } + c, err := strconv.Atoi(arr[1]) + if err != nil { + err = errors.New("Record file format error: " + arr[1] + + " syntax wrong, should be int") + return "", 0, err + } + debug.Printf("user: %s, capacity: %d", arr[0], c) + return arr[0], c, nil +} + +func parseUsage(line string) (user string, usage int, err error) { + arr := strings.Split(line, ":") + n := len(arr) + if n != 2 { + err = errors.New("Record file format error: " + line + + " syntax wrong, should be username:usage") + return "", 0, err + } + c, err := strconv.Atoi(arr[1]) + if err != nil { + err = errors.New("Record file format error: " + arr[1] + + " syntax wrong, should be int") + return "", 0, err + } + debug.Printf("user: %s, usage: %d", arr[0], c) + return arr[0], c, nil +} + +func loadCapcity(file string) { + // load capcity first + if file == "" { + return + } + f, err := os.Open(file) + if err != nil { + Fatal("error opening user usage file:", err) + } + + r := bufio.NewReader(f) + s := bufio.NewScanner(r) + for s.Scan() { + line := s.Text() + if line == "" { + continue + } + u, c, err := parseCapacity(s.Text()) + if err != nil { + Fatal(err) + } + if _, ok := userUsage.capacity[u]; ok { + Fatal("duplicate user:", u) + } + userUsage.capacity[u] = c + userUsage.usage[u] = 0 + + } + f.Close() +} + +func loadUsage() { + f, err := os.OpenFile(recordPath, os.O_CREATE, 0600) + if err != nil { + Fatal("error opening/creating user record file:", err) + } + r := bufio.NewReader(f) + s := bufio.NewScanner(r) + for s.Scan() { + ts := s.Text() + if ts == "" { + continue + } + if t, e := time.Parse(time.ANSIC, ts); e == nil { + userUsage.lastSavedts = t + break + } else { + Fatal("incomplete user record, please delete ", recordPath, " and restart: ", e) + return + } + } + for s.Scan() { + line := s.Text() + if line == "" { + continue + } + u, c, err := parseUsage(s.Text()) + if err != nil { + Fatal(err) + } + userUsage.usage[u] = c + + } + f.Close() +} + +func flushLog() { + if time.Now().Day() == config.UsageResetDate && + config.UsageResetDate != -1 && + userUsage.lastSavedts.Day() != config.UsageResetDate { + //it's time to clear the record of last month + for k, _ := range userUsage.usage { + userUsage.usage[k] = 0 + } + + } + bakPath := recordPath + ".bak" + f, err := os.OpenFile(bakPath, os.O_WRONLY | os.O_CREATE, 0600) + if err != nil { + Fatal("error opening/creating user record file:", err) + } + w := bufio.NewWriter(f) + t := time.Now() + w.WriteString(t.Format(time.ANSIC)) + w.WriteString("\n") + w.Flush() + for k, v := range userUsage.usage { + r := fmt.Sprintf("%s:%d\n", k, v) + w.WriteString(r) + } + w.Flush() + f.Close() + + os.Remove(recordPath) + os.Rename(bakPath, recordPath) + userUsage.lastSavedts = t + + +} + +func startUsageRecorder(wg *sync.WaitGroup, quit <-chan struct{}) { + defer func() { + flushLog() + debug.Println("exit the usage recorder") + wg.Done() + }() + var exit bool + go func() { + <-quit + userUsage.updateSig <- true + exit=true + }() + + go updateUsage(userUsage.updateMsg, userUsage.updateSig) + + + debug.Println("start usage recording!") + interval := 0 + for { + time.Sleep(1000 * time.Millisecond) + interval += 1 + if exit { + break + } + if interval > 1800 { + flushLog() + interval = 0 + } + } +} + +func initUsage() bool{ + if config.UserPasswdFile == "" || + config.UserCapacityFile == ""{ + return false + } + + if config.UsageResetDate == 0 || config.UsageResetDate > 30 { + Fatal("wrong UsageResetDate: ", config.UsageResetDate) + } + // get current running path + dir, err := os.Getwd() + if err != nil { + Fatal("error opening current directory:", err) + } + buf := new(bytes.Buffer) + fmt.Fprint(buf, dir, "/_records.log") + recordPath = buf.String() + + userUsage.capacity = make(map[string]int) + userUsage.usage = make(map[string]int) + userUsage.addrToUser = make(map[string]string) + // tempUsage = make(map[string]int) + + userUsage.updateMsg = make(chan string, 5000) + userUsage.updateSig = make(chan bool) + //load capacity at first + loadCapcity(config.UserCapacityFile) + + // load usage + loadUsage() + return true +} + +func checkUsage(addr string) bool { + clientIP, _, _ := net.SplitHostPort(addr) + var user string + var capacity int + var usage int + if val, ok := userUsage.addrToUser[clientIP]; ok { + user = val + } else { + errl.Println("unkonw address: ", addr) + return false + } + if val, ok := userUsage.capacity[user]; ok { + capacity = val + } else { + errl.Println("unkonw user: ", user) + return false + } + // don't have to check here + usage = userUsage.usage[user] + usageInMB := usage / 1024 / 1024 + return (usageInMB < capacity) +} + +func accumulateUsage(addr string, size int) { + msg := addr + "-" + strconv.Itoa(size) + userUsage.updateMsg <- msg + return + // clientIP, _, _ := net.SplitHostPort(addr) + // if _, ok := userUsage.addrToUser[clientIP]; !ok { + // errl.Println("un recorded addr: ", addr) + // return + // } + + // if _, ok := tempUsage[addr]; ok { + // tempUsage[addr] += size + // } else { + // tempUsage[addr] = size + // } +} + +func updateAddrToUser(addr string, user string) { + userUsage.addrToUser[addr] = user + // add record + if _, ok := userUsage.capacity[user]; !ok { + errl.Println("un restricted user: ", user, " check user password file and user capcity file") + } + debug.Println("add addr: ", addr, "to user: ", user) +} + +func addAllowedClient(addr string) { + if _, ok := userUsage.addrToUser[addr]; ok { + debug.Println("duplicated allowed client ip: ", addr) + return + } + + userUsage.addrToUser[addr] = addr +} + +func updateUsage(msgChan chan string, sigChan chan bool) { + for { + select { + case msg := <- msgChan: + arr := strings.Split(msg, "-") + addr := arr[0] + size, _ := strconv.Atoi(arr[1]) + clientIP, _, _ := net.SplitHostPort(addr) + if user, ok := userUsage.addrToUser[clientIP]; ok { + userUsage.usage[user] += size + } else { + errl.Println("un recorded addr: ", addr) + } + + case <- sigChan: + return + } + } + +}