network/pkg/auth/enhanced_auth.go
anonpenguin 1fca8cb411 Add authentication to protected CLI commands
This commit adds wallet-based authentication to protected CLI commands
by removing the manual auth command and automatically prompting for
credentials when needed. Protected commands will check for valid
credentials and trigger the auth
2025-08-20 12:51:54 +03:00

396 lines
11 KiB
Go

package auth
import (
"bufio"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
)
// EnhancedCredentialStore is the top-level credentials document. It groups
// stored credentials by gateway and carries a format version marker so that
// LoadEnhancedCredentials can tell it apart from the older CredentialStore
// layout.
type EnhancedCredentialStore struct {
// Gateways maps a gateway URL to the credentials stored for that gateway.
Gateways map[string]*GatewayCredentials `json:"gateways"`
// Version is the format marker; this file reads and writes "2.0".
Version string `json:"version"`
}
// GatewayCredentials holds the credentials stored for a single gateway,
// together with two positions into that list.
type GatewayCredentials struct {
// Credentials is the list of stored credentials; the index fields below
// are positions in this slice.
Credentials []*Credentials `json:"credentials"`
// DefaultIndex is the position of the credential that
// GetDefaultCredential returns.
DefaultIndex int `json:"default_index"`
// LastUsedIndex is set together with DefaultIndex by
// SetDefaultCredential.
LastUsedIndex int `json:"last_used_index"`
}
// AuthChoice identifies which action the user picked in the menu shown by
// DisplayCredentialMenu.
type AuthChoice int
const (
// AuthChoiceUseCredential means one of the listed credentials was picked.
AuthChoiceUseCredential AuthChoice = iota
// AuthChoiceAddCredential means the user asked to authenticate a new wallet.
AuthChoiceAddCredential
// AuthChoiceLogout means the user asked to clear all stored credentials.
AuthChoiceLogout
// AuthChoiceExit means the user chose to leave the menu; it is also the
// value DisplayCredentialMenu returns alongside an error.
AuthChoiceExit
)
// LoadEnhancedCredentials reads the credentials file located by
// GetCredentialsPath and returns it as an EnhancedCredentialStore.
//
// Behaviour by file state:
//   - file missing: an empty version "2.0" store is returned, no error;
//   - file parses as JSON with version "2.0": it is returned as-is (a missing
//     gateway map is replaced by an empty one);
//   - anything else: the file is parsed as the older CredentialStore layout
//     and each non-nil gateway credential becomes a one-element list with
//     default and last-used index 0. The migrated store is only held in
//     memory; nothing is written back here.
//
// Errors from GetCredentialsPath are returned unchanged; read and parse
// failures are wrapped.
func LoadEnhancedCredentials() (*EnhancedCredentialStore, error) {
	const currentVersion = "2.0"

	credPath, err := GetCredentialsPath()
	if err != nil {
		return nil, err
	}

	// Read first and classify the error afterwards. Checking for existence
	// with a separate Stat call leaves a window in which the file can vanish
	// and turn "no credentials yet" into a hard read error.
	data, err := os.ReadFile(credPath)
	if err != nil {
		if os.IsNotExist(err) {
			return &EnhancedCredentialStore{
				Gateways: make(map[string]*GatewayCredentials),
				Version:  currentVersion,
			}, nil
		}
		return nil, fmt.Errorf("failed to read credentials file: %w", err)
	}

	// Try to parse as enhanced store first.
	var enhancedStore EnhancedCredentialStore
	if err := json.Unmarshal(data, &enhancedStore); err == nil && enhancedStore.Version == currentVersion {
		// A nil map would panic on the first write, so normalise it here.
		if enhancedStore.Gateways == nil {
			enhancedStore.Gateways = make(map[string]*GatewayCredentials)
		}
		return &enhancedStore, nil
	}

	// Fall back to the old format and migrate.
	var oldStore CredentialStore
	if err := json.Unmarshal(data, &oldStore); err != nil {
		return nil, fmt.Errorf("failed to parse credentials file: %w", err)
	}

	// Start from a clean value: the failed attempt above may have partially
	// populated enhancedStore.
	enhancedStore = EnhancedCredentialStore{
		Gateways: make(map[string]*GatewayCredentials, len(oldStore.Gateways)),
		Version:  currentVersion,
	}
	for gatewayURL, creds := range oldStore.Gateways {
		if creds == nil {
			continue
		}
		enhancedStore.Gateways[gatewayURL] = &GatewayCredentials{
			Credentials:   []*Credentials{creds},
			DefaultIndex:  0,
			LastUsedIndex: 0,
		}
	}
	return &enhancedStore, nil
}
// Save encodes the store as indented JSON and writes it to the path reported
// by GetCredentialsPath, passing file mode 0600 to os.WriteFile. A store
// whose Version is empty is stamped "2.0" before encoding.
//
// The error from GetCredentialsPath or os.WriteFile is returned unchanged; an
// encoding failure is wrapped.
func (store *EnhancedCredentialStore) Save() error {
	target, pathErr := GetCredentialsPath()
	if pathErr != nil {
		return pathErr
	}

	if len(store.Version) == 0 {
		store.Version = "2.0"
	}

	encoded, encodeErr := json.MarshalIndent(store, "", " ")
	if encodeErr != nil {
		return fmt.Errorf("failed to marshal credentials: %w", encodeErr)
	}

	if writeErr := os.WriteFile(target, encoded, 0600); writeErr != nil {
		return writeErr
	}
	return nil
}
// AddCredential stores creds under gatewayURL.
//
// If the gateway already has a credential whose Wallet matches creds.Wallet
// (compared case-insensitively), that entry is replaced in place so its
// position, and therefore any index pointing at it, is preserved. Otherwise
// creds is appended to the end of the gateway's list. The gateway map and the
// gateway entry are created on demand.
//
// A nil creds is ignored: storing it would make later code that reads
// credential fields dereference a nil pointer. The change is in-memory only;
// call Save to persist it.
func (store *EnhancedCredentialStore) AddCredential(gatewayURL string, creds *Credentials) {
	if creds == nil {
		return
	}

	if store.Gateways == nil {
		store.Gateways = make(map[string]*GatewayCredentials)
	}

	gatewayCredentials := store.Gateways[gatewayURL]
	if gatewayCredentials == nil {
		gatewayCredentials = &GatewayCredentials{
			// Non-nil empty slice so the field encodes as [] rather than null.
			Credentials:   []*Credentials{},
			DefaultIndex:  0,
			LastUsedIndex: 0,
		}
		store.Gateways[gatewayURL] = gatewayCredentials
	}

	// Replace an existing credential for the same wallet instead of
	// duplicating it. Nil slots (for example a JSON null in the file) are
	// skipped rather than dereferenced.
	for i, existing := range gatewayCredentials.Credentials {
		if existing != nil && strings.EqualFold(existing.Wallet, creds.Wallet) {
			gatewayCredentials.Credentials[i] = creds
			return
		}
	}

	gatewayCredentials.Credentials = append(gatewayCredentials.Credentials, creds)
}
// GetDefaultCredential returns the credential at the gateway's DefaultIndex,
// or nil when the gateway is unknown or has no credentials.
//
// Side effect: a DefaultIndex that falls outside the credential list is reset
// to 0 on the stored gateway entry before the lookup.
func (store *EnhancedCredentialStore) GetDefaultCredential(gatewayURL string) *Credentials {
	entry := store.Gateways[gatewayURL]
	if entry == nil {
		return nil
	}

	count := len(entry.Credentials)
	if count == 0 {
		return nil
	}

	// Repair an out-of-range index rather than panicking on the lookup below.
	if idx := entry.DefaultIndex; idx < 0 || idx >= count {
		entry.DefaultIndex = 0
	}
	return entry.Credentials[entry.DefaultIndex]
}
// SetDefaultCredential marks the credential at position index as both the
// default and the last-used credential for gatewayURL.
//
// It reports false, changing nothing, when the gateway is unknown or index is
// not a valid position in its credential list; otherwise it reports true.
func (store *EnhancedCredentialStore) SetDefaultCredential(gatewayURL string, index int) bool {
	entry := store.Gateways[gatewayURL]
	if entry == nil {
		return false
	}
	if index < 0 || index >= len(entry.Credentials) {
		return false
	}

	entry.DefaultIndex, entry.LastUsedIndex = index, index
	return true
}
// ClearAllCredentials drops the credentials of every gateway by replacing the
// gateway map with a fresh, empty one. Only the in-memory store is changed;
// nothing is written to disk here.
func (store *EnhancedCredentialStore) ClearAllCredentials() {
	store.Gateways = map[string]*GatewayCredentials{}
}
// DisplayCredentialMenu prints a numbered menu for gatewayURL on standard
// output, reads one choice via readUserChoice, and translates it into an
// AuthChoice.
//
// With no stored credentials the menu offers only "authenticate" and "exit".
// Otherwise each stored credential is listed (shortened wallet address,
// validity marker, optional plan, default marker), followed by "add",
// "logout" and "exit" entries.
//
// The int result is the zero-based position of the chosen credential when the
// choice is AuthChoiceUseCredential, and -1 in every other case. When reading
// the choice fails, AuthChoiceExit, -1 and that error are returned.
func (store *EnhancedCredentialStore) DisplayCredentialMenu(gatewayURL string) (AuthChoice, int, error) {
	entry := store.Gateways[gatewayURL]

	// Nothing stored for this gateway: short two-item menu.
	if entry == nil || len(entry.Credentials) == 0 {
		fmt.Println("\n🔐 No credentials found. Choose an option:")
		fmt.Println("1. Authenticate with new wallet")
		fmt.Println("2. Exit")
		fmt.Print("Choose (1-2): ")

		picked, err := readUserChoice(2)
		if err != nil {
			return AuthChoiceExit, -1, err
		}
		if picked == 1 {
			return AuthChoiceAddCredential, -1, nil
		}
		if picked == 2 {
			return AuthChoiceExit, -1, nil
		}
		return AuthChoiceExit, -1, fmt.Errorf("invalid choice")
	}

	total := len(entry.Credentials)
	fmt.Printf("\n🔐 Multiple wallets available for %s:\n", gatewayURL)

	// One numbered line per stored credential.
	for idx, cred := range entry.Credentials {
		// Shorten long addresses to "first six...last four".
		addr := cred.Wallet
		if n := len(addr); n > 10 {
			addr = addr[:6] + "..." + addr[n-4:]
		}

		status := "❌"
		if cred.IsValid() {
			status = "✅"
		}

		plan := ""
		if cred.Plan != "" {
			plan = fmt.Sprintf(" (%s)", cred.Plan)
		}

		suffix := ""
		if idx == entry.DefaultIndex {
			suffix = " (default)"
		}

		fmt.Printf("%d. %s %s%s%s\n", idx+1, status, addr, plan, suffix)
	}

	// The three fixed actions are numbered directly after the credentials.
	addOption, logoutOption, exitOption := total+1, total+2, total+3
	fmt.Printf("%d. Add new wallet\n", addOption)
	fmt.Printf("%d. Logout (clear all credentials)\n", logoutOption)
	fmt.Printf("%d. Exit\n", exitOption)
	fmt.Printf("Choose (1-%d): ", exitOption)

	picked, err := readUserChoice(exitOption)
	if err != nil {
		return AuthChoiceExit, -1, err
	}

	switch {
	case picked <= total:
		// Menu numbers are one-based; callers get a slice index.
		return AuthChoiceUseCredential, picked - 1, nil
	case picked == addOption:
		return AuthChoiceAddCredential, -1, nil
	case picked == logoutOption:
		return AuthChoiceLogout, -1, nil
	default:
		return AuthChoiceExit, -1, nil
	}
}
// singleByteReader hands its consumer at most one byte of src per Read call.
// Wrapping standard input in it keeps a buffered reader from pulling in (and
// then throwing away) input that belongs to a later prompt.
type singleByteReader struct {
	src *os.File
}

// Read reads at most one byte from the underlying file into p.
func (r singleByteReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	return r.src.Read(p[:1])
}

// readUserChoice reads one newline-terminated line from standard input and
// returns it as a menu choice in the range 1..maxChoice.
//
// Exactly one line is consumed per call, so successive prompts work when
// input is piped or redirected as well as when it is typed interactively.
//
// An error is returned when the line cannot be read (including end of input
// before a newline), when it is not an integer after trimming surrounding
// whitespace, or when the number is outside 1..maxChoice. The int result is 0
// whenever an error is returned.
func readUserChoice(maxChoice int) (int, error) {
	// A bufio.Reader created per call over os.Stdin directly would read ahead
	// and discard everything past the first line when it goes out of scope.
	// Feeding it one byte at a time stops it at the newline.
	reader := bufio.NewReader(singleByteReader{src: os.Stdin})
	input, err := reader.ReadString('\n')
	if err != nil {
		return 0, fmt.Errorf("failed to read input: %w", err)
	}

	choice, err := strconv.Atoi(strings.TrimSpace(input))
	if err != nil {
		return 0, fmt.Errorf("invalid input: please enter a number")
	}
	if choice < 1 || choice > maxChoice {
		return 0, fmt.Errorf("invalid choice: please enter a number between 1 and %d", maxChoice)
	}
	return choice, nil
}
// GetOrPromptForCredentials returns a usable credential for gatewayURL,
// prompting the user on the terminal when the stored default cannot be used.
//
// Flow:
//  1. Load the credential store.
//  2. If the gateway's default credential exists and IsValid reports true,
//     update its last-used time, save (a save failure is only reported as a
//     warning on stderr), and return it.
//  3. Otherwise loop on DisplayCredentialMenu until the user picks a valid
//     stored credential, successfully authenticates a new wallet via
//     PerformWalletAuthentication, or exits. "Logout" clears every stored
//     credential, saves, and shows the menu again.
//
// A credential chosen or added in step 3 becomes the gateway's default.
//
// Errors: a store load failure, a menu read failure, a failure to save after
// adding or clearing credentials, or the user choosing exit
// ("authentication cancelled by user").
func GetOrPromptForCredentials(gatewayURL string) (*Credentials, error) {
	store, err := LoadEnhancedCredentials()
	if err != nil {
		return nil, fmt.Errorf("failed to load credential store: %w", err)
	}

	// Fast path: a valid default needs no interaction.
	if defaultCreds := store.GetDefaultCredential(gatewayURL); defaultCreds != nil && defaultCreds.IsValid() {
		defaultCreds.UpdateLastUsed()
		if err := store.Save(); err != nil {
			// Best effort: the credential is still usable if the timestamp
			// could not be persisted.
			fmt.Fprintf(os.Stderr, "Warning: failed to update last used time: %v\n", err)
		}
		return defaultCreds, nil
	}

	// From here on the default is missing or invalid, so ask the user.
	for {
		choice, credIndex, err := store.DisplayCredentialMenu(gatewayURL)
		if err != nil {
			return nil, fmt.Errorf("menu selection failed: %w", err)
		}

		switch choice {
		case AuthChoiceUseCredential:
			gatewayCredentials := store.Gateways[gatewayURL]
			if gatewayCredentials == nil || credIndex < 0 || credIndex >= len(gatewayCredentials.Credentials) {
				fmt.Println("❌ Invalid credential selection")
				continue
			}
			selectedCreds := gatewayCredentials.Credentials[credIndex]
			if !selectedCreds.IsValid() {
				fmt.Println("❌ Selected credentials are invalid or expired")
				continue
			}

			// Remember the selection so the next call takes the fast path.
			store.SetDefaultCredential(gatewayURL, credIndex)
			selectedCreds.UpdateLastUsed()
			if err := store.Save(); err != nil {
				fmt.Fprintf(os.Stderr, "Warning: failed to save credentials: %v\n", err)
			}
			return selectedCreds, nil

		case AuthChoiceAddCredential:
			fmt.Println("\n🌐 Opening browser for wallet authentication...")
			newCreds, err := PerformWalletAuthentication(gatewayURL)
			if err != nil {
				fmt.Printf("❌ Authentication failed: %v\n", err)
				continue
			}
			if newCreds == nil {
				// Defensive: a nil credential with a nil error would
				// otherwise panic when its wallet is printed below.
				fmt.Println("❌ Authentication failed: no credentials returned")
				continue
			}

			store.AddCredential(gatewayURL, newCreds)

			// Make the freshly authenticated wallet the default. The menu
			// is only reached when the previous default was missing or
			// invalid, so keeping the old default would make every later
			// call prompt again.
			if gatewayCredentials := store.Gateways[gatewayURL]; gatewayCredentials != nil {
				for i, stored := range gatewayCredentials.Credentials {
					if stored == newCreds {
						store.SetDefaultCredential(gatewayURL, i)
						break
					}
				}
			}

			if err := store.Save(); err != nil {
				return nil, fmt.Errorf("failed to save new credentials: %w", err)
			}
			fmt.Printf("✅ Wallet %s added successfully\n", newCreds.Wallet)
			return newCreds, nil

		case AuthChoiceLogout:
			store.ClearAllCredentials()
			if err := store.Save(); err != nil {
				return nil, fmt.Errorf("failed to clear credentials: %w", err)
			}
			fmt.Println("✅ All credentials cleared")
			continue

		case AuthChoiceExit:
			return nil, fmt.Errorf("authentication cancelled by user")

		default:
			fmt.Println("❌ Invalid choice")
			continue
		}
	}
}
// HasValidEnhancedCredentials reports whether the gateway returned by
// GetDefaultGatewayURL has a default credential for which IsValid is true.
// It never prompts. A failure to load the store is returned as the error,
// with false.
func HasValidEnhancedCredentials() (bool, error) {
	store, loadErr := LoadEnhancedCredentials()
	if loadErr != nil {
		return false, loadErr
	}

	creds := store.GetDefaultCredential(GetDefaultGatewayURL())
	if creds == nil {
		return false, nil
	}
	return creds.IsValid(), nil
}
// GetValidEnhancedCredentials returns the default credential of the gateway
// reported by GetDefaultGatewayURL, without prompting.
//
// It returns an error when the store cannot be loaded, when the gateway has
// no default credential, or when that credential's IsValid reports false.
func GetValidEnhancedCredentials() (*Credentials, error) {
	store, loadErr := LoadEnhancedCredentials()
	if loadErr != nil {
		return nil, loadErr
	}

	gateway := GetDefaultGatewayURL()
	creds := store.GetDefaultCredential(gateway)

	switch {
	case creds == nil:
		return nil, fmt.Errorf("no credentials found for gateway %s", gateway)
	case !creds.IsValid():
		return nil, fmt.Errorf("credentials for gateway %s are expired or invalid", gateway)
	}
	return creds, nil
}