server.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "sync"
  6. "time"
  7. "github.com/denisbrodbeck/machineid"
  8. "github.com/gin-gonic/gin"
  9. "github.com/golang-jwt/jwt/v5"
  10. "github.com/google/uuid"
  11. "github.com/gorilla/websocket"
  12. "github.com/nicksnyder/go-i18n/v2/i18n"
  13. "github.com/rotisserie/eris"
  14. "github.com/spf13/viper"
  15. "go.uber.org/zap"
  16. "sikey.com/websocket/config"
  17. "sikey.com/websocket/models"
  18. "sikey.com/websocket/pkg/dbx"
  19. "sikey.com/websocket/pkg/format"
  20. "sikey.com/websocket/pkg/rdbx"
  21. "sikey.com/websocket/repositories"
  22. "sikey.com/websocket/utils/keys"
  23. )
  24. type Server struct {
  25. id string
  26. nats *Nats
  27. // ctx *gin.Context
  28. repos *repositories.Repositories
  29. clients map[*Client]struct{}
  30. clientsMutex sync.RWMutex
  31. Connect chan *Client
  32. Disconnect chan *Client
  33. upgrader websocket.Upgrader
  34. }
  35. func NewServer() *Server {
  36. repos := repositories.NewRepositories(dbx.GetConnect(), rdbx.GetConnect())
  37. srv := &Server{
  38. id: serverId(),
  39. nats: NewNats(config.Websocket.NatsUrl, repos),
  40. upgrader: websocket.Upgrader{
  41. ReadBufferSize: config.Websocket.ReadBufferSize,
  42. WriteBufferSize: config.Websocket.WriteBufferSize,
  43. CheckOrigin: func(r *http.Request) bool {
  44. return true
  45. },
  46. },
  47. repos: repos,
  48. clients: make(map[*Client]struct{}),
  49. clientsMutex: sync.RWMutex{},
  50. Connect: make(chan *Client),
  51. Disconnect: make(chan *Client),
  52. }
  53. // 监听连接的事件
  54. go srv.events()
  55. return srv
  56. }
  57. func (srv *Server) events() {
  58. for {
  59. select {
  60. case clt, ok := <-srv.Connect:
  61. if !ok {
  62. return
  63. }
  64. srv.clientsMutex.Lock()
  65. srv.clients[clt] = struct{}{}
  66. srv.clientsMutex.Unlock()
  67. case clt, ok := <-srv.Disconnect:
  68. if !ok {
  69. return
  70. }
  71. srv.clientsMutex.Lock()
  72. if _, ok := srv.clients[clt]; ok {
  73. delete(srv.clients, clt)
  74. }
  75. srv.clientsMutex.Unlock()
  76. }
  77. }
  78. }
  79. func (srv *Server) WebsocketHandler(ctx *gin.Context, bundle *i18n.Bundle) {
  80. // Builder headers
  81. headers := headerBuilder(ctx)
  82. var err error
  83. var isWatch bool
  84. var id string
  85. id, ok := ctx.GetQuery("uid")
  86. if !ok {
  87. // Validate token
  88. user, ok, err := jwtParse(headers)
  89. if !ok {
  90. zap.L().Error("[conn] invalid token", zap.Error(err))
  91. ctx.AbortWithError(http.StatusUnauthorized, err)
  92. return
  93. }
  94. id = user.UID
  95. isWatch = user.IsChild
  96. }
  97. // 查询用户是否有 firebase messaging token
  98. firebaseMessageToken, err := srv.repos.FirebaseMessageRepository.GetFirebaseToken(ctx, id)
  99. if err != nil {
  100. if !eris.Is(err, models.ErrRecordNotFound) {
  101. zap.L().Error("[conn] find firebase token", zap.Error(err))
  102. ctx.AbortWithError(http.StatusUnauthorized, err)
  103. return
  104. }
  105. }
  106. // zap.L().Info("[conn] client online", zap.String("user_id", id))
  107. conn, err := srv.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  108. if err != nil {
  109. zap.L().Error("[conn] upgrade error", zap.Error(err))
  110. ctx.AbortWithError(http.StatusInternalServerError, err)
  111. return
  112. }
  113. // Create client
  114. client := &Client{
  115. ctx: ctx.Copy(),
  116. UserId: id,
  117. srv: srv,
  118. nats: srv.nats,
  119. UnderlyingConn: conn,
  120. isWatch: isWatch,
  121. localizer: i18n.NewLocalizer(bundle, headers[keys.LocalizationHeader].(string)),
  122. firebaseToken: firebaseMessageToken.Token,
  123. loginToken: headers[keys.AccessTokenHeader].(string),
  124. online: &models.Online{UserId: id, ServerId: srv.id},
  125. Received: make(chan Message, config.Websocket.MessageSize),
  126. ReplySend: make(chan Message, config.Websocket.MessageSize),
  127. firstReadWait: config.Websocket.FirstReadWait * time.Second,
  128. heartbeatWait: config.Websocket.HeartbeatWait * time.Second,
  129. writeWait: config.Websocket.WriteWait * time.Second,
  130. repos: srv.repos,
  131. }
  132. // 连接监听 nats 的消息, 加入到监听者
  133. srv.nats.Subscribe <- &subscriber{client: client}
  134. // 将连接加入到 server 的连接管理里
  135. srv.Connect <- client
  136. // Online status to redis
  137. if err := client.Online(); err != nil {
  138. ctx.AbortWithError(http.StatusInternalServerError, err)
  139. return
  140. }
  141. // 客户端读/写协程
  142. go client.reader()
  143. go client.recv()
  144. go client.reply()
  145. // 加载离线消息, 一次行推送给客户端
  146. go client.loadOfflineMessage()
  147. zap.L().Info("[conn] 客户端上线", zap.String("user_id", client.UserId), zap.String("token", client.loginToken))
  148. }
  149. func (srv *Server) GetClients() []map[string]interface{} {
  150. srv.clientsMutex.RLock()
  151. defer srv.clientsMutex.RUnlock()
  152. result := make([]map[string]interface{}, 0, len(srv.clients))
  153. for clt := range srv.clients {
  154. result = append(result, map[string]interface{}{
  155. "addr": clt.UnderlyingConn.RemoteAddr().String(),
  156. "user_id": clt.UserId,
  157. "online": clt.online,
  158. "is_simple_msg": clt.isSimpleMsg,
  159. "is_watch": clt.isWatch,
  160. "localization": clt.localization,
  161. "firebase_token": clt.firebaseToken,
  162. "login_token": clt.loginToken,
  163. "first_read_wait": clt.firstReadWait,
  164. "write_wait": clt.writeWait,
  165. "heartbeat_wait": clt.heartbeatWait,
  166. "last_heartbeat_time": clt.LastHeartbeatTime.Format(format.DateParseAllUnixMilliFormat),
  167. "last_received_message_time": clt.LastReceivedMessageTime.Format(format.DateParseAllUnixMilliFormat),
  168. "last_received_notify_time": clt.LastReceivedNotifyTime.Format(format.DateParseAllUnixMilliFormat),
  169. })
  170. }
  171. return result
  172. }
  173. func serverId() string {
  174. var id string
  175. id, err := machineid.ID()
  176. if err != nil {
  177. id = uuid.NewString()
  178. } else {
  179. id = strings.ToLower(id)
  180. }
  181. return id
  182. }
  183. type Headers = map[string]interface{}
  184. func headerBuilder(ctx *gin.Context) Headers {
  185. headers := make(Headers)
  186. request := ctx.Request
  187. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  188. simple := request.URL.Query().Get(keys.SimpleHeader)
  189. localization := request.URL.Query().Get(keys.LocalizationHeader)
  190. if localization == "" {
  191. localization = "en"
  192. }
  193. headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
  194. headers[keys.AccessTokenHeader] = accessToken
  195. headers[keys.SimpleHeader] = simple == "1"
  196. headers[keys.LocalizationHeader] = localization
  197. return headers
  198. }
  199. type UserClaims struct {
  200. jwt.RegisteredClaims
  201. UID string `json:"uid"`
  202. IsChild bool `json:"is_child"`
  203. }
  204. func jwtParse(headers Headers) (*UserClaims, bool, error) {
  205. if userId, ok := headers[keys.UserIdHeader]; ok {
  206. if userId != "" {
  207. return &UserClaims{UID: userId.(string), IsChild: false}, true, nil
  208. }
  209. }
  210. accessToken := headers[keys.AccessTokenHeader].(string)
  211. if len(accessToken) == 0 {
  212. return nil, false, eris.New("token is empty")
  213. }
  214. accessToken = strings.Trim(accessToken, " ")
  215. token, err := jwt.ParseWithClaims(accessToken, &UserClaims{}, func(t *jwt.Token) (interface{}, error) {
  216. return []byte(viper.GetString("auth.secret")), nil
  217. })
  218. userClaims, ok := token.Claims.(*UserClaims)
  219. if ok {
  220. return userClaims, true, nil
  221. } else {
  222. return nil, false, eris.Wrap(err, "token parse error")
  223. }
  224. // mapClaims := token.Claims.(jwt.MapClaims)
  225. // exp := mapClaims["exp"].(float64)
  226. // if exp != 0 && exp < float64(time.Now().Unix()) {
  227. // return "", false, eris.New("token is expired")
  228. // }
  229. }