PT-73 Added SSL support

This commit is contained in:
Carlos Salguero
2017-02-21 14:32:17 -03:00
parent 0f9a1bcf42
commit c00ccf0d8d
6 changed files with 118 additions and 50 deletions

4
glide.lock generated
View File

@@ -1,5 +1,5 @@
hash: 2ff7c989fb0fde1375999fded74ae44e10be513a21416571f026390b679924e4 hash: 2ff7c989fb0fde1375999fded74ae44e10be513a21416571f026390b679924e4
updated: 2017-02-15T13:56:15.338996189-03:00 updated: 2017-02-21T13:59:41.533544309-03:00
imports: imports:
- name: github.com/bradfitz/slice - name: github.com/bradfitz/slice
version: d9036e2120b5ddfa53f3ebccd618c4af275f47da version: d9036e2120b5ddfa53f3ebccd618c4af275f47da
@@ -23,6 +23,8 @@ imports:
version: eeaced052adbcfeea372c749c281099ed7fdaa38 version: eeaced052adbcfeea372c749c281099ed7fdaa38
- name: github.com/pborman/getopt - name: github.com/pborman/getopt
version: 7148bc3a4c3008adfcab60cbebfd0576018f330b version: 7148bc3a4c3008adfcab60cbebfd0576018f330b
subpackages:
- v2
- name: github.com/percona/pmgo - name: github.com/percona/pmgo
version: 27d979df6c6885ff16abe375aead061a86da6df8 version: 27d979df6c6885ff16abe375aead061a86da6df8
subpackages: subpackages:

View File

@@ -12,7 +12,7 @@ import (
"gopkg.in/mgo.v2/bson" "gopkg.in/mgo.v2/bson"
) )
func GetReplicasetMembers(dialer pmgo.Dialer, di *mgo.DialInfo) ([]proto.Members, error) { func GetReplicasetMembers(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]proto.Members, error) {
hostnames, err := GetHostnames(dialer, di) hostnames, err := GetHostnames(dialer, di)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -75,7 +75,7 @@ func GetReplicasetMembers(dialer pmgo.Dialer, di *mgo.DialInfo) ([]proto.Members
return members, nil return members, nil
} }
func GetHostnames(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) { func GetHostnames(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]string, error) {
hostnames := []string{di.Addrs[0]} hostnames := []string{di.Addrs[0]}
di.Direct = true di.Direct = true
di.Timeout = 2 * time.Second di.Timeout = 2 * time.Second
@@ -182,7 +182,7 @@ func buildHostsListFromShardMap(shardsMap proto.ShardsMap) []string {
// This function is like GetHostnames but it uses listShards instead of getShardMap // This function is like GetHostnames but it uses listShards instead of getShardMap
// so it won't include config servers in the returned list // so it won't include config servers in the returned list
func GetShardedHosts(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) { func GetShardedHosts(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]string, error) {
hostnames := []string{di.Addrs[0]} hostnames := []string{di.Addrs[0]}
session, err := dialer.DialWithInfo(di) session, err := dialer.DialWithInfo(di)
if err != nil { if err != nil {
@@ -206,7 +206,7 @@ func GetShardedHosts(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) {
return hostnames, nil return hostnames, nil
} }
func getTmpDI(di *mgo.DialInfo, hostname string) *mgo.DialInfo { func getTmpDI(di *pmgo.DialInfo, hostname string) *pmgo.DialInfo {
tmpdi := *di tmpdi := *di
tmpdi.Addrs = []string{hostname} tmpdi.Addrs = []string{hostname}
tmpdi.Direct = true tmpdi.Direct = true
@@ -215,7 +215,7 @@ func getTmpDI(di *mgo.DialInfo, hostname string) *mgo.DialInfo {
return &tmpdi return &tmpdi
} }
func GetServerStatus(dialer pmgo.Dialer, di *mgo.DialInfo, hostname string) (proto.ServerStatus, error) { func GetServerStatus(dialer pmgo.Dialer, di *pmgo.DialInfo, hostname string) (proto.ServerStatus, error) {
ss := proto.ServerStatus{} ss := proto.ServerStatus{}
tmpdi := getTmpDI(di, hostname) tmpdi := getTmpDI(di, hostname)

View File

@@ -71,6 +71,8 @@ type options struct {
OrderBy []string OrderBy []string
Password string Password string
SkipCollections []string SkipCollections []string
SSLCAFile string
SSLPEMKeyFile string
User string User string
Version bool Version bool
} }
@@ -527,6 +529,8 @@ func getOptions() (*options, error) {
gop.StringVarLong(&opts.LogLevel, "log-level", 'l', "Log level: error", "panic, fatal, error, warn, info, debug. Default: error") gop.StringVarLong(&opts.LogLevel, "log-level", 'l', "Log level: error", "panic, fatal, error, warn, info, debug. Default: error")
gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional() gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional()
gop.StringVarLong(&opts.User, "username", 'u', "Username to use for optional MongoDB authentication") gop.StringVarLong(&opts.User, "username", 'u', "Username to use for optional MongoDB authentication")
gop.StringVarLong(&opts.SSLCAFile, "sslCAFile", 0, "SSL CA cert file used for authentication")
gop.StringVarLong(&opts.SSLPEMKeyFile, "sslPEMKeyFile", 0, "SSL client PEM file used for authentication")
gop.SetParameters("host[:port]/database") gop.SetParameters("host[:port]/database")
@@ -566,25 +570,34 @@ func getOptions() (*options, error) {
return opts, nil return opts, nil
} }
func getDialInfo(opts *options) *mgo.DialInfo { func getDialInfo(opts *options) *pmgo.DialInfo {
di, _ := mgo.ParseURL(opts.Host) di, _ := mgo.ParseURL(opts.Host)
di.FailFast = true di.FailFast = true
if getopt.IsSet("username") { if di.Username != "" {
di.Username = opts.User di.Username = opts.User
} }
if getopt.IsSet("password") { if di.Password != "" {
di.Password = opts.Password di.Password = opts.Password
} }
if getopt.IsSet("authenticationDatabase") { if opts.AuthDB != "" {
di.Source = opts.AuthDB di.Source = opts.AuthDB
} }
if opts.Database != "" {
if getopt.IsSet("database") {
di.Database = opts.Database di.Database = opts.Database
} }
return di pmgoDialInfo := pmgo.NewDialInfo(di)
if opts.SSLCAFile != "" {
pmgoDialInfo.SSLCAFile = opts.SSLCAFile
}
if opts.SSLPEMKeyFile != "" {
pmgoDialInfo.SSLPEMKeyFile = opts.SSLPEMKeyFile
}
return pmgoDialInfo
} }
func getQueryField(query map[string]interface{}) (map[string]interface{}, error) { func getQueryField(query map[string]interface{}) (map[string]interface{}, error) {
@@ -891,7 +904,7 @@ func sortQueries(queries []stat, orderby []string) []stat {
} }
func isProfilerEnabled(dialer pmgo.Dialer, di *mgo.DialInfo) (bool, error) { func isProfilerEnabled(dialer pmgo.Dialer, di *pmgo.DialInfo) (bool, error) {
var ps proto.ProfilerStatus var ps proto.ProfilerStatus
replicaMembers, err := util.GetReplicasetMembers(dialer, di) replicaMembers, err := util.GetReplicasetMembers(dialer, di)
if err != nil { if err != nil {

View File

@@ -132,11 +132,17 @@ type options struct {
NoRunningOps bool NoRunningOps bool
RunningOpsSamples int RunningOpsSamples int
RunningOpsInterval int RunningOpsInterval int
SSLCAFile string
SSLPEMKeyFile string
} }
func main() { func main() {
opts := parseFlags() opts, err := parseFlags()
if err != nil {
log.Errorf("cannot get parameters: %s", err.Error())
os.Exit(2)
}
if opts.Help { if opts.Help {
getopt.Usage() getopt.Usage()
@@ -169,22 +175,14 @@ func main() {
} }
} }
if getopt.IsSet("password") && opts.Password == "" { di := &pmgo.DialInfo{
print("Password: ") Username: opts.User,
pass, err := gopass.GetPasswd() Password: opts.Password,
if err != nil { Addrs: []string{opts.Host},
fmt.Println(err) FailFast: true,
os.Exit(2) Source: opts.AuthDB,
} SSLCAFile: opts.SSLCAFile,
opts.Password = string(pass) SSLPEMKeyFile: opts.SSLPEMKeyFile,
}
di := &mgo.DialInfo{
Username: opts.User,
Password: opts.Password,
Addrs: []string{opts.Host},
FailFast: true,
Source: opts.AuthDB,
} }
log.Debugf("Connecting to the db using:\n%+v", di) log.Debugf("Connecting to the db using:\n%+v", di)
@@ -211,7 +209,7 @@ func main() {
hostInfo, err := GetHostinfo(session) hostInfo, err := GetHostinfo(session)
if err != nil { if err != nil {
message := fmt.Sprintf("Cannot connect to %q: %s", di.Addrs[0], err.Error()) message := fmt.Sprintf("Cannot get host info for %q: %s", di.Addrs[0], err.Error())
log.Errorf(message) log.Errorf(message)
os.Exit(2) os.Exit(2)
} }
@@ -786,7 +784,7 @@ func externalIP() (string, error) {
return "", errors.New("are you connected to the network?") return "", errors.New("are you connected to the network?")
} }
func parseFlags() options { func parseFlags() (options, error) {
opts := options{ opts := options{
Host: DEFAULT_HOST, Host: DEFAULT_HOST,
LogLevel: DEFAULT_LOGLEVEL, LogLevel: DEFAULT_LOGLEVEL,
@@ -795,31 +793,42 @@ func parseFlags() options {
AuthDB: DEFAULT_AUTHDB, AuthDB: DEFAULT_AUTHDB,
} }
getopt.BoolVarLong(&opts.Help, "help", 'h', "Show help") gop := getopt.New()
getopt.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit") gop.BoolVarLong(&opts.Help, "help", 'h', "Show help")
getopt.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates") gop.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit")
gop.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates")
getopt.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication") gop.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication")
getopt.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional() gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional()
getopt.StringVarLong(&opts.AuthDB, "authenticationDatabase", 'a', "admin", gop.StringVarLong(&opts.AuthDB, "authenticationDatabase", 'a', "admin",
"Databaae to use for optional MongoDB authentication. Default: admin") "Databaae to use for optional MongoDB authentication. Default: admin")
getopt.StringVarLong(&opts.LogLevel, "log-level", 'l', "error", "Log level: panic, fatal, error, warn, info, debug. Default: error") gop.StringVarLong(&opts.LogLevel, "log-level", 'l', "error", "Log level: panic, fatal, error, warn, info, debug. Default: error")
getopt.IntVarLong(&opts.RunningOpsSamples, "running-ops-samples", 's', gop.IntVarLong(&opts.RunningOpsSamples, "running-ops-samples", 's',
fmt.Sprintf("Number of samples to collect for running ops. Default: %d", opts.RunningOpsSamples)) fmt.Sprintf("Number of samples to collect for running ops. Default: %d", opts.RunningOpsSamples))
getopt.IntVarLong(&opts.RunningOpsInterval, "running-ops-interval", 'i', gop.IntVarLong(&opts.RunningOpsInterval, "running-ops-interval", 'i',
fmt.Sprintf("Interval to wait betwwen running ops samples in milliseconds. Default %d milliseconds", opts.RunningOpsInterval)) fmt.Sprintf("Interval to wait betwwen running ops samples in milliseconds. Default %d milliseconds", opts.RunningOpsInterval))
getopt.SetParameters("host[:port]") gop.StringVarLong(&opts.SSLCAFile, "sslCAFile", 0, "SSL CA cert file used for authentication")
gop.StringVarLong(&opts.SSLPEMKeyFile, "sslPEMKeyFile", 0, "SSL client PEM file used for authentication")
gop.SetParameters("host[:port]")
var gop = getopt.CommandLine
gop.Parse(os.Args) gop.Parse(os.Args)
if gop.NArgs() > 0 { if gop.NArgs() > 0 {
opts.Host = gop.Arg(0) opts.Host = gop.Arg(0)
gop.Parse(gop.Args()) gop.Parse(gop.Args())
} }
return opts if gop.IsSet("password") && opts.Password == "" {
print("Password: ")
pass, err := gopass.GetPasswd()
if err != nil {
return opts, err
}
opts.Password = string(pass)
}
return opts, nil
} }
func getChunksCount(session pmgo.SessionManager) ([]proto.ChunksByCollection, error) { func getChunksCount(session pmgo.SessionManager) ([]proto.ChunksByCollection, error) {

View File

@@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
@@ -12,6 +13,7 @@ import (
"gopkg.in/mgo.v2/dbtest" "gopkg.in/mgo.v2/dbtest"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/pborman/getopt"
"github.com/percona/percona-toolkit/src/go/lib/tutil" "github.com/percona/percona-toolkit/src/go/lib/tutil"
"github.com/percona/percona-toolkit/src/go/mongolib/proto" "github.com/percona/percona-toolkit/src/go/mongolib/proto"
"github.com/percona/pmgo" "github.com/percona/pmgo"
@@ -382,3 +384,44 @@ func addToCounters(ss proto.ServerStatus, increment int64) proto.ServerStatus {
ss.Opcounters.Update += increment ss.Opcounters.Update += increment
return ss return ss
} }
func TestParseArgs(t *testing.T) {
tests := []struct {
args []string
want *options
}{
{
args: []string{TOOLNAME}, // arg[0] is the command itself
want: &options{
Host: DEFAULT_HOST,
LogLevel: DEFAULT_LOGLEVEL,
OrderBy: strings.Split(DEFAULT_ORDERBY, ","),
SkipCollections: strings.Split(DEFAULT_SKIPCOLLECTIONS, ","),
AuthDB: DEFAULT_AUTHDB,
},
},
{
args: []string{TOOLNAME, "zapp.brannigan.net:27018/samples", "--help"},
want: &options{
Host: "zapp.brannigan.net:27018/samples",
LogLevel: DEFAULT_LOGLEVEL,
OrderBy: strings.Split(DEFAULT_ORDERBY, ","),
SkipCollections: strings.Split(DEFAULT_SKIPCOLLECTIONS, ","),
AuthDB: DEFAULT_AUTHDB,
Help: true,
},
},
}
for i, test := range tests {
getopt.Reset()
os.Args = test.args
got, err := getOptions()
if err != nil {
t.Errorf("error parsing command line arguments: %s", err.Error())
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("invalid command line options test %d\ngot %+v\nwant %+v\n", i, got, test.want)
}
}
}

View File

@@ -6,12 +6,12 @@ import (
"time" "time"
"github.com/percona/percona-toolkit/src/go/mongolib/proto" "github.com/percona/percona-toolkit/src/go/mongolib/proto"
"github.com/percona/pmgo"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson" "gopkg.in/mgo.v2/bson"
) )
func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, error) { func GetOplogInfo(hostnames []string, di *pmgo.DialInfo) ([]proto.OplogInfo, error) {
results := proto.OpLogs{} results := proto.OpLogs{}
@@ -20,7 +20,8 @@ func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, erro
Hostname: hostname, Hostname: hostname,
} }
di.Addrs = []string{hostname} di.Addrs = []string{hostname}
session, err := mgo.DialWithInfo(di) dialer := pmgo.NewDialer()
session, err := dialer.DialWithInfo(di)
if err != nil { if err != nil {
continue continue
} }
@@ -91,7 +92,7 @@ func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, erro
} }
func getOplogCollection(session *mgo.Session) (string, error) { func getOplogCollection(session pmgo.SessionManager) (string, error) {
oplog := "oplog.rs" oplog := "oplog.rs"
db := session.DB("local") db := session.DB("local")
@@ -110,7 +111,7 @@ func getOplogCollection(session *mgo.Session) (string, error) {
return oplog, nil return oplog, nil
} }
func getOplogEntry(session *mgo.Session, oplogCol string) (*proto.OplogEntry, error) { func getOplogEntry(session pmgo.SessionManager, oplogCol string) (*proto.OplogEntry, error) {
olEntry := &proto.OplogEntry{} olEntry := &proto.OplogEntry{}
err := session.DB("local").C("system.namespaces").Find(bson.M{"name": "local." + oplogCol}).One(&olEntry) err := session.DB("local").C("system.namespaces").Find(bson.M{"name": "local." + oplogCol}).One(&olEntry)