server.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package server
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/golang-jwt/jwt/v5"
  8. "github.com/gorilla/websocket"
  9. "github.com/rotisserie/eris"
  10. "github.com/spf13/viper"
  11. "sikey.com/websocket/repositories"
  12. "sikey.com/websocket/utils/keys"
  13. "sikey.com/websocket/utils/zlog"
  14. )
  15. type Server struct {
  16. Ctx *gin.Context
  17. Repositories *repositories.Repositories
  18. Upgrader websocket.Upgrader
  19. Hub *Hub
  20. ReadWait time.Duration
  21. WriteWait time.Duration
  22. PingWait time.Duration
  23. }
  // WebsocketHandler authenticates the request, upgrades it to a WebSocket
  // connection, registers a new Client with srv.Hub, and starts the client's
  // reader and writer goroutines.
  //
  // Authentication runs BEFORE the upgrade. After Upgrade the connection is
  // hijacked, so an HTTP 401 can no longer be delivered. An unauthenticated,
  // already-upgraded connection would also be left open with nothing reading
  // from or closing it.
  func WebsocketHandler(ctx *gin.Context, srv *Server) {
  	// NOTE(review): if srv is shared across requests, this per-request write
  	// races with concurrent handlers. It is kept because code outside this
  	// file may read srv.Ctx; consider removing the field.
  	srv.Ctx = ctx

  	// Build connection parameters from the request's query string.
  	headers := headerBuilder(ctx)

  	// Validate the token while we can still answer with a normal HTTP error.
  	id, ok, err := jwtParse(headers)
  	if !ok {
  		ctx.AbortWithError(http.StatusUnauthorized, err)
  		return
  	}

  	conn, err := srv.Upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
  	if err != nil {
  		// Upgrade has already replied to the client with an HTTP error.
  		// Record the error and stop the handler chain without writing a
  		// second status code.
  		_ = ctx.Error(err)
  		ctx.Abort()
  		return
  	}

  	// Create the client. The gin context is copied because the original is
  	// not safe to use after this handler returns, while the client outlives it.
  	client := &Client{
  		ctx:             ctx.Copy(),
  		UserId:          id,
  		hub:             srv.Hub,
  		UnderlyingConn:  conn,
  		Send:            make(chan *Message, 256),
  		RemotelyMessage: make(chan *Message, 256),
  		writeWait:       srv.WriteWait,
  		readWait:        srv.ReadWait,
  		pingWait:        srv.PingWait,
  		isSimpleMsg:     headers[keys.SimpleHeader].(bool),
  		localization:    headers[keys.LocalizationHeader].(string),
  		repos:           srv.Repositories,
  	}
  	client.hub.Connect <- client
  	// client.hub.exchange.OnlineNotify(client.UserId)
  	zlog.Info("client online: ", client.UserId)

  	go client.reader()
  	go client.writer()
  }
  60. type Headers = map[string]interface{}
  61. func headerBuilder(ctx *gin.Context) Headers {
  62. headers := make(Headers)
  63. request := ctx.Request
  64. accessToken := request.URL.Query().Get(keys.AccessTokenHeader)
  65. simple := request.URL.Query().Get(keys.SimpleHeader)
  66. localization := request.URL.Query().Get(keys.LocalizationHeader)
  67. headers[keys.AccessTokenHeader] = accessToken
  68. headers[keys.SimpleHeader] = simple == "1"
  69. headers[keys.LocalizationHeader] = localization
  70. return headers
  71. }
  // jwtParse validates the access token stored in headers under
  // keys.AccessTokenHeader and returns the user id from its "Uid" claim.
  //
  // On success it returns (uid, true, nil). On any failure it returns
  // ("", false, err) with a non-nil err. Callers pass err straight to gin's
  // AbortWithError, which requires it to be non-nil.
  func jwtParse(headers Headers) (string, bool, error) {
  	accessToken, _ := headers[keys.AccessTokenHeader].(string)
  	// Trim before the emptiness check so whitespace-only tokens are reported
  	// as empty rather than as parse errors.
  	accessToken = strings.TrimSpace(accessToken)
  	if accessToken == "" {
  		return "", false, eris.New("token is empty")
  	}

  	token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
  		// The verification key is a shared secret. Accept only HMAC-signed
  		// tokens so the algorithm cannot be substituted by the token's author.
  		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  			return nil, eris.Errorf("unexpected signing method: %v", token.Header["alg"])
  		}
  		secret := viper.GetString("auth.secret")
  		if secret == "" {
  			// Verifying with an empty key would accept tokens anyone can mint.
  			return nil, eris.New("auth.secret is not configured")
  		}
  		return []byte(secret), nil
  	})
  	if err != nil {
  		zlog.Error(err.Error())
  		return "", false, eris.Wrap(err, "token parse error")
  	}

  	// jwt/v5's Parse already rejects tokens whose "exp" claim lies in the
  	// past, so no separate expiry check is needed here.
  	mapClaims, ok := token.Claims.(jwt.MapClaims)
  	if !ok {
  		return "", false, eris.New("unexpected token claims type")
  	}
  	// A correctly signed token may still lack a usable "Uid". Report that
  	// as an error instead of panicking on the type assertion.
  	uid, ok := mapClaims["Uid"].(string)
  	if !ok || uid == "" {
  		return "", false, eris.New("token has no Uid claim")
  	}
  	return uid, true, nil
  }