server.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/golang-jwt/jwt/v5"
  8. "github.com/gorilla/websocket"
  9. "github.com/rotisserie/eris"
  10. "github.com/spf13/viper"
  11. "sikey.com/websocket/repositories"
  12. "sikey.com/websocket/utils/keys"
  13. "sikey.com/websocket/utils/zlog"
  14. )
  15. type Server struct {
  16. Ctx *gin.Context
  17. Repositories *repositories.Repositories
  18. Upgrader websocket.Upgrader
  19. Hub *Hub
  20. ReadWait time.Duration
  21. WriteWait time.Duration
  22. PingWait time.Duration
  23. }
  24. func WebsocketHandler(ctx *gin.Context, srv *Server) {
  25. srv.Ctx = ctx
  26. conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  27. if err != nil {
  28. ctx.AbortWithError(http.StatusInternalServerError, err)
  29. return
  30. }
  31. // Builder headers
  32. headers := headerBuilder(ctx)
  33. // Validate token
  34. id, ok, err := jwtParse(headers)
  35. if !ok {
  36. ctx.AbortWithError(http.StatusUnauthorized, err)
  37. return
  38. }
  39. // Create client
  40. client := &Client{
  41. ctx: ctx.Copy(),
  42. UserId: id,
  43. hub: srv.Hub,
  44. UnderlyingConn: conn,
  45. Send: make(chan *Message, 256),
  46. writeWait: srv.WriteWait,
  47. readWait: srv.ReadWait,
  48. pingWait: srv.PingWait,
  49. isSimpleMsg: headers[keys.SimpleHeader].(bool),
  50. localization: headers[keys.LocalizationHeader].(string),
  51. repos: srv.Repositories,
  52. }
  53. client.hub.Connect <- client
  54. // client.hub.exchange.OnlineNotify(client.UserId)
  55. zlog.Info("client online: ", client.UserId)
  56. go client.reader()
  57. go client.writer()
  58. }
  59. type Headers = map[string]interface{}
  60. func headerBuilder(ctx *gin.Context) Headers {
  61. headers := make(Headers)
  62. request := ctx.Request
  63. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  64. simple := request.URL.Query().Get(keys.SimpleHeader)
  65. localization := request.URL.Query().Get(keys.LocalizationHeader)
  66. headers[keys.AccessTokenHeader] = accessToken
  67. headers[keys.SimpleHeader] = simple == "1"
  68. headers[keys.LocalizationHeader] = localization
  69. return headers
  70. }
  71. func jwtParse(headers Headers) (string, bool, error) {
  72. accessToken := headers[keys.AccessTokenHeader].(string)
  73. if len(accessToken) == 0 {
  74. return "", false, eris.New("token is empty")
  75. }
  76. accessToken = strings.Trim(accessToken, " ")
  77. token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
  78. return []byte(viper.GetString("auth.secret")), nil
  79. })
  80. if err != nil {
  81. zlog.Error(err.Error())
  82. return "", false, eris.Wrap(err, "token parse error")
  83. }
  84. mapClaims := token.Claims.(jwt.MapClaims)
  85. // exp := mapClaims["exp"].(float64)
  86. // if exp != 0 && exp < float64(time.Now().Unix()) {
  87. // return "", false, eris.New("token is expired")
  88. // }
  89. uid := mapClaims["Uid"].(string)
  90. return uid, true, nil
  91. }