diff --git a/pkg/cluster/clustermap/errors.go b/pkg/cluster/clustermap/errors.go index 080851008..bfa93e686 100644 --- a/pkg/cluster/clustermap/errors.go +++ b/pkg/cluster/clustermap/errors.go @@ -39,3 +39,13 @@ type ErrClusterNotInMap struct { func (e ErrClusterNotInMap) Error() string { return fmt.Sprintf("cluster '%s' is not defined in cluster map %v", e.Child, e.Map) } + +// ErrClusterCircularDependency returned for circular dependencies +type ErrClusterCircularDependency struct { + Parent string + Map *v1alpha1.ClusterMap +} + +func (e ErrClusterCircularDependency) Error() string { + return fmt.Sprintf("%v contains cluster referenced as both parent and child: %s", e.Map, e.Parent) +} diff --git a/pkg/cluster/clustermap/map.go b/pkg/cluster/clustermap/map.go index 98560b554..c5b55dd01 100644 --- a/pkg/cluster/clustermap/map.go +++ b/pkg/cluster/clustermap/map.go @@ -36,6 +36,7 @@ type WriteOptions struct { // TODO use typed cluster names type ClusterMap interface { ParentCluster(string) (string, error) + ValidateClusterMap() error AllClusters() []string ClusterKubeconfigContext(string) (string, error) Sources(string) ([]v1alpha1.KubeconfigSource, error) @@ -66,6 +67,33 @@ func (cm clusterMap) ParentCluster(child string) (string, error) { return currentCluster.Parent, nil } +// Validates a clustermap has valid parent-child map structure +func (cm clusterMap) ValidateClusterMap() error { + clusterMap := cm.AllClusters() + for _, childCluster := range clusterMap { + var parentClusters []string + var currentChild string = childCluster + for { + currentCluster, _ := cm.apiMap.Map[currentChild] + for _, c := range parentClusters { + if c == currentCluster.Parent { + // Quit on parent whos also child + return ErrClusterCircularDependency{Parent: childCluster, Map: cm.apiMap} + } + } + // Quit loop once top level of current cluster is reached + if currentCluster.Parent == "" { + break + } + parentClusters = append(parentClusters, currentCluster.Parent) + currentChild = currentCluster.Parent + } + } + + // Return success if there are no conflicts + return nil +} + // AllClusters returns all clusters in a map func (cm clusterMap) AllClusters() []string { clusters := []string{} diff --git a/pkg/cluster/clustermap/map_test.go b/pkg/cluster/clustermap/map_test.go index 96ceaff06..009ab7781 100644 --- a/pkg/cluster/clustermap/map_test.go +++ b/pkg/cluster/clustermap/map_test.go @@ -97,6 +97,27 @@ func TestClusterMap(t *testing.T) { assert.Equal(t, "", parent) }) + t.Run("Validate Circular Clustermap", func(t *testing.T) { + // Create new map with circular dependency + circularAPIMap := &v1alpha1.ClusterMap{ + Map: map[string]*v1alpha1.Cluster{}, + } + for key, value := range apiMap.Map { + newValue := *value + circularAPIMap.Map[key] = &newValue + } + circularAPIMap.Map["ephemeral"].Parent = "workload" + cMapCircular := clustermap.NewClusterMap(circularAPIMap) + err := cMapCircular.ValidateClusterMap() + assert.Error(t, err) + }) + + t.Run("Validate all Clustermaps", func(t *testing.T) { + // Check child clusterID against map of parent clusterID map + err := cMap.ValidateClusterMap() + assert.NoError(t, err) + }) + t.Run("all clusters", func(t *testing.T) { clusters := cMap.AllClusters() assert.Len(t, clusters, 4)