PT-1891 Fixed mongodb-summary connection with ssl (#469)

* PT-1891 Fixed mongodb-summary connection with ssl

- Added SSL connection options
- Fixed old tests
- Replaced gofmt by gofumpt in Makefile
- There are no ssl test for mongodb-summary because the current sandbox
doesnt support it

* PT-1891 Ran gofumports

* PMM-1891 Fixes for CR

* PT-1891 Decreased minimum TLS reqs for compatibility
This commit is contained in:
Carlos Salguero
2020-11-02 17:13:29 -03:00
committed by GitHub
parent e731cf4d83
commit ff6b05b381
35 changed files with 286 additions and 136 deletions

View File

@@ -3,11 +3,16 @@ package main
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"time"
@@ -39,18 +44,23 @@ const (
DefaultOutputFormat = "text"
typeMongos = "mongos"
// Exit Codes
// Exit Codes.
cannotFormatResults = 1
cannotParseCommandLineParameters = 2
cannotGetHostInfo = 3
cannotGetReplicasetMembers = 4
cannotGetClientOptions = 4
cannotConnectToMongoDB = 5
)
//nolint:gochecknoglobals
var (
Build string = "2020-04-23" // nolint
GoVersion string = "1.14.1" // nolint
Build string = "2020-04-23"
GoVersion string = "1.14.1"
Version string = "3.2.0"
Commit string
defaultConnectionTimeout = 3 * time.Second
directConnection = true
)
type TimedStats struct {
@@ -69,6 +79,7 @@ type opCounters struct {
Command TimedStats
SampleRate time.Duration
}
type hostInfo struct {
Hostname string
HostOsType string
@@ -164,14 +175,11 @@ func main() {
opts, err := parseFlags()
if err != nil {
log.Errorf("cannot get parameters: %s", err.Error())
os.Exit(cannotParseCommandLineParameters)
}
if opts == nil && err == nil {
return
}
if opts.Help {
getopt.Usage()
if opts == nil && err == nil {
return
}
@@ -187,6 +195,7 @@ func main() {
fmt.Printf("Version %s\n", Version)
fmt.Printf("Build: %s using %s\n", Build, GoVersion)
fmt.Printf("Commit: %s\n", Commit)
return
}
@@ -201,34 +210,46 @@ func main() {
}
ctx := context.Background()
clientOptions := getClientOptions(opts)
clientOptions, err := getClientOptions(opts)
if err != nil {
log.Error(err)
os.Exit(cannotGetClientOptions)
}
client, err := mongo.NewClient(clientOptions)
if err != nil {
log.Fatalf("Cannot get a MongoDB client: %s", err)
log.Errorf("Cannot get a MongoDB client: %s", err)
os.Exit(cannotConnectToMongoDB)
}
if err := client.Connect(ctx); err != nil {
log.Fatalf("Cannot connect to MongoDB: %s", err)
log.Errorf("Cannot connect to MongoDB: %s", err)
os.Exit(cannotConnectToMongoDB)
}
defer client.Disconnect(ctx) // nolint
hostnames, err := util.GetHostnames(ctx, client)
if err != nil && errors.Is(err, util.ShardingNotEnabledError) {
log.Errorf("Cannot get hostnames: %s", err)
}
log.Debugf("hostnames: %v", hostnames)
ci := &collectedInfo{}
ci.HostInfo, err = getHostInfo(ctx, client)
if err != nil {
message := fmt.Sprintf("Cannot get host info for %q: %s", opts.Host, err.Error())
log.Errorf(message)
os.Exit(cannotGetHostInfo)
log.Errorf("Cannot get host info for %q: %s", opts.Host, err)
os.Exit(cannotGetHostInfo) //nolint:gocritic
}
if ci.ReplicaMembers, err = util.GetReplicasetMembers(ctx, clientOptions); err != nil {
log.Warnf("[Error] cannot get replicaset members: %v\n", err)
}
log.Debugf("replicaMembers:\n%+v\n", ci.ReplicaMembers)
if opts.RunningOpsSamples > 0 && opts.RunningOpsInterval > 0 {
@@ -274,9 +295,10 @@ func main() {
out, err := formatResults(ci, opts.OutputFormat)
if err != nil {
log.Errorf("Cannot format the results: %s", err.Error())
log.Errorf("Cannot format the results: %s", err)
os.Exit(cannotFormatResults)
}
fmt.Println(string(out))
}
@@ -287,8 +309,9 @@ func formatResults(ci *collectedInfo, format string) ([]byte, error) {
case "json":
b, err := json.MarshalIndent(ci, "", " ")
if err != nil {
return nil, fmt.Errorf("[Error] Cannot convert results to json: %s", err.Error())
return nil, errors.Wrap(err, "Cannot convert results to json")
}
buf = bytes.NewBuffer(b)
default:
buf = new(bytes.Buffer)
@@ -338,6 +361,7 @@ func getHostInfo(ctx context.Context, client *mongo.Client) (*hostInfo, error) {
hi := proto.HostInfo{}
if err := client.Database("admin").RunCommand(ctx, primitive.M{"hostInfo": 1}).Decode(&hi); err != nil {
log.Debugf("run('hostInfo') error: %s", err)
return nil, errors.Wrap(err, "GetHostInfo.hostInfo")
}
@@ -393,12 +417,15 @@ func countMongodProcesses() (int, error) {
if err != nil {
return 0, err
}
count := 0
for _, pid := range pids {
p, err := process.NewProcess(pid)
if err != nil {
continue
}
if name, _ := p.Name(); name == "mongod" || name == typeMongos {
count++
}
@@ -440,6 +467,7 @@ func getClusterwideInfo(ctx context.Context, client *mongo.Client) (*clusterwide
if collStats.Sharded {
cwi.ShardedDataSize += collStats.Size
cwi.ShardedColsCount++
continue
}
@@ -464,11 +492,14 @@ func sizeAndUnit(size int64) (float64, string) {
unit := []string{"bytes", "KB", "MB", "GB", "TB"}
idx := 0
newSize := float64(size)
for newSize > 1024 {
newSize /= 1024
idx++
}
newSize = float64(int64(newSize*100)) / 100
return newSize, unit[idx]
}
@@ -486,7 +517,10 @@ func getSecuritySettings(ctx context.Context, client *mongo.Client, ver string)
}
cmdOpts := proto.CommandLineOptions{}
err = client.Database("admin").RunCommand(ctx, primitive.D{{"getCmdLineOpts", 1}, {"recordStats", 1}}).Decode(&cmdOpts)
err = client.Database("admin").RunCommand(ctx, primitive.D{
{Key: "getCmdLineOpts", Value: 1},
{Key: "recordStats", Value: 1},
}).Decode(&cmdOpts)
if err != nil {
return nil, errors.Wrap(err, "cannot get command line options")
}
@@ -503,7 +537,7 @@ func getSecuritySettings(ctx context.Context, client *mongo.Client, ver string)
s.BindIP = cmdOpts.Parsed.Net.BindIP
s.Port = cmdOpts.Parsed.Net.Port
if cmdOpts.Parsed.Net.BindIP == "" {
if cmdOpts.Parsed.Net.BindIP == "" { //nolint:nestif
if prior26 {
s.WarningMsgs = append(s.WarningMsgs, "WARNING: You might be insecure. There is no IP binding")
}
@@ -586,7 +620,11 @@ func getOpCountersStats(ctx context.Context, client *mongo.Client, count int,
// count + 1 because we need 1st reading to stablish a base to measure variation
for i := 0; i < count+1; i++ {
<-ticker.C
err := client.Database("admin").RunCommand(ctx, primitive.D{{"serverStatus", 1}, {"recordStats", 1}}).Decode(&ss)
err := client.Database("admin").RunCommand(ctx, primitive.D{
{Key: "serverStatus", Value: 1},
{Key: "recordStats", Value: 1},
}).Decode(&ss)
if err != nil {
return nil, err
}
@@ -598,6 +636,7 @@ func getOpCountersStats(ctx context.Context, client *mongo.Client, count int,
prevOpCount.Insert.Total = ss.Opcounters.Insert
prevOpCount.Query.Total = ss.Opcounters.Query
prevOpCount.Update.Total = ss.Opcounters.Update
continue
}
@@ -631,57 +670,63 @@ func getOpCountersStats(ctx context.Context, client *mongo.Client, count int,
}
// Insert --------------------------------------
if delta.Opcounters.Insert > oc.Insert.Max {
switch {
case delta.Opcounters.Insert > oc.Insert.Max:
oc.Insert.Max = delta.Opcounters.Insert
}
if delta.Opcounters.Insert < oc.Insert.Min {
case delta.Opcounters.Insert < oc.Insert.Min:
oc.Insert.Min = delta.Opcounters.Insert
}
oc.Insert.Total += delta.Opcounters.Insert
// Query ---------------------------------------
if delta.Opcounters.Query > oc.Query.Max {
switch {
case delta.Opcounters.Query > oc.Query.Max:
oc.Query.Max = delta.Opcounters.Query
}
if delta.Opcounters.Query < oc.Query.Min {
case delta.Opcounters.Query < oc.Query.Min:
oc.Query.Min = delta.Opcounters.Query
}
oc.Query.Total += delta.Opcounters.Query
// Command -------------------------------------
if delta.Opcounters.Command > oc.Command.Max {
switch {
case delta.Opcounters.Command > oc.Command.Max:
oc.Command.Max = delta.Opcounters.Command
}
if delta.Opcounters.Command < oc.Command.Min {
case delta.Opcounters.Command < oc.Command.Min:
oc.Command.Min = delta.Opcounters.Command
}
oc.Command.Total += delta.Opcounters.Command
// Update --------------------------------------
if delta.Opcounters.Update > oc.Update.Max {
switch {
case delta.Opcounters.Update > oc.Update.Max:
oc.Update.Max = delta.Opcounters.Update
}
if delta.Opcounters.Update < oc.Update.Min {
case delta.Opcounters.Update < oc.Update.Min:
oc.Update.Min = delta.Opcounters.Update
}
oc.Update.Total += delta.Opcounters.Update
// Delete --------------------------------------
if delta.Opcounters.Delete > oc.Delete.Max {
switch {
case delta.Opcounters.Delete > oc.Delete.Max:
oc.Delete.Max = delta.Opcounters.Delete
}
if delta.Opcounters.Delete < oc.Delete.Min {
case delta.Opcounters.Delete < oc.Delete.Min:
oc.Delete.Min = delta.Opcounters.Delete
}
oc.Delete.Total += delta.Opcounters.Delete
// GetMore -------------------------------------
if delta.Opcounters.GetMore > oc.GetMore.Max {
switch {
case delta.Opcounters.GetMore > oc.GetMore.Max:
oc.GetMore.Max = delta.Opcounters.GetMore
}
if delta.Opcounters.GetMore < oc.GetMore.Min {
case delta.Opcounters.GetMore < oc.GetMore.Min:
oc.GetMore.Min = delta.Opcounters.GetMore
}
oc.GetMore.Total += delta.Opcounters.GetMore
prevOpCount.Insert.Total = ss.Opcounters.Insert
@@ -690,8 +735,8 @@ func getOpCountersStats(ctx context.Context, client *mongo.Client, count int,
prevOpCount.Update.Total = ss.Opcounters.Update
prevOpCount.Delete.Total = ss.Opcounters.Delete
prevOpCount.GetMore.Total = ss.Opcounters.GetMore
}
ticker.Stop()
oc.Insert.Avg = oc.Insert.Total
@@ -707,11 +752,12 @@ func getOpCountersStats(ctx context.Context, client *mongo.Client, count int,
}
func getProcInfo(pid int32, templateData *procInfo) error {
//proc, err := process.NewProcess(templateData.ServerStatus.Pid)
// proc, err := process.NewProcess(templateData.ServerStatus.Pid)
proc, err := process.NewProcess(pid)
if err != nil {
return fmt.Errorf("cannot get process %d", pid)
return errors.New(fmt.Sprintf("cannot get process %d", pid))
}
ct, err := proc.CreateTime()
if err != nil {
return err
@@ -742,6 +788,7 @@ func GetBalancerStats(ctx context.Context, client *mongo.Client) (*proto.Balance
event := item.Id.Event
note := item.Id.Note
count := item.Count
switch event {
case "moveChunk.to", "moveChunk.from", "moveChunk.commit":
if note == "success" || note == "" {
@@ -760,7 +807,7 @@ func GetBalancerStats(ctx context.Context, client *mongo.Client) (*proto.Balance
}
func GetShardingChangelogStatus(ctx context.Context, client *mongo.Client) (*proto.ShardingChangelogStats, error) {
var qresults []proto.ShardingChangelogSummary
qresults := []proto.ShardingChangelogSummary{}
coll := client.Database("config").Collection("changelog")
match := primitive.M{"time": primitive.M{"$gt": time.Now().Add(-240 * time.Hour)}}
group := primitive.M{"_id": primitive.M{"event": "$what", "note": "$details.note"}, "count": primitive.M{"$sum": 1}}
@@ -776,6 +823,7 @@ func GetShardingChangelogStatus(ctx context.Context, client *mongo.Client) (*pro
if err := cursor.Decode(&res); err != nil {
return nil, errors.Wrap(err, "cannot decode GetShardingChangelogStatus")
}
qresults = append(qresults, res)
}
@@ -796,6 +844,7 @@ func isPrivateNetwork(ip string) (bool, error) {
if err != nil {
return false, err
}
addr := net.ParseIP(ip)
if cidrnet.Contains(addr) {
return true, nil
@@ -803,7 +852,6 @@ func isPrivateNetwork(ip string) (bool, error) {
}
return false, nil
}
func externalIP() (string, error) {
@@ -811,17 +859,21 @@ func externalIP() (string, error) {
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
@@ -830,9 +882,11 @@ func externalIP() (string, error) {
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
@@ -840,6 +894,7 @@ func externalIP() (string, error) {
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}
@@ -880,8 +935,8 @@ func parseFlags() (*cliOptions, error) {
gop.StringVarLong(&opts.SSLPEMKeyFile, "sslPEMKeyFile", 0, "SSL client PEM file used for authentication")
gop.SetParameters("host[:port]")
gop.Parse(os.Args)
if gop.NArgs() > 0 {
opts.Host = gop.Arg(0)
gop.Parse(gop.Args())
@@ -889,19 +944,25 @@ func parseFlags() (*cliOptions, error) {
if gop.IsSet("password") && opts.Password == "" {
print("Password: ")
pass, err := gopass.GetPasswd()
if err != nil {
return opts, err
}
opts.Password = string(pass)
}
if !strings.HasPrefix(opts.Host, "mongodb://") {
opts.Host = "mongodb://" + opts.Host
}
if opts.Help {
gop.PrintUsage(os.Stdout)
return nil, nil
}
if opts.OutputFormat != "json" && opts.OutputFormat != "text" {
log.Infof("Invalid output format '%s'. Using text format", opts.OutputFormat)
}
@@ -920,27 +981,93 @@ func getChunksCount(ctx context.Context, client *mongo.Client) ([]proto.ChunksBy
if err != nil {
return nil, err
}
for cursor.Next(ctx) {
res := proto.ChunksByCollection{}
if err := cursor.Decode(&res); err != nil {
return nil, errors.Wrap(err, "cannot decode chunks aggregation")
}
result = append(result, res)
}
return result, nil
}
func getClientOptions(opts *cliOptions) *options.ClientOptions {
func getClientOptions(opts *cliOptions) (*options.ClientOptions, error) {
clientOptions := options.Client().ApplyURI(opts.Host)
clientOptions.ServerSelectionTimeout = &defaultConnectionTimeout
clientOptions.Direct = &directConnection
credential := options.Credential{}
if opts.User != "" {
credential.Username = opts.User
clientOptions.SetAuth(credential)
}
if opts.Password != "" {
credential.Password = opts.Password
credential.PasswordSet = true
clientOptions.SetAuth(credential)
}
return clientOptions
if opts.SSLPEMKeyFile != "" || opts.SSLCAFile != "" {
tlsConfig, err := getTLSConfig(opts.SSLPEMKeyFile, opts.SSLCAFile)
if err != nil {
return nil, errors.Wrap(err, "cannot read SSL certificate files")
}
clientOptions.TLSConfig = tlsConfig
}
return clientOptions, nil
}
func getTLSConfig(sslPEMKeyFile, sslCAFile string) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: true,
}
roots := x509.NewCertPool()
if sslPEMKeyFile != "" {
crt, err := ioutil.ReadFile(filepath.Clean(expandHome(sslPEMKeyFile)))
if err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(crt, crt)
if err != nil {
log.Fatal(err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
if sslCAFile != "" {
ca, err := ioutil.ReadFile(filepath.Clean(expandHome(sslCAFile)))
if err != nil {
return nil, err
}
roots.AppendCertsFromPEM(ca)
tlsConfig.RootCAs = roots
}
return tlsConfig, nil
}
func expandHome(path string) string {
usr, _ := user.Current()
dir := usr.HomeDir
switch {
case path == "~":
path = dir
case strings.HasPrefix(path, "~/"):
path = filepath.Join(dir, path[2:])
}
return path
}