luoyangwei преди 1 година
родител
ревизия
9374a8241b

+ 3 - 1
etc/websocket.toml

@@ -13,7 +13,9 @@ writeWait = 10 # 写超时
 natsUrl = "nats://106.75.230.4:4333" 
 
 [mysql]
-dsn = "root:qq123123@tcp(127.0.0.1:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local"
+# dsn = "root:qq123123@tcp(127.0.0.1:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local"
+dsn = "root:9RKdJsEQKnjrni9R@tcp(10.23.148.10:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local"
+ssh = true
 # Ignore ErrRecordNotFound error for logger
 skipDefaultTransaction = true
 # Slow SQL threshold

+ 2 - 1
models/session.go

@@ -26,7 +26,8 @@ type SessionMember struct {
 	DeletedAt gorm.DeletedAt `gorm:"index"`
 
 	SessionId       string
-	AccountId       string
+	RefId           string
+	RefType         string
 	AccountIdentity string
 	JoinTime        time.Time
 }

+ 1 - 1
repositories/online_repostiroy_test.go

@@ -13,7 +13,7 @@ import (
 func TestOnlineRepository_Is(t *testing.T) {
 	config.MustLoadConfig("../etc/websocket.toml")
 	zlog.WithZapLogger(zlog.NewLogger(config.MustLoadLogger()))
-	db := mysqlx.ConnectMysql()
+	db := mysqlx.Connect()
 	rdb := redisx.RedisConnect()
 	repos := NewRepositories(db, rdb)
 	var err error

+ 3 - 3
repositories/session_repositroy.go

@@ -75,8 +75,8 @@ func (repo *sessionRepository) GetSessionMembersRemoveOneself(ctx context.Contex
 	var members []models.SessionMember
 	err = repo.source.WithContext(ctx).
 		Where(&models.SessionMember{SessionId: sessionId}).
-		Not(&models.SessionMember{AccountId: accountId}).
-		Group("account_id").
+		Not(&models.SessionMember{RefId: accountId}).
+		Group("ref_id").
 		Find(&members).Error
 	switch {
 	case err == nil:
@@ -111,7 +111,7 @@ func (repo *sessionRepository) GetJoinedSessions(ctx context.Context, accountId
 	err = repo.source.WithContext(ctx).Where("id in (?)", repo.source.
 		WithContext(ctx).
 		Select("session_id").
-		Where(&models.SessionMember{AccountId: accountId}).Table("tb_session_member")).Find(&sessions).Error
+		Where(&models.SessionMember{RefId: accountId}).Table("tb_session_member")).Find(&sessions).Error
 	switch {
 	case err == nil:
 		return sessions, nil

+ 1 - 1
repositories/session_repositroy_test.go

@@ -12,7 +12,7 @@ import (
 
 func getSessionRepository() SessionRepository {
 	config.MustLoadConfig("../etc/websocket.toml")
-	source := mysqlx.ConnectMysql()
+	source := mysqlx.Connect()
 	return NewSessionRepository(source)
 }
 

+ 1 - 1
server/client.go

@@ -226,7 +226,7 @@ func (c *Client) getReceiverUserIds(receiver string) []string {
 
 	var ms = make([]string, len(members))
 	for i, memb := range members {
-		ms[i] = memb.AccountId
+		ms[i] = memb.RefId
 	}
 	return ms
 }

+ 16 - 0
server/hub_test.go

@@ -2,11 +2,15 @@ package server
 
 import (
 	"context"
+	"fmt"
 	"log"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/redis/go-redis/v9"
 	"sikey.com/websocket/config"
+	"sikey.com/websocket/repositories"
+	"sikey.com/websocket/utils/mysqlx"
 )
 
 var userId = "d6faa0af-b863-48bb-b658-d961a9381585"
@@ -31,6 +35,18 @@ func TestHub_ConnectMessage(t *testing.T) {
 	}
 }
 
+func TestHub_getReceiverUserIds(t *testing.T) {
+	config.MustLoadConfig("../etc/websocket.toml")
+	repos := repositories.NewRepositories(mysqlx.Connect(), nil)
+	c := &Client{
+		ctx:    &gin.Context{},
+		UserId: "beaf8878-03d8-4bf6-a783-22a0d4881265",
+		repos:  repos,
+	}
+	users := c.getReceiverUserIds("1753371851436216320")
+	fmt.Println(users)
+}
+
 // func TestHub_ConnectMessage(t *testing.T) {
 // 	writer := &kafka.Writer{
 // 		Addr:                   kafka.TCP("106.75.230.4:9092"),

+ 80 - 4
utils/mysqlx/mysql.go

@@ -1,17 +1,26 @@
 package mysqlx
 
 import (
+	"context"
+	"fmt"
 	"log"
+	"net"
 	"os"
+	"strconv"
+	"strings"
 	"time"
 
+	"github.com/go-sql-driver/mysql"
+	_ "github.com/go-sql-driver/mysql"
 	"github.com/spf13/viper"
-	"gorm.io/driver/mysql"
+	"golang.org/x/crypto/ssh"
+	sql "gorm.io/driver/mysql"
 	"gorm.io/gorm"
 	"gorm.io/gorm/logger"
 )
 
 type mysqlConfig struct {
+	SSH                       bool          `toml:"ssh"` // SSH 是否开启SSH
 	Dsn                       string        // Dsn 数据源地址
 	SkipDefaultTransaction    bool          // SkipDefaultTransaction 跳过默认事务
 	SlowThreshold             time.Duration // SlowThreshold 慢 SQL 阈值
@@ -21,10 +30,33 @@ type mysqlConfig struct {
 	MaxIdleConns              int           // MaxIdleConns 空闲连接池中连接的最大数量
 }
 
-// ConnectMysql 初始化 mysql 连接
-func ConnectMysql() *gorm.DB {
+type driverConfig struct {
+	username string
+	password string
+	protocol string
+	address  string
+	port     int
+	db       string
+	params   string
+}
+
+func (vc *driverConfig) formatDSN() string {
+	return vc.username + ":" + vc.password + "@" +
+		vc.protocol + "(" + vc.address + ":" + strconv.Itoa(vc.port) + ")/" + vc.db + "?" + vc.params
+}
+
+type Dialer struct {
+	client *ssh.Client
+}
+
+func (v *Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
+	return v.client.Dial("tcp", address)
+}
+
+// Connect 初始化 mysql 连接
+func Connect() *gorm.DB {
 	cfg := readMysqlConfig()
-	conn, err := gorm.Open(mysql.New(mysql.Config{
+	conn, err := gorm.Open(sql.New(sql.Config{
 		DSN:                       cfg.Dsn,
 		DefaultStringSize:         255,
 		SkipInitializeWithVersion: false,
@@ -45,12 +77,56 @@ func ConnectMysql() *gorm.DB {
 	return conn
 }
 
+func withDsn(dsn string) *driverConfig {
+	// root:qq123123@tcp(127.0.0.1:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local
+	var user = strings.Split(dsn, "@")
+	var ua = strings.Split(user[0], ":")
+	var protocol = strings.Split(user[1], "(")
+	var address = strings.Split(protocol[1], ")")
+	var addr = strings.Split(address[0], ":")
+	var port, err = strconv.ParseInt(addr[1], 10, 64)
+	if err != nil {
+		port = 3306
+	}
+	return &driverConfig{
+		username: ua[0],
+		password: ua[1],
+		protocol: protocol[0],
+		address:  addr[0],
+		port:     int(port),
+		db:       strings.Split(address[1], "?")[0][1:],
+		params:   strings.Split(address[1], "?")[1],
+	}
+}
+
 // readMysqlConfig 加载配置
 func readMysqlConfig() mysqlConfig {
 	var cfg mysqlConfig
 	if err := viper.UnmarshalKey("mysql", &cfg); err != nil {
 		log.Fatalln(err)
 	}
+
+	if cfg.SSH {
+		config := &ssh.ClientConfig{
+			User: "root",
+			Auth: []ssh.AuthMethod{
+				ssh.Password("RHTUH2z49aEXnsgz"),
+			},
+			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+		}
+		var err error
+		var clt *ssh.Client
+		if clt, err = ssh.Dial("tcp", "106.75.230.4:22", config); err != nil {
+			log.Fatalln(err)
+		}
+
+		var protocol = "ssh"
+		vc := withDsn(cfg.Dsn)
+		vc.protocol = protocol
+		cfg.Dsn = vc.formatDSN()
+		fmt.Println(cfg.Dsn)
+		mysql.RegisterDialContext(protocol, (&Dialer{client: clt}).Dial)
+	}
 	return cfg
 }
 

+ 12 - 0
utils/mysqlx/mysql_test.go

@@ -0,0 +1,12 @@
+package mysqlx
+
+import (
+	"fmt"
+	"testing"
+)
+
+func TestGetDsn(t *testing.T) {
+	var str = `root:qq123123@tcp(127.0.0.1:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local`
+	dsn := withDsn(str)
+	fmt.Println(dsn)
+}

+ 1 - 1
websocket.go

@@ -53,7 +53,7 @@ func newApp() *gin.Engine {
 			},
 		},
 		Hub:          server.NewHub(id),
-		Repositories: repositories.NewRepositories(mysqlx.ConnectMysql(), redisx.RedisConnect()),
+		Repositories: repositories.NewRepositories(mysqlx.Connect(), redisx.RedisConnect()),
 	}
 
 	app.GET("/websocket/endpoint", func(ctx *gin.Context) { srv.WebsocketHandler(ctx) })