Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 45 additions & 29 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ var auth struct {

allowedClient []netAddr

authed *TimeoutSet // cache authenticated users based on ip
// cache authenticated users based on ip and port
// add port to identify the users behind one ip address
// this may cause a user auth more than once
authed *TimeoutSet

template *template.Template
}
Expand Down Expand Up @@ -114,6 +117,11 @@ func parseAllowedClient(val string) {
mask = NewNbitIPv4Mask(32)
}
auth.allowedClient[i] = netAddr{ip.Mask(mask), mask}

// TODO: add mask here, add record in usage
if (usageFlag) {
addAllowedClient(ipAndMask[0])
}
}
}

Expand Down Expand Up @@ -182,15 +190,20 @@ func initAuth() {
func Authenticate(conn *clientConn, r *Request) (err error) {
clientIP, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
if auth.authed.has(clientIP) {
debug.Printf("%s has already authed\n", clientIP)
debug.Printf("%s has already authed\n", conn.RemoteAddr().String())
return
}
if authIP(clientIP) { // IP is allowed
return
}
err = authUserPasswd(conn, r)
if err == nil {
err, user := authUserPasswd(conn, r)
if err == nil && user != ""{
auth.authed.add(clientIP)
// update the map of address to userid in usage
if usageFlag {
updateAddrToUser(clientIP, user)
}

}
return
}
Expand Down Expand Up @@ -231,22 +244,22 @@ func calcRequestDigest(kv map[string]string, ha1, method string) string {
return md5sum(strings.Join(arr, ":"))
}

func checkProxyAuthorization(conn *clientConn, r *Request) error {
func checkProxyAuthorization(conn *clientConn, r *Request) (error, string) {
if debug {
debug.Printf("cli(%s) authorization: %s\n", conn.RemoteAddr(), r.ProxyAuthorization)
}

arr := strings.SplitN(r.ProxyAuthorization, " ", 2)
if len(arr) != 2 {
return errors.New("auth: malformed ProxyAuthorization header: " + r.ProxyAuthorization)
return errors.New("auth: malformed ProxyAuthorization header: " + r.ProxyAuthorization), ""
}
authMethod := strings.ToLower(strings.TrimSpace(arr[0]))
if authMethod == "digest" {
return authDigest(conn, r, arr[1])
} else if authMethod == "basic" {
return authBasic(conn, arr[1])
}
return errors.New("auth: method " + arr[0] + " unsupported, must use digest")
return errors.New("auth: method " + arr[0] + " unsupported, must use digest"), ""
}

func authPort(conn *clientConn, user string, au *authUser) error {
Expand All @@ -262,73 +275,76 @@ func authPort(conn *clientConn, user string, au *authUser) error {
return nil
}

func authBasic(conn *clientConn, userPasswd string) error {
func authBasic(conn *clientConn, userPasswd string) (error, string) {
b64, err := base64.StdEncoding.DecodeString(userPasswd)
if err != nil {
return errors.New("auth:" + err.Error())
return errors.New("auth:" + err.Error()), ""
}
arr := strings.Split(string(b64), ":")
if len(arr) != 2 {
return errors.New("auth: malformed basic auth user:passwd")
return errors.New("auth: malformed basic auth user:passwd"), ""
}
user := arr[0]
passwd := arr[1]

au, ok := auth.user[user]
if !ok || au.passwd != passwd {
return errAuthRequired
return errAuthRequired, user
}
if ret := authPort(conn, user, au); ret != nil {
return ret, user
}
return authPort(conn, user, au)
return nil, user
}

func authDigest(conn *clientConn, r *Request, keyVal string) error {
func authDigest(conn *clientConn, r *Request, keyVal string) (error, string) {
authHeader := parseKeyValueList(keyVal)
if len(authHeader) == 0 {
return errors.New("auth: empty authorization list")
return errors.New("auth: empty authorization list"), ""
}
nonceTime, err := strconv.ParseInt(authHeader["nonce"], 16, 64)
if err != nil {
return fmt.Errorf("auth: nonce %v", err)
return fmt.Errorf("auth: nonce %v", err), ""
}
// If nonce time too early, reject. iOS will create a new connection to do
// authentication.
if time.Now().Sub(time.Unix(nonceTime, 0)) > time.Minute {
return errAuthRequired
return errAuthRequired, ""
}

user := authHeader["username"]
au, ok := auth.user[user]
if !ok {
errl.Printf("cli(%s) auth: no such user: %s\n", conn.RemoteAddr(), authHeader["username"])
return errAuthRequired
return errAuthRequired, "user"
}

if err = authPort(conn, user, au); err != nil {
return err
return err, user
}
if authHeader["qop"] != "auth" {
return errors.New("auth: qop wrong: " + authHeader["qop"])
return errors.New("auth: qop wrong: " + authHeader["qop"]), user
}
response, ok := authHeader["response"]
if !ok {
return errors.New("auth: no request-digest response")
return errors.New("auth: no request-digest response"), user
}

au.initHA1(user)
digest := calcRequestDigest(authHeader, au.ha1, r.Method)
if response != digest {
errl.Printf("cli(%s) auth: digest not match, maybe password wrong", conn.RemoteAddr())
return errAuthRequired
return errAuthRequired, user
}
return nil
return nil, user
}

func authUserPasswd(conn *clientConn, r *Request) (err error) {
func authUserPasswd(conn *clientConn, r *Request) (err error, user string) {
if r.ProxyAuthorization != "" {
// client has sent authorization header
err = checkProxyAuthorization(conn, r)
if err == nil {
return
err, user = checkProxyAuthorization(conn, r)
if err == nil && user != ""{
return nil, user
} else if err != errAuthRequired {
sendErrorPage(conn, statusBadReq, "Bad authorization request", err.Error())
return
Expand All @@ -344,13 +360,13 @@ func authUserPasswd(conn *clientConn, r *Request) (err error) {
}
buf := new(bytes.Buffer)
if err := auth.template.Execute(buf, data); err != nil {
return fmt.Errorf("error generating auth response: %v", err)
return fmt.Errorf("error generating auth response: %v", err), ""
}
if bool(debug) && verbose {
debug.Printf("authorization response:\n%s", buf.String())
}
if _, err := conn.Write(buf.Bytes()); err != nil {
return fmt.Errorf("send auth response error: %v", err)
return fmt.Errorf("send auth response error: %v", err), ""
}
return errAuthRequired
return errAuthRequired, ""
}
22 changes: 22 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ type Config struct {

// not config option
saveReqLine bool // for http and cow parent, should save request line from client

// capacity limitation file
UserCapacityFile string
UsageResetDate int

RestartInterval time.Duration
}

var config Config
Expand Down Expand Up @@ -559,6 +565,18 @@ func (p configParser) ParseUserPasswdFile(val string) {
config.UserPasswdFile = val
}

func (p configParser) ParseUserCapacityFile(val string) {
err := isFileExists(val)
if err != nil {
Fatal("userCapacityFile:", err)
}
config.UserCapacityFile = val
}

func (p configParser) ParseUsageResetDate(val string) {
config.UsageResetDate = parseInt(val, "usageResetDate")
}

func (p configParser) ParseAllowedClient(val string) {
config.AllowedClient = val
}
Expand All @@ -567,6 +585,10 @@ func (p configParser) ParseAuthTimeout(val string) {
config.AuthTimeout = parseDuration(val, "authTimeout")
}

func (p configParser) ParseRestartInterval(val string) {
config.RestartInterval = parseDuration(val, "restartInterval")
}

func (p configParser) ParseCore(val string) {
config.Core = parseInt(val, "core")
}
Expand Down
16 changes: 16 additions & 0 deletions doc/sample-config/rc
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,26 @@ listen = http://127.0.0.1:7777
# 注意:如有重复用户,COW 会报错退出
#userPasswdFile = /path/to/file

# 如需开启用户流量监控,必须使用文件来管理用户名密码,
# 同时在另一个文件中设定每个用户的流量,流量为整数,以MB为单位, 每行内容如下
# username:capacity
# 这里的username必须与密码文件中的username对应
# 同时设定该文件的路径
# 以及自动重置用量的日期
# 若不需要重置,请将日期设置为-1
# 如果添加了allowedClient,并且要对其限流,则需要将IP作为username和对应的流量加入到userCapacityFile中
# 暂时只支持单个IP的记录
#userCapacityFile = /path/to/file
#usageResetDate = 12

# 认证失效时间
# 语法:2h3m4s 表示 2 小时 3 分钟 4 秒
#authTimeout = 2h

# 代理自动重启时间
# 语法:2h3m4s 表示 2 小时 3 分钟 4 秒
#restartInterval = 2h

#############################
# 高级选项
#############################
Expand Down
16 changes: 16 additions & 0 deletions doc/sample-config/rc-en
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,26 @@ listen = http://127.0.0.1:7777
# COW will report error and exit if there's duplicated user.
#userPasswdFile = /path/to/file

# To enable data transfer recording, the userPasswdFile must be enabled.
# List all those content in another file like this:
# username:capacity
# The username here should match those in userPasswd file.
# Set the path to the capacity file
# and usage reset date.
# Set the reset date to -1 if it is unnecessary.
# If the allowedClient is enable and their capcity is limited, the IP address(as username) and capacity must been added to userCapacityFile
# The sub-network is not supported yet.
#userCapacityFile = /path/to/file
#usageResetDate = 12

# Time interval to keep authentication information.
# Syntax: 2h3m4s means 2 hours 3 minutes 4 seconds
#authTimeout = 2h

# Time interval to restart proxy.
# Syntax: 2h3m4s means 2 hours 3 minutes 4 seconds
#restartInterval = 2h

#############################
# Advanced options
#############################
Expand Down
20 changes: 20 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ var (
relaunch bool
)

var usageFlag bool

// This code is from goagain
func lookPath() (argv0 string, err error) {
argv0, err = exec.LookPath(os.Args[0])
Expand All @@ -28,7 +30,9 @@ func lookPath() (argv0 string, err error) {
return
}


func main() {
usageFlag = false
quit = make(chan struct{})
// Parse flags after load config to allow override options in config
cmdLineConfig := parseCmdLineConfig()
Expand All @@ -41,12 +45,16 @@ func main() {

initSelfListenAddr()
initLog()

usageFlag = initUsage()

initAuth()
initSiteStat()
initPAC() // initPAC uses siteStat, so must init after site stat

initStat()


initParentPool()

/*
Expand Down Expand Up @@ -77,6 +85,18 @@ func main() {
go proxy.Serve(&wg, quit)
}

//add the usage recorder
if usageFlag {
wg.Add(1)
go startUsageRecorder(&wg, quit)
}

// start restart deamon
if config.RestartInterval != 0 {
wg.Add(1)
pid := os.Getpid()
go restartDeamon(pid, &wg, quit)
}
wg.Wait()

if relaunch {
Expand Down
30 changes: 30 additions & 0 deletions main_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os"
"os/signal"
"syscall"
"time"
"sync"
)

func sigHandler() {
Expand All @@ -28,3 +30,31 @@ func sigHandler() {
}
*/
}

func restartDeamon(pid int, wg *sync.WaitGroup, quit <-chan struct{}) {
defer func() {
wg.Done()
}()

duration := int(config.RestartInterval.Seconds())
interval := 0
debug.Println("Pid: ", pid, "restart interval: ", duration)
for {
select {
case <- quit:
debug.Println("exit the restart deamon")
return
default:
time.Sleep(time.Second)
interval += 1
if (interval > duration) {
info.Println("Restart proxy now!")
// connPool.CloseAll()
syscall.Kill(pid, syscall.SIGUSR1)
return
}
}
}


}
Loading