diff --git a/pkg/scanners/terraform/parser/evaluator.go b/pkg/scanners/terraform/parser/evaluator.go index ef40a85..251291c 100644 --- a/pkg/scanners/terraform/parser/evaluator.go +++ b/pkg/scanners/terraform/parser/evaluator.go @@ -3,6 +3,7 @@ package parser import ( "context" "errors" + "fmt" "io/fs" "reflect" "time" @@ -17,7 +18,6 @@ import ( "github.com/hashicorp/hcl/v2/ext/typeexpr" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/convert" - "github.com/zclconf/go-cty/cty/gocty" ) const ( @@ -228,6 +228,42 @@ func (e *evaluator) expandDynamicBlock(b *terraform.Block) { } } +func validateForEachArg(arg cty.Value) error { + if arg.IsNull() { + return errors.New("arg is null") + } + + ty := arg.Type() + + if !arg.IsKnown() || ty.Equals(cty.DynamicPseudoType) || arg.LengthInt() == 0 { + return nil + } + + if !(ty.IsSetType() || ty.IsObjectType() || ty.IsMapType()) { + return fmt.Errorf("%s type is not supported: arg is not set or map", ty.FriendlyName()) + } + + if ty.IsSetType() { + if !ty.ElementType().Equals(cty.String) { + return errors.New("arg is not set of strings") + } + + it := arg.ElementIterator() + for it.Next() { + key, _ := it.Element() + if key.IsNull() { + return errors.New("arg is set of strings, but contains null") + } + + if !key.IsKnown() { + return errors.New("arg is set of strings, but contains unknown value") + } + } + } + + return nil +} + func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool { return slices.Contains([]string{"module", "resource", "data", "dynamic"}, block.Type()) } @@ -243,43 +279,50 @@ func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Bloc forEachFiltered = append(forEachFiltered, block) continue } - if !forEachAttr.Value().IsNull() && forEachAttr.Value().IsKnown() && forEachAttr.IsIterable() { - var clones []cty.Value - _ = forEachAttr.Each(func(key cty.Value, val cty.Value) { - index := key + forEachVal := forEachAttr.Value() + + if err := validateForEachArg(forEachVal); err != nil { + e.debug.Log(`"for_each" argument is invalid: %s`, err.Error()) + continue + } + + clones := make(map[string]cty.Value) + _ = forEachAttr.Each(func(key cty.Value, val cty.Value) { - switch val.Type() { - case cty.String, cty.Number: - index = val - } + if !key.Type().Equals(cty.String) { + e.debug.Log( + `Invalid "for-each" argument: map key (or set value) is not a string, but %s`, + key.Type().FriendlyName(), + ) + return + } - clone := block.Clone(index) + clone := block.Clone(key) - ctx := clone.Context() + ctx := clone.Context() - e.copyVariables(block, clone) + e.copyVariables(block, clone) - ctx.SetByDot(key, "each.key") - ctx.SetByDot(val, "each.value") + ctx.SetByDot(key, "each.key") + ctx.SetByDot(val, "each.value") - ctx.Set(key, block.TypeLabel(), "key") - ctx.Set(val, block.TypeLabel(), "value") + ctx.Set(key, block.TypeLabel(), "key") + ctx.Set(val, block.TypeLabel(), "value") - forEachFiltered = append(forEachFiltered, clone) + forEachFiltered = append(forEachFiltered, clone) - clones = append(clones, clone.Values()) - metadata := clone.GetMetadata() - e.ctx.SetByDot(clone.Values(), metadata.Reference()) - }) - metadata := block.GetMetadata() - if len(clones) == 0 { - e.ctx.SetByDot(cty.EmptyTupleVal, metadata.Reference()) - } else { - e.ctx.SetByDot(cty.TupleVal(clones), metadata.Reference()) - } - e.debug.Log("Expanded block '%s' into %d clones via 'for_each' attribute.", block.LocalName(), len(clones)) + clones[key.AsString()] = clone.Values() + metadata := clone.GetMetadata() + e.ctx.SetByDot(clone.Values(), metadata.Reference()) + }) + metadata := block.GetMetadata() + if len(clones) == 0 { + e.ctx.SetByDot(cty.EmptyTupleVal, metadata.Reference()) + } else { + e.ctx.SetByDot(cty.MapVal(clones), metadata.Reference()) } + e.debug.Log("Expanded block '%s' into %d clones via 'for_each' attribute.", block.LocalName(), len(clones)) } return forEachFiltered @@ -298,19 +341,15 @@ func (e *evaluator) expandBlockCounts(blocks terraform.Blocks) terraform.Blocks continue } count := 1 - if !countAttr.Value().IsNull() && countAttr.Value().IsKnown() { - if countAttr.Value().Type() == cty.Number { - f, _ := countAttr.Value().AsBigFloat().Float64() - count = int(f) - } + countAttrVal := countAttr.Value() + if !countAttrVal.IsNull() && countAttrVal.IsKnown() && countAttrVal.Type() == cty.Number { + count = int(countAttr.AsNumber()) } var clones []cty.Value for i := 0; i < count; i++ { - c, _ := gocty.ToCtyValue(i, cty.Number) - clone := block.Clone(c) + clone := block.Clone(cty.NumberIntVal(int64(i))) clones = append(clones, clone.Values()) - block.TypeLabel() countFiltered = append(countFiltered, clone) metadata := clone.GetMetadata() e.ctx.SetByDot(clone.Values(), metadata.Reference()) diff --git a/pkg/scanners/terraform/parser/evaluator_test.go b/pkg/scanners/terraform/parser/evaluator_test.go new file mode 100644 index 0000000..8d3ef7b --- /dev/null +++ b/pkg/scanners/terraform/parser/evaluator_test.go @@ -0,0 +1,94 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zclconf/go-cty/cty" +) + +func TestValidateForEachArg(t *testing.T) { + tests := []struct { + name string + arg cty.Value + expectedError string + }{ + { + name: "empty set", + arg: cty.SetValEmpty(cty.String), + }, + { + name: "set of strings", + arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.StringVal("val2")}), + }, + { + name: "set of non-strings", + arg: cty.SetVal([]cty.Value{cty.NumberIntVal(1), cty.NumberIntVal(2)}), + expectedError: "is not set of strings", + }, + { + name: "set with null", + arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.NullVal(cty.String)}), + expectedError: "arg is set of strings, but contains null", + }, + { + name: "set with unknown", + arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}), + expectedError: "arg is set of strings, but contains unknown", + }, + { + name: "set with unknown", + arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}), + expectedError: "arg is set of strings, but contains unknown", + }, + { + name: "non empty map", + arg: cty.MapVal(map[string]cty.Value{ + "val1": cty.StringVal("..."), + "val2": cty.StringVal("..."), + }), + }, + { + name: "map with unknown", + arg: cty.MapVal(map[string]cty.Value{ + "val1": cty.UnknownVal(cty.String), + "val2": cty.StringVal("..."), + }), + }, + { + name: "empty obj", + arg: cty.EmptyObjectVal, + }, + { + name: "obj with strings", + arg: cty.ObjectVal(map[string]cty.Value{ + "val1": cty.StringVal("..."), + "val2": cty.StringVal("..."), + }), + }, + { + name: "null", + arg: cty.NullVal(cty.Set(cty.String)), + expectedError: "arg is null", + }, + { + name: "unknown", + arg: cty.UnknownVal(cty.Set(cty.String)), + }, + { + name: "dynamic", + arg: cty.DynamicVal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateForEachArg(tt.arg) + if tt.expectedError != "" && err != nil { + assert.ErrorContains(t, err, tt.expectedError) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/pkg/scanners/terraform/parser/parser_test.go b/pkg/scanners/terraform/parser/parser_test.go index 18cee56..57ce4e6 100644 --- a/pkg/scanners/terraform/parser/parser_test.go +++ b/pkg/scanners/terraform/parser/parser_test.go @@ -805,10 +805,48 @@ policy_rules = { assert.Equal(t, 1001, block.GetAttribute("priority").AsIntValueOrDefault(0, block).Value()) } +func Test_ForEachRefersToMapThatContainsSameStringValues(t *testing.T) { + fs := testutil.CreateFS(t, map[string]string{ + "main.tf": `locals { + buckets = { + bucket1 = "test1" + bucket2 = "test1" + } +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.key +} +`, + }) + + parser := New(fs, "", OptionStopOnHCLError(true)) + require.NoError(t, parser.ParseFS(context.TODO(), ".")) + + modules, _, err := parser.EvaluateAll(context.TODO()) + assert.NoError(t, err) + assert.Len(t, modules, 1) + + bucketBlocks := modules.GetResourcesByType("aws_s3_bucket") + assert.Len(t, bucketBlocks, 2) + + var labels []string + + for _, b := range bucketBlocks { + labels = append(labels, b.Label()) + } + + expectedLabels := []string{ + `aws_s3_bucket.this["bucket1"]`, + `aws_s3_bucket.this["bucket2"]`, + } + assert.Equal(t, expectedLabels, labels) +} + func TestDataSourceWithCountMetaArgument(t *testing.T) { fs := testutil.CreateFS(t, map[string]string{ "main.tf": ` - data "http" "example" { count = 2 } @@ -843,10 +881,10 @@ func TestDataSourceWithForEachMetaArgument(t *testing.T) { fs := testutil.CreateFS(t, map[string]string{ "main.tf": ` locals { - ports = [80, 8080] + ports = ["80", "8080"] } data "http" "example" { - for_each = local.ports + for_each = toset(local.ports) url = "localhost:${each.key}" } `, @@ -864,3 +902,125 @@ data "http" "example" { httpDataSources := rootModule.GetDatasByType("http") assert.Len(t, httpDataSources, 2) } + +func TestForEach(t *testing.T) { + + tests := []struct { + name string + source string + expectedCount int + }{ + { + name: "arg is list of strings", + source: `locals { + buckets = ["bucket1", "bucket2"] +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.key +}`, + expectedCount: 0, + }, + { + name: "arg is empty set", + source: `locals { + buckets = toset([]) +} + +resource "aws_s3_bucket" "this" { + for_each = loca.buckets + bucket = each.key +}`, + expectedCount: 0, + }, + { + name: "arg is set of strings", + source: `locals { + buckets = ["bucket1", "bucket2"] +} + +resource "aws_s3_bucket" "this" { + for_each = toset(local.buckets) + bucket = each.key +}`, + expectedCount: 2, + }, + { + name: "arg is map", + source: `locals { + buckets = { + 1 = {} + 2 = {} + } +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.key +}`, + expectedCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fs := testutil.CreateFS(t, map[string]string{ + "main.tf": tt.source, + }) + parser := New(fs, "", OptionStopOnHCLError(true)) + require.NoError(t, parser.ParseFS(context.TODO(), ".")) + + modules, _, err := parser.EvaluateAll(context.TODO()) + assert.NoError(t, err) + assert.Len(t, modules, 1) + + bucketBlocks := modules.GetResourcesByType("aws_s3_bucket") + assert.Len(t, bucketBlocks, tt.expectedCount) + }) + } +} + +func TestForEachRefToResource(t *testing.T) { + fs := testutil.CreateFS(t, map[string]string{ + "main.tf": ` + locals { + vpcs = { + "test1" = { + cidr_block = "192.168.0.0/28" + } + "test2" = { + cidr_block = "192.168.1.0/28" + } + } +} + +resource "aws_vpc" "example" { + for_each = local.vpcs + cidr_block = each.value.cidr_block +} + +resource "aws_internet_gateway" "example" { + for_each = aws_vpc.example + vpc_id = each.key +} +`, + }) + parser := New(fs, "", OptionStopOnHCLError(true)) + require.NoError(t, parser.ParseFS(context.TODO(), ".")) + + modules, _, err := parser.EvaluateAll(context.TODO()) + assert.NoError(t, err) + assert.Len(t, modules, 1) + + blocks := modules.GetResourcesByType("aws_internet_gateway") + assert.Len(t, blocks, 2) + + var vpcIds []string + for _, b := range blocks { + vpcIds = append(vpcIds, b.GetAttribute("vpc_id").Value().AsString()) + } + + expectedVpcIds := []string{"test1", "test2"} + assert.Equal(t, expectedVpcIds, vpcIds) +} diff --git a/test/module_test.go b/test/module_test.go index c1af12d..ffe4141 100644 --- a/test/module_test.go +++ b/test/module_test.go @@ -537,7 +537,7 @@ func Test_Dynamic_Variables_FalsePositive(t *testing.T) { resource "something" "else" { x = 1 dynamic "blah" { - for_each = [true] + for_each = toset(["true"]) content { ok = each.value