123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- package server
- import (
- "net/http"
- "strings"
- "sync"
- "time"
- "github.com/denisbrodbeck/machineid"
- "github.com/gin-gonic/gin"
- "github.com/golang-jwt/jwt/v5"
- "github.com/google/uuid"
- "github.com/gorilla/websocket"
- "github.com/nicksnyder/go-i18n/v2/i18n"
- "github.com/rotisserie/eris"
- "github.com/spf13/viper"
- "go.uber.org/zap"
- "sikey.com/websocket/config"
- "sikey.com/websocket/models"
- "sikey.com/websocket/pkg/dbx"
- "sikey.com/websocket/pkg/format"
- "sikey.com/websocket/pkg/rdbx"
- "sikey.com/websocket/repositories"
- "sikey.com/websocket/utils/keys"
- )
// Server owns the websocket clients connected to this process, the NATS
// bridge those clients register with, and the repositories they share.
type Server struct {
	// id identifies this server instance (see serverId); it is written into
	// each client's online record as ServerId.
	id string
	// nats is the NATS bridge; WebsocketHandler registers each new client via
	// its Subscribe channel.
	nats *Nats
	// ctx *gin.Context
	repos *repositories.Repositories
	// clients is the set of registered clients. In this file it is written
	// only by the events goroutine, under clientsMutex.
	clients      map[*Client]struct{}
	clientsMutex sync.RWMutex
	// Connect and Disconnect are unbuffered channels drained by events, which
	// adds/removes the received client to/from clients.
	Connect    chan *Client
	Disconnect chan *Client
	// upgrader performs the HTTP -> websocket upgrade in WebsocketHandler.
	upgrader websocket.Upgrader
}
- func NewServer() *Server {
- repos := repositories.NewRepositories(dbx.GetConnect(), rdbx.GetConnect())
- srv := &Server{
- id: serverId(),
- nats: NewNats(config.Websocket.NatsUrl, repos),
- upgrader: websocket.Upgrader{
- ReadBufferSize: config.Websocket.ReadBufferSize,
- WriteBufferSize: config.Websocket.WriteBufferSize,
- CheckOrigin: func(r *http.Request) bool {
- return true
- },
- },
- repos: repos,
- clients: make(map[*Client]struct{}),
- clientsMutex: sync.RWMutex{},
- Connect: make(chan *Client),
- Disconnect: make(chan *Client),
- }
- // 监听连接的事件
- go srv.events()
- return srv
- }
- func (srv *Server) events() {
- for {
- select {
- case clt, ok := <-srv.Connect:
- if !ok {
- return
- }
- srv.clientsMutex.Lock()
- srv.clients[clt] = struct{}{}
- srv.clientsMutex.Unlock()
- case clt, ok := <-srv.Disconnect:
- if !ok {
- return
- }
- srv.clientsMutex.Lock()
- if _, ok := srv.clients[clt]; ok {
- delete(srv.clients, clt)
- }
- srv.clientsMutex.Unlock()
- }
- }
- }
- func (srv *Server) WebsocketHandler(ctx *gin.Context, bundle *i18n.Bundle) {
- // Builder headers
- headers := headerBuilder(ctx)
- var err error
- var isWatch bool
- var id string
- id, ok := ctx.GetQuery("uid")
- if !ok {
- // Validate token
- user, ok, err := jwtParse(headers)
- if !ok {
- zap.L().Error("[conn] invalid token", zap.Error(err))
- ctx.AbortWithError(http.StatusUnauthorized, err)
- return
- }
- id = user.UID
- isWatch = user.IsChild
- }
- // 查询用户是否有 firebase messaging token
- firebaseMessageToken, err := srv.repos.FirebaseMessageRepository.GetFirebaseToken(ctx, id)
- if err != nil {
- if !eris.Is(err, models.ErrRecordNotFound) {
- zap.L().Error("[conn] find firebase token", zap.Error(err))
- ctx.AbortWithError(http.StatusUnauthorized, err)
- return
- }
- }
- // zap.L().Info("[conn] client online", zap.String("user_id", id))
- conn, err := srv.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
- if err != nil {
- zap.L().Error("[conn] upgrade error", zap.Error(err))
- ctx.AbortWithError(http.StatusInternalServerError, err)
- return
- }
- // Create client
- client := &Client{
- ctx: ctx.Copy(),
- UserId: id,
- srv: srv,
- nats: srv.nats,
- UnderlyingConn: conn,
- isWatch: isWatch,
- localizer: i18n.NewLocalizer(bundle, headers[keys.LocalizationHeader].(string)),
- firebaseToken: firebaseMessageToken.Token,
- loginToken: headers[keys.AccessTokenHeader].(string),
- online: &models.Online{UserId: id, ServerId: srv.id},
- Received: make(chan Message, config.Websocket.MessageSize),
- ReplySend: make(chan Message, config.Websocket.MessageSize),
- firstReadWait: config.Websocket.FirstReadWait * time.Second,
- heartbeatWait: config.Websocket.HeartbeatWait * time.Second,
- writeWait: config.Websocket.WriteWait * time.Second,
- repos: srv.repos,
- }
- // 连接监听 nats 的消息, 加入到监听者
- srv.nats.Subscribe <- &subscriber{client: client}
- // 将连接加入到 server 的连接管理里
- srv.Connect <- client
- // Online status to redis
- if err := client.Online(); err != nil {
- ctx.AbortWithError(http.StatusInternalServerError, err)
- return
- }
- // 客户端读/写协程
- go client.reader()
- go client.recv()
- go client.reply()
- // 加载离线消息, 一次行推送给客户端
- go client.loadOfflineMessage()
- zap.L().Info("[conn] 客户端上线", zap.String("user_id", client.UserId), zap.String("token", client.loginToken))
- }
- func (srv *Server) GetClients() []map[string]interface{} {
- srv.clientsMutex.RLock()
- defer srv.clientsMutex.RUnlock()
- result := make([]map[string]interface{}, 0, len(srv.clients))
- for clt := range srv.clients {
- result = append(result, map[string]interface{}{
- "addr": clt.UnderlyingConn.RemoteAddr().String(),
- "user_id": clt.UserId,
- "online": clt.online,
- "is_simple_msg": clt.isSimpleMsg,
- "is_watch": clt.isWatch,
- "localization": clt.localization,
- "firebase_token": clt.firebaseToken,
- "login_token": clt.loginToken,
- "first_read_wait": clt.firstReadWait,
- "write_wait": clt.writeWait,
- "heartbeat_wait": clt.heartbeatWait,
- "last_heartbeat_time": clt.LastHeartbeatTime.Format(format.DateParseAllUnixMilliFormat),
- "last_received_message_time": clt.LastReceivedMessageTime.Format(format.DateParseAllUnixMilliFormat),
- "last_received_notify_time": clt.LastReceivedNotifyTime.Format(format.DateParseAllUnixMilliFormat),
- })
- }
- return result
- }
- func serverId() string {
- var id string
- id, err := machineid.ID()
- if err != nil {
- id = uuid.NewString()
- } else {
- id = strings.ToLower(id)
- }
- return id
- }
// Headers holds the handshake parameters built by headerBuilder, keyed by the
// keys.*Header names. Values are strings, except keys.SimpleHeader, which is
// a bool.
type Headers = map[string]interface{}
- func headerBuilder(ctx *gin.Context) Headers {
- headers := make(Headers)
- request := ctx.Request
- accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
- simple := request.URL.Query().Get(keys.SimpleHeader)
- localization := request.URL.Query().Get(keys.LocalizationHeader)
- if localization == "" {
- localization = "en"
- }
- headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
- headers[keys.AccessTokenHeader] = accessToken
- headers[keys.SimpleHeader] = simple == "1"
- headers[keys.LocalizationHeader] = localization
- return headers
- }
// UserClaims is the JWT claim set decoded by jwtParse: the registered claims
// plus the application's uid and is_child fields.
type UserClaims struct {
	jwt.RegisteredClaims
	// UID is used by WebsocketHandler as the connection's user id.
	UID string `json:"uid"`
	// IsChild is copied into the client's isWatch flag by WebsocketHandler.
	IsChild bool `json:"is_child"`
}
- func jwtParse(headers Headers) (*UserClaims, bool, error) {
- if userId, ok := headers[keys.UserIdHeader]; ok {
- if userId != "" {
- return &UserClaims{UID: userId.(string), IsChild: false}, true, nil
- }
- }
- accessToken := headers[keys.AccessTokenHeader].(string)
- if len(accessToken) == 0 {
- return nil, false, eris.New("token is empty")
- }
- accessToken = strings.Trim(accessToken, " ")
- token, err := jwt.ParseWithClaims(accessToken, &UserClaims{}, func(t *jwt.Token) (interface{}, error) {
- return []byte(viper.GetString("auth.secret")), nil
- })
- userClaims, ok := token.Claims.(*UserClaims)
- if ok {
- return userClaims, true, nil
- } else {
- return nil, false, eris.Wrap(err, "token parse error")
- }
- // mapClaims := token.Claims.(jwt.MapClaims)
- // exp := mapClaims["exp"].(float64)
- // if exp != 0 && exp < float64(time.Now().Unix()) {
- // return "", false, eris.New("token is expired")
- // }
- }
|