server.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/golang-jwt/jwt/v5"
  8. "github.com/gorilla/websocket"
  9. "github.com/rotisserie/eris"
  10. "github.com/spf13/viper"
  11. "sikey.com/websocket/config"
  12. "sikey.com/websocket/models"
  13. "sikey.com/websocket/repositories"
  14. "sikey.com/websocket/utils/keys"
  15. "sikey.com/websocket/utils/zlog"
  16. )
  17. type Server struct {
  18. ID string
  19. Ctx *gin.Context
  20. Repositories *repositories.Repositories
  21. Upgrader websocket.Upgrader
  22. Hub *Hub
  23. }
  24. func (srv *Server) WebsocketHandler(ctx *gin.Context) {
  25. srv.Ctx = ctx
  26. conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  27. if err != nil {
  28. zlog.Error(err)
  29. ctx.AbortWithError(http.StatusInternalServerError, err)
  30. return
  31. }
  32. // Builder headers
  33. headers := headerBuilder(ctx)
  34. // Validate token
  35. id, ok, err := jwtParse(headers)
  36. if !ok {
  37. zlog.Error(err)
  38. ctx.AbortWithError(http.StatusUnauthorized, err)
  39. return
  40. }
  41. // Create client
  42. client := &Client{
  43. ctx: ctx.Copy(),
  44. UserId: id,
  45. hub: srv.Hub,
  46. UnderlyingConn: conn,
  47. Send: make(chan *Message, config.Websocket.MessageSize),
  48. writeWait: config.Websocket.WriteWait * time.Second,
  49. readWait: config.Websocket.ReadWait * time.Second,
  50. pingWait: config.Websocket.HeartbeatWait * time.Second,
  51. isSimpleMsg: headers[keys.SimpleHeader].(bool),
  52. localization: headers[keys.LocalizationHeader].(string),
  53. repos: srv.Repositories,
  54. }
  55. srv.Hub.Connect <- client
  56. zlog.Debugf("client: %s", client.UserId)
  57. // Online status to redis
  58. online := &models.Online{UserId: client.UserId, ServerId: srv.ID}
  59. if err := srv.Repositories.OnlineRepository.SetOnline(ctx, online); err != nil {
  60. ctx.AbortWithError(http.StatusInternalServerError,
  61. eris.Wrapf(err, "unable to set online status for user: %s", client.UserId))
  62. return
  63. }
  64. go client.reader()
  65. go client.writer()
  66. }
  67. type Headers = map[string]interface{}
  68. func headerBuilder(ctx *gin.Context) Headers {
  69. headers := make(Headers)
  70. request := ctx.Request
  71. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  72. simple := request.URL.Query().Get(keys.SimpleHeader)
  73. localization := request.URL.Query().Get(keys.LocalizationHeader)
  74. headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
  75. headers[keys.AccessTokenHeader] = accessToken
  76. headers[keys.SimpleHeader] = simple == "1"
  77. headers[keys.LocalizationHeader] = localization
  78. return headers
  79. }
  80. func jwtParse(headers Headers) (string, bool, error) {
  81. if userId, ok := headers[keys.UserIdHeader]; ok {
  82. if userId != "" {
  83. return userId.(string), true, nil
  84. }
  85. }
  86. accessToken := headers[keys.AccessTokenHeader].(string)
  87. if len(accessToken) == 0 {
  88. return "", false, eris.New("token is empty")
  89. }
  90. accessToken = strings.Trim(accessToken, " ")
  91. token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
  92. return []byte(viper.GetString("auth.secret")), nil
  93. })
  94. if err != nil {
  95. return "", false, eris.Wrap(err, "token parse error")
  96. }
  97. mapClaims := token.Claims.(jwt.MapClaims)
  98. // exp := mapClaims["exp"].(float64)
  99. // if exp != 0 && exp < float64(time.Now().Unix()) {
  100. // return "", false, eris.New("token is expired")
  101. // }
  102. uid := mapClaims["Uid"].(string)
  103. return uid, true, nil
  104. }