conn.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package biz
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/go-kratos/kratos/v2/errors"
  6. "github.com/gorilla/websocket"
  7. "io"
  8. "time"
  9. v2 "w303a/server/apis/gate/v2"
  10. "w303a/server/app/gate/internal/conf"
  11. "w303a/server/pkg/zaplog"
  12. )
  13. var (
  14. // ErrUserOffline 用户不在线
  15. ErrUserOffline = errors.New(1001, v2.ErrorReason_ERR_USER_OFFLINE.String(), "user offline")
  16. )
  17. type Conn struct {
  18. *websocket.Conn
  19. c *conf.Server
  20. log *zaplog.Logger
  21. ctx context.Context
  22. msguc *MessageUsecase
  23. ID string
  24. }
  25. type ConnRepo interface {
  26. // Online 上线
  27. Online(ctx context.Context, id string, exp time.Duration) error
  28. // Offline 下线
  29. Offline(ctx context.Context, id string) error
  30. // ResetHeartbeat 重置心跳
  31. ResetHeartbeat(ctx context.Context, c *Conn, exp time.Duration) error
  32. // GetServerId 获取服务器ID
  33. GetServerId(ctx context.Context, id string) (string, error)
  34. }
  35. type ConnUsecase struct {
  36. log *zaplog.Logger
  37. c *conf.Server
  38. h *Hub
  39. connRepo ConnRepo
  40. msguc *MessageUsecase
  41. }
  42. func NewConnUsecase(
  43. log *zaplog.Logger,
  44. c *conf.Server,
  45. h *Hub,
  46. msguc *MessageUsecase,
  47. connRepo ConnRepo,
  48. ) *ConnUsecase {
  49. return &ConnUsecase{
  50. log: log,
  51. c: c,
  52. h: h,
  53. msguc: msguc,
  54. connRepo: connRepo,
  55. }
  56. }
  57. func (uc *ConnUsecase) NewConn(ctx context.Context, nsConn *websocket.Conn, id string) (*Conn, error) {
  58. if id == "" {
  59. _ = nsConn.Close()
  60. return nil, errors.New(400, "ID_IS_EMPTY", "id is empty")
  61. }
  62. return &Conn{
  63. ctx: ctx,
  64. Conn: nsConn,
  65. ID: id,
  66. c: uc.c,
  67. log: uc.log,
  68. }, nil
  69. }
  70. // Reading 读取消息并处理
  71. func (uc *ConnUsecase) Reading(c *Conn) error {
  72. // 查询时的消息数据
  73. uc.msguc.LoadOfflineMessage(c.ctx, c)
  74. // 持续读取消息
  75. return uc.heartbeatTimeoutWrapper(c, uc.msguc.HandlerMessage)
  76. }
  77. // 读取消息并处理
  78. func (uc *ConnUsecase) heartbeatTimeoutWrapper(c *Conn, handlerMessage MessageHandler) error {
  79. ctx := context.WithoutCancel(c.ctx)
  80. c.SetCloseHandler(func(code int, text string) error {
  81. uc.h.Unregister(c)
  82. uc.connRepo.Offline(ctx, c.ID)
  83. uc.log.Sugar().Infof("conn closed: %d %s", code, text)
  84. return errors.New(code, "CONN_CLOSED", text)
  85. })
  86. for {
  87. c.resetReadDeadline()
  88. uc.connRepo.ResetHeartbeat(ctx, c, c.c.Websocket.Keepalive.AsDuration())
  89. msgBuf, err := c.readMessage()
  90. if err != nil {
  91. return err
  92. }
  93. // 没消息内容
  94. if len(msgBuf) == 0 {
  95. continue
  96. }
  97. msg, err := newMessage(msgBuf)
  98. if err != nil {
  99. c.log.Sugar().Errorf("new message error: %v", err)
  100. continue
  101. }
  102. // 如果是心跳消息,回复 pong
  103. if msg.isHeartbeat() {
  104. c.resetReadDeadline()
  105. uc.connRepo.ResetHeartbeat(ctx, c, c.c.Websocket.Keepalive.AsDuration())
  106. c.writeMessage(newReplyHeartbeat(msg))
  107. continue
  108. }
  109. if err = handlerMessage(ctx, c, msg); err != nil {
  110. c.log.Sugar().Errorf("handler message error: %v", err)
  111. continue
  112. }
  113. }
  114. }
  115. func (c *Conn) readMessage() ([]byte, error) {
  116. var r io.Reader
  117. mt, r, err := c.NextReader()
  118. if err != nil {
  119. return nil, err
  120. }
  121. if !c.isMessageTypeSupported(mt) {
  122. c.log.Sugar().Errorf("unsupported message type: %d", mt)
  123. return nil, fmt.Errorf("unsupported message type: %d", mt)
  124. }
  125. // 处理 ping / pong 帧消息,处理结束后跳过当前消息
  126. if c.handlePingPongFrame(mt) {
  127. return make([]byte, 0), nil
  128. }
  129. p, err := io.ReadAll(r)
  130. return p, err
  131. }
  132. // 判断消息类型是否支持
  133. func (c *Conn) isMessageTypeSupported(mt int) bool {
  134. return mt == websocket.TextMessage || mt == websocket.PingMessage || mt == websocket.PongMessage
  135. }
  136. // 处理 ping / pong 帧消息
  137. func (c *Conn) handlePingPongFrame(mt int) bool {
  138. if mt == websocket.PingMessage {
  139. c.WriteMessage(websocket.PongMessage, []byte(""))
  140. return true
  141. } else if mt == websocket.PongMessage {
  142. c.WriteMessage(websocket.PingMessage, []byte(""))
  143. return true
  144. }
  145. return false
  146. }
  147. func (c *Conn) resetReadDeadline() {
  148. c.SetReadDeadline(time.Now().Add(c.c.Websocket.Keepalive.AsDuration()))
  149. }
  150. func (c *Conn) resetWriteDeadline() {
  151. c.SetWriteDeadline(time.Now().Add(c.c.Websocket.WriteTimeout.AsDuration()))
  152. }
  153. func (c *Conn) writeMessage(msg *Message) {
  154. c.resetWriteDeadline()
  155. if err := c.WriteMessage(websocket.TextMessage, msg.Bytes()); err != nil {
  156. c.log.Sugar().Errorf("write message error: %v", err)
  157. }
  158. }