package auth import ( "bufio" "encoding/json" "fmt" "os" "strconv" "strings" ) // EnhancedCredentialStore manages multiple credentials per gateway type EnhancedCredentialStore struct { Gateways map[string]*GatewayCredentials `json:"gateways"` Version string `json:"version"` } // GatewayCredentials holds multiple credentials for a single gateway type GatewayCredentials struct { Credentials []*Credentials `json:"credentials"` DefaultIndex int `json:"default_index"` LastUsedIndex int `json:"last_used_index"` } // AuthChoice represents user's choice during authentication type AuthChoice int const ( AuthChoiceUseCredential AuthChoice = iota AuthChoiceAddCredential AuthChoiceLogout AuthChoiceExit ) // LoadEnhancedCredentials loads the enhanced credential store func LoadEnhancedCredentials() (*EnhancedCredentialStore, error) { credPath, err := GetCredentialsPath() if err != nil { return nil, err } // If file doesn't exist, return empty store if _, err := os.Stat(credPath); os.IsNotExist(err) { return &EnhancedCredentialStore{ Gateways: make(map[string]*GatewayCredentials), Version: "2.0", }, nil } data, err := os.ReadFile(credPath) if err != nil { return nil, fmt.Errorf("failed to read credentials file: %w", err) } // Try to parse as enhanced store first var enhancedStore EnhancedCredentialStore if err := json.Unmarshal(data, &enhancedStore); err == nil && enhancedStore.Version == "2.0" { // Initialize maps if nil if enhancedStore.Gateways == nil { enhancedStore.Gateways = make(map[string]*GatewayCredentials) } return &enhancedStore, nil } // Fall back to old format and migrate var oldStore CredentialStore if err := json.Unmarshal(data, &oldStore); err != nil { return nil, fmt.Errorf("failed to parse credentials file: %w", err) } // Migrate old format to new enhancedStore = EnhancedCredentialStore{ Gateways: make(map[string]*GatewayCredentials), Version: "2.0", } for gatewayURL, creds := range oldStore.Gateways { if creds != nil { enhancedStore.Gateways[gatewayURL] = &GatewayCredentials{ Credentials: []*Credentials{creds}, DefaultIndex: 0, LastUsedIndex: 0, } } } return &enhancedStore, nil } // Save saves the enhanced credential store func (store *EnhancedCredentialStore) Save() error { credPath, err := GetCredentialsPath() if err != nil { return err } if store.Version == "" { store.Version = "2.0" } data, err := json.MarshalIndent(store, "", " ") if err != nil { return fmt.Errorf("failed to marshal credentials: %w", err) } return os.WriteFile(credPath, data, 0600) } // AddCredential adds a new credential for the gateway func (store *EnhancedCredentialStore) AddCredential(gatewayURL string, creds *Credentials) { if store.Gateways == nil { store.Gateways = make(map[string]*GatewayCredentials) } gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials == nil { gatewayCredentials = &GatewayCredentials{ Credentials: []*Credentials{}, DefaultIndex: 0, LastUsedIndex: 0, } store.Gateways[gatewayURL] = gatewayCredentials } // Check if credential already exists (by wallet address) for i, existing := range gatewayCredentials.Credentials { if strings.EqualFold(existing.Wallet, creds.Wallet) { // Update existing credential gatewayCredentials.Credentials[i] = creds return } } // Add new credential gatewayCredentials.Credentials = append(gatewayCredentials.Credentials, creds) } // GetDefaultCredential returns the default credential for a gateway func (store *EnhancedCredentialStore) GetDefaultCredential(gatewayURL string) *Credentials { gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials == nil || len(gatewayCredentials.Credentials) == 0 { return nil } // Ensure default index is valid if gatewayCredentials.DefaultIndex < 0 || gatewayCredentials.DefaultIndex >= len(gatewayCredentials.Credentials) { gatewayCredentials.DefaultIndex = 0 } return gatewayCredentials.Credentials[gatewayCredentials.DefaultIndex] } // SetDefaultCredential sets the default credential by index func (store *EnhancedCredentialStore) SetDefaultCredential(gatewayURL string, index int) bool { gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials == nil || index < 0 || index >= len(gatewayCredentials.Credentials) { return false } gatewayCredentials.DefaultIndex = index gatewayCredentials.LastUsedIndex = index return true } // ClearAllCredentials removes all credentials func (store *EnhancedCredentialStore) ClearAllCredentials() { store.Gateways = make(map[string]*GatewayCredentials) } // DisplayCredentialMenu shows the interactive credential selection menu func (store *EnhancedCredentialStore) DisplayCredentialMenu(gatewayURL string) (AuthChoice, int, error) { gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials == nil || len(gatewayCredentials.Credentials) == 0 { fmt.Println("\nšŸ” No credentials found. Choose an option:") fmt.Println("1. Authenticate with new wallet") fmt.Println("2. Exit") fmt.Print("Choose (1-2): ") choice, err := readUserChoice(2) if err != nil { return AuthChoiceExit, -1, err } switch choice { case 1: return AuthChoiceAddCredential, -1, nil case 2: return AuthChoiceExit, -1, nil default: return AuthChoiceExit, -1, fmt.Errorf("invalid choice") } } fmt.Printf("\nšŸ” Multiple wallets available for %s:\n", gatewayURL) // Display credentials for i, creds := range gatewayCredentials.Credentials { defaultMark := "" if i == gatewayCredentials.DefaultIndex { defaultMark = " (default)" } // Format wallet address for display displayAddr := creds.Wallet if len(displayAddr) > 10 { displayAddr = displayAddr[:6] + "..." + displayAddr[len(displayAddr)-4:] } statusEmoji := "āœ…" if !creds.IsValid() { statusEmoji = "āŒ" } planInfo := "" if creds.Plan != "" { planInfo = fmt.Sprintf(" (%s)", creds.Plan) } fmt.Printf("%d. %s %s%s%s\n", i+1, statusEmoji, displayAddr, planInfo, defaultMark) } fmt.Printf("%d. Add new wallet\n", len(gatewayCredentials.Credentials)+1) fmt.Printf("%d. Logout (clear all credentials)\n", len(gatewayCredentials.Credentials)+2) fmt.Printf("%d. Exit\n", len(gatewayCredentials.Credentials)+3) maxChoice := len(gatewayCredentials.Credentials) + 3 fmt.Printf("Choose (1-%d): ", maxChoice) choice, err := readUserChoice(maxChoice) if err != nil { return AuthChoiceExit, -1, err } if choice <= len(gatewayCredentials.Credentials) { // User selected a credential return AuthChoiceUseCredential, choice - 1, nil } else if choice == len(gatewayCredentials.Credentials)+1 { // Add new credential return AuthChoiceAddCredential, -1, nil } else if choice == len(gatewayCredentials.Credentials)+2 { // Logout return AuthChoiceLogout, -1, nil } else { // Exit return AuthChoiceExit, -1, nil } } // readUserChoice reads and validates user input func readUserChoice(maxChoice int) (int, error) { reader := bufio.NewReader(os.Stdin) input, err := reader.ReadString('\n') if err != nil { return 0, fmt.Errorf("failed to read input: %w", err) } choiceStr := strings.TrimSpace(input) choice, err := strconv.Atoi(choiceStr) if err != nil { return 0, fmt.Errorf("invalid input: please enter a number") } if choice < 1 || choice > maxChoice { return 0, fmt.Errorf("invalid choice: please enter a number between 1 and %d", maxChoice) } return choice, nil } // GetOrPromptForCredentials handles the complete authentication flow func GetOrPromptForCredentials(gatewayURL string) (*Credentials, error) { store, err := LoadEnhancedCredentials() if err != nil { return nil, fmt.Errorf("failed to load credential store: %w", err) } // Check if we have a valid default credential defaultCreds := store.GetDefaultCredential(gatewayURL) if defaultCreds != nil && defaultCreds.IsValid() { // Update last used time defaultCreds.UpdateLastUsed() if err := store.Save(); err != nil { // Log warning but don't fail fmt.Fprintf(os.Stderr, "Warning: failed to update last used time: %v\n", err) } return defaultCreds, nil } // Need to prompt user for credential selection for { choice, credIndex, err := store.DisplayCredentialMenu(gatewayURL) if err != nil { return nil, fmt.Errorf("menu selection failed: %w", err) } switch choice { case AuthChoiceUseCredential: gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials == nil || credIndex < 0 || credIndex >= len(gatewayCredentials.Credentials) { fmt.Println("āŒ Invalid credential selection") continue } selectedCreds := gatewayCredentials.Credentials[credIndex] if !selectedCreds.IsValid() { fmt.Println("āŒ Selected credentials are invalid or expired") continue } // Update default and last used store.SetDefaultCredential(gatewayURL, credIndex) selectedCreds.UpdateLastUsed() if err := store.Save(); err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to save credentials: %v\n", err) } return selectedCreds, nil case AuthChoiceAddCredential: fmt.Println("\n🌐 Opening browser for wallet authentication...") newCreds, err := PerformWalletAuthentication(gatewayURL) if err != nil { fmt.Printf("āŒ Authentication failed: %v\n", err) continue } // Add the new credential store.AddCredential(gatewayURL, newCreds) // Set as default if it's the first credential gatewayCredentials := store.Gateways[gatewayURL] if gatewayCredentials != nil && len(gatewayCredentials.Credentials) == 1 { store.SetDefaultCredential(gatewayURL, 0) } if err := store.Save(); err != nil { return nil, fmt.Errorf("failed to save new credentials: %w", err) } fmt.Printf("āœ… Wallet %s added successfully\n", newCreds.Wallet) return newCreds, nil case AuthChoiceLogout: store.ClearAllCredentials() if err := store.Save(); err != nil { return nil, fmt.Errorf("failed to clear credentials: %w", err) } fmt.Println("āœ… All credentials cleared") continue case AuthChoiceExit: return nil, fmt.Errorf("authentication cancelled by user") default: fmt.Println("āŒ Invalid choice") continue } } } // HasValidEnhancedCredentials checks if there are valid credentials for the default gateway func HasValidEnhancedCredentials() (bool, error) { store, err := LoadEnhancedCredentials() if err != nil { return false, err } gatewayURL := GetDefaultGatewayURL() defaultCreds := store.GetDefaultCredential(gatewayURL) return defaultCreds != nil && defaultCreds.IsValid(), nil } // GetValidEnhancedCredentials returns valid credentials for the default gateway func GetValidEnhancedCredentials() (*Credentials, error) { store, err := LoadEnhancedCredentials() if err != nil { return nil, err } gatewayURL := GetDefaultGatewayURL() defaultCreds := store.GetDefaultCredential(gatewayURL) if defaultCreds == nil { return nil, fmt.Errorf("no credentials found for gateway %s", gatewayURL) } if !defaultCreds.IsValid() { return nil, fmt.Errorf("credentials for gateway %s are expired or invalid", gatewayURL) } return defaultCreds, nil }