luoyangwei 1 жил өмнө
parent
commit
c3bedc5669

+ 19 - 3
server/hub.go

@@ -53,17 +53,22 @@ func (h *Hub) run() {
 		select {
 		case client := <-h.Connect:
 
+			h.mutex.Lock()
 			h.clients[client.UserId] = client
 			h.exchange.OnPublishConnect(client)
+			h.mutex.Unlock()
 
 		case client := <-h.Disconnect:
 
+			h.mutex.Lock()
 			close(client.Send)
 			delete(h.clients, client.UserId)
 			h.exchange.OnPublishDisconnect(client)
+			h.mutex.Unlock()
 
 		case message := <-h.Message:
 
+			h.mutex.RLock()
 			if client, ok := h.clients[message.Receiver]; ok {
 				if client.isRemotely {
 					h.exchange.OnPublishMessage(client, message)
@@ -72,29 +77,40 @@ func (h *Hub) run() {
 					client.Send <- message
 				}
 			}
+			h.mutex.RUnlock()
 
 		case conn := <-h.exchange.Connect:
 
+			h.mutex.Lock()
 			h.clients[conn.UserId] = &Client{
 				isRemotely: true,
 				hub:        h,
 				Send:       make(chan *Message),
 			}
+			h.mutex.Unlock()
 
 		case conn := <-h.exchange.Disconnect:
 
+			h.mutex.Lock()
 			if client, ok := h.clients[conn.UserId]; ok {
 				close(client.Send)
 				delete(h.clients, client.UserId)
 			}
+			h.mutex.Unlock()
 
 		case message := <-h.exchange.Message:
-			zlog.Info("收到远程消息:", message)
-			zlog.Info("receiver: ", message.Receiver)
+
+			h.mutex.RLock()
 			if client, ok := h.clients[message.Receiver]; ok {
 				client.Send <- message
-				zlog.Info("发送消息成功")
 			}
+			h.mutex.RUnlock()
 		}
 	}
 }
+
+func (h *Hub) GetClients() map[string]*Client {
+	h.mutex.RLock()
+	defer h.mutex.RUnlock()
+	return h.clients
+}

+ 1 - 1
server/redis_exchange.go

@@ -41,7 +41,7 @@ var (
 	disconnectChannelEvent = "client.event.disconnect"
 	messageChannelEvent    = "client.event.message"
 
-	clients = "client:users:%s"
+	clients = "client.users.%s"
 )
 
 type (

+ 48 - 18
server/server_test.go

@@ -2,9 +2,11 @@ package server
 
 import (
 	"fmt"
+	"github.com/rotisserie/eris"
 	"log"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -19,36 +21,64 @@ func BenchmarkTest1(b *testing.B) {
 	}
 }
 
-var conns []*websocket.Conn
-
 func TestServer_WebsocketKeep(t *testing.T) {
 	wg := sync.WaitGroup{}
-	for j := 10; j < 100; j++ {
+
+	var count int32
+	var failed int32
+
+	go func() {
+		wg.Add(1)
+		defer wg.Done()
+		for {
+			time.Sleep(3 * time.Second)
+			fmt.Printf("%d, %d \n", atomic.LoadInt32(&count), atomic.LoadInt32(&failed))
+		}
+	}()
+
+	for i := 0; i < 3000; i++ {
 		go func() {
+			wg.Add(1)
 			defer wg.Done()
-			for i := 0; i < 100; i++ {
-				conn, _, err := websocket.DefaultDialer.Dial(
-					// fmt.Sprintf(`ws://127.0.0.1:10082/websocket/endpoint?X-Websocket-Header-ID=%s`, uuid.NewString()),
-					fmt.Sprintf(`ws://106.75.230.4:10082/websocket/endpoint?X-Websocket-Header-ID=%s`, uuid.NewString()),
-					nil)
-				if err != nil {
-					continue
+
+			conn, _, err := websocket.DefaultDialer.Dial(
+				fmt.Sprintf(`ws://127.0.0.1:10082/websocket/endpoint?X-Websocket-Header-ID=%s`, uuid.NewString()),
+				//fmt.Sprintf(`ws://106.75.230.4:10082/websocket/endpoint?X-Websocket-Header-ID=%s`, uuid.NewString()),
+				nil)
+			if err != nil {
+				atomic.StoreInt32(&failed, failed+1)
+				log.Println(eris.Wrap(err, "unable to create connect"))
+				return
+			}
+
+			conn.SetCloseHandler(func(code int, text string) error {
+				log.Println(code, text)
+				return nil
+			})
+
+			go func() {
+				for {
+					time.Sleep(3 * time.Second)
+					message := Message{Type: MessageTypePingPong, RequestId: strconv.Itoa(int(time.Now().UnixMilli())), Content: "ping"}
+					err = conn.WriteMessage(websocket.TextMessage, serializationMessage(&message))
+					if err != nil {
+						log.Println(eris.Wrap(err, "unable to write message"))
+						return
+					}
 				}
+			}()
 
-				message := Message{Type: MessageTypePingPong, RequestId: strconv.Itoa(int(time.Now().UnixMilli())), Content: "ping"}
-				conn.SetWriteDeadline(time.Now().Add(time.Second * 60))
-				err = conn.WriteMessage(websocket.TextMessage, serializationMessage(&message))
+			atomic.StoreInt32(&count, count+1)
+			for {
+				_, _, err = conn.ReadMessage()
 				if err != nil {
-					continue
+					log.Println(eris.Wrap(err, "unable to read message"))
+					return
 				}
-
-				conns = append(conns, conn)
 			}
 		}()
-		wg.Add(1)
 	}
 	wg.Wait()
-	fmt.Println(len(conns))
 }
 
 func BenchmarkServer_WebsocketPressure(b *testing.B) {

+ 21 - 15
websocket.go

@@ -35,6 +35,17 @@ func newApp() *gin.Engine {
 	app := gin.Default()
 	ginpprof.Wrap(app)
 
+	hub := server.NewHub(server.HubConfig{
+		ServerId: uuid.NewString(),
+		Rdb: redis.NewUniversalClient(&redis.UniversalOptions{
+			Addrs:    []string{"106.75.230.4:6379"},
+			Password: "sikey!Q@W#E456",
+			DB:       0,
+		}),
+		ConnectSize:    1024,
+		DisconnectSize: 1024,
+		MessageSize:    125,
+	})
 	srv := &server.Server{
 		Upgrader: websocket.Upgrader{
 			ReadBufferSize:  1024,
@@ -43,23 +54,18 @@ func newApp() *gin.Engine {
 				return true
 			},
 		},
-		WriteWait: 10 * time.Second,
-		ReadWait:  10 * time.Second,
-		PingWait:  120 * time.Second,
-		Hub: server.NewHub(server.HubConfig{
-			ServerId: uuid.NewString(),
-			Rdb: redis.NewUniversalClient(&redis.UniversalOptions{
-				Addrs:    []string{"106.75.230.4:6379"},
-				Password: "sikey!Q@W#E456",
-				DB:       0,
-			}),
-			ConnectSize:    1024,
-			DisconnectSize: 1024,
-			MessageSize:    125,
-		}),
+		WriteWait:    10 * time.Second,
+		ReadWait:     10 * time.Second,
+		PingWait:     120 * time.Second,
+		Hub:          hub,
 		Repositories: repositories.NewRepositories(mysqlx.ConnectMysql()),
 	}
 	app.GET("/websocket/endpoint", func(ctx *gin.Context) { server.WebsocketHandler(ctx, srv) })
-
+	app.GET("/index", func(ctx *gin.Context) {
+		clients := hub.GetClients()
+		ctx.JSON(http.StatusOK, gin.H{
+			"clients": len(clients),
+		})
+	})
 	return app
 }