diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index b6a4e24a8..0ab2906e9 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -3,6 +3,7 @@ package cluster // Postgres CustomResourceDefinition object i.e. Spilo import ( + "context" "database/sql" "encoding/json" "fmt" @@ -70,7 +71,7 @@ type kubeResources struct { CriticalOpPodDisruptionBudget *policyv1.PodDisruptionBudget LogicalBackupJob *batchv1.CronJob Streams map[string]*zalandov1.FabricEventStream - //Pods are treated separately + // Pods are treated separately } // Cluster describes postgresql cluster @@ -88,6 +89,11 @@ type Cluster struct { podSubscribersMu sync.RWMutex pgDb *sql.DB mu sync.Mutex + ctx context.Context + cancelFunc context.CancelFunc + syncMu sync.Mutex // protects syncRunning and needsResync + syncRunning bool + needsResync bool userSyncStrategy spec.UserSyncer deleteOptions metav1.DeleteOptions podEventsQueue *cache.FIFO @@ -95,7 +101,7 @@ type Cluster struct { teamsAPIClient teams.Interface oauthTokenGetter OAuthTokenGetter - KubeClient k8sutil.KubernetesClient //TODO: move clients to the better place? + KubeClient k8sutil.KubernetesClient // TODO: move clients to the better place? currentProcess Process processMu sync.RWMutex // protects the current operation for reporting, no need to hold the master mutex specMu sync.RWMutex // protects the spec for reporting, no need to hold the master mutex @@ -120,9 +126,12 @@ type compareLogicalBackupJobResult struct { } // New creates a new cluster. This function should be called from a controller. -func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster { +func New(ctx context.Context, cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster { deletePropagationPolicy := metav1.DeletePropagationOrphan + // Create a cancellable context for this cluster + clusterCtx, cancelFunc := context.WithCancel(ctx) + podEventsQueue := cache.NewFIFO(func(obj interface{}) (string, error) { e, ok := obj.(PodEvent) if !ok { @@ -137,6 +146,8 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres } cluster := &Cluster{ + ctx: clusterCtx, + cancelFunc: cancelFunc, Config: cfg, Postgresql: pgSpec, pgUsers: make(map[string]spec.PgUser), @@ -149,7 +160,8 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres PatroniEndpoints: make(map[string]*v1.Endpoints), PatroniConfigMaps: make(map[string]*v1.ConfigMap), VolumeClaims: make(map[types.UID]*v1.PersistentVolumeClaim), - Streams: make(map[string]*zalandov1.FabricEventStream)}, + Streams: make(map[string]*zalandov1.FabricEventStream), + }, userSyncStrategy: users.DefaultUserSyncStrategy{ PasswordEncryption: passwordEncryption, RoleDeletionSuffix: cfg.OpConfig.RoleDeletionSuffix, @@ -175,6 +187,62 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres return cluster } +// Cancel cancels the cluster's context, which will cause any ongoing +// context-aware operations (like Sync) to return early. +func (c *Cluster) Cancel() { + if c.cancelFunc != nil { + c.cancelFunc() + } +} + +// StartSync attempts to start a sync operation. Returns true if sync can start +// (no sync currently running and context not cancelled). Returns false if a sync +// is already running (needsResync is set) or if context is cancelled (deletion in progress). 
+func (c *Cluster) StartSync() bool { + c.syncMu.Lock() + defer c.syncMu.Unlock() + + // Check if context is cancelled (deletion in progress) + select { + case <-c.ctx.Done(): + return false + default: + } + + if c.syncRunning { + c.needsResync = true + return false + } + c.syncRunning = true + c.needsResync = false + return true +} + +// EndSync marks the sync operation as complete. +func (c *Cluster) EndSync() { + c.syncMu.Lock() + defer c.syncMu.Unlock() + c.syncRunning = false +} + +// NeedsResync returns true if a resync was requested while sync was running, +// and clears the flag. Returns false if context is cancelled (deletion in progress). +func (c *Cluster) NeedsResync() bool { + c.syncMu.Lock() + defer c.syncMu.Unlock() + + // Check if context is cancelled (deletion in progress) + select { + case <-c.ctx.Done(): + return false + default: + } + + result := c.needsResync + c.needsResync = false + return result +} + func (c *Cluster) clusterName() spec.NamespacedName { return util.NameFromMeta(c.ObjectMeta) } @@ -276,7 +344,7 @@ func (c *Cluster) Create() (err error) { errStatus error ) if err == nil { - pgUpdatedStatus, errStatus = c.KubeClient.SetPostgresCRDStatus(c.clusterName(), acidv1.ClusterStatusRunning) //TODO: are you sure it's running? + pgUpdatedStatus, errStatus = c.KubeClient.SetPostgresCRDStatus(c.clusterName(), acidv1.ClusterStatusRunning) // TODO: are you sure it's running? } else { c.logger.Warningf("cluster created failed: %v", err) pgUpdatedStatus, errStatus = c.KubeClient.SetPostgresCRDStatus(c.clusterName(), acidv1.ClusterStatusAddFailed) @@ -440,7 +508,7 @@ func (c *Cluster) compareStatefulSetWith(statefulSet *appsv1.StatefulSet) *compa var match, needsRollUpdate, needsReplace bool match = true - //TODO: improve me + // TODO: improve me if *c.Statefulset.Spec.Replicas != *statefulSet.Spec.Replicas { match = false reasons = append(reasons, "new statefulset's number of replicas does not match the current one") @@ -672,7 +740,6 @@ func compareResourcesAssumeFirstNotNil(a *v1.ResourceRequirements, b *v1.Resourc } } return true - } func compareEnv(a, b []v1.EnvVar) bool { @@ -707,9 +774,7 @@ func compareEnv(a, b []v1.EnvVar) bool { } func compareSpiloConfiguration(configa, configb string) bool { - var ( - oa, ob spiloConfiguration - ) + var oa, ob spiloConfiguration var err error err = json.Unmarshal([]byte(configa), &oa) @@ -818,7 +883,6 @@ func (c *Cluster) compareAnnotations(old, new map[string]string, removedList *[] } return reason != "", reason - } func (c *Cluster) compareServices(old, new *v1.Service) (bool, string) { @@ -895,7 +959,7 @@ func (c *Cluster) compareLogicalBackupJob(cur, new *batchv1.CronJob) *compareLog } func (c *Cluster) comparePodDisruptionBudget(cur, new *policyv1.PodDisruptionBudget) (bool, string) { - //TODO: improve comparison + // TODO: improve comparison if !reflect.DeepEqual(new.Spec, cur.Spec) { return false, "new PDB's spec does not match the current one" } @@ -944,8 +1008,17 @@ func (c *Cluster) removeFinalizer() error { } c.logger.Infof("removing finalizer %s", finalizerName) - finalizers := util.RemoveString(c.ObjectMeta.Finalizers, finalizerName) - newSpec, err := c.KubeClient.SetFinalizer(c.clusterName(), c.DeepCopy(), finalizers) + + // Fetch the latest version of the object to avoid resourceVersion conflicts + clusterName := c.clusterName() + latestPg, err := c.KubeClient.PostgresqlsGetter.Postgresqls(clusterName.Namespace).Get( + context.TODO(), clusterName.Name, metav1.GetOptions{}) + if err != nil { + return 
fmt.Errorf("error fetching latest postgresql for finalizer removal: %v", err) + } + + finalizers := util.RemoveString(latestPg.ObjectMeta.Finalizers, finalizerName) + newSpec, err := c.KubeClient.SetFinalizer(clusterName, latestPg, finalizers) if err != nil { return fmt.Errorf("error removing finalizer: %v", err) } @@ -1063,7 +1136,7 @@ func (c *Cluster) Update(oldSpec, newSpec *acidv1.Postgresql) error { } c.logger.Debug("syncing secrets") - //TODO: mind the secrets of the deleted/new users + // TODO: mind the secrets of the deleted/new users if err := c.syncSecrets(); err != nil { c.logger.Errorf("could not sync secrets: %v", err) updateFailed = true @@ -1101,7 +1174,6 @@ func (c *Cluster) Update(oldSpec, newSpec *acidv1.Postgresql) error { // logical backup job func() { - // create if it did not exist if !oldSpec.Spec.EnableLogicalBackup && newSpec.Spec.EnableLogicalBackup { c.logger.Debug("creating backup cron job") @@ -1129,7 +1201,6 @@ func (c *Cluster) Update(oldSpec, newSpec *acidv1.Postgresql) error { updateFailed = true } } - }() // Roles and Databases @@ -1206,7 +1277,7 @@ func syncResources(a, b *v1.ResourceRequirements) bool { // before the pods, it will be re-created by the current master pod and will remain, obstructing the // creation of the new cluster with the same name. Therefore, the endpoints should be deleted last. func (c *Cluster) Delete() error { - var anyErrors = false + anyErrors := false c.mu.Lock() defer c.mu.Unlock() c.eventRecorder.Event(c.GetReference(), v1.EventTypeNormal, "Delete", "Started deletion of cluster resources") @@ -1297,7 +1368,6 @@ func (c *Cluster) NeedsRepair() (bool, acidv1.PostgresStatus) { c.specMu.RLock() defer c.specMu.RUnlock() return !c.Status.Success(), c.Status - } // ReceivePodEvent is called back by the controller in order to add the cluster's pod event to the queue. 
@@ -1406,7 +1476,6 @@ func (c *Cluster) initSystemUsers() { } func (c *Cluster) initPreparedDatabaseRoles() error { - if c.Spec.PreparedDatabases != nil && len(c.Spec.PreparedDatabases) == 0 { // TODO: add option to disable creating such a default DB c.Spec.PreparedDatabases = map[string]acidv1.PreparedDatabase{strings.Replace(c.Name, "-", "_", -1): {}} } @@ -1472,10 +1541,9 @@ func (c *Cluster) initPreparedDatabaseRoles() error { } func (c *Cluster) initDefaultRoles(defaultRoles map[string]string, admin, prefix, searchPath, secretNamespace string) error { - for defaultRole, inherits := range defaultRoles { namespace := c.Namespace - //if namespaced secrets are allowed + // if namespaced secrets are allowed if secretNamespace != "" { if c.Config.OpConfig.EnableCrossNamespaceSecret { namespace = secretNamespace @@ -1543,7 +1611,7 @@ func (c *Cluster) initRobotUsers() error { } } - //if namespaced secrets are allowed + // if namespaced secrets are allowed if c.Config.OpConfig.EnableCrossNamespaceSecret { if strings.Contains(username, ".") { splits := strings.Split(username, ".") @@ -1594,7 +1662,6 @@ func (c *Cluster) initAdditionalOwnerRoles() { func (c *Cluster) initTeamMembers(teamID string, isPostgresSuperuserTeam bool) error { teamMembers, err := c.getTeamMembers(teamID) - if err != nil { return fmt.Errorf("could not get list of team members for team %q: %v", teamID, err) } @@ -1633,7 +1700,6 @@ func (c *Cluster) initTeamMembers(teamID string, isPostgresSuperuserTeam bool) e } func (c *Cluster) initHumanUsers() error { - var clusterIsOwnedBySuperuserTeam bool superuserTeams := []string{} diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go index d78d4c92e..cf5f4aae1 100644 --- a/pkg/cluster/cluster_test.go +++ b/pkg/cluster/cluster_test.go @@ -43,7 +43,7 @@ var logger = logrus.New().WithField("test", "cluster") // 1 cluster, primary endpoint, 2 services, the secrets, the statefulset and pods being ready var eventRecorder = record.NewFakeRecorder(7) -var cl = New( +var cl = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -135,7 +135,7 @@ func TestCreate(t *testing.T) { client.Postgresqls(clusterNamespace).Create(context.TODO(), &pg, metav1.CreateOptions{}) client.Pods(clusterNamespace).Create(context.TODO(), &pod, metav1.CreateOptions{}) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1629,7 +1629,7 @@ func TestCompareLogicalBackupJob(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1778,7 +1778,7 @@ func TestCrossNamespacedSecrets(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ diff --git a/pkg/cluster/connection_pooler_test.go b/pkg/cluster/connection_pooler_test.go index 78d1c2527..cda965e9a 100644 --- a/pkg/cluster/connection_pooler_test.go +++ b/pkg/cluster/connection_pooler_test.go @@ -161,7 +161,7 @@ func noEmptySync(cluster *Cluster, err error, reason SyncReason) error { func TestNeedConnectionPooler(t *testing.T) { testName := "Test how connection pooler can be enabled" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -297,7 +297,7 @@ func TestConnectionPoolerCreateDeletion(t *testing.T) { }, } - var cluster = New( + var 
cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ @@ -405,7 +405,7 @@ func TestConnectionPoolerSync(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ @@ -667,7 +667,7 @@ func TestConnectionPoolerSync(t *testing.T) { func TestConnectionPoolerPodSpec(t *testing.T) { testName := "Test connection pooler pod template generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -690,7 +690,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) { ConnectionPooler: &acidv1.ConnectionPooler{}, EnableReplicaConnectionPooler: boolToPointer(true), } - var clusterNoDefaultRes = New( + var clusterNoDefaultRes = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -780,7 +780,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) { func TestConnectionPoolerDeploymentSpec(t *testing.T) { testName := "Test connection pooler deployment spec generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -983,7 +983,7 @@ func TestPoolerTLS(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1063,7 +1063,7 @@ func TestPoolerTLS(t *testing.T) { func TestConnectionPoolerServiceSpec(t *testing.T) { testName := "Test connection pooler service spec generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, diff --git a/pkg/cluster/database.go b/pkg/cluster/database.go index 56b5f3638..4f3d5c775 100644 --- a/pkg/cluster/database.go +++ b/pkg/cluster/database.go @@ -2,6 +2,7 @@ package cluster import ( "bytes" + "context" "database/sql" "fmt" "net" @@ -121,26 +122,31 @@ func (c *Cluster) initDbConn() error { if c.pgDb != nil { return nil } - return c.initDbConnWithName("") } -// Worker function for connection initialization. This function does not check -// if the connection is already open, if it is then it will be overwritten. -// Callers need to make sure no connection is open, otherwise we could leak -// connections +// initDbConnWithName initializes a database connection using the cluster's context. +// This function does not check if the connection is already open. func (c *Cluster) initDbConnWithName(dbname string) error { + return c.initDbConnWithNameContext(c.ctx, dbname) +} + +// initDbConnWithNameContext initializes a database connection with an explicit context. +// Use this when you need a custom context (e.g., different timeout, or context.Background() +// for operations that should not be cancelled). This function does not check if the +// connection is already open, callers need to ensure no connection is open to avoid leaks. 
+func (c *Cluster) initDbConnWithNameContext(ctx context.Context, dbname string) error { c.setProcessName("initializing db connection") var conn *sql.DB connstring := c.pgConnectionString(dbname) - finalerr := retryutil.Retry(constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, + finalerr := retryutil.RetryWithContext(ctx, constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, func() (bool, error) { var err error conn, err = sql.Open("postgres", connstring) if err == nil { - err = conn.Ping() + err = conn.PingContext(ctx) } if err == nil { @@ -268,9 +274,7 @@ func findUsersFromRotation(rotatedUsers []string, db *sql.DB) (map[string]string }() for rows.Next() { - var ( - rolname, roldatesuffix string - ) + var rolname, roldatesuffix string err := rows.Scan(&rolname, &roldatesuffix) if err != nil { return nil, fmt.Errorf("error when processing rows of deprecated users: %v", err) @@ -331,9 +335,7 @@ func (c *Cluster) cleanupRotatedUsers(rotatedUsers []string) error { // getDatabases returns the map of current databases with owners // The caller is responsible for opening and closing the database connection func (c *Cluster) getDatabases() (dbs map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getDatabasesSQL); err != nil { return nil, fmt.Errorf("could not query database: %v", err) @@ -551,9 +553,7 @@ func (c *Cluster) getOwnerRoles(dbObjPath string, withUser bool) (owners []strin // getExtension returns the list of current database extensions // The caller is responsible for opening and closing the database connection func (c *Cluster) getExtensions() (dbExtensions map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getExtensionsSQL); err != nil { return nil, fmt.Errorf("could not query database extensions: %v", err) @@ -598,7 +598,6 @@ func (c *Cluster) executeAlterExtension(extName, schemaName string) error { } func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doing, operation string) error { - c.logger.Infof("%s %q schema %q", doing, extName, schemaName) if _, err := c.pgDb.Exec(fmt.Sprintf(statement, extName, schemaName)); err != nil { return fmt.Errorf("could not execute %s: %v", operation, err) @@ -610,9 +609,7 @@ func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doi // getPublications returns the list of current database publications with tables // The caller is responsible for opening and closing the database connection func (c *Cluster) getPublications() (publications map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getPublicationsSQL); err != nil { return nil, fmt.Errorf("could not query database publications: %v", err) @@ -668,7 +665,6 @@ func (c *Cluster) executeAlterPublication(pubName, tableList string) error { } func (c *Cluster) execCreateOrAlterPublication(pubName, tableList, statement, doing, operation string) error { - c.logger.Debugf("%s %q with table list %q", doing, pubName, tableList) if _, err := c.pgDb.Exec(fmt.Sprintf(statement, pubName, tableList)); err != nil { return fmt.Errorf("could not execute %s: %v", operation, err) @@ -743,7 +739,6 @@ func (c *Cluster) installLookupFunction(poolerSchema, poolerUser string) error { constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, func() (bool, error) { - // At this moment we are not connected to any database if err := 
c.initDbConnWithName(dbname); err != nil { msg := "could not init database connection to %s" diff --git a/pkg/cluster/k8sres_test.go b/pkg/cluster/k8sres_test.go index 6bd87366d..7c98eb6f0 100644 --- a/pkg/cluster/k8sres_test.go +++ b/pkg/cluster/k8sres_test.go @@ -52,7 +52,7 @@ type ExpectedValue struct { } func TestGenerateSpiloJSONConfiguration(t *testing.T) { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -1143,7 +1143,7 @@ func TestGetNumberOfInstances(t *testing.T) { } for _, tt := range tests { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: tt.config, }, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) @@ -1210,7 +1210,7 @@ func TestCloneEnv(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ WALES3Bucket: "wale-bucket", @@ -1384,7 +1384,7 @@ func TestStandbyEnv(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{}, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) for _, tt := range tests { @@ -1431,7 +1431,7 @@ func TestNodeAffinity(t *testing.T) { } } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1524,7 +1524,7 @@ func TestPodAffinity(t *testing.T) { } for _, tt := range tests { - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ EnablePodAntiAffinity: tt.anti, @@ -1687,7 +1687,7 @@ func TestTLS(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1942,7 +1942,7 @@ func TestAdditionalVolume(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2083,7 +2083,7 @@ func TestVolumeSelector(t *testing.T) { }, } - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2182,7 +2182,7 @@ func TestSidecars(t *testing.T) { }, } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2491,7 +2491,7 @@ func TestContainerValidation(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - cluster := New(tc.clusterConfig, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) + cluster := New(context.Background(), tc.clusterConfig, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) _, err := cluster.generateStatefulSet(&tc.spec) @@ -2580,7 +2580,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }{ { scenario: "With multiple instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2597,7 +2597,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With zero instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2614,7 +2614,7 @@ func 
TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PodDisruptionBudget disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2631,7 +2631,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With non-default PDBNameFormat and PodDisruptionBudget explicitly enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-databass-budget", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2648,7 +2648,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PDBMasterLabelSelector disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True(), PDBMasterLabelSelector: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2665,7 +2665,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With OwnerReference enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role", EnableOwnerReferences: util.True()}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2700,7 +2700,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }{ { scenario: "With multiple instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2717,7 +2717,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With zero instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2734,7 +2734,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PodDisruptionBudget disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2751,7 +2751,7 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With OwnerReference enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role", EnableOwnerReferences: util.True()}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ @@ -2813,7 +2813,7 @@ func TestGenerateService(t *testing.T) { EnableMasterLoadBalancer: &enableLB, } - cluster = New( + 
cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2864,7 +2864,7 @@ func TestGenerateService(t *testing.T) { } func TestCreateLoadBalancerLogic(t *testing.T) { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -3076,7 +3076,7 @@ func TestEnableLoadBalancers(t *testing.T) { } for _, tt := range tests { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: tt.config, }, client, tt.pgSpec, logger, eventRecorder) @@ -3708,7 +3708,7 @@ func TestGenerateResourceRequirements(t *testing.T) { } for _, tt := range tests { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: tt.config, }, client, tt.pgSpec, logger, newEventRecorder) @@ -3893,7 +3893,7 @@ func TestGenerateLogicalBackupJob(t *testing.T) { } for _, tt := range tests { - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: tt.config, }, k8sutil.NewMockKubernetesClient(), acidv1.Postgresql{}, logger, eventRecorder) diff --git a/pkg/cluster/pod_test.go b/pkg/cluster/pod_test.go index 6816b4d7a..854f7e8c2 100644 --- a/pkg/cluster/pod_test.go +++ b/pkg/cluster/pod_test.go @@ -2,6 +2,7 @@ package cluster import ( "bytes" + "context" "fmt" "io" "net/http" @@ -24,7 +25,7 @@ func TestGetSwitchoverCandidate(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), diff --git a/pkg/cluster/streams_test.go b/pkg/cluster/streams_test.go index 934f2bfd4..86f26eea5 100644 --- a/pkg/cluster/streams_test.go +++ b/pkg/cluster/streams_test.go @@ -223,7 +223,7 @@ var ( }, } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ @@ -529,7 +529,7 @@ func newFabricEventStream(streams []zalandov1.EventStream, annotations map[strin func TestSyncStreams(t *testing.T) { newClusterName := fmt.Sprintf("%s-2", pg.Name) pg.Name = newClusterName - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -688,7 +688,7 @@ func TestSameStreams(t *testing.T) { func TestUpdateStreams(t *testing.T) { pg.Name = fmt.Sprintf("%s-3", pg.Name) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -787,7 +787,7 @@ func patchPostgresqlStreams(t *testing.T, cluster *Cluster, pgSpec *acidv1.Postg func TestDeleteStreams(t *testing.T) { pg.Name = fmt.Sprintf("%s-4", pg.Name) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", diff --git a/pkg/cluster/sync.go b/pkg/cluster/sync.go index ecf692702..320be6cc6 100644 --- a/pkg/cluster/sync.go +++ b/pkg/cluster/sync.go @@ -70,7 +70,7 @@ func (c *Cluster) Sync(newSpec *acidv1.Postgresql) error { return err } - //TODO: mind the secrets of the deleted/new users + // TODO: mind the secrets of the deleted/new users if err = c.syncSecrets(); err != nil { err = fmt.Errorf("could not sync secrets: %v", err) return err @@ -856,7 +856,6 @@ func (c *Cluster) restartInstance(pod *v1.Pod, restartWait uint32) error { // AnnotationsToPropagate get the annotations to update if required // based on the annotations in postgres CRD func (c *Cluster) AnnotationsToPropagate(annotations map[string]string) 
map[string]string { - if annotations == nil { annotations = make(map[string]string) } @@ -1110,7 +1109,8 @@ func (c *Cluster) updateSecret( secretUsername string, generatedSecret *v1.Secret, retentionUsers *[]string, - currentTime time.Time) (*v1.Secret, error) { + currentTime time.Time, +) (*v1.Secret, error) { var ( secret *v1.Secret err error @@ -1244,7 +1244,8 @@ func (c *Cluster) rotatePasswordInSecret( secretUsername string, roleOrigin spec.RoleOrigin, currentTime time.Time, - retentionUsers *[]string) (string, error) { + retentionUsers *[]string, +) (string, error) { var ( err error nextRotationDate time.Time @@ -1469,7 +1470,7 @@ func (c *Cluster) syncDatabases() error { preparedDatabases := make([]string, 0) if err := c.initDbConn(); err != nil { - return fmt.Errorf("could not init database connection") + return fmt.Errorf("could not init database connection: %v", err) } defer func() { if err := c.closeDbConn(); err != nil { @@ -1553,7 +1554,7 @@ func (c *Cluster) syncPreparedDatabases() error { errors := make([]string, 0) for preparedDbName, preparedDB := range c.Spec.PreparedDatabases { - if err := c.initDbConnWithName(preparedDbName); err != nil { + if err := c.initDbConnWithNameContext(c.ctx, preparedDbName); err != nil { errors = append(errors, fmt.Sprintf("could not init connection to database %s: %v", preparedDbName, err)) continue } @@ -1697,7 +1698,8 @@ func (c *Cluster) syncLogicalBackupJob() error { } if len(cmp.deletedPodAnnotations) != 0 { templateMetadataReq := map[string]map[string]map[string]map[string]map[string]map[string]map[string]*string{ - "spec": {"jobTemplate": {"spec": {"template": {"metadata": {"annotations": {}}}}}}} + "spec": {"jobTemplate": {"spec": {"template": {"metadata": {"annotations": {}}}}}}, + } for _, anno := range cmp.deletedPodAnnotations { templateMetadataReq["spec"]["jobTemplate"]["spec"]["template"]["metadata"]["annotations"][anno] = nil } diff --git a/pkg/cluster/sync_test.go b/pkg/cluster/sync_test.go index 87e9dc8a5..dda862d67 100644 --- a/pkg/cluster/sync_test.go +++ b/pkg/cluster/sync_test.go @@ -88,7 +88,7 @@ func TestSyncStatefulSetsAnnotations(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -184,7 +184,7 @@ func TestPodAnnotationsSync(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -369,7 +369,7 @@ func TestCheckAndSetGlobalPostgreSQLConfiguration(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -691,7 +691,7 @@ func TestSyncStandbyClusterConfiguration(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -844,7 +844,7 @@ func TestUpdateSecret(t *testing.T) { } // new cluster with enabled password rotation - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ @@ -988,7 +988,7 @@ func TestUpdateSecretNameConflict(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ diff --git a/pkg/cluster/util_test.go b/pkg/cluster/util_test.go index 9cd7dc7e9..e9c3ca1bb 100644 --- a/pkg/cluster/util_test.go +++ b/pkg/cluster/util_test.go @@ -288,7 +288,7 @@ func 
newInheritedAnnotationsCluster(client k8sutil.KubernetesClient) (*Cluster, }, } - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), diff --git a/pkg/cluster/volumes_test.go b/pkg/cluster/volumes_test.go index 95ecc7624..4d9eb0189 100644 --- a/pkg/cluster/volumes_test.go +++ b/pkg/cluster/volumes_test.go @@ -59,7 +59,7 @@ func TestResizeVolumeClaim(t *testing.T) { assert.NoError(t, err) // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -185,7 +185,7 @@ func TestMigrateEBS(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -293,7 +293,7 @@ func TestMigrateGp3Support(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -355,7 +355,7 @@ func TestManualGp2Gp3Support(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -415,7 +415,7 @@ func TestDontTouchType(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index e46b9ee44..aa6262264 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -277,7 +277,7 @@ func (c *Controller) initRoleBinding() { }`, c.PodServiceAccount.Name, c.PodServiceAccount.Name, c.PodServiceAccount.Name) c.opConfig.PodServiceAccountRoleBindingDefinition = compactValue(stringValue) } - c.logger.Info("Parse role bindings") + // re-uses k8s internal parsing. 
See k8s client-go issue #193 for explanation decode := scheme.Codecs.UniversalDeserializer().Decode obj, groupVersionKind, err := decode([]byte(c.opConfig.PodServiceAccountRoleBindingDefinition), nil, nil) diff --git a/pkg/controller/postgresql.go b/pkg/controller/postgresql.go index 824a030f4..3a1f748f9 100644 --- a/pkg/controller/postgresql.go +++ b/pkg/controller/postgresql.go @@ -166,7 +166,7 @@ func (c *Controller) addCluster(lg *logrus.Entry, clusterName spec.NamespacedNam } } - cl := cluster.New(c.makeClusterConfig(), c.KubeClient, *pgSpec, lg, c.eventRecorder) + cl := cluster.New(context.Background(), c.makeClusterConfig(), c.KubeClient, *pgSpec, lg, c.eventRecorder) cl.Run(c.stopCh) teamName := strings.ToLower(cl.Spec.TeamID) @@ -258,13 +258,26 @@ func (c *Controller) processEvent(event ClusterEvent) { lg.Infoln("cluster has been created") case EventUpdate: - lg.Infoln("update of the cluster started") - if !clusterFound { lg.Warningln("cluster does not exist") return } c.curWorkerCluster.Store(event.WorkerID, cl) + + // Check if this cluster has been marked for deletion + if !event.NewSpec.ObjectMeta.DeletionTimestamp.IsZero() { + lg.Infof("cluster has a DeletionTimestamp of %s, starting deletion now.", event.NewSpec.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() // Cancel any ongoing operations + if err = cl.Delete(); err != nil { + cl.Error = fmt.Sprintf("error deleting cluster and its resources: %v", err) + c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) + lg.Error(cl.Error) + return + } + lg.Infoln("cluster has been deleted via update event") + return + } + err = cl.Update(event.OldSpec, event.NewSpec) if err != nil { cl.Error = fmt.Sprintf("could not update cluster: %v", err) @@ -292,6 +305,7 @@ func (c *Controller) processEvent(event ClusterEvent) { // when using finalizers the deletion already happened if c.opConfig.EnableFinalizers == nil || !*c.opConfig.EnableFinalizers { lg.Infoln("deletion of the cluster started") + cl.Cancel() // Cancel any ongoing operations if err := cl.Delete(); err != nil { cl.Error = fmt.Sprintf("could not delete cluster: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) @@ -317,8 +331,6 @@ func (c *Controller) processEvent(event ClusterEvent) { lg.Infof("cluster has been deleted") case EventSync: - lg.Infof("syncing of the cluster started") - // no race condition because a cluster is always processed by single worker if !clusterFound { cl, err = c.addCluster(lg, clusterName, event.NewSpec) @@ -333,22 +345,42 @@ func (c *Controller) processEvent(event ClusterEvent) { // has this cluster been marked as deleted already, then we shall start cleaning up if !cl.ObjectMeta.DeletionTimestamp.IsZero() { lg.Infof("cluster has a DeletionTimestamp of %s, starting deletion now.", cl.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() // Cancel any ongoing operations if err = cl.Delete(); err != nil { cl.Error = fmt.Sprintf("error deleting cluster and its resources: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) lg.Error(cl.Error) return } - } else { - if err = cl.Sync(event.NewSpec); err != nil { + return + } + + // Try to start sync - returns false if sync already running or cluster deleted + if !cl.StartSync() { + lg.Infof("sync already in progress, will resync when current sync completes") + return + } + + // Run sync in background goroutine so we can process other events (like delete) 
+ lg.Infof("syncing of the cluster started (background)") + go func() { + defer cl.EndSync() + + if err := cl.Sync(event.NewSpec); err != nil { cl.Error = fmt.Sprintf("could not sync cluster: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Sync", "%v", cl.Error) lg.Error(cl.Error) return } + cl.Error = "" lg.Infof("cluster has been synced") - } - cl.Error = "" + + // Check if resync was requested while we were syncing + if cl.NeedsResync() { + lg.Infof("resync requested, queueing new sync event") + c.queueClusterEvent(nil, event.NewSpec, EventSync) + } + }() } } @@ -379,7 +411,6 @@ func (c *Controller) processClusterEventsQueue(idx int, stopCh <-chan struct{}, } func (c *Controller) warnOnDeprecatedPostgreSQLSpecParameters(spec *acidv1.PostgresSpec) { - deprecate := func(deprecated, replacement string) { c.logger.Warningf("parameter %q is deprecated. Consider setting %q instead", deprecated, replacement) } @@ -425,7 +456,7 @@ func (c *Controller) queueClusterEvent(informerOldSpec, informerNewSpec *acidv1. clusterError string ) - if informerOldSpec != nil { //update, delete + if informerOldSpec != nil { // update, delete uid = informerOldSpec.GetUID() clusterName = util.NameFromMeta(informerOldSpec.ObjectMeta) @@ -440,7 +471,7 @@ func (c *Controller) queueClusterEvent(informerOldSpec, informerNewSpec *acidv1. } else { clusterError = informerOldSpec.Error } - } else { //add, sync + } else { // add, sync uid = informerNewSpec.GetUID() clusterName = util.NameFromMeta(informerNewSpec.ObjectMeta) clusterError = informerNewSpec.Error @@ -465,6 +496,19 @@ func (c *Controller) queueClusterEvent(informerOldSpec, informerNewSpec *acidv1. } } + // If the cluster is marked for deletion, cancel any ongoing operations immediately + // This unblocks stuck Sync operations so the delete can proceed + if informerNewSpec != nil && !informerNewSpec.ObjectMeta.DeletionTimestamp.IsZero() { + c.clustersMu.RLock() + if cl, found := c.clusters[clusterName]; found { + c.logger.WithField("cluster-name", clusterName).Infof( + "cluster marked for deletion (DeletionTimestamp: %s), cancelling ongoing operations", + informerNewSpec.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() + } + c.clustersMu.RUnlock() + } + if clusterError != "" && eventType != EventDelete { c.logger.WithField("cluster-name", clusterName).Debugf("skipping %q event for the invalid cluster: %s", eventType, clusterError) @@ -539,12 +583,28 @@ func (c *Controller) postgresqlUpdate(prev, cur interface{}) { pgOld := c.postgresqlCheck(prev) pgNew := c.postgresqlCheck(cur) if pgOld != nil && pgNew != nil { - // Avoid the inifinite recursion for status updates + clusterName := util.NameFromMeta(pgNew.ObjectMeta) + + // Check if DeletionTimestamp was set (resource marked for deletion) + deletionTimestampChanged := pgOld.ObjectMeta.DeletionTimestamp.IsZero() && !pgNew.ObjectMeta.DeletionTimestamp.IsZero() + if deletionTimestampChanged { + c.logger.WithField("cluster-name", clusterName).Infof( + "UPDATE event: DeletionTimestamp set to %s, queueing event", + pgNew.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + c.queueClusterEvent(pgOld, pgNew, EventUpdate) + return + } + + // Avoid the infinite recursion for status updates if reflect.DeepEqual(pgOld.Spec, pgNew.Spec) { if reflect.DeepEqual(pgNew.Annotations, pgOld.Annotations) { + c.logger.WithField("cluster-name", clusterName).Debugf( + "UPDATE event: no spec/annotation changes, skipping") return } } + + c.logger.WithField("cluster-name", 
clusterName).Infof("UPDATE event: spec or annotations changed, queueing event") c.queueClusterEvent(pgOld, pgNew, EventUpdate) } } @@ -578,7 +638,6 @@ or config maps. The operator does not sync accounts/role bindings after creation. */ func (c *Controller) submitRBACCredentials(event ClusterEvent) error { - namespace := event.NewSpec.GetNamespace() if err := c.createPodServiceAccount(namespace); err != nil { @@ -592,7 +651,6 @@ func (c *Controller) submitRBACCredentials(event ClusterEvent) error { } func (c *Controller) createPodServiceAccount(namespace string) error { - podServiceAccountName := c.opConfig.PodServiceAccountName _, err := c.KubeClient.ServiceAccounts(namespace).Get(context.TODO(), podServiceAccountName, metav1.GetOptions{}) if k8sutil.ResourceNotFound(err) { @@ -615,7 +673,6 @@ func (c *Controller) createPodServiceAccount(namespace string) error { } func (c *Controller) createRoleBindings(namespace string) error { - podServiceAccountName := c.opConfig.PodServiceAccountName podServiceAccountRoleBindingName := c.PodServiceAccountRoleBinding.Name diff --git a/pkg/util/retryutil/retry_util.go b/pkg/util/retryutil/retry_util.go index 868ba6e98..b5fab3b47 100644 --- a/pkg/util/retryutil/retry_util.go +++ b/pkg/util/retryutil/retry_util.go @@ -1,6 +1,7 @@ package retryutil import ( + "context" "fmt" "time" ) @@ -25,7 +26,7 @@ func (t *Ticker) Tick() { <-t.ticker.C } // Retry is a wrapper around RetryWorker that provides a real RetryTicker func Retry(interval time.Duration, timeout time.Duration, f func() (bool, error)) error { - //TODO: make the retry exponential + // TODO: make the retry exponential if timeout < interval { return fmt.Errorf("timeout(%s) should be greater than interval(%v)", timeout, interval) } @@ -33,6 +34,18 @@ func Retry(interval time.Duration, timeout time.Duration, f func() (bool, error) return RetryWorker(interval, timeout, tick, f) } +// RetryWithContext is like Retry but checks for context cancellation before each attempt. +func RetryWithContext(ctx context.Context, interval time.Duration, timeout time.Duration, f func() (bool, error)) error { + return Retry(interval, timeout, func() (bool, error) { + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + return f() + } + }) +} + // RetryWorker calls ConditionFunc until either: // * it returns boolean true // * a timeout expires @@ -41,8 +54,8 @@ func RetryWorker( interval time.Duration, timeout time.Duration, tick RetryTicker, - f func() (bool, error)) error { - + f func() (bool, error), +) error { maxRetries := int(timeout / interval) defer tick.Stop()
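
Below is a minimal sketch of how the new StartSync/EndSync/NeedsResync gate is intended to be driven from a controller worker, mirroring the EventSync path in pkg/controller/postgresql.go above. The helper name syncWithGate and the requeue callback are illustrative placeholders, not part of this change:

```go
package controller

import (
	acidv1 "github.com/zalando/postgres-operator/pkg/apis/acid.zalan.do/v1"
	"github.com/zalando/postgres-operator/pkg/cluster"
)

// syncWithGate is an illustrative helper (not part of this diff). StartSync
// gates concurrent syncs, EndSync releases the gate, and NeedsResync tells the
// caller to schedule one more sync if a request arrived while a sync was
// already running.
func syncWithGate(cl *cluster.Cluster, newSpec *acidv1.Postgresql, requeue func()) {
	// StartSync returns false if a sync is already running (a resync is then
	// recorded) or if the cluster context was cancelled because deletion started.
	if !cl.StartSync() {
		return
	}
	go func() {
		defer cl.EndSync()
		if err := cl.Sync(newSpec); err != nil {
			return // the real worker records cl.Error and emits a warning event
		}
		if cl.NeedsResync() {
			requeue() // in the diff: c.queueClusterEvent(nil, event.NewSpec, EventSync)
		}
	}()
}
```

And a self-contained example of retryutil.RetryWithContext, showing that a cancelled context now aborts a retry loop that would otherwise keep polling until its timeout. The interval, timeout, and sleep below are made up for illustration:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/zalando/postgres-operator/pkg/util/retryutil"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Simulate a cluster being marked for deletion one second into the retry loop.
	go func() {
		time.Sleep(time.Second)
		cancel()
	}()

	// The condition function never succeeds, standing in for an unreachable
	// database. Without the context this would spin for the full 10 seconds;
	// with it, RetryWithContext returns ctx.Err() shortly after cancel().
	err := retryutil.RetryWithContext(ctx, 200*time.Millisecond, 10*time.Second,
		func() (bool, error) {
			fmt.Println("still trying...")
			return false, nil
		})
	fmt.Println("retry finished:", err)
}
```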