server.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "code.sikey.com.cn/serverbackend/Serverx/zlog"
  7. "github.com/gin-gonic/gin"
  8. "github.com/golang-jwt/jwt/v5"
  9. "github.com/gorilla/websocket"
  10. "github.com/rotisserie/eris"
  11. "github.com/spf13/viper"
  12. "sikey.com/websocket/config"
  13. "sikey.com/websocket/models"
  14. "sikey.com/websocket/repositories"
  15. "sikey.com/websocket/utils/keys"
  16. )
  17. type Server struct {
  18. ID string
  19. Ctx *gin.Context
  20. Repositories *repositories.Repositories
  21. Upgrader websocket.Upgrader
  22. Hub *Hub
  23. }
  24. func (srv *Server) WebsocketHandler(ctx *gin.Context) {
  25. srv.Ctx = ctx
  26. // Builder headers
  27. headers := headerBuilder(ctx)
  28. // Validate token
  29. id, ok, err := jwtParse(headers)
  30. if !ok {
  31. zlog.Error(err)
  32. ctx.AbortWithError(http.StatusUnauthorized, err)
  33. return
  34. }
  35. log := zlog.With("ClientId", id)
  36. log.Info("client online")
  37. conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  38. if err != nil {
  39. zlog.Error(err)
  40. ctx.AbortWithError(http.StatusInternalServerError, err)
  41. return
  42. }
  43. // Reconnection mechanism
  44. var client *Client
  45. if client = srv.Hub.getClientByUserId(id); client != nil {
  46. _ = client.UnderlyingConn.Close()
  47. client.UnderlyingConn = conn
  48. log.Info("client reconnection")
  49. return
  50. }
  51. // Create client
  52. client = &Client{
  53. ctx: ctx.Copy(),
  54. UserId: id,
  55. hub: srv.Hub,
  56. UnderlyingConn: conn,
  57. online: &models.Online{UserId: id, ServerId: srv.ID},
  58. Send: make(chan *Message, config.Websocket.MessageSize),
  59. writeWait: config.Websocket.WriteWait * time.Second,
  60. readWait: config.Websocket.ReadWait * time.Second,
  61. pingWait: config.Websocket.HeartbeatWait * time.Second,
  62. isSimpleMsg: headers[keys.SimpleHeader].(bool),
  63. localization: headers[keys.LocalizationHeader].(string),
  64. repos: srv.Repositories,
  65. }
  66. srv.Hub.Connect <- client
  67. // Online status to redis
  68. online := &models.Online{UserId: client.UserId, ServerId: srv.ID}
  69. if err := srv.Repositories.OnlineRepository.SetOnline(ctx, online); err != nil {
  70. ctx.AbortWithError(http.StatusInternalServerError,
  71. eris.Wrapf(err, "unable to set online status for user: %s", client.UserId))
  72. return
  73. }
  74. go client.reader()
  75. go client.writer()
  76. }
  77. type Headers = map[string]interface{}
  78. func headerBuilder(ctx *gin.Context) Headers {
  79. headers := make(Headers)
  80. request := ctx.Request
  81. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  82. simple := request.URL.Query().Get(keys.SimpleHeader)
  83. localization := request.URL.Query().Get(keys.LocalizationHeader)
  84. headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
  85. headers[keys.AccessTokenHeader] = accessToken
  86. headers[keys.SimpleHeader] = simple == "1"
  87. headers[keys.LocalizationHeader] = localization
  88. return headers
  89. }
  90. type UserClaims struct {
  91. jwt.RegisteredClaims
  92. UID string `json:"uid"`
  93. }
  94. func jwtParse(headers Headers) (string, bool, error) {
  95. if userId, ok := headers[keys.UserIdHeader]; ok {
  96. if userId != "" {
  97. return userId.(string), true, nil
  98. }
  99. }
  100. accessToken := headers[keys.AccessTokenHeader].(string)
  101. if len(accessToken) == 0 {
  102. return "", false, eris.New("token is empty")
  103. }
  104. accessToken = strings.Trim(accessToken, " ")
  105. token, err := jwt.ParseWithClaims(accessToken, &UserClaims{}, func(t *jwt.Token) (interface{}, error) {
  106. return []byte(viper.GetString("auth.secret")), nil
  107. })
  108. userClaims, ok := token.Claims.(*UserClaims)
  109. if ok {
  110. return userClaims.UID, true, nil
  111. } else {
  112. zlog.Errorf("invalid token: %v", err)
  113. return "", false, eris.Wrap(err, "token parse error")
  114. }
  115. // mapClaims := token.Claims.(jwt.MapClaims)
  116. // exp := mapClaims["exp"].(float64)
  117. // if exp != 0 && exp < float64(time.Now().Unix()) {
  118. // return "", false, eris.New("token is expired")
  119. // }
  120. }