server.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/golang-jwt/jwt/v5"
  8. "github.com/gorilla/websocket"
  9. "github.com/rotisserie/eris"
  10. "github.com/spf13/viper"
  11. "sikey.com/websocket/config"
  12. "sikey.com/websocket/repositories"
  13. "sikey.com/websocket/utils/keys"
  14. "sikey.com/websocket/utils/zlog"
  15. )
  16. type Server struct {
  17. Ctx *gin.Context
  18. Repositories *repositories.Repositories
  19. Upgrader websocket.Upgrader
  20. Hub *Hub
  21. }
  22. func WebsocketHandler(ctx *gin.Context, srv *Server) {
  23. srv.Ctx = ctx
  24. conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  25. if err != nil {
  26. ctx.AbortWithError(http.StatusInternalServerError, err)
  27. return
  28. }
  29. // Builder headers
  30. headers := headerBuilder(ctx)
  31. // Validate token
  32. id, ok, err := jwtParse(headers)
  33. if !ok {
  34. ctx.AbortWithError(http.StatusUnauthorized, err)
  35. return
  36. }
  37. // Create client
  38. client := &Client{
  39. ctx: ctx.Copy(),
  40. UserId: id,
  41. hub: srv.Hub,
  42. UnderlyingConn: conn,
  43. Send: make(chan *Message, config.Websocket.MessageSize),
  44. writeWait: config.Websocket.WriteWait * time.Second,
  45. readWait: config.Websocket.ReadWait * time.Second,
  46. pingWait: config.Websocket.HeartbeatWait * time.Second,
  47. isSimpleMsg: headers[keys.SimpleHeader].(bool),
  48. localization: headers[keys.LocalizationHeader].(string),
  49. repos: srv.Repositories,
  50. }
  51. srv.Hub.Connect <- client
  52. zlog.Info("client: ", client.UserId)
  53. go client.reader()
  54. go client.writer()
  55. }
  56. type Headers = map[string]interface{}
  57. func headerBuilder(ctx *gin.Context) Headers {
  58. headers := make(Headers)
  59. request := ctx.Request
  60. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  61. simple := request.URL.Query().Get(keys.SimpleHeader)
  62. localization := request.URL.Query().Get(keys.LocalizationHeader)
  63. headers[keys.UserIdHeader] = request.URL.Query().Get(keys.UserIdHeader)
  64. headers[keys.AccessTokenHeader] = accessToken
  65. headers[keys.SimpleHeader] = simple == "1"
  66. headers[keys.LocalizationHeader] = localization
  67. return headers
  68. }
  69. func jwtParse(headers Headers) (string, bool, error) {
  70. if userId, ok := headers[keys.UserIdHeader]; ok {
  71. return userId.(string), true, nil
  72. }
  73. accessToken := headers[keys.AccessTokenHeader].(string)
  74. if len(accessToken) == 0 {
  75. return "", false, eris.New("token is empty")
  76. }
  77. accessToken = strings.Trim(accessToken, " ")
  78. token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
  79. return []byte(viper.GetString("auth.secret")), nil
  80. })
  81. if err != nil {
  82. zlog.Error(err.Error())
  83. return "", false, eris.Wrap(err, "token parse error")
  84. }
  85. mapClaims := token.Claims.(jwt.MapClaims)
  86. // exp := mapClaims["exp"].(float64)
  87. // if exp != 0 && exp < float64(time.Now().Unix()) {
  88. // return "", false, eris.New("token is expired")
  89. // }
  90. uid := mapClaims["Uid"].(string)
  91. return uid, true, nil
  92. }