Skip to content

Commit a7d4916

Browse files
authored
feat: add wechat provider (#486)
1 parent a3b41a0 commit a7d4916

File tree

4 files changed

+405
-0
lines changed

4 files changed

+405
-0
lines changed

providers/wechat/session.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package wechat
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"strings"
7+
"time"
8+
9+
"github.com/markbates/goth"
10+
)
11+
12+
// Session stores data during the auth process with Wechat.
13+
type Session struct {
14+
AuthURL string
15+
AccessToken string
16+
RefreshToken string
17+
ExpiresAt time.Time
18+
Openid string
19+
Unionid string
20+
}
21+
22+
var _ goth.Session = &Session{}
23+
24+
// GetAuthURL will return the URL set by calling the `BeginAuth` function on the Wepay provider.
25+
func (s Session) GetAuthURL() (string, error) {
26+
if s.AuthURL == "" {
27+
return "", errors.New(goth.NoAuthUrlErrorMessage)
28+
}
29+
return s.AuthURL, nil
30+
}
31+
32+
// Authorize the session with Wepay and return the access token to be stored for future use.
33+
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
34+
p := provider.(*Provider)
35+
token, openid, err := p.fetchToken(params.Get("code"))
36+
37+
if err != nil {
38+
return "", err
39+
}
40+
41+
if !token.Valid() {
42+
return "", errors.New("invalid token received from provider")
43+
}
44+
45+
s.AccessToken = token.AccessToken
46+
s.RefreshToken = token.RefreshToken
47+
s.ExpiresAt = token.Expiry
48+
s.Openid = openid
49+
return token.AccessToken, err
50+
}
51+
52+
// Marshal the session into a string
53+
func (s Session) Marshal() string {
54+
b, _ := json.Marshal(s)
55+
return string(b)
56+
}
57+
58+
func (s Session) String() string {
59+
return s.Marshal()
60+
}
61+
62+
// UnmarshalSession wil unmarshal a JSON string into a session.
63+
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
64+
s := &Session{}
65+
err := json.NewDecoder(strings.NewReader(data)).Decode(s)
66+
return s, err
67+
}

providers/wechat/session_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package wechat_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/markbates/goth"
7+
"github.com/markbates/goth/providers/wechat"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func Test_Implements_Session(t *testing.T) {
12+
t.Parallel()
13+
a := assert.New(t)
14+
s := &wechat.Session{}
15+
16+
a.Implements((*goth.Session)(nil), s)
17+
}
18+
19+
func Test_GetAuthURL(t *testing.T) {
20+
t.Parallel()
21+
a := assert.New(t)
22+
s := &wechat.Session{}
23+
24+
_, err := s.GetAuthURL()
25+
a.Error(err)
26+
27+
s.AuthURL = "/foo"
28+
29+
url, _ := s.GetAuthURL()
30+
a.Equal(url, "/foo")
31+
}
32+
33+
func Test_ToJSON(t *testing.T) {
34+
t.Parallel()
35+
a := assert.New(t)
36+
s := &wechat.Session{}
37+
38+
data := s.Marshal()
39+
a.Equal(data, `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z"}`)
40+
}
41+
42+
func Test_String(t *testing.T) {
43+
t.Parallel()
44+
a := assert.New(t)
45+
s := &wechat.Session{}
46+
47+
a.Equal(s.String(), s.Marshal())
48+
}

providers/wechat/wechat.go

+237
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
package wechat
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"net/url"
9+
"time"
10+
11+
"github.com/markbates/goth"
12+
"golang.org/x/oauth2"
13+
)
14+
15+
const (
16+
AuthURL = "https://open.weixin.qq.com/connect/qrconnect"
17+
TokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
18+
19+
ScopeSnsapiLogin = "snsapi_login"
20+
21+
ProfileURL = "https://api.weixin.qq.com/sns/userinfo"
22+
)
23+
24+
type Provider struct {
25+
providerName string
26+
config *oauth2.Config
27+
httpClient *http.Client
28+
ClientID string
29+
ClientSecret string
30+
RedirectURL string
31+
Lang WechatLangType
32+
33+
AuthURL string
34+
TokenURL string
35+
ProfileURL string
36+
}
37+
38+
type WechatLangType string
39+
40+
const (
41+
WECHAT_LANG_CN WechatLangType = "cn"
42+
WECHAT_LANG_EN WechatLangType = "en"
43+
)
44+
45+
// New creates a new Wechat provider, and sets up important connection details.
46+
// You should always call `wechat.New` to get a new Provider. Never try to create
47+
// one manually.
48+
func New(clientID, clientSecret, redirectURL string, lang WechatLangType) *Provider {
49+
p := &Provider{
50+
providerName: "wechat",
51+
ClientID: clientID,
52+
ClientSecret: clientSecret,
53+
RedirectURL: redirectURL,
54+
Lang: lang,
55+
AuthURL: AuthURL,
56+
TokenURL: TokenURL,
57+
ProfileURL: ProfileURL,
58+
}
59+
p.config = newConfig(p)
60+
return p
61+
}
62+
63+
// Name is the name used to retrieve this provider later.
64+
func (p *Provider) Name() string {
65+
return p.providerName
66+
}
67+
68+
// SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
69+
func (p *Provider) SetName(name string) {
70+
p.providerName = name
71+
}
72+
73+
func (p *Provider) Client() *http.Client {
74+
return goth.HTTPClientWithFallBack(p.httpClient)
75+
}
76+
77+
// Debug is a no-op for the wechat package.
78+
func (p *Provider) Debug(debug bool) {}
79+
80+
// BeginAuth asks Wechat for an authentication end-point.
81+
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
82+
params := url.Values{}
83+
params.Add("appid", p.ClientID)
84+
params.Add("response_type", "code")
85+
params.Add("state", state)
86+
params.Add("scope", ScopeSnsapiLogin)
87+
params.Add("redirect_uri", p.RedirectURL)
88+
session := &Session{
89+
AuthURL: fmt.Sprintf("%s?%s", p.AuthURL, params.Encode()),
90+
}
91+
return session, nil
92+
}
93+
94+
// FetchUser will go to Wepay and access basic information about the user.
95+
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
96+
s := session.(*Session)
97+
user := goth.User{
98+
AccessToken: s.AccessToken,
99+
Provider: p.Name(),
100+
RefreshToken: s.RefreshToken,
101+
ExpiresAt: s.ExpiresAt,
102+
}
103+
104+
if user.AccessToken == "" {
105+
// data is not yet retrieved since accessToken is still empty
106+
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
107+
}
108+
109+
params := url.Values{}
110+
params.Add("access_token", s.AccessToken)
111+
params.Add("openid", s.Openid)
112+
params.Add("lang", string(p.Lang))
113+
114+
url := fmt.Sprintf("%s?%s", p.ProfileURL, params.Encode())
115+
116+
req, err := http.NewRequest("GET", url, nil)
117+
if err != nil {
118+
return user, err
119+
}
120+
resp, err := p.Client().Do(req)
121+
if err != nil {
122+
if resp != nil {
123+
resp.Body.Close()
124+
}
125+
return user, err
126+
}
127+
defer resp.Body.Close()
128+
129+
if resp.StatusCode != http.StatusOK {
130+
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, resp.StatusCode)
131+
}
132+
133+
err = userFromReader(resp.Body, &user)
134+
return user, err
135+
}
136+
137+
func newConfig(provider *Provider) *oauth2.Config {
138+
c := &oauth2.Config{
139+
ClientID: provider.ClientID,
140+
ClientSecret: provider.ClientSecret,
141+
RedirectURL: provider.RedirectURL,
142+
Endpoint: oauth2.Endpoint{
143+
AuthURL: provider.AuthURL,
144+
TokenURL: provider.TokenURL,
145+
},
146+
Scopes: []string{},
147+
}
148+
149+
c.Scopes = append(c.Scopes, ScopeSnsapiLogin)
150+
151+
return c
152+
}
153+
154+
func userFromReader(r io.Reader, user *goth.User) error {
155+
u := struct {
156+
Openid string `json:"openid"`
157+
Nickname string `json:"nickname"`
158+
Sex int `json:"sex"`
159+
Province string `json:"province"`
160+
City string `json:"city"`
161+
Country string `json:"country"`
162+
AvatarURL string `json:"headimgurl"`
163+
Unionid string `json:"unionid"`
164+
Code int `json:"errcode"`
165+
Msg string `json:"errmsg"`
166+
}{}
167+
err := json.NewDecoder(r).Decode(&u)
168+
if err != nil {
169+
return err
170+
}
171+
172+
if len(u.Msg) > 0 {
173+
return fmt.Errorf("CODE: %d, MSG: %s", u.Code, u.Msg)
174+
}
175+
176+
user.Email = fmt.Sprintf("%[email protected]", u.Openid)
177+
user.Name = u.Nickname
178+
user.UserID = u.Openid
179+
user.NickName = u.Nickname
180+
user.Location = u.City
181+
user.AvatarURL = u.AvatarURL
182+
user.RawData = map[string]interface{}{
183+
"Unionid": u.Unionid,
184+
}
185+
return nil
186+
}
187+
188+
// RefreshTokenAvailable refresh token is provided by auth provider or not
189+
func (p *Provider) RefreshTokenAvailable() bool {
190+
return false
191+
}
192+
193+
// RefreshToken get new access token based on the refresh token
194+
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
195+
196+
return nil, nil
197+
}
198+
199+
func (p *Provider) fetchToken(code string) (*oauth2.Token, string, error) {
200+
201+
params := url.Values{}
202+
params.Add("appid", p.ClientID)
203+
params.Add("secret", p.ClientSecret)
204+
params.Add("grant_type", "authorization_code")
205+
params.Add("code", code)
206+
url := fmt.Sprintf("%s?%s", p.TokenURL, params.Encode())
207+
resp, err := p.Client().Get(url)
208+
209+
if err != nil {
210+
return nil, "", err
211+
}
212+
defer resp.Body.Close()
213+
if resp.StatusCode != http.StatusOK {
214+
return nil, "", fmt.Errorf("wechat /gettoken returns code: %d", resp.StatusCode)
215+
}
216+
217+
obj := struct {
218+
AccessToken string `json:"access_token"`
219+
ExpiresIn time.Duration `json:"expires_in"`
220+
Openid string `json:"openid"`
221+
Code int `json:"errcode"`
222+
Msg string `json:"errmsg"`
223+
}{}
224+
if err = json.NewDecoder(resp.Body).Decode(&obj); err != nil {
225+
return nil, "", err
226+
}
227+
if obj.Code != 0 {
228+
return nil, "", fmt.Errorf("CODE: %d, MSG: %s", obj.Code, obj.Msg)
229+
}
230+
231+
token := &oauth2.Token{
232+
AccessToken: obj.AccessToken,
233+
Expiry: time.Now().Add(obj.ExpiresIn * time.Second),
234+
}
235+
236+
return token, obj.Openid, nil
237+
}

0 commit comments

Comments
 (0)