/*
Copyright 2022 The CDI Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package controller

import (
	"context"
	"errors"
	"fmt"
	"reflect"

	"github.com/go-logr/logr"
	snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v6/apis/volumesnapshot/v1"

	corev1 "k8s.io/api/core/v1"
	k8serrors "k8s.io/apimachinery/pkg/api/errors"
	"k8s.io/apimachinery/pkg/api/meta"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/client-go/tools/record"

	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller"
	"sigs.k8s.io/controller-runtime/pkg/event"
	"sigs.k8s.io/controller-runtime/pkg/handler"
	"sigs.k8s.io/controller-runtime/pkg/manager"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
	"sigs.k8s.io/controller-runtime/pkg/reconcile"
	"sigs.k8s.io/controller-runtime/pkg/source"

	cdiv1 "kubevirt.io/containerized-data-importer-api/pkg/apis/core/v1beta1"
	cc "kubevirt.io/containerized-data-importer/pkg/controller/common"
)

// DataSourceReconciler reconciles DataSource objects. For each DataSource it
// resolves the referenced source (a PVC, a VolumeSnapshot, or another
// DataSource), updates the Ready condition and status.source accordingly, and
// writes the object back when anything changed.
//
// Fields:
//   - client: reads DataSources and the objects they reference, and writes
//     DataSource updates.
//   - recorder: event recorder obtained for dataSourceControllerName.
//     NOTE(review): not referenced by the code in this file.
//   - scheme: the manager's scheme. NOTE(review): not referenced by the code
//     in this file.
//   - log: logger named after dataSourceControllerName.
//   - installerLabels: supplied by the caller of NewDataSourceController.
//     NOTE(review): not referenced by the code in this file.
type DataSourceReconciler struct {
	client          client.Client
	recorder        record.EventRecorder
	scheme          *runtime.Scheme
	log             logr.Logger
	installerLabels map[string]string
}

// Constants used by the DataSource controller:
//   - ready, noSource, maxReferenceDepthReached, selfReference and
//     crossNamespaceReference are Reason values written to the DataSource
//     Ready condition.
//   - dataSourceControllerName names the controller, its logger and its event
//     recorder.
//   - dataSource*Field are the field-index keys registered in setupIndexers.
//     Each one indexes DataSources by the "namespace/name" of the object
//     referenced in the matching spec.source member.
const (
	ready                    = "Ready"
	noSource                 = "NoSource"
	dataSourceControllerName = "datasource-controller"
	maxReferenceDepthReached = "MaxReferenceDepthReached"
	selfReference            = "SelfReference"
	crossNamespaceReference  = "CrossNamespaceReference"

	dataSourcePvcField        = "spec.source.pvc"
	dataSourceSnapshotField   = "spec.source.snapshot"
	dataSourceDataSourceField = "spec.source.dataSource"
)

// Reconcile fetches the DataSource named by req and refreshes its status.
//
// If the DataSource no longer exists, the request is treated as done and no
// error is returned. Any other lookup error, or any error from update, is
// returned to the caller. The Result is always empty: no explicit requeue is
// requested.
func (r *DataSourceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) {
	ds := &cdiv1.DataSource{}
	err := r.client.Get(ctx, req.NamespacedName, ds)
	switch {
	case k8serrors.IsNotFound(err):
		return reconcile.Result{}, nil
	case err != nil:
		return reconcile.Result{}, err
	}
	return reconcile.Result{}, r.update(ctx, ds)
}

func (r *DataSourceReconciler) update(ctx context.Context, dataSource *cdiv1.DataSource) error {
	dataSourceCopy := dataSource.DeepCopy()
	resolved, err := cc.ResolveDataSourceChain(ctx, r.client, dataSource)
	if err != nil {
		log := r.log.WithValues("datasource", dataSource.Name, "namespace", dataSource.Namespace)
		log.Info(err.Error())
		if err := handleDataSourceRefError(dataSource, err); err != nil {
			return err
		}
		resolved = dataSource
	} else {
		resolved.Spec.Source.DeepCopyInto(&dataSource.Status.Source)
		dataSource.Status.Conditions = nil
	}

	switch {
	case resolved.Spec.Source.DataSource != nil:
		// Status condition handling already took place, continue to update
	case resolved.Spec.Source.PVC != nil:
		if err := r.handlePvcSource(ctx, resolved.Spec.Source.PVC, dataSource); err != nil {
			return err
		}
	case resolved.Spec.Source.Snapshot != nil:
		if err := r.handleSnapshotSource(ctx, resolved.Spec.Source.Snapshot, dataSource); err != nil {
			return err
		}
	default:
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, "No source PVC set", noSource)
	}

	if dsCond := FindDataSourceConditionByType(dataSource, cdiv1.DataSourceReady); dsCond != nil && dsCond.Status == corev1.ConditionFalse {
		dataSource.Status.Source = cdiv1.DataSourceSource{}
	}

	if !reflect.DeepEqual(dataSource, dataSourceCopy) {
		if err := r.client.Update(ctx, dataSource); err != nil {
			return err
		}
	}
	return nil
}

// handlePvcSource sets dataSource's Ready condition from the PVC referenced by
// sourcePVC and from a DataVolume with the same namespace/name, if one exists.
//
//   - A DataVolume exists and is not Succeeded: Ready=False, with the DV phase
//     as the reason.
//   - No DataVolume and no PVC: Ready=False, reason cc.NotFound.
//   - Otherwise (the DV Succeeded, or there is no DV but the PVC exists):
//     allowed labels from the DV and PVC are copied onto dataSource via
//     cc.CopyAllowedLabels, and Ready=True.
//
// Lookup errors other than NotFound are returned unchanged.
func (r *DataSourceReconciler) handlePvcSource(ctx context.Context, sourcePVC *cdiv1.DataVolumeSourcePVC, dataSource *cdiv1.DataSource) error {
	key := types.NamespacedName{
		Namespace: cc.GetNamespace(sourcePVC.Namespace, dataSource.Namespace),
		Name:      sourcePVC.Name,
	}

	pvc := &corev1.PersistentVolumeClaim{}
	pvcErr := r.client.Get(ctx, key, pvc)
	if pvcErr != nil && !k8serrors.IsNotFound(pvcErr) {
		return pvcErr
	}
	pvcFound := pvcErr == nil

	dv := &cdiv1.DataVolume{}
	dvErr := r.client.Get(ctx, key, dv)
	if dvErr != nil && !k8serrors.IsNotFound(dvErr) {
		return dvErr
	}
	dvFound := dvErr == nil

	switch {
	case dvFound && dv.Status.Phase != cdiv1.Succeeded:
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, fmt.Sprintf("Import DataVolume phase %s", dv.Status.Phase), string(dv.Status.Phase))
		return nil
	case !dvFound && !pvcFound:
		r.log.Info("PVC not found", "name", sourcePVC.Name)
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, "PVC not found", cc.NotFound)
		return nil
	}

	cc.CopyAllowedLabels(dv.GetLabels(), dataSource, true)
	cc.CopyAllowedLabels(pvc.GetLabels(), dataSource, true)
	updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionTrue, "DataSource is ready to be consumed", ready)
	return nil
}

// handleSnapshotSource sets dataSource's Ready condition from the
// VolumeSnapshot referenced by sourceSnapshot:
//   - not found: Ready=False, reason cc.NotFound.
//   - found but cc.IsSnapshotReady reports false: Ready=False, reason
//     "SnapshotNotReady".
//   - found and ready: allowed snapshot labels are copied onto dataSource and
//     Ready=True.
//
// Lookup errors other than NotFound are returned unchanged.
func (r *DataSourceReconciler) handleSnapshotSource(ctx context.Context, sourceSnapshot *cdiv1.DataVolumeSourceSnapshot, dataSource *cdiv1.DataSource) error {
	key := types.NamespacedName{
		Namespace: cc.GetNamespace(sourceSnapshot.Namespace, dataSource.Namespace),
		Name:      sourceSnapshot.Name,
	}
	snapshot := &snapshotv1.VolumeSnapshot{}
	err := r.client.Get(ctx, key, snapshot)

	switch {
	case k8serrors.IsNotFound(err):
		r.log.Info("Snapshot not found", "name", sourceSnapshot.Name)
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, "Snapshot not found", cc.NotFound)
	case err != nil:
		return err
	case !cc.IsSnapshotReady(snapshot):
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, "Snapshot phase is not ready", "SnapshotNotReady")
	default:
		cc.CopyAllowedLabels(snapshot.GetLabels(), dataSource, true)
		updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionTrue, "DataSource is ready to be consumed", ready)
	}
	return nil
}

// handleDataSourceRefError maps a DataSource-reference resolution error to a
// Ready=False condition on dataSource, using err.Error() as the message.
//
// It recognizes the cc max-depth, self-reference and cross-namespace sentinels
// (via errors.Is) and Kubernetes NotFound errors. For these it records the
// condition and returns nil. Any other error is returned unchanged and
// dataSource is left untouched.
func handleDataSourceRefError(dataSource *cdiv1.DataSource, err error) error {
	// Checked in order; the first match wins.
	sentinels := []struct {
		target error
		reason string
	}{
		{cc.ErrDataSourceMaxDepthReached, maxReferenceDepthReached},
		{cc.ErrDataSourceSelfReference, selfReference},
		{cc.ErrDataSourceCrossNamespace, crossNamespaceReference},
	}

	reason := ""
	for _, s := range sentinels {
		if errors.Is(err, s.target) {
			reason = s.reason
			break
		}
	}
	if reason == "" {
		if !k8serrors.IsNotFound(err) {
			return err
		}
		reason = cc.NotFound
	}

	updateDataSourceCondition(dataSource, cdiv1.DataSourceReady, corev1.ConditionFalse, err.Error(), reason)
	return nil
}

// updateDataSourceCondition sets the status, message and reason of ds's
// condition of the given type via updateConditionState. If no condition of
// that type exists yet, a new one is appended first.
func updateDataSourceCondition(ds *cdiv1.DataSource, conditionType cdiv1.DataSourceConditionType, status corev1.ConditionStatus, message, reason string) {
	cond := FindDataSourceConditionByType(ds, conditionType)
	if cond == nil {
		ds.Status.Conditions = append(ds.Status.Conditions, cdiv1.DataSourceCondition{Type: conditionType})
		cond = &ds.Status.Conditions[len(ds.Status.Conditions)-1]
	}
	updateConditionState(&cond.ConditionState, status, message, reason)
}

// FindDataSourceConditionByType finds the first DataSourceCondition in
// ds.Status.Conditions whose Type equals conditionType.
//
// The returned pointer aliases the element stored in the slice, so writes
// through it modify ds directly. It returns nil when no condition matches.
func FindDataSourceConditionByType(ds *cdiv1.DataSource, conditionType cdiv1.DataSourceConditionType) *cdiv1.DataSourceCondition {
	conditions := ds.Status.Conditions
	for i := range conditions {
		if c := &conditions[i]; c.Type == conditionType {
			return c
		}
	}
	return nil
}

// NewDataSourceController creates the DataSource controller and registers it
// with mgr. The controller runs up to 3 concurrent reconciles.
//
// It also registers the DataSource field indexes and the watches the
// controller depends on. Any error from controller creation or watch setup is
// returned and no controller is handed back.
func NewDataSourceController(mgr manager.Manager, log logr.Logger, installerLabels map[string]string) (controller.Controller, error) {
	r := &DataSourceReconciler{
		client:          mgr.GetClient(),
		recorder:        mgr.GetEventRecorderFor(dataSourceControllerName),
		scheme:          mgr.GetScheme(),
		log:             log.WithName(dataSourceControllerName),
		installerLabels: installerLabels,
	}

	opts := controller.Options{
		MaxConcurrentReconciles: 3,
		Reconciler:              r,
	}
	ctrl, err := controller.New(dataSourceControllerName, mgr, opts)
	if err != nil {
		return nil, err
	}
	if err = addDataSourceControllerWatches(mgr, ctrl, log); err != nil {
		return nil, err
	}

	log.Info("Initialized DataSource controller")
	return ctrl, nil
}

// addDataSourceControllerWatches registers the DataSource field indexes and
// then the watches. The indexes come first because the watches' map functions
// list DataSources by those index keys.
func addDataSourceControllerWatches(mgr manager.Manager, c controller.Controller, log logr.Logger) error {
	if err := setupIndexers(mgr); err != nil {
		return err
	}
	return setupWatches(mgr, c, log)
}

func setupIndexers(mgr manager.Manager) error {
	if err := mgr.GetFieldIndexer().IndexField(context.TODO(), &cdiv1.DataSource{}, dataSourcePvcField, func(obj client.Object) []string {
		if pvc := obj.(*cdiv1.DataSource).Spec.Source.PVC; pvc != nil {
			ns := cc.GetNamespace(pvc.Namespace, obj.GetNamespace())
			return []string{types.NamespacedName{Name: pvc.Name, Namespace: ns}.String()}
		}
		return nil
	}); err != nil {
		return err
	}

	if err := mgr.GetFieldIndexer().IndexField(context.TODO(), &cdiv1.DataSource{}, dataSourceSnapshotField, func(obj client.Object) []string {
		if snapshot := obj.(*cdiv1.DataSource).Spec.Source.Snapshot; snapshot != nil {
			ns := cc.GetNamespace(snapshot.Namespace, obj.GetNamespace())
			return []string{types.NamespacedName{Name: snapshot.Name, Namespace: ns}.String()}
		}
		return nil
	}); err != nil {
		return err
	}

	if err := mgr.GetFieldIndexer().IndexField(context.TODO(), &cdiv1.DataSource{}, dataSourceDataSourceField, func(obj client.Object) []string {
		if sourceDS := obj.(*cdiv1.DataSource).Spec.Source.DataSource; sourceDS != nil {
			ns := cc.GetNamespace(sourceDS.Namespace, obj.GetNamespace())
			return []string{types.NamespacedName{Name: sourceDS.Name, Namespace: ns}.String()}
		}
		return nil
	}); err != nil {
		return err
	}

	return nil
}

// setupWatches registers the controller's event sources:
//   - DataSources: enqueue the changed DataSource plus every DataSource that
//     references it through spec.source.dataSource. Updates pass only when the
//     source spec or the conditions changed.
//   - DataVolumes and PVCs: enqueue DataSources that reference them. Updates
//     pass on a status phase or label change.
//   - VolumeSnapshots: enqueue DataSources that reference them. Updates pass on
//     any status or label change. This watch is skipped (returning nil) when
//     listing VolumeSnapshots fails with a no-match error. A "cache not
//     started" error from that probe does not prevent the watch.
//
// Create and delete events always pass.
func setupWatches(mgr manager.Manager, c controller.Controller, log logr.Logger) error {
	enqueueSelfAndDependents := handler.TypedEnqueueRequestsFromMapFunc(func(ctx context.Context, ds *cdiv1.DataSource) []reconcile.Request {
		self := []reconcile.Request{{
			NamespacedName: types.NamespacedName{Name: ds.Name, Namespace: ds.Namespace},
		}}
		return appendMatchingDataSourceRequests(ctx, mgr, dataSourceDataSourceField, ds, self, log)
	})
	dataSourceChanged := predicate.TypedFuncs[*cdiv1.DataSource]{
		CreateFunc: func(e event.TypedCreateEvent[*cdiv1.DataSource]) bool { return true },
		DeleteFunc: func(e event.TypedDeleteEvent[*cdiv1.DataSource]) bool { return true },
		UpdateFunc: func(e event.TypedUpdateEvent[*cdiv1.DataSource]) bool {
			return !sameSourceSpec(e.ObjectOld, e.ObjectNew) ||
				!sameConditions(e.ObjectOld, e.ObjectNew)
		},
	}
	if err := c.Watch(source.Kind(mgr.GetCache(), &cdiv1.DataSource{}, enqueueSelfAndDependents, dataSourceChanged)); err != nil {
		return err
	}

	// DataVolume: only a status phase or label change is interesting to reconcile.
	if err := watchDataSourceSource(mgr, c, log, &cdiv1.DataVolume{}, func(oldDV, newDV *cdiv1.DataVolume) bool {
		return oldDV.Status.Phase != newDV.Status.Phase ||
			!reflect.DeepEqual(oldDV.Labels, newDV.Labels)
	}); err != nil {
		return err
	}

	if err := watchDataSourceSource(mgr, c, log, &corev1.PersistentVolumeClaim{}, func(oldPVC, newPVC *corev1.PersistentVolumeClaim) bool {
		return oldPVC.Status.Phase != newPVC.Status.Phase ||
			!reflect.DeepEqual(oldPVC.Labels, newPVC.Labels)
	}); err != nil {
		return err
	}

	// Probe for the VolumeSnapshot API before watching it.
	if err := mgr.GetClient().List(context.TODO(), &snapshotv1.VolumeSnapshotList{}); err != nil {
		switch {
		case meta.IsNoMatchError(err):
			// Back out if there's no point to attempt watch
			return nil
		case !cc.IsErrCacheNotStarted(err):
			return err
		}
	}
	return watchDataSourceSource(mgr, c, log, &snapshotv1.VolumeSnapshot{}, func(oldSnap, newSnap *snapshotv1.VolumeSnapshot) bool {
		return !reflect.DeepEqual(oldSnap.Status, newSnap.Status) ||
			!reflect.DeepEqual(oldSnap.Labels, newSnap.Labels)
	})
}

// watchDataSourceSource watches objects of obj's type that a DataSource may
// consume. Each event enqueues the DataSources returned by mapToDataSource.
// Create and delete events always pass; update events pass when
// changed(old, new) reports true.
func watchDataSourceSource[T client.Object](mgr manager.Manager, c controller.Controller, log logr.Logger, obj T, changed func(oldObj, newObj T) bool) error {
	return c.Watch(source.Kind(mgr.GetCache(), obj,
		handler.TypedEnqueueRequestsFromMapFunc(func(ctx context.Context, o T) []reconcile.Request {
			return mapToDataSource(ctx, mgr, o, log)
		}),
		predicate.TypedFuncs[T]{
			CreateFunc: func(e event.TypedCreateEvent[T]) bool { return true },
			DeleteFunc: func(e event.TypedDeleteEvent[T]) bool { return true },
			UpdateFunc: func(e event.TypedUpdateEvent[T]) bool { return changed(e.ObjectOld, e.ObjectNew) },
		},
	))
}

func appendMatchingDataSourceRequests(ctx context.Context, mgr manager.Manager, indexingKey string, obj client.Object, reqs []reconcile.Request, log logr.Logger) []reconcile.Request {
	var dataSources cdiv1.DataSourceList
	matchingFields := client.MatchingFields{indexingKey: client.ObjectKeyFromObject(obj).String()}
	if err := mgr.GetClient().List(ctx, &dataSources, matchingFields); err != nil {
		log.Error(err, "Unable to list DataSources", "matchingFields", matchingFields)
		return reqs
	}
	for _, ds := range dataSources.Items {
		reqs = append(reqs, reconcile.Request{NamespacedName: types.NamespacedName{Namespace: ds.Namespace, Name: ds.Name}})
	}
	return reqs
}

func mapToDataSource(ctx context.Context, mgr manager.Manager, obj client.Object, log logr.Logger) []reconcile.Request {
	reqs := appendMatchingDataSourceRequests(ctx, mgr, dataSourcePvcField, obj, nil, log)
	return appendMatchingDataSourceRequests(ctx, mgr, dataSourceSnapshotField, obj, reqs, log)
}

// sameSourceSpec reports whether objOld and objNew are both DataSources with
// an identical spec.source. It returns false when either object is not a
// *cdiv1.DataSource.
//
// The whole spec.source struct is compared. The previous implementation only
// compared the first non-nil member of the old object. As a result, an edit
// that left, say, spec.source.pvc unchanged but added spec.source.dataSource
// was reported as "same" and filtered out by the update predicate. Yet update()
// checks spec.source.dataSource before pvc and snapshot, so the effective
// source had changed. Two empty sources now compare equal. Such updates still
// pass the predicate when the conditions differ (see sameConditions).
func sameSourceSpec(objOld, objNew client.Object) bool {
	dsOld, okOld := objOld.(*cdiv1.DataSource)
	dsNew, okNew := objNew.(*cdiv1.DataSource)
	if !okOld || !okNew {
		return false
	}
	return reflect.DeepEqual(dsOld.Spec.Source, dsNew.Spec.Source)
}

// sameConditions reports whether objOld and objNew are both DataSources whose
// status conditions match by Type on Status, Reason and Message. Timestamps
// are ignored. The two lists must have the same length, and every new
// condition must have an old condition of the same Type. If the old list
// repeats a Type, its last occurrence is used. It returns false when either
// object is not a *cdiv1.DataSource.
func sameConditions(objOld, objNew client.Object) bool {
	before, ok := objOld.(*cdiv1.DataSource)
	if !ok {
		return false
	}
	after, ok := objNew.(*cdiv1.DataSource)
	if !ok {
		return false
	}

	prevConds, nextConds := before.Status.Conditions, after.Status.Conditions
	if len(prevConds) != len(nextConds) {
		return false
	}

	byType := make(map[cdiv1.DataSourceConditionType]cdiv1.DataSourceCondition, len(prevConds))
	for i := range prevConds {
		byType[prevConds[i].Type] = prevConds[i]
	}

	for i := range nextConds {
		cur := &nextConds[i]
		prev, found := byType[cur.Type]
		if !found {
			return false
		}
		if prev.Reason != cur.Reason || prev.Message != cur.Message || prev.Status != cur.Status {
			return false
		}
	}
	return true
}
