PT-73 Added SSL support

This commit is contained in:
Carlos Salguero
2017-02-21 14:32:17 -03:00
parent 0f9a1bcf42
commit c00ccf0d8d
6 changed files with 118 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ import (
"gopkg.in/mgo.v2/bson"
)
func GetReplicasetMembers(dialer pmgo.Dialer, di *mgo.DialInfo) ([]proto.Members, error) {
func GetReplicasetMembers(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]proto.Members, error) {
hostnames, err := GetHostnames(dialer, di)
if err != nil {
return nil, err
@@ -75,7 +75,7 @@ func GetReplicasetMembers(dialer pmgo.Dialer, di *mgo.DialInfo) ([]proto.Members
return members, nil
}
func GetHostnames(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) {
func GetHostnames(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]string, error) {
hostnames := []string{di.Addrs[0]}
di.Direct = true
di.Timeout = 2 * time.Second
@@ -182,7 +182,7 @@ func buildHostsListFromShardMap(shardsMap proto.ShardsMap) []string {
// This function is like GetHostnames but it uses listShards instead of getShardMap
// so it won't include config servers in the returned list
func GetShardedHosts(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) {
func GetShardedHosts(dialer pmgo.Dialer, di *pmgo.DialInfo) ([]string, error) {
hostnames := []string{di.Addrs[0]}
session, err := dialer.DialWithInfo(di)
if err != nil {
@@ -206,7 +206,7 @@ func GetShardedHosts(dialer pmgo.Dialer, di *mgo.DialInfo) ([]string, error) {
return hostnames, nil
}
func getTmpDI(di *mgo.DialInfo, hostname string) *mgo.DialInfo {
func getTmpDI(di *pmgo.DialInfo, hostname string) *pmgo.DialInfo {
tmpdi := *di
tmpdi.Addrs = []string{hostname}
tmpdi.Direct = true
@@ -215,7 +215,7 @@ func getTmpDI(di *mgo.DialInfo, hostname string) *mgo.DialInfo {
return &tmpdi
}
func GetServerStatus(dialer pmgo.Dialer, di *mgo.DialInfo, hostname string) (proto.ServerStatus, error) {
func GetServerStatus(dialer pmgo.Dialer, di *pmgo.DialInfo, hostname string) (proto.ServerStatus, error) {
ss := proto.ServerStatus{}
tmpdi := getTmpDI(di, hostname)

View File

@@ -71,6 +71,8 @@ type options struct {
OrderBy []string
Password string
SkipCollections []string
SSLCAFile string
SSLPEMKeyFile string
User string
Version bool
}
@@ -527,6 +529,8 @@ func getOptions() (*options, error) {
gop.StringVarLong(&opts.LogLevel, "log-level", 'l', "Log level: error", "panic, fatal, error, warn, info, debug. Default: error")
gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional()
gop.StringVarLong(&opts.User, "username", 'u', "Username to use for optional MongoDB authentication")
gop.StringVarLong(&opts.SSLCAFile, "sslCAFile", 0, "SSL CA cert file used for authentication")
gop.StringVarLong(&opts.SSLPEMKeyFile, "sslPEMKeyFile", 0, "SSL client PEM file used for authentication")
gop.SetParameters("host[:port]/database")
@@ -566,25 +570,34 @@ func getOptions() (*options, error) {
return opts, nil
}
func getDialInfo(opts *options) *mgo.DialInfo {
func getDialInfo(opts *options) *pmgo.DialInfo {
di, _ := mgo.ParseURL(opts.Host)
di.FailFast = true
if getopt.IsSet("username") {
if di.Username != "" {
di.Username = opts.User
}
if getopt.IsSet("password") {
if di.Password != "" {
di.Password = opts.Password
}
if getopt.IsSet("authenticationDatabase") {
if opts.AuthDB != "" {
di.Source = opts.AuthDB
}
if getopt.IsSet("database") {
if opts.Database != "" {
di.Database = opts.Database
}
return di
pmgoDialInfo := pmgo.NewDialInfo(di)
if opts.SSLCAFile != "" {
pmgoDialInfo.SSLCAFile = opts.SSLCAFile
}
if opts.SSLPEMKeyFile != "" {
pmgoDialInfo.SSLPEMKeyFile = opts.SSLPEMKeyFile
}
return pmgoDialInfo
}
func getQueryField(query map[string]interface{}) (map[string]interface{}, error) {
@@ -891,7 +904,7 @@ func sortQueries(queries []stat, orderby []string) []stat {
}
func isProfilerEnabled(dialer pmgo.Dialer, di *mgo.DialInfo) (bool, error) {
func isProfilerEnabled(dialer pmgo.Dialer, di *pmgo.DialInfo) (bool, error) {
var ps proto.ProfilerStatus
replicaMembers, err := util.GetReplicasetMembers(dialer, di)
if err != nil {

View File

@@ -132,11 +132,17 @@ type options struct {
NoRunningOps bool
RunningOpsSamples int
RunningOpsInterval int
SSLCAFile string
SSLPEMKeyFile string
}
func main() {
opts := parseFlags()
opts, err := parseFlags()
if err != nil {
log.Errorf("cannot get parameters: %s", err.Error())
os.Exit(2)
}
if opts.Help {
getopt.Usage()
@@ -169,22 +175,14 @@ func main() {
}
}
if getopt.IsSet("password") && opts.Password == "" {
print("Password: ")
pass, err := gopass.GetPasswd()
if err != nil {
fmt.Println(err)
os.Exit(2)
}
opts.Password = string(pass)
}
di := &mgo.DialInfo{
Username: opts.User,
Password: opts.Password,
Addrs: []string{opts.Host},
FailFast: true,
Source: opts.AuthDB,
di := &pmgo.DialInfo{
Username: opts.User,
Password: opts.Password,
Addrs: []string{opts.Host},
FailFast: true,
Source: opts.AuthDB,
SSLCAFile: opts.SSLCAFile,
SSLPEMKeyFile: opts.SSLPEMKeyFile,
}
log.Debugf("Connecting to the db using:\n%+v", di)
@@ -211,7 +209,7 @@ func main() {
hostInfo, err := GetHostinfo(session)
if err != nil {
message := fmt.Sprintf("Cannot connect to %q: %s", di.Addrs[0], err.Error())
message := fmt.Sprintf("Cannot get host info for %q: %s", di.Addrs[0], err.Error())
log.Errorf(message)
os.Exit(2)
}
@@ -786,7 +784,7 @@ func externalIP() (string, error) {
return "", errors.New("are you connected to the network?")
}
func parseFlags() options {
func parseFlags() (options, error) {
opts := options{
Host: DEFAULT_HOST,
LogLevel: DEFAULT_LOGLEVEL,
@@ -795,31 +793,42 @@ func parseFlags() options {
AuthDB: DEFAULT_AUTHDB,
}
getopt.BoolVarLong(&opts.Help, "help", 'h', "Show help")
getopt.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit")
getopt.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates")
gop := getopt.New()
gop.BoolVarLong(&opts.Help, "help", 'h', "Show help")
gop.BoolVarLong(&opts.Version, "version", 'v', "", "Show version & exit")
gop.BoolVarLong(&opts.NoVersionCheck, "no-version-check", 'c', "", "Default: Don't check for updates")
getopt.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication")
getopt.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional()
getopt.StringVarLong(&opts.AuthDB, "authenticationDatabase", 'a', "admin",
gop.StringVarLong(&opts.User, "username", 'u', "", "Username to use for optional MongoDB authentication")
gop.StringVarLong(&opts.Password, "password", 'p', "", "Password to use for optional MongoDB authentication").SetOptional()
gop.StringVarLong(&opts.AuthDB, "authenticationDatabase", 'a', "admin",
"Databaae to use for optional MongoDB authentication. Default: admin")
getopt.StringVarLong(&opts.LogLevel, "log-level", 'l', "error", "Log level: panic, fatal, error, warn, info, debug. Default: error")
gop.StringVarLong(&opts.LogLevel, "log-level", 'l', "error", "Log level: panic, fatal, error, warn, info, debug. Default: error")
getopt.IntVarLong(&opts.RunningOpsSamples, "running-ops-samples", 's',
gop.IntVarLong(&opts.RunningOpsSamples, "running-ops-samples", 's',
fmt.Sprintf("Number of samples to collect for running ops. Default: %d", opts.RunningOpsSamples))
getopt.IntVarLong(&opts.RunningOpsInterval, "running-ops-interval", 'i',
gop.IntVarLong(&opts.RunningOpsInterval, "running-ops-interval", 'i',
fmt.Sprintf("Interval to wait betwwen running ops samples in milliseconds. Default %d milliseconds", opts.RunningOpsInterval))
getopt.SetParameters("host[:port]")
gop.StringVarLong(&opts.SSLCAFile, "sslCAFile", 0, "SSL CA cert file used for authentication")
gop.StringVarLong(&opts.SSLPEMKeyFile, "sslPEMKeyFile", 0, "SSL client PEM file used for authentication")
gop.SetParameters("host[:port]")
var gop = getopt.CommandLine
gop.Parse(os.Args)
if gop.NArgs() > 0 {
opts.Host = gop.Arg(0)
gop.Parse(gop.Args())
}
return opts
if gop.IsSet("password") && opts.Password == "" {
print("Password: ")
pass, err := gopass.GetPasswd()
if err != nil {
return opts, err
}
opts.Password = string(pass)
}
return opts, nil
}
func getChunksCount(session pmgo.SessionManager) ([]proto.ChunksByCollection, error) {

View File

@@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"reflect"
"strings"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"gopkg.in/mgo.v2/dbtest"
"github.com/golang/mock/gomock"
"github.com/pborman/getopt"
"github.com/percona/percona-toolkit/src/go/lib/tutil"
"github.com/percona/percona-toolkit/src/go/mongolib/proto"
"github.com/percona/pmgo"
@@ -382,3 +384,44 @@ func addToCounters(ss proto.ServerStatus, increment int64) proto.ServerStatus {
ss.Opcounters.Update += increment
return ss
}
func TestParseArgs(t *testing.T) {
tests := []struct {
args []string
want *options
}{
{
args: []string{TOOLNAME}, // arg[0] is the command itself
want: &options{
Host: DEFAULT_HOST,
LogLevel: DEFAULT_LOGLEVEL,
OrderBy: strings.Split(DEFAULT_ORDERBY, ","),
SkipCollections: strings.Split(DEFAULT_SKIPCOLLECTIONS, ","),
AuthDB: DEFAULT_AUTHDB,
},
},
{
args: []string{TOOLNAME, "zapp.brannigan.net:27018/samples", "--help"},
want: &options{
Host: "zapp.brannigan.net:27018/samples",
LogLevel: DEFAULT_LOGLEVEL,
OrderBy: strings.Split(DEFAULT_ORDERBY, ","),
SkipCollections: strings.Split(DEFAULT_SKIPCOLLECTIONS, ","),
AuthDB: DEFAULT_AUTHDB,
Help: true,
},
},
}
for i, test := range tests {
getopt.Reset()
os.Args = test.args
got, err := getOptions()
if err != nil {
t.Errorf("error parsing command line arguments: %s", err.Error())
}
if !reflect.DeepEqual(got, test.want) {
t.Errorf("invalid command line options test %d\ngot %+v\nwant %+v\n", i, got, test.want)
}
}
}

View File

@@ -6,12 +6,12 @@ import (
"time"
"github.com/percona/percona-toolkit/src/go/mongolib/proto"
"github.com/percona/pmgo"
"github.com/pkg/errors"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, error) {
func GetOplogInfo(hostnames []string, di *pmgo.DialInfo) ([]proto.OplogInfo, error) {
results := proto.OpLogs{}
@@ -20,7 +20,8 @@ func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, erro
Hostname: hostname,
}
di.Addrs = []string{hostname}
session, err := mgo.DialWithInfo(di)
dialer := pmgo.NewDialer()
session, err := dialer.DialWithInfo(di)
if err != nil {
continue
}
@@ -91,7 +92,7 @@ func GetOplogInfo(hostnames []string, di *mgo.DialInfo) ([]proto.OplogInfo, erro
}
func getOplogCollection(session *mgo.Session) (string, error) {
func getOplogCollection(session pmgo.SessionManager) (string, error) {
oplog := "oplog.rs"
db := session.DB("local")
@@ -110,7 +111,7 @@ func getOplogCollection(session *mgo.Session) (string, error) {
return oplog, nil
}
func getOplogEntry(session *mgo.Session, oplogCol string) (*proto.OplogEntry, error) {
func getOplogEntry(session pmgo.SessionManager, oplogCol string) (*proto.OplogEntry, error) {
olEntry := &proto.OplogEntry{}
err := session.DB("local").C("system.namespaces").Find(bson.M{"name": "local." + oplogCol}).One(&olEntry)