package cookie

import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	clockutil "github.com/benbjohnson/clock"
	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/require"

	"github.com/influxdata/telegraf/config"
	"github.com/influxdata/telegraf/testutil"
)

// Credentials, request body and header value that the fake auth server
// accepts as valid.
const (
	reqUser   = "testUser"
	reqPasswd = "testPassword"
	reqBody   = "a body"

	reqHeaderKey = "hello"
	reqHeaderVal = "world"
)

// Auth paths served by the fake server; each one validates the incoming
// request differently before handing out the auth cookie.
const (
	authEndpointNoCreds                   = "/auth"
	authEndpointWithBasicAuth             = "/authWithCreds"
	authEndpointWithBasicAuthOnlyUsername = "/authWithCredsUser"
	authEndpointWithBody                  = "/authWithBody"
	authEndpointWithHeader                = "/authWithHeader"
)

var (
	// fakeCookie is the cookie the fake server sets on a successful
	// authentication and later requires on requests to non-auth paths.
	fakeCookie = &http.Cookie{
		Value: "this is an auth cookie",
		Name:  "test-cookie",
	}

	// reqHeaderValSecret wraps reqHeaderVal in a config.Secret so it can be
	// used as an auth header value in the test configuration.
	reqHeaderValSecret = config.NewSecret([]byte(reqHeaderVal))
)

// fakeServer bundles a running httptest.Server with a pointer to the counter
// of successful authentications it has handled. The counter is modified with
// sync/atomic, so read it with atomic.LoadInt32.
type fakeServer struct {
	*httptest.Server // the running fake auth server
	*int32           // number of successful authentications; accessed atomically
}

// newFakeServer starts an httptest server that emulates a cookie-issuing auth
// service.
//
// Requests to the auth endpoints are validated per endpoint: no check, a
// required header, a required body, basic auth, or basic auth with an empty
// password. On success the server increments its auth counter and sets
// fakeCookie on the response; on a validation failure it replies 401.
//
// Any other path is treated as protected. It replies 403 when fakeCookie is
// missing from the request and otherwise returns a plain 200 body.
//
// The caller owns the returned server and must Close it.
func newFakeServer(t *testing.T) fakeServer {
	var c int32
	return fakeServer{
		Server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authed := func() {
				atomic.AddInt32(&c, 1)        // increment auth counter
				http.SetCookie(w, fakeCookie) // set fake cookie
			}
			switch r.URL.Path {
			case authEndpointNoCreds:
				authed()
			case authEndpointWithHeader:
				if !cmp.Equal(r.Header.Get(reqHeaderKey), reqHeaderVal) {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBody:
				body, err := io.ReadAll(r.Body)
				if err != nil {
					w.WriteHeader(http.StatusInternalServerError)
					t.Error(err)
					return
				}
				if !cmp.Equal([]byte(reqBody), body) {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBasicAuth:
				u, p, ok := r.BasicAuth()
				if !ok || u != reqUser || p != reqPasswd {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBasicAuthOnlyUsername:
				u, p, ok := r.BasicAuth()
				if !ok || u != reqUser || p != "" {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			default:
				// ensure cookie exists on request
				if _, err := r.Cookie(fakeCookie.Name); err != nil {
					w.WriteHeader(http.StatusForbidden)
					return
				}
				// Write implicitly sends a 200 header first, so a failed write
				// can no longer be turned into a 500. Calling WriteHeader here
				// would only be a superfluous call, so just report the error.
				if _, err := w.Write([]byte("good test response")); err != nil {
					t.Error(err)
				}
			}
		})),
		int32: &c,
	}
}

// checkResp sends a GET to the server's protected "/endpoint" path using the
// server's own client and asserts that the response status equals expCode.
// For a 200 response it also asserts that exactly one cookie, named like
// fakeCookie, was sent with the request.
func (s fakeServer) checkResp(t *testing.T, expCode int) {
	t.Helper()
	resp, err := s.Client().Get(s.URL + "/endpoint")
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, expCode, resp.StatusCode)

	if expCode != http.StatusOK {
		return
	}
	// resp.Request is the request that was actually sent, so its cookies are
	// the ones the client attached. Parse them once.
	cookies := resp.Request.Cookies()
	require.Len(t, cookies, 1)
	require.Equal(t, fakeCookie.Name, cookies[0].Name)
}

// checkAuthCount asserts that the server has recorded at least atLeast
// successful authentications so far. The counter is read atomically because
// the HTTP handler updates it from other goroutines.
func (s fakeServer) checkAuthCount(t *testing.T, atLeast int32) {
	t.Helper()
	seen := atomic.LoadInt32(s.int32)
	require.GreaterOrEqual(t, seen, atLeast)
}

// TestAuthConfig_Start exercises CookieAuthConfig against newFakeServer for
// each auth flavour: no credentials, header, basic auth and request body.
//
// For every case it checks:
//   - the error returned by initializeClient;
//   - the auth count and the protected endpoint's status before the renewal
//     clock advances;
//   - the same two values again after a mocked clock is advanced by several
//     renewal periods.
//
// Failing configurations must never obtain the cookie, so for them the
// protected endpoint stays forbidden.
func TestAuthConfig_Start(t *testing.T) {
	const (
		renewal      = 50 * time.Millisecond
		renewalCheck = 5 * renewal
	)
	type fields struct {
		Method   string
		Username string
		Password string
		Body     string
		Headers  map[string]*config.Secret
	}
	type args struct {
		renewal  time.Duration
		endpoint string
	}
	tests := []struct {
		name              string
		fields            fields
		args              args
		wantErr           error
		firstAuthCount    int32
		lastAuthCount     int32
		firstHTTPResponse int
		lastHTTPResponse  int
	}{
		{
			name: "success no creds, no body, default method",
			args: args{
				renewal:  renewal,
				endpoint: authEndpointNoCreds,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "success no creds, no body, default method, header set",
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithHeader,
			},
			fields: fields{
				Headers: map[string]*config.Secret{reqHeaderKey: &reqHeaderValSecret},
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "success with creds, no body",
			fields: fields{
				Method:   http.MethodPost,
				Username: reqUser,
				Password: reqPasswd,
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBasicAuth,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "failure with bad creds",
			fields: fields{
				Method:   http.MethodPost,
				Username: reqUser,
				Password: "a bad password",
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBasicAuth,
			},
			wantErr:           errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
			firstAuthCount:    0,
			lastAuthCount:     0,
			firstHTTPResponse: http.StatusForbidden,
			lastHTTPResponse:  http.StatusForbidden,
		},
		{
			name: "success with no creds, with good body",
			fields: fields{
				Method: http.MethodPost,
				Body:   reqBody,
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBody,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "failure with bad body",
			fields: fields{
				Method: http.MethodPost,
				Body:   "a bad body",
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBody,
			},
			wantErr:           errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
			firstAuthCount:    0,
			lastAuthCount:     0,
			firstHTTPResponse: http.StatusForbidden,
			lastHTTPResponse:  http.StatusForbidden,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			srv := newFakeServer(t)
			// Deferred so the server is shut down even when an assertion
			// below aborts the subtest early.
			defer srv.Close()

			c := &CookieAuthConfig{
				URL:      srv.URL + tt.args.endpoint,
				Method:   tt.fields.Method,
				Username: tt.fields.Username,
				Password: tt.fields.Password,
				Body:     tt.fields.Body,
				Headers:  tt.fields.Headers,
				Renewal:  config.Duration(tt.args.renewal),
			}
			err := c.initializeClient(srv.Client())
			if tt.wantErr != nil {
				require.EqualError(t, err, tt.wantErr.Error())
			} else {
				require.NoError(t, err)
			}

			mock := clockutil.NewMock()
			ticker := mock.Ticker(time.Duration(c.Renewal))
			defer ticker.Stop()

			c.wg.Add(1)
			ctx, cancel := context.WithCancel(t.Context())
			// cancel is also called explicitly below. Deferring it ensures the
			// renewal goroutine's context is canceled if an assertion fails
			// first.
			defer cancel()
			go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"})

			srv.checkAuthCount(t, tt.firstAuthCount)
			srv.checkResp(t, tt.firstHTTPResponse)
			mock.Add(renewalCheck)

			// Wait until the server has seen the expected number of
			// authentications. This does not wait for the renewal goroutine
			// to exit; that happens via cancel + wg.Wait below.
			require.Eventually(t, func() bool {
				return atomic.LoadInt32(srv.int32) >= tt.lastAuthCount
			}, time.Second, 10*time.Millisecond)

			cancel()
			c.wg.Wait()
			srv.checkAuthCount(t, tt.lastAuthCount)
			srv.checkResp(t, tt.lastHTTPResponse)
		})
	}
}
