server.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/golang-jwt/jwt/v5"
  8. "github.com/gorilla/websocket"
  9. "github.com/rotisserie/eris"
  10. "github.com/spf13/viper"
  11. "go.uber.org/zap"
  12. "sikey.com/websocket/config"
  13. "sikey.com/websocket/models"
  14. "sikey.com/websocket/repositories"
  15. "sikey.com/websocket/utils/keys"
  16. )
// Server accepts websocket connections (see WebsocketHandler) for one server
// instance and hands the resulting clients to the Hub.
type Server struct {
	// ID identifies this server; it is recorded as ServerId in the
	// models.Online status stored for each connected user.
	ID string

	// Ctx is overwritten with the gin context of the most recent
	// WebsocketHandler call.
	// NOTE(review): the Server is presumably shared by all requests, so this
	// field is written concurrently and must not be relied on per request.
	Ctx *gin.Context

	// Repositories is handed to every Client and used to persist online status.
	Repositories *repositories.Repositories

	// Upgrader upgrades incoming HTTP requests to websocket connections.
	Upgrader websocket.Upgrader

	// Hub is the client registry: WebsocketHandler looks existing clients up
	// with getClientByUserId and registers new ones via its Connect channel.
	Hub *Hub
}
  24. func (srv *Server) WebsocketHandler(ctx *gin.Context) {
  25. srv.Ctx = ctx
  26. // Builder headers
  27. headers := headerBuilder(ctx)
  28. // Validate token
  29. id, ok, err := jwtParse(headers)
  30. if !ok {
  31. zap.L().Error("invalid token", zap.Error(err))
  32. ctx.AbortWithError(http.StatusUnauthorized, err)
  33. return
  34. }
  35. zap.L().Info("client online", zap.String("user_id", id))
  36. conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  37. if err != nil {
  38. zap.L().Error("upgrade error", zap.Error(err))
  39. ctx.AbortWithError(http.StatusInternalServerError, err)
  40. return
  41. }
  42. // Reconnection mechanism
  43. var client *Client
  44. if client = srv.Hub.getClientByUserId(id); client != nil {
  45. _ = client.UnderlyingConn.Close()
  46. client.UnderlyingConn = conn
  47. zap.L().Info("client reconnection", zap.String("user_id", id))
  48. return
  49. }
  50. // Create client
  51. client = &Client{
  52. ctx: ctx.Copy(),
  53. UserId: id,
  54. hub: srv.Hub,
  55. UnderlyingConn: conn,
  56. online: &models.Online{UserId: id, ServerId: srv.ID},
  57. Send: make(chan *Message, config.Websocket.MessageSize),
  58. writeWait: config.Websocket.WriteWait * time.Second,
  59. readWait: config.Websocket.ReadWait * time.Second,
  60. pingWait: config.Websocket.HeartbeatWait * time.Second,
  61. isSimpleMsg: headers[keys.SimpleHeader].(bool),
  62. localization: headers[keys.LocalizationHeader].(string),
  63. repos: srv.Repositories,
  64. }
  65. srv.Hub.Connect <- client
  66. // Online status to redis
  67. online := &models.Online{UserId: client.UserId, ServerId: srv.ID}
  68. if err := srv.Repositories.OnlineRepository.SetOnline(ctx, online); err != nil {
  69. ctx.AbortWithError(http.StatusInternalServerError,
  70. eris.Wrapf(err, "unable to set online status for user: %s", client.UserId))
  71. return
  72. }
  73. go client.reader()
  74. go client.writer()
  75. }
// Headers holds the connection parameters built by headerBuilder, keyed by
// the keys.*Header names. Values are strings, except keys.SimpleHeader, which
// is a bool.
type Headers = map[string]interface{}
  77. func headerBuilder(ctx *gin.Context) Headers {
  78. headers := make(Headers)
  79. request := ctx.Request
  80. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  81. simple := request.URL.Query().Get(keys.SimpleHeader)
  82. localization := request.URL.Query().Get(keys.LocalizationHeader)
  83. headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
  84. headers[keys.AccessTokenHeader] = accessToken
  85. headers[keys.SimpleHeader] = simple == "1"
  86. headers[keys.LocalizationHeader] = localization
  87. return headers
  88. }
// UserClaims is the JWT claim set that jwtParse decodes access tokens into:
// the standard registered claims plus the user ID carried in the "uid" claim.
type UserClaims struct {
	jwt.RegisteredClaims
	UID string `json:"uid"`
}
  93. func jwtParse(headers Headers) (string, bool, error) {
  94. if userId, ok := headers[keys.UserIdHeader]; ok {
  95. if userId != "" {
  96. return userId.(string), true, nil
  97. }
  98. }
  99. accessToken := headers[keys.AccessTokenHeader].(string)
  100. if len(accessToken) == 0 {
  101. return "", false, eris.New("token is empty")
  102. }
  103. accessToken = strings.Trim(accessToken, " ")
  104. token, err := jwt.ParseWithClaims(accessToken, &UserClaims{}, func(t *jwt.Token) (interface{}, error) {
  105. return []byte(viper.GetString("auth.secret")), nil
  106. })
  107. userClaims, ok := token.Claims.(*UserClaims)
  108. if ok {
  109. return userClaims.UID, true, nil
  110. } else {
  111. return "", false, eris.Wrap(err, "token parse error")
  112. }
  113. // mapClaims := token.Claims.(jwt.MapClaims)
  114. // exp := mapClaims["exp"].(float64)
  115. // if exp != 0 && exp < float64(time.Now().Unix()) {
  116. // return "", false, eris.New("token is expired")
  117. // }
  118. }