client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. package server
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "regexp"
  8. "time"
  9. "github.com/gin-gonic/gin"
  10. "github.com/google/uuid"
  11. "github.com/gorilla/websocket"
  12. "github.com/nats-io/nats.go"
  13. "github.com/rotisserie/eris"
  14. "go.uber.org/zap"
  15. "sikey.com/websocket/models"
  16. "sikey.com/websocket/pkg/gid"
  17. "sikey.com/websocket/pkg/natx"
  18. "sikey.com/websocket/pkg/tox"
  19. "sikey.com/websocket/repositories"
  20. )
  21. type Client struct {
  22. ctx *gin.Context
  23. UserId string
  24. srv *Server
  25. nats *Nats
  26. UnderlyingConn *websocket.Conn
  27. online *models.Online
  28. isSimpleMsg bool // isSimpleMsg 是否是简单消息
  29. localization string // localization 国际码
  30. firebaseToken string // firebaseToken FCM 推送的 token
  31. loginToken string // loginToken 登录 token
  32. Received chan Message
  33. // Send message channel 发送消息
  34. // 当用户在线时会通过 Send channel 发送在线消息 但如果用户不在线,
  35. // Send chan Message
  36. // send chan *EncodedMessage // 消息通道
  37. ReplySend chan Message // ReplySend 回复消息专用 Channel
  38. // firstReadWait 首次消息等待超时时间
  39. // 当客户端连接创建后会等待一个首次消息,首次消息如果没有在指定时间内发送会主动断开连接
  40. firstReadWait time.Duration
  41. // readWait time.Duration // readWait 读超时
  42. // writeWait 写超时
  43. // 为了保持服务器稳定, 在往客户端发送消息时设置一个超时时间,
  44. // 客户端连接不佳时不用花太多的时间在写客户端消息上, 从而保证服务器的协程不堵塞
  45. writeWait time.Duration
  46. // heartbeatWait 心跳等待
  47. // 服务器控制客户端连接的手段, 通过心跳的方式控制服务器保持存活,
  48. // 一方面是网络协议需要保持活跃, 另一方面是服务器需要踢出长期未活跃的连接
  49. heartbeatWait time.Duration
  50. // LastHeartbeatTime 上次心跳时间
  51. // 记录上次心跳的时间,方便 debug
  52. LastHeartbeatTime time.Time
  53. LastReceivedMessageTime time.Time // LastReceivedMessageTime 上次收到消息的时间
  54. LastReceivedNotifyTime time.Time // LastReceivedNotifyTime 上次收到推送的时间
  55. repos *repositories.Repositories
  56. }
  57. func (c *Client) withRequestIdContext(ctx context.Context, requestId string) context.Context {
  58. return context.WithValue(ctx, "request_id", requestId)
  59. }
  60. // reader 读取到客户端发送的消息, 将消息发送到 nats 里
  61. func (c *Client) reader() {
  62. defer func() {
  63. // c.srv.Disconnect <- c
  64. // c.nats.Unsubscribe <- &subscriber{client: c}
  65. // _ = c.UnderlyingConn.Close()
  66. c.close()
  67. }()
  68. // 首次消息超时设置
  69. //
  70. // firstReadDeadlineTime := time.Now().Add(c.firstReadWait)
  71. // _ = c.UnderlyingConn.SetReadDeadline(firstReadDeadlineTime)
  72. //
  73. // 客户端断开重新连接后无法保证首次发送消息,所以按照之前的 firstReadWait 无法保证一直连接
  74. // 这里设置一个写周期,让客户端持续保持,不会因为重连后没有在规定时间内发送心跳而断开
  75. writeReadDeadlineTime := time.Now().Add(c.writeWait)
  76. _ = c.UnderlyingConn.SetWriteDeadline(writeReadDeadlineTime)
  77. for {
  78. // 接收消息
  79. msgType, bytes, err := c.UnderlyingConn.ReadMessage()
  80. if err != nil {
  81. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
  82. zap.L().Error("[conn] normal disconnection",
  83. zap.String("user_id", c.UserId),
  84. zap.Error(err))
  85. }
  86. // Close connect
  87. zap.L().Error("[conn] read message error",
  88. zap.String("user_id", c.UserId),
  89. zap.Error(err))
  90. return
  91. }
  92. // 收到客户端 websocket 的 ping 消息
  93. // 收到后默认返回,用途告诉客户端这条连接还是需要保留着
  94. if msgType == websocket.PingMessage {
  95. writeReadDeadlineTime := time.Now().Add(c.writeWait)
  96. c.UnderlyingConn.SetWriteDeadline(writeReadDeadlineTime)
  97. _ = c.UnderlyingConn.WriteMessage(websocket.PongMessage, []byte(""))
  98. continue
  99. }
  100. // 解码消息
  101. message := deserializeMessage(bytes)
  102. // 刷新超时时间, 客户端让客户端保持
  103. heartbeatReadDeadlineTime := time.Now().Add(c.heartbeatWait)
  104. _ = c.UnderlyingConn.SetReadDeadline(heartbeatReadDeadlineTime)
  105. switch message.MessageType() {
  106. case MessageTypePingPong:
  107. c.LastHeartbeatTime = time.Now()
  108. zap.L().Info("[reader] 心跳消息", zap.String("user_id", c.UserId))
  109. // 检查用户是否还在线
  110. // 如果已经不在线需要关闭连接
  111. online, _ := c.repos.OnlineRepository.GetOnline(c.ctx, c.UserId)
  112. if online == nil {
  113. zap.L().Info("[reader] 心跳时检查 online", zap.Any("data", online), zap.String("user_id", c.UserId))
  114. return
  115. }
  116. // 刷新 Redis 的在线信息
  117. _ = c.repos.OnlineRepository.Heartbeat(c.ctx, c.online)
  118. // 心跳消息
  119. c.ReplySend <- newPongMessage(message.RequestId())
  120. zap.L().Info("[reader] 心跳响应", zap.String("user_id", c.UserId))
  121. case MessageTypeUpChating:
  122. c.LastReceivedMessageTime = time.Now()
  123. // 如果是语聊消息, 将消息落库到 tb_message 表
  124. if chating, ok := message.(*Chating); ok {
  125. zap.L().Info("[reader] 收到语聊消息", zap.Any("msg", chating), zap.String("user_id", c.UserId))
  126. content := chating.Content
  127. // 字段验证
  128. if content.SendTime == 0 {
  129. // 暂时先将发送时间字段不做限制, 并且设置一个默认值
  130. content.SendTime = time.Now().UTC().UnixMilli()
  131. // c.ReplySend <- newErrorMessage(message.RequestId(), eris.New("incorrect send time"))
  132. // continue
  133. }
  134. // 查询出来接收消息的人, 这里会过滤掉当前客户端
  135. isSessionId, receivers := c.regexpReceiveUserIds(content.Receiver)
  136. if isSessionId {
  137. content.SessionId = content.Receiver
  138. }
  139. zap.L().Info("[reader] 消息接收人", zap.Strings("receivers", receivers))
  140. // 语聊消息回执, 告诉客户端服务器收到了消息
  141. // 需要做一条回复消息给到客户端, 让客户端知道消息发送成功了, 然后客户端会站直给用户消息发送状态
  142. // 这条消息叫消息回执
  143. c.ReplySend <- newReplyMessage(chating)
  144. // 将消息写入数据库
  145. if err := c.repos.Transaction(c.ctx, func(ctx context.Context, repos *repositories.Repositories) error {
  146. for _, receiver := range receivers {
  147. mid := gid.GetSnowflakeId()
  148. err = c.repos.MessageRepository.Create(c.ctx, &models.Message{
  149. MessageId: mid,
  150. PayloadType: int(content.PayloadType),
  151. Payload: serializePayload(content.PayloadType, content.Payload),
  152. IsRead: -1,
  153. Receiver: receiver,
  154. Sender: c.UserId,
  155. SessionId: sql.NullString{String: content.SessionId, Valid: content.SessionId != ""},
  156. SendTime: time.UnixMilli(content.SendTime).UTC(),
  157. })
  158. if err != nil {
  159. return err
  160. }
  161. // 将消息发送给不同的接收人
  162. content.MessageId = mid
  163. content.Receiver = receiver
  164. // 发送消息到 Nats
  165. if c.nats.nc.IsClosed() {
  166. c.nats.nc = natx.Connect()
  167. }
  168. resp, err := c.nats.nc.RequestMsg(&nats.Msg{
  169. Subject: natx.GetSubject(),
  170. Data: serializeMessage(message),
  171. }, time.Second*10)
  172. if err != nil {
  173. // ERROR
  174. return err
  175. }
  176. var respond RespondStructural
  177. _ = json.Unmarshal(resp.Data, &respond)
  178. if !respond.Ok {
  179. // ERROR
  180. return err
  181. }
  182. }
  183. return nil
  184. }); err != nil {
  185. // ERROR
  186. c.ReplySend <- newErrorMessage(message.RequestId(), err)
  187. zap.L().Error("[reader] 存消息出现问题")
  188. }
  189. }
  190. case MessageTypeNotification:
  191. c.LastReceivedNotifyTime = time.Now()
  192. // 通知消息
  193. if notification, ok := message.(*Notification); ok {
  194. // 将通知消息存起来, 记录下通知消息的
  195. if err := c.repos.NotifyRepository.Create(c.ctx, &models.Notify{
  196. NotifyId: notification.Content.ID,
  197. Sender: notification.Content.Sender,
  198. Receiver: notification.Content.Receiver,
  199. IsSent: models.NotifyNotReceived,
  200. Payload: serializePayload(0, notification.Content.Payload),
  201. }); err != nil {
  202. c.ReplySend <- newErrorMessage(message.RequestId(), err)
  203. continue
  204. }
  205. c.Received <- message
  206. }
  207. case MessageTypeAck:
  208. // 收到消息的回执, 处理回执消息
  209. if notification, ok := message.(*Notification); ok {
  210. notify, err := c.repos.NotifyRepository.Find(c.ctx, int64(notification.Content.AckId))
  211. if err != nil || notify == nil {
  212. c.ReplySend <- newErrorMessage(message.RequestId(), eris.New(fmt.Sprintf("unable to find notify ackId:%d", notification.Content.AckId)))
  213. continue
  214. }
  215. notify.IsSent = models.NotifyReceived
  216. _ = c.repos.NotifyRepository.Save(c.ctx, notify)
  217. }
  218. }
  219. }
  220. }
  221. // recv 客户端收到消息, 收到消息的统一收口
  222. // 消息来源分为两种:
  223. // 1. 消息通过 nats 发送过来
  224. // 2. 消息通过 reader 读取到后被放到了 nats 里然后发送过来
  225. func (c *Client) recv() {
  226. defer close(c.Received)
  227. for {
  228. select {
  229. case message, ok := <-c.Received:
  230. if !ok {
  231. zap.L().Error("[recv] channel error",
  232. zap.Error(eris.New("empty is received channel.")),
  233. zap.String("user_id", c.UserId))
  234. return
  235. }
  236. // 有一些客户端是通过 receiver 来取出 sessionId, 然后去查询消息的, 但这并不是正确的.
  237. switch message.MessageType() {
  238. case MessageTypeUpChating:
  239. if chating, ok := message.(*Chating); ok {
  240. // 将消息类型改为客户端可以接受到的类型, 客户端下行消息只会收到 message_type: MessageTypeDownChating
  241. chating.Type = MessageTypeDownChating
  242. // 这个 message_id 其实用不到, 但为了先兼容App收发消息
  243. chating.Content.MessageId = "_" + gid.GetSnowflakeId()
  244. message = chating
  245. }
  246. }
  247. zap.L().Info("[recv] 发送消息", zap.Any("message", message), zap.String("user_id", c.UserId))
  248. writeReadDeadlineTime := time.Now().Add(c.writeWait)
  249. _ = c.UnderlyingConn.SetWriteDeadline(writeReadDeadlineTime)
  250. if err := c.UnderlyingConn.WriteJSON(message); err != nil {
  251. zap.L().Error("[recv] write error", zap.Error(err), zap.String("user_id", c.UserId))
  252. }
  253. }
  254. }
  255. }
  256. // reply 回复消息专用
  257. func (c *Client) reply() {
  258. defer close(c.ReplySend)
  259. for {
  260. select {
  261. case message, ok := <-c.ReplySend:
  262. if !ok {
  263. zap.L().Error("[reply] channel error",
  264. zap.Error(eris.New("empty is replySend channel.")),
  265. zap.String("user_id", c.UserId))
  266. return
  267. }
  268. // zap.L().Info("[conn] 回复消息到客户端", zap.Any("message", message), zap.String("user_id", c.UserId))
  269. writeReadDeadlineTime := time.Now().Add(c.writeWait)
  270. _ = c.UnderlyingConn.SetWriteDeadline(writeReadDeadlineTime)
  271. if err := c.UnderlyingConn.WriteJSON(message); err != nil {
  272. zap.L().Error("[reply] write error", zap.Error(err), zap.String("user_id", c.UserId))
  273. }
  274. }
  275. }
  276. }
  277. // regexpReceiveUserIds 通过 receiver 获取接收者的用户ID
  278. // 使用正则表达式验证ID 是否是 account_id 或 session_id
  279. // session_id 的话需要查询 session_member 表获取 session 的成员
  280. func (c *Client) regexpReceiveUserIds(receiver string) (bool, []string) {
  281. if receiver == "" {
  282. return false, []string{}
  283. }
  284. reg, _ := regexp.Compile(`[0-9a-f]{8}(-[0-9a-f]{4}){3}-[0-9a-f]{12}`)
  285. if reg.Match([]byte(receiver)) {
  286. return false, []string{receiver}
  287. }
  288. var receivers = make([]string, 0)
  289. if models.IsSessionSingle(receiver) {
  290. single, err := c.repos.SessionSingleRepository.Get(c.ctx, receiver)
  291. if err != nil {
  292. zap.L().Error("unable to get single session", zap.Error(err), zap.String("user_id", c.UserId))
  293. return false, []string{}
  294. }
  295. if single == nil {
  296. // 会话已经被删除
  297. zap.L().Error("[reader] 检查接收人 ID 类型时出错", zap.Error(err), zap.String("user_id", c.UserId))
  298. return false, []string{}
  299. }
  300. toUser := tox.TernaryOperation(
  301. single.TargetUserID == c.UserId,
  302. single.ToUserID,
  303. single.TargetUserID,
  304. ).(string)
  305. receivers = append(receivers, toUser)
  306. } else if models.IsSessionGroup(receiver) {
  307. group, err := c.repos.SessionGroupRepository.Get(c.ctx, receiver)
  308. if err != nil {
  309. zap.L().Error("unable to get session group", zap.Error(err), zap.String("user_id", c.UserId))
  310. return false, []string{}
  311. }
  312. members := group.Members.ToSlice()
  313. for _, member := range members {
  314. if member.UserId != c.UserId {
  315. receivers = append(receivers, member.UserId)
  316. }
  317. }
  318. } else {
  319. return false, []string{}
  320. }
  321. return true, receivers
  322. }
  323. // loadOfflineMessage 查询离线时未接收的消息,并且推送给客户端
  324. func (c *Client) loadOfflineMessage() {
  325. defer func() {
  326. if err := recover(); err != nil {
  327. zap.L().Error("[conn] 加载离线消息出现问题", zap.Any("err", err))
  328. }
  329. return
  330. }()
  331. rid := uuid.NewString()
  332. unreadMsg, err := c.repos.MessageRepository.FindUnread(c.ctx, c.UserId)
  333. if err != nil {
  334. // 查询未读消息出现错误, 给登录的用户发送一个错误信息
  335. c.ReplySend <- newErrorMessage(rid, err)
  336. return
  337. }
  338. for _, msg := range unreadMsg {
  339. nc := c.nats.nc
  340. if nc.IsClosed() {
  341. nc = natx.Connect()
  342. }
  343. chating := Chating{
  344. MessageImpl: MessageImpl{
  345. Type: MessageTypeUpChating,
  346. RId: rid,
  347. },
  348. Content: &ChatingContent{
  349. MessageId: msg.MessageId,
  350. Receiver: msg.Receiver,
  351. SessionId: msg.SessionId.String,
  352. PayloadType: uint8(msg.PayloadType),
  353. Payload: deserializePayload(msg.Payload),
  354. SendTime: msg.SendTime.UTC().UnixMilli(),
  355. },
  356. }
  357. zap.L().Info("[conn] 客户端离线时的离线消息", zap.Any("msg", chating), zap.String("user_id", c.UserId))
  358. if _, err := nc.RequestMsg(&nats.Msg{
  359. Subject: natx.GetSubject(),
  360. Data: chating.Data(),
  361. }, time.Second*5); err != nil {
  362. c.ReplySend <- newErrorMessage(rid, err)
  363. }
  364. }
  365. }
  366. // Online 客户端上线, 将用户数据加入到 Redis
  367. func (c *Client) Online() error {
  368. err := c.repos.OnlineRepository.SetOnline(c.ctx, c.online)
  369. if err != nil {
  370. return eris.Wrapf(err, "unable to set online status for user: %s", c.UserId)
  371. }
  372. return nil
  373. }
  374. // Offline 客户端下线, 将用户信息从 Redis 移除
  375. func (c *Client) Offline() {
  376. _ = c.repos.OnlineRepository.Offline(c.ctx, c.online)
  377. }
  378. // Close websocket connection
  379. func (c *Client) close() {
  380. c.srv.Disconnect <- c
  381. c.nats.Unsubscribe <- &subscriber{client: c}
  382. c.UnderlyingConn.Close()
  383. c.Offline()
  384. }