diff --git a/interfaces/generate.go b/interfaces/generate.go index 4f3d9b9..9fca13d 100644 --- a/interfaces/generate.go +++ b/interfaces/generate.go @@ -156,23 +156,44 @@ func Generate(clients []any, dir string, opts ...Option) error { return nil } +func normalizeFullTypeName(typeName string) string { + // Preserve pointer prefix + prefix := "" + if strings.HasPrefix(typeName, "*") { + prefix = "*" + typeName = typeName[1:] + } + + versionPattern := regexp.MustCompile(`/v\d+\.`) + parts := strings.Split(typeName, "/") + importName := parts[len(parts)-1] + if versionPattern.MatchString(typeName) { + // Example typeName: github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appconfiguration/armappconfiguration/v2.ConfigurationStoresClientCreateResponse + importName = parts[len(parts)-2] + "." + strings.Split(parts[len(parts)-1], ".")[1] + } + return prefix + importName +} + func normalizedGenericTypeName(str string) string { // Generic output types have the full import path in the string value, so we need to normalize it - pattern := regexp.MustCompile(`\[(.*?)\]`) - groups := pattern.FindStringSubmatch((str)) + pattern := regexp.MustCompile(`\[(.*)\]`) + groups := pattern.FindStringSubmatch(str) if len(groups) < 2 { return str } - typeName := groups[1] - normalizedGenericTypeName := strings.Split(typeName, "/") - importName := normalizedGenericTypeName[len(normalizedGenericTypeName)-1] - versionPattern := regexp.MustCompile(`/v\d+\.`) - if versionPattern.MatchString(typeName) { - // Example typeName: github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appconfiguration/armappconfiguration/v2.ConfigurationStoresClientCreateResponse - importName = normalizedGenericTypeName[len(normalizedGenericTypeName)-2] + "." + strings.Split(normalizedGenericTypeName[len(normalizedGenericTypeName)-1], ".")[1] + // Handle multiple type parameters (e.g. iter.Seq2[*github.com/.../github.Artifact,error]) + typeParams := strings.Split(groups[1], ",") + normalized := make([]string, len(typeParams)) + for i, tp := range typeParams { + tp = strings.TrimSpace(tp) + if strings.Contains(tp, "/") { + normalized[i] = normalizeFullTypeName(tp) + } else { + normalized[i] = tp + } } - return pattern.ReplaceAllString(str, "["+importName+"]") + return pattern.ReplaceAllString(str, "["+strings.Join(normalized, ", ")+"]") } // Adapted from https://stackoverflow.com/a/54129236 diff --git a/interfaces/generate_test.go b/interfaces/generate_test.go index ac9d613..419ee21 100644 --- a/interfaces/generate_test.go +++ b/interfaces/generate_test.go @@ -129,6 +129,59 @@ type ConfigurationStoresClient interface { `, } +func TestNormalizedGenericTypeName(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no generics", + input: "github.Response", + want: "github.Response", + }, + { + name: "single type param without pointer (existing behavior)", + input: "interfaces.Pager[github.com/cloudquery/codegen/interfaces.Response]", + want: "interfaces.Pager[interfaces.Response]", + }, + { + name: "single type param with versioned import", + input: "runtime.Pager[github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appconfiguration/armappconfiguration/v2.ConfigurationStoresClientListResponse]", + want: "runtime.Pager[armappconfiguration.ConfigurationStoresClientListResponse]", + }, + { + name: "two type params with pointer (iter.Seq2 style)", + input: "iter.Seq2[*github.com/google/go-github/v83/github.Artifact,error]", + want: "iter.Seq2[*github.Artifact, error]", + }, + { + name: "two type params without pointer", + input: "iter.Seq2[github.com/google/go-github/v83/github.Artifact,error]", + want: "iter.Seq2[github.Artifact, error]", + }, + { + name: "two type params both with full paths", + input: "iter.Seq2[*github.com/google/go-github/v83/github.Artifact,*github.com/google/go-github/v83/github.Response]", + want: "iter.Seq2[*github.Artifact, *github.Response]", + }, + { + name: "single type param with pointer and versioned import", + input: "runtime.Pager[*github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appconfiguration/armappconfiguration/v2.ConfigurationStoresClientListResponse]", + want: "runtime.Pager[*armappconfiguration.ConfigurationStoresClientListResponse]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizedGenericTypeName(tt.input) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("normalizedGenericTypeName(%q) mismatch (-got +want):\n%s", tt.input, diff) + } + }) + } +} + func TestGenerate(t *testing.T) { dir := t.TempDir() err := Generate([]any{&Client{}}, dir,