mysql.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package mysqlx
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "os"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/go-sql-driver/mysql"
  12. _ "github.com/go-sql-driver/mysql"
  13. "github.com/spf13/viper"
  14. "golang.org/x/crypto/ssh"
  15. sql "gorm.io/driver/mysql"
  16. "gorm.io/gorm"
  17. "gorm.io/gorm/logger"
  18. )
  19. type mysqlConfig struct {
  20. SSH bool `toml:"ssh"` // SSH 是否开启SSH
  21. Dsn string // Dsn 数据源地址
  22. SkipDefaultTransaction bool // SkipDefaultTransaction 跳过默认事务
  23. SlowThreshold time.Duration // SlowThreshold 慢 SQL 阈值
  24. IgnoreRecordNotFoundError bool // IgnoreRecordNotFoundError 忽略记录未找到的错误
  25. MaxLifetime time.Duration // MaxLifetime 连接的有效时长
  26. MaxOpenConns int // MaxOpenConns 打开数据库连接的最大数量。
  27. MaxIdleConns int // MaxIdleConns 空闲连接池中连接的最大数量
  28. }
  29. type driverConfig struct {
  30. username string
  31. password string
  32. protocol string
  33. address string
  34. port int
  35. db string
  36. params string
  37. }
  38. func (vc *driverConfig) formatDSN() string {
  39. return vc.username + ":" + vc.password + "@" +
  40. vc.protocol + "(" + vc.address + ":" + strconv.Itoa(vc.port) + ")/" + vc.db + "?" + vc.params
  41. }
  42. type Dialer struct {
  43. client *ssh.Client
  44. }
  45. func (v *Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
  46. return v.client.Dial("tcp", address)
  47. }
  48. // Connect 初始化 mysql 连接
  49. func Connect() *gorm.DB {
  50. cfg := readMysqlConfig()
  51. conn, err := gorm.Open(sql.New(sql.Config{
  52. DSN: cfg.Dsn,
  53. DefaultStringSize: 255,
  54. SkipInitializeWithVersion: false,
  55. }), getGormConfig(cfg))
  56. if err != nil {
  57. log.Panicln(err)
  58. }
  59. sqlDB, _ := conn.DB()
  60. // SetMaxIdleConns 设置空闲连接池中连接的最大数量
  61. sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
  62. // SetMaxOpenConns 设置打开数据库连接的最大数量。
  63. sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
  64. // SetConnMaxLifetime 设置了连接可复用的最大时间。
  65. sqlDB.SetConnMaxLifetime(cfg.MaxLifetime)
  66. log.Printf("Mysql connected to %s \n", cfg.Dsn)
  67. return conn
  68. }
  69. func withDsn(dsn string) *driverConfig {
  70. // root:qq123123@tcp(127.0.0.1:3306)/sikey?charset=utf8mb4&parseTime=true&loc=Local
  71. var user = strings.Split(dsn, "@")
  72. var ua = strings.Split(user[0], ":")
  73. var protocol = strings.Split(user[1], "(")
  74. var address = strings.Split(protocol[1], ")")
  75. var addr = strings.Split(address[0], ":")
  76. var port, err = strconv.ParseInt(addr[1], 10, 64)
  77. if err != nil {
  78. port = 3306
  79. }
  80. return &driverConfig{
  81. username: ua[0],
  82. password: ua[1],
  83. protocol: protocol[0],
  84. address: addr[0],
  85. port: int(port),
  86. db: strings.Split(address[1], "?")[0][1:],
  87. params: strings.Split(address[1], "?")[1],
  88. }
  89. }
  90. // readMysqlConfig 加载配置
  91. func readMysqlConfig() mysqlConfig {
  92. var cfg mysqlConfig
  93. if err := viper.UnmarshalKey("mysql", &cfg); err != nil {
  94. log.Fatalln(err)
  95. }
  96. if cfg.SSH {
  97. config := &ssh.ClientConfig{
  98. User: "root",
  99. Auth: []ssh.AuthMethod{
  100. ssh.Password("RHTUH2z49aEXnsgz"),
  101. },
  102. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  103. }
  104. var err error
  105. var clt *ssh.Client
  106. if clt, err = ssh.Dial("tcp", "106.75.230.4:22", config); err != nil {
  107. log.Fatalln(err)
  108. }
  109. var protocol = "ssh"
  110. vc := withDsn(cfg.Dsn)
  111. vc.protocol = protocol
  112. cfg.Dsn = vc.formatDSN()
  113. fmt.Println(cfg.Dsn)
  114. mysql.RegisterDialContext(protocol, (&Dialer{client: clt}).Dial)
  115. }
  116. return cfg
  117. }
  118. // getGormConfig 获取 gorm 配置
  119. func getGormConfig(cfg mysqlConfig) *gorm.Config {
  120. return &gorm.Config{
  121. DisableForeignKeyConstraintWhenMigrating: true,
  122. SkipDefaultTransaction: cfg.SkipDefaultTransaction,
  123. Logger: defaultLogger(cfg),
  124. }
  125. }
  126. // defaultLogger 默认的日志打印
  127. func defaultLogger(cfg mysqlConfig) logger.Interface {
  128. return logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
  129. SlowThreshold: cfg.SlowThreshold * time.Millisecond, // Slow SQL threshold
  130. LogLevel: logger.Silent, // Log level
  131. IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFoundError, // Ignore ErrRecordNotFound error for logger
  132. })
  133. }