瀏覽代碼

通知消息

luoyangwei 1 年之前
父節點
當前提交
bd16be4e0b
共有 9 個文件被更改,包括 229 次插入30 次删除
  1. 19 0
      models/notify.go
  2. 17 0
      repositories/message_repository.go
  3. 25 0
      repositories/notify_repository.go
  4. 15 0
      repositories/repositories.go
  5. 22 1
      server/client.go
  6. 52 21
      server/message.go
  7. 11 2
      server/nats.go
  8. 26 1
      server/nats_test.go
  9. 42 5
      server/server.go

+ 19 - 0
models/notify.go

@@ -0,0 +1,19 @@
+package models
+
+import (
+	"encoding/json"
+	"time"
+)
+
+type Notify struct {
+	NotifyId  int
+	Sender    string
+	Receiver  string
+	Payload   json.RawMessage
+	CreatedAt time.Time
+	Id        int64
+}
+
+func (*Notify) TableName() string {
+	return "tb_notify"
+}

+ 17 - 0
repositories/message_repository.go

@@ -9,6 +9,9 @@ import (
 
 type MessageRepository interface {
 	Create(ctx context.Context, m *models.Message) error
+
+	// 查询未读消息
+	FindUnread(ctx context.Context, uid string) ([]models.Message, error)
 }
 
 var _ MessageRepository = (*messageRepository)(nil)
@@ -17,6 +20,20 @@ type messageRepository struct {
 	source *gorm.DB
 }
 
+// FindUnread implements MessageRepository.
+func (repo *messageRepository) FindUnread(ctx context.Context, uid string) ([]models.Message, error) {
+	var err error
+	var mods []models.Message
+	err = repo.source.WithContext(ctx).
+		Where(&models.Message{
+			IsRead:   -1,
+			Receiver: uid,
+		}).
+		Find(&mods).
+		Error
+	return resultSlice(mods, err)
+}
+
 // Create implements MessageRepository.
 func (repo *messageRepository) Create(ctx context.Context, m *models.Message) error {
 	return repo.source.WithContext(ctx).Model(&models.Message{}).Create(m).Error

+ 25 - 0
repositories/notify_repository.go

@@ -0,0 +1,25 @@
+package repositories
+
+import (
+	"context"
+
+	"gorm.io/gorm"
+	"sikey.com/websocket/models"
+)
+
+type NotifyRepository interface {
+	Create(ctx context.Context, n *models.Notify) error
+}
+
+func NewNotifyRepository(source *gorm.DB) NotifyRepository {
+	return &notifyRepository{source: source}
+}
+
+type notifyRepository struct {
+	source *gorm.DB
+}
+
+// Create implements NotifyRepository.
+func (repo *notifyRepository) Create(ctx context.Context, n *models.Notify) error {
+	return repo.source.WithContext(ctx).Model(&models.Notify{}).Create(n).Error
+}

+ 15 - 0
repositories/repositories.go

@@ -4,7 +4,9 @@ import (
 	"context"
 
 	"github.com/redis/go-redis/v9"
+	"github.com/rotisserie/eris"
 	"gorm.io/gorm"
+	"sikey.com/websocket/models"
 )
 
 type TransactionFun func(ctx context.Context, repos *Repositories) error
@@ -18,6 +20,7 @@ type Repositories struct {
 	FirebaseMessageRepository FirebaseMessageRepository
 	MessageReadLogRepository  MessageReadLogRepository
 	MessageRepository         MessageRepository
+	NotifyRepository          NotifyRepository
 }
 
 func NewRepositories(source *gorm.DB, rdb *redis.Client) *Repositories {
@@ -30,6 +33,7 @@ func NewRepositories(source *gorm.DB, rdb *redis.Client) *Repositories {
 		FirebaseMessageRepository: NewFirebaseMessageRepository(source),
 		MessageReadLogRepository:  NewMessageReadLogRepository(source),
 		MessageRepository:         NewMessageRepository(source),
+		NotifyRepository:          NewNotifyRepository(source),
 	}
 }
 
@@ -38,3 +42,14 @@ func (repos *Repositories) Transaction(ctx context.Context, fn TransactionFun) e
 		return fn(ctx, repos)
 	})
 }
+
+func resultSlice[T any](mods []T, err error) ([]T, error) {
+	switch {
+	case err == nil:
+		return mods, nil
+	case eris.Is(err, models.ErrRecordNotFound):
+		return make([]T, 0), nil
+	default:
+		return nil, err
+	}
+}

+ 22 - 1
server/client.go

@@ -140,6 +140,9 @@ func (c *Client) reader() {
 						content.Receiver = receiver
 
 						// 发送消息到 Nats
+						if c.nats.nc.IsClosed() {
+							c.nats.nc = natx.Connect()
+						}
 						resp, err := c.nats.nc.RequestMsg(&nats.Msg{
 							Subject: natx.GetSubject(),
 							Data:    serializeMessage(message),
@@ -165,8 +168,26 @@ func (c *Client) reader() {
 					c.ReplySend <- newErrorMessage(message.RequestId(), err)
 				}
 			}
-
 			zap.L().Info("received", zap.Any("message", message))
+
+		case MessageTypeNotification:
+
+			// 通知消息
+			if notification, ok := message.(*Notification); ok {
+				// 将通知消息存起来, 记录下通知消息的
+				if err := c.repos.NotifyRepository.Create(c.ctx, &models.Notify{
+					NotifyId: notification.Content.ID,
+					Sender:   notification.Content.Sender,
+					Receiver: notification.Content.Receiver,
+					Payload:  serializePayload(0, notification.Content.Payload),
+				}); err != nil {
+					c.ReplySend <- newErrorMessage(message.RequestId(), err)
+					continue
+				}
+
+				c.Received <- message
+			}
+
 		}
 	}
 }

+ 52 - 21
server/message.go

@@ -26,14 +26,14 @@ const (
 )
 
 const (
-	NotificationTypeDeviceStatusChange = 100 // NotificationTypeDeviceStatusChange 设备绑定状态变更通知
-	NotificationTypeAskLocation        = 111 // NotificationTypeAskLocation 询问设备位置
-	NotificationTypeChangedContacts    = 130 // NotificationTypeChangedContacts 联系人变动通知
-	NotificationTypeCreatedFriend      = 120 //  NotificationTypeCreatedFriend 通知添加好友结果通知
-	NotificationTypeChangedAlarmClock  = 140 //  NotificationTypeChangedAlarmClock 闹钟变更通知
-	NotificationTypeChangedSchoolMode  = 141 //  NotificationTypeChangedSchoolMode 上课禁用变更通知
-	NotificationTypeDeviceShutdown     = 150 //  NotificationTypeDeviceShutdown 设备关机通知
-	NotificationTypeDeviceReboot       = 160 //  NotificationTypeDeviceReboot 设备重启通知
+	NotificationTypeDeviceStatusChange int = 100 // NotificationTypeDeviceStatusChange 设备绑定状态变更通知
+	NotificationTypeAskLocation        int = 111 // NotificationTypeAskLocation 询问设备位置
+	NotificationTypeChangedContacts    int = 130 // NotificationTypeChangedContacts 联系人变动通知
+	NotificationTypeCreatedFriend      int = 120 //  NotificationTypeCreatedFriend 通知添加好友结果通知
+	NotificationTypeChangedAlarmClock  int = 140 //  NotificationTypeChangedAlarmClock 闹钟变更通知
+	NotificationTypeChangedSchoolMode  int = 141 //  NotificationTypeChangedSchoolMode 上课禁用变更通知
+	NotificationTypeDeviceShutdown     int = 150 //  NotificationTypeDeviceShutdown 设备关机通知
+	NotificationTypeDeviceReboot       int = 160 //  NotificationTypeDeviceReboot 设备重启通知
 )
 
 // var _ encoding.BinaryMarshaler = (*Message)(nil)
@@ -118,32 +118,51 @@ type metadata struct {
 	Duration uint   `json:"duration" mapstructure:""` // Duration 视频/语音时长
 }
 
-type ErrMessage struct {
+type Err struct {
 	MessageImpl
 
 	Content string `json:"content"` // 错误内容或者提示
 }
 
 // Data implements Message.
-func (c *ErrMessage) Data() []byte {
+func (c *Err) Data() []byte {
 	data, _ := json.Marshal(c)
 	return data
 }
 
-var _ Message = (*HeartbeatMessage)(nil)
+var _ Message = (*Heartbeat)(nil)
 
-type HeartbeatMessage struct {
+type Heartbeat struct {
 	MessageImpl
 
 	Content string `json:"content"`
 }
 
 // Data implements Message.
-func (h *HeartbeatMessage) Data() []byte {
+func (h *Heartbeat) Data() []byte {
 	data, _ := json.Marshal(h)
 	return data
 }
 
+// 通知消息
+type Notification struct {
+	MessageImpl
+	Content *NotificationContent `json:"context"`
+}
+
+type NotificationContent struct {
+	ID       int                    `json:"id"` // ID 通知消息的ID, 不同的通知会有不同的ID, 可以 NotificationType
+	Receiver string                 `json:"receiver"`
+	Sender   string                 `json:"sender"`
+	Payload  map[string]interface{} `json:"payload,omitempty"`
+}
+
+// Data implements Message.
+func (n *Notification) Data() []byte {
+	data, _ := json.Marshal(n)
+	return data
+}
+
 type messageOption func(msg Message)
 
 func deserializeMessage(data []byte, opts ...messageOption) Message {
@@ -160,9 +179,14 @@ func deserializeMessage(data []byte, opts ...messageOption) Message {
 		_ = json.Unmarshal(data, &chating)
 		msg = &chating
 	case MessageTypePingPong:
-		var heartbeat HeartbeatMessage
+		var heartbeat Heartbeat
 		_ = json.Unmarshal(data, &heartbeat)
 		msg = &heartbeat
+
+	case MessageTypeNotification:
+		var notification Notification
+		_ = json.Unmarshal(data, &notification)
+		msg = &notification
 	}
 
 	for _, opt := range opts {
@@ -187,29 +211,36 @@ func serializeMessage(message Message) []byte {
 }
 
 func serializePayload(payloadType PayloadType, m map[string]interface{}) []byte {
-	if payloadType == PayloadTypeText {
+	switch payloadType {
+	case PayloadTypeText:
 		var text text
 		_ = mapstructure.Decode(m, &text)
 		data, _ := json.Marshal(text)
 		return data
-	}
-
-	if payloadType == PayloadTypeMetadata {
+	case PayloadTypeMetadata:
 		var metadata metadata
 		_ = mapstructure.Decode(m, &metadata)
 		data, _ := json.Marshal(metadata)
 		return data
+
+	default:
+		data, _ := json.Marshal(m)
+		return data
 	}
+}
 
-	return nil
+func deserializePayload(data []byte) map[string]interface{} {
+	var res map[string]interface{}
+	_ = json.Unmarshal(data, &res)
+	return res
 }
 
 func newErrorMessage(rid string, err error) Message {
-	return &ErrMessage{MessageImpl: MessageImpl{Type: MessageTypeError, RId: rid}, Content: err.Error()}
+	return &Err{MessageImpl: MessageImpl{Type: MessageTypeError, RId: rid}, Content: err.Error()}
 }
 
 func newPongMessage(rid string) Message {
-	return &HeartbeatMessage{MessageImpl: MessageImpl{Type: MessageTypePingPong, RId: rid}, Content: "pong"}
+	return &Heartbeat{MessageImpl: MessageImpl{Type: MessageTypePingPong, RId: rid}, Content: "pong"}
 }
 
 // type EncodedMessage struct {

+ 11 - 2
server/nats.go

@@ -94,10 +94,19 @@ func (n *Nats) run() {
 			resp := RespondStructural{RequestId: message.RequestId(), Ok: true}
 			_ = natsMsg.Respond([]byte(resp.Marshaler()))
 
-			n.mutex.RLock()
+			// 从消息里获取得到接收人信息
+			var receiver string
 			if chating, ok := message.(*Chating); ok {
+				receiver = chating.Content.Receiver
+			}
+			if notification, ok := message.(*Notification); ok {
+				receiver = notification.Content.Receiver
+			}
+
+			n.mutex.RLock()
+			if receiver != "" {
 				for uid, sub := range n.Subscribers {
-					if uid == chating.Content.Receiver {
+					if uid == receiver {
 						// 写入消息到不同的客户端
 						sub.client.Received <- message
 					}

+ 26 - 1
server/nats_test.go

@@ -13,7 +13,6 @@ import (
 
 func TestNats_Connect(t *testing.T) {
 	nc, _ := nats.Connect("nats://127.0.0.1:4222")
-
 	chating := Chating{
 		MessageImpl: MessageImpl{
 			Type: MessageTypeUpChating,
@@ -43,8 +42,34 @@ func TestNats_Connect(t *testing.T) {
 	resp, err := nc.RequestMsg(&msg, time.Second*10)
 	if err != nil {
 		slog.Error("err", err)
+		return
 	}
 
 	str := string(resp.Data)
 	slog.Info(fmt.Sprintf("ok %s", str))
 }
+
+func TestNats_NotificationMessage(t *testing.T) {
+	nc, _ := nats.Connect("nats://127.0.0.1:4222")
+
+	notification := Notification{
+		MessageImpl: MessageImpl{Type: MessageTypeNotification, RId: uuid.NewString()},
+		Content: &NotificationContent{
+			ID:       NotificationTypeChangedContacts,
+			Receiver: "dce34294-467d-4fcb-a550-b729f7167f69",
+			Sender:   "f2b17475-7800-47e2-a1e3-dfa39828b54d",
+		},
+	}
+
+	natsMsg := &nats.Msg{
+		Subject: "clients.message",
+		Data:    notification.Data(),
+	}
+	resp, err := nc.RequestMsg(natsMsg, time.Second*10)
+	if err != nil {
+		slog.Error("err", err)
+		return
+	}
+	str := string(resp.Data)
+	slog.Info(fmt.Sprintf("ok %s", str))
+}

+ 42 - 5
server/server.go

@@ -11,6 +11,7 @@ import (
 	"github.com/golang-jwt/jwt/v5"
 	"github.com/google/uuid"
 	"github.com/gorilla/websocket"
+	"github.com/nats-io/nats.go"
 	"github.com/rotisserie/eris"
 	"github.com/spf13/viper"
 	"go.uber.org/zap"
@@ -27,7 +28,7 @@ type Server struct {
 	id   string
 	nats *Nats
 	// ctx          *gin.Context
-	repo *repositories.Repositories
+	repos *repositories.Repositories
 
 	clients      map[*Client]struct{}
 	clientsMutex sync.RWMutex
@@ -50,7 +51,7 @@ func NewServer() *Server {
 				return true
 			},
 		},
-		repo: repositories.NewRepositories(dbx.GetConnect(), rdbx.GetConnect()),
+		repos: repositories.NewRepositories(dbx.GetConnect(), rdbx.GetConnect()),
 
 		clients:      make(map[*Client]struct{}),
 		clientsMutex: sync.RWMutex{},
@@ -88,7 +89,6 @@ func (srv *Server) events() {
 }
 
 func (srv *Server) WebsocketHandler(ctx *gin.Context) {
-
 	// Builder headers
 	headers := headerBuilder(ctx)
 
@@ -136,7 +136,7 @@ func (srv *Server) WebsocketHandler(ctx *gin.Context) {
 		//isSimpleMsg:  headers[keys.SimpleHeader].(bool),
 		//localization: headers[keys.LocalizationHeader].(string),
 
-		repos: srv.repo,
+		repos: srv.repos,
 	}
 
 	// 连接监听 nats 的消息, 加入到监听者
@@ -146,7 +146,7 @@ func (srv *Server) WebsocketHandler(ctx *gin.Context) {
 	srv.Connect <- client
 
 	// Online status to redis
-	if err := srv.repo.OnlineRepository.SetOnline(ctx, client.online); err != nil {
+	if err := srv.repos.OnlineRepository.SetOnline(ctx, client.online); err != nil {
 		ctx.AbortWithError(http.StatusInternalServerError,
 			eris.Wrapf(err, "unable to set online status for user: %s", client.UserId))
 		return
@@ -155,6 +155,43 @@ func (srv *Server) WebsocketHandler(ctx *gin.Context) {
 	go client.recv()
 	go client.reader()
 	go client.writer()
+
+	// 查询离线时未接收的消息,并且推送给客户端
+	rid := uuid.NewString()
+	unreadMsg, err := srv.repos.MessageRepository.FindUnread(ctx, id)
+	if err != nil {
+		// 查询未读消息出现错误, 给登录的用户发送一个错误信息
+		client.ReplySend <- newErrorMessage(rid, err)
+		return
+	}
+
+	for _, msg := range unreadMsg {
+		nc := client.nats.nc
+		if nc.IsClosed() {
+			nc = natx.Connect()
+		}
+
+		chating := Chating{
+			MessageImpl: MessageImpl{
+				Type: MessageTypeUpChating,
+				RId:  rid,
+			},
+			Content: &ChatingContent{
+				MessageId:   msg.MessageId,
+				Receiver:    msg.Receiver,
+				SessionId:   msg.SessionId.String,
+				PayloadType: uint8(msg.PayloadType),
+				Payload:     deserializePayload(msg.Payload),
+				SendTime:    msg.SendTime.UTC().UnixMilli(),
+			},
+		}
+		if _, err := nc.RequestMsg(&nats.Msg{
+			Subject: natx.GetSubject(),
+			Data:    chating.Data(),
+		}, time.Second*5); err != nil {
+			client.ReplySend <- newErrorMessage(rid, err)
+		}
+	}
 }
 
 func serverId() string {