package server

import (
	"encoding/json"
	"regexp"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/mitchellh/mapstructure"
	"github.com/rotisserie/eris"
	"gorm.io/gorm"
	"sikey.com/websocket/models"
	"sikey.com/websocket/repositories"
	"sikey.com/websocket/utils/gid"
	"sikey.com/websocket/utils/zlog"
)

// accountIDPattern matches a UUID-shaped run of lowercase hex digits
// (8-4-4-4-12). It is compiled once at package scope instead of on every
// getReceiverUserIds call.
//
// NOTE(review): the pattern is not anchored, so any receiver that merely
// contains such a run is treated as an account ID. Confirm that this is
// intended before tightening it.
var accountIDPattern = regexp.MustCompile(`[0-9a-f]{8}(-[0-9a-f]{4}){3}-[0-9a-f]{12}`)

// Client is one websocket connection registered with a Hub. reader consumes
// inbound frames from the connection and writer drains Send onto it.
type Client struct {
	ctx *gin.Context

	UserId string

	hub            *Hub
	UnderlyingConn *websocket.Conn
	online         *models.Online

	// isSimpleMsg reports whether this client uses simple messages.
	isSimpleMsg bool
	// localization is the client's locale / international code.
	localization string
	// firebaseToken is the token used for FCM push notifications.
	firebaseToken string

	// Send is the outbound message channel: writer writes every message
	// received on it to the connection, and a closed channel makes writer
	// send a close frame and stop. What happens for users that are offline
	// is not described here (the original comment was left unfinished).
	Send chan *Message

	// readWait is the read timeout applied before the first read.
	readWait time.Duration
	// writeWait is the write timeout applied before each write.
	writeWait time.Duration
	// pingWait is the heartbeat timeout. It is used both as the interval
	// between ping frames and as the read deadline extension after a reply.
	pingWait time.Duration

	repos *repositories.Repositories

	// closeOnce makes Close idempotent. Both reader and writer call Close
	// when they exit, so it is reached twice for every client.
	closeOnce sync.Once
}

// reader reads frames from the connection until a read fails, dispatching
// each decoded message by type. On exit it hands the client to the hub's
// Disconnect channel and closes the connection.
func (c *Client) reader() {
	defer func() {
		c.hub.Disconnect <- c
		c.Close()
		zlog.Debugf("Close client %s", c.UserId)
	}()

	c.UnderlyingConn.SetReadDeadline(time.Now().Add(c.readWait))
	for {
		_, data, err := c.UnderlyingConn.ReadMessage()
		if err != nil {
			// Every read failure ends the loop; the deferred cleanup above
			// closes the connection and marks the user offline.
			zlog.Errorf("error: %v", eris.Wrap(err, c.UserId))
			return
		}

		message := deserializeMessage(data)
		if message == nil {
			// deserializeMessage is defined elsewhere; this guards against
			// it yielding nil for a malformed frame (TODO confirm), because
			// message is dereferenced below.
			zlog.Errorf("discard undecodable message from %s", c.UserId)
			continue
		}

		switch message.Type {
		case MessageTypePingPong:
			zlog.Debugf("receive ping message from %s", c.UserId)
			_ = c.repos.OnlineRepository.Heartbeat(c.ctx, c.online)
		case MessageTypeUpChating, MessageTypeDownChating:
			// Chat dialogue messages.
			if !c.handleChating(message) {
				// The failure has already been reported to the client.
				continue
			}
		}

		// Reply message.
		if message.IsNeedReply() {
			c.Send <- newReplyMessage(message)
			// Reset the read deadline to prevent reader from shutting down.
			c.UnderlyingConn.SetReadDeadline(time.Now().Add(c.pingWait))
		}
	}
}

// handleChating persists a chat message and forwards one copy per receiver
// to the hub. On success it replaces message.Content with a ContentReply
// carrying the new message ID and returns true. On failure it sends an error
// message to the client and returns false.
func (c *Client) handleChating(message *Message) bool {
	chatingContent, ok := message.Content.(ChatingContent)
	if !ok {
		// An unchecked assertion here would panic on any chat-typed message
		// whose content is not a ChatingContent.
		c.writeError(message.RequestId, eris.New("invalid chating message content"))
		return false
	}

	// Save message to database.
	messageId := gid.GetMessageId()
	chatingContent.MessageId = messageId
	chatingContent.SessionId = chatingContent.Receiver
	if err := c.saveMessage(messageId, message.Type, &chatingContent); err != nil {
		c.writeError(message.RequestId, err)
		return false
	}

	// The receiver ID format determines whether the receiver is an account
	// or a session.
	zlog.Debugf("message receiver: %s", chatingContent.Receiver)
	users := c.getReceiverUserIds(chatingContent.Receiver)
	zlog.Debugf("users: %s", users)

	for _, id := range users {
		// Each receiver gets its own copy so that Receiver can differ.
		messaging := *message
		messaging.Receiver = id
		messaging.Content = chatingContent
		zlog.Infof("Send message %s to %s", c.UserId, id)

		// Check if the user is online.
		// NOTE(review): the token tested is this (sending) client's, not the
		// receiver's. Confirm that this is intended.
		if c.firebaseToken != "" {
			online, err := c.repos.OnlineRepository.Is(c.ctx, id)
			if err != nil {
				zlog.Error(eris.Wrap(err, "unable to find online user"))
				continue
			}
			if !online {
				// TODO: send FCM message (not implemented yet).
			}
		}

		c.hub.Message <- &messaging
	}

	// Reply with the message ID.
	message.Content = &ContentReply{MessageId: messageId}
	return true
}

// writer writes every message received on Send to the connection and sends a
// ping frame every pingWait. It stops when Send is closed or a write fails,
// and closes the connection on exit.
func (c *Client) writer() {
	ticker := time.NewTicker(c.pingWait)
	defer func() {
		ticker.Stop()
		c.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			c.UnderlyingConn.SetWriteDeadline(time.Now().Add(c.writeWait))
			if !ok {
				// The hub closed the channel. The close frame is best
				// effort: the connection is torn down right after.
				_ = c.UnderlyingConn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			if err := c.UnderlyingConn.WriteMessage(websocket.TextMessage, serializationMessage(message)); err != nil {
				return
			}

			// Update the stored message status now that it has been written.
			c.markReceived(message)

			// The marshalled form is only used for the debug log below, so
			// the marshalling result is used as-is.
			buf, _ := message.MarshalBinary()
			zlog.Debugf("send message %s to %s content: %v", message.RequestId, c.UserId, string(buf))
		case <-ticker.C:
			// Time to send a ping frame.
			c.UnderlyingConn.SetWriteDeadline(time.Now().Add(c.writeWait))
			if err := c.UnderlyingConn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// markReceived sets Received on the stored record of a chat message that has
// just been written to the connection. Messages of other types, or whose
// content is not a ChatingContent, are ignored. A missing record is skipped
// silently; any other lookup error is logged.
func (c *Client) markReceived(message *Message) {
	switch message.Type {
	case MessageTypeUpChating, MessageTypeDownChating:
	default:
		return
	}

	chatingContent, ok := message.Content.(ChatingContent)
	if !ok {
		return
	}

	msg, err := c.repos.SessionRepository.FindMessageById(c.ctx, chatingContent.MessageId)
	if err != nil {
		if !eris.Is(err, gorm.ErrRecordNotFound) {
			zlog.Error(err)
		}
		return
	}

	msg.Received = true
	c.repos.SessionRepository.UpdateMessage(c.ctx, msg)
}

// writeError queues an error message for the client, echoing requestId so
// the failure can be matched to the request that caused it.
func (c *Client) writeError(requestId string, err error) {
	c.Send <- &Message{
		RequestId: requestId,
		Type:      MessageTypeError,
		Content:   ContentError{Err: err.Error()},
	}
}

// saveMessage normalises content.Payload into the struct matching
// content.PayloadType and stores the message through the session repository.
// It returns an error when the payload cannot be decoded or encoded, or when
// the repository call fails.
func (c *Client) saveMessage(messageId string, messageType int8, content *ChatingContent) error {
	// Standardized structure. This is not an unnecessary step: decoding into
	// the concrete type filters out excess fields. The payload is replaced
	// only after a successful decode.
	switch content.PayloadType {
	case ChatingContentTypeText:
		var textContent ContentText
		if err := mapstructure.Decode(content.Payload, &textContent); err != nil {
			return eris.Wrap(err, "unable to decode message content")
		}
		content.Payload = textContent
	case ChatingContentTypeMetadata:
		var contentMetadata ContentMetadata
		if err := mapstructure.Decode(content.Payload, &contentMetadata); err != nil {
			return eris.Wrap(err, "unable to decode message content")
		}
		content.Payload = contentMetadata
	}

	payload, err := json.Marshal(content.Payload)
	if err != nil {
		return eris.Wrap(err, "unable to encode message content")
	}

	return c.repos.SessionRepository.CreateMessage(c.ctx, &models.SessionMessage{
		ID:          messageId,
		SessionId:   content.Receiver,
		Receiver:    content.Receiver,
		Sender:      c.UserId,
		Type:        messageType,
		ContentType: content.PayloadType,
		Content:     payload,
		IsRead:      false,
	})
}

// getReceiverUserIds resolves receiver to the user IDs a message is sent to.
// A receiver matching accountIDPattern is returned as the only element.
// Otherwise receiver is passed as a session ID to the session repository,
// and the RefId of every member other than this client's user is returned.
// A failed lookup is logged and yields an empty slice.
func (c *Client) getReceiverUserIds(receiver string) []string {
	if accountIDPattern.MatchString(receiver) {
		return []string{receiver}
	}

	members, err := c.repos.SessionRepository.GetSessionMembersRemoveOneself(
		c.ctx, receiver, c.UserId)
	if err != nil {
		zlog.Error(eris.Wrap(err, "unable to load session members"))
		return []string{}
	}

	ids := make([]string, len(members))
	for i, member := range members {
		ids[i] = member.RefId
	}
	return ids
}

// Close closes the websocket connection and records the user as offline on
// this server. Only the first call does any work; later calls are no-ops.
func (c *Client) Close() {
	c.closeOnce.Do(func() {
		c.UnderlyingConn.Close()
		online := &models.Online{UserId: c.UserId, ServerId: c.hub.serverId}
		c.repos.OnlineRepository.Offline(c.ctx, online)
	})
}