// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cmd

import (
	"crypto/rand"
	"fmt"
	"os"
	"strings"

	"github.com/bgentry/speakeasy"
	"google.golang.org/grpc/grpclog"

	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/pkg/v3/report"
)

var (
	// cache the username and password for multiple connections
	globalUserName string
	globalPassword string
)

func getUsernamePassword(usernameFlag string) (string, string, error) {
	if globalUserName != "" && globalPassword != "" {
		return globalUserName, globalPassword, nil
	}
	var ok bool
	globalUserName, globalPassword, ok = strings.Cut(usernameFlag, ":")
	if !ok {
		// Prompt for the password.
		password, err := speakeasy.Ask("Password: ")
		if err != nil {
			return "", "", err
		}
		globalUserName = usernameFlag
		globalPassword = password
	}
	return globalUserName, globalPassword, nil
}

func mustCreateConn() *clientv3.Client {
	cfg := clientv3.Config{
		AutoSyncInterval: autoSyncInterval,
		Endpoints:        endpoints,
		DialTimeout:      dialTimeout,
	}
	if !tls.Empty() || tls.TrustedCAFile != "" {
		cfgtls, err := tls.ClientConfig()
		if err != nil {
			fmt.Fprintf(os.Stderr, "bad tls config: %v\n", err)
			os.Exit(1)
		}
		cfg.TLS = cfgtls
	}

	if len(user) != 0 {
		username, password, err := getUsernamePassword(user)
		if err != nil {
			fmt.Fprintf(os.Stderr, "bad user information: %s %v\n", user, err)
			os.Exit(1)
		}
		cfg.Username = username
		cfg.Password = password
	}

	client, err := clientv3.New(cfg)
	grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))

	if err != nil {
		fmt.Fprintf(os.Stderr, "dial error: %v\n", err)
		os.Exit(1)
	}

	return client
}

func mustCreateClients(totalClients, totalConns uint) []*clientv3.Client {
	conns := make([]*clientv3.Client, totalConns)
	for i := range conns {
		conns[i] = mustCreateConn()
	}

	clients := make([]*clientv3.Client, totalClients)
	for i := range clients {
		clients[i] = conns[i%int(totalConns)]
	}
	return clients
}

func mustRandBytes(n int) []byte {
	rb := make([]byte, n)
	_, err := rand.Read(rb)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to generate value: %v\n", err)
		os.Exit(1)
	}
	return rb
}

func newReport(benchmarkOp string) report.Report {
	p := "%4.4f"
	if precise {
		p = "%g"
	}
	if sample {
		return report.NewReportSample(p, benchmarkOp, generatePerfReport)
	}
	return report.NewReport(p, benchmarkOp, generatePerfReport)
}

func newWeightedReport(benchmarkOp string) report.Report {
	p := "%4.4f"
	if precise {
		p = "%g"
	}
	if sample {
		return report.NewReportSample(p, benchmarkOp, generatePerfReport)
	}
	return report.NewWeightedReport(report.NewReport(p, benchmarkOp, generatePerfReport), p, benchmarkOp, generatePerfReport)
}
