mirror of
https://github.com/DeBrosOfficial/network.git
synced 2025-10-06 14:19:08 +00:00
This commit adds wallet-based authentication to protected CLI commands by removing the manual auth command and automatically prompting for credentials when needed. Protected commands will check for valid credentials and trigger the authentication flow when none are found.
396 lines
11 KiB
Go
396 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// EnhancedCredentialStore manages multiple credentials per gateway
|
|
type EnhancedCredentialStore struct {
|
|
Gateways map[string]*GatewayCredentials `json:"gateways"`
|
|
Version string `json:"version"`
|
|
}
|
|
|
|
// GatewayCredentials holds multiple credentials for a single gateway
|
|
type GatewayCredentials struct {
|
|
Credentials []*Credentials `json:"credentials"`
|
|
DefaultIndex int `json:"default_index"`
|
|
LastUsedIndex int `json:"last_used_index"`
|
|
}
|
|
|
|
// AuthChoice enumerates the actions a user can pick from the interactive
// credential menu.
type AuthChoice int

const (
	// AuthChoiceUseCredential selects one of the stored credentials.
	AuthChoiceUseCredential AuthChoice = iota
	// AuthChoiceAddCredential starts authentication with a new wallet.
	AuthChoiceAddCredential
	// AuthChoiceLogout clears all stored credentials.
	AuthChoiceLogout
	// AuthChoiceExit leaves the menu without authenticating.
	AuthChoiceExit
)
|
|
|
|
// LoadEnhancedCredentials loads the enhanced credential store
|
|
func LoadEnhancedCredentials() (*EnhancedCredentialStore, error) {
|
|
credPath, err := GetCredentialsPath()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If file doesn't exist, return empty store
|
|
if _, err := os.Stat(credPath); os.IsNotExist(err) {
|
|
return &EnhancedCredentialStore{
|
|
Gateways: make(map[string]*GatewayCredentials),
|
|
Version: "2.0",
|
|
}, nil
|
|
}
|
|
|
|
data, err := os.ReadFile(credPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
|
}
|
|
|
|
// Try to parse as enhanced store first
|
|
var enhancedStore EnhancedCredentialStore
|
|
if err := json.Unmarshal(data, &enhancedStore); err == nil && enhancedStore.Version == "2.0" {
|
|
// Initialize maps if nil
|
|
if enhancedStore.Gateways == nil {
|
|
enhancedStore.Gateways = make(map[string]*GatewayCredentials)
|
|
}
|
|
return &enhancedStore, nil
|
|
}
|
|
|
|
// Fall back to old format and migrate
|
|
var oldStore CredentialStore
|
|
if err := json.Unmarshal(data, &oldStore); err != nil {
|
|
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
|
}
|
|
|
|
// Migrate old format to new
|
|
enhancedStore = EnhancedCredentialStore{
|
|
Gateways: make(map[string]*GatewayCredentials),
|
|
Version: "2.0",
|
|
}
|
|
|
|
for gatewayURL, creds := range oldStore.Gateways {
|
|
if creds != nil {
|
|
enhancedStore.Gateways[gatewayURL] = &GatewayCredentials{
|
|
Credentials: []*Credentials{creds},
|
|
DefaultIndex: 0,
|
|
LastUsedIndex: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
return &enhancedStore, nil
|
|
}
|
|
|
|
// Save saves the enhanced credential store
|
|
func (store *EnhancedCredentialStore) Save() error {
|
|
credPath, err := GetCredentialsPath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if store.Version == "" {
|
|
store.Version = "2.0"
|
|
}
|
|
|
|
data, err := json.MarshalIndent(store, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal credentials: %w", err)
|
|
}
|
|
|
|
return os.WriteFile(credPath, data, 0600)
|
|
}
|
|
|
|
// AddCredential adds a new credential for the gateway
|
|
func (store *EnhancedCredentialStore) AddCredential(gatewayURL string, creds *Credentials) {
|
|
if store.Gateways == nil {
|
|
store.Gateways = make(map[string]*GatewayCredentials)
|
|
}
|
|
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
if gatewayCredentials == nil {
|
|
gatewayCredentials = &GatewayCredentials{
|
|
Credentials: []*Credentials{},
|
|
DefaultIndex: 0,
|
|
LastUsedIndex: 0,
|
|
}
|
|
store.Gateways[gatewayURL] = gatewayCredentials
|
|
}
|
|
|
|
// Check if credential already exists (by wallet address)
|
|
for i, existing := range gatewayCredentials.Credentials {
|
|
if strings.EqualFold(existing.Wallet, creds.Wallet) {
|
|
// Update existing credential
|
|
gatewayCredentials.Credentials[i] = creds
|
|
return
|
|
}
|
|
}
|
|
|
|
// Add new credential
|
|
gatewayCredentials.Credentials = append(gatewayCredentials.Credentials, creds)
|
|
}
|
|
|
|
// GetDefaultCredential returns the default credential for a gateway
|
|
func (store *EnhancedCredentialStore) GetDefaultCredential(gatewayURL string) *Credentials {
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
if gatewayCredentials == nil || len(gatewayCredentials.Credentials) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Ensure default index is valid
|
|
if gatewayCredentials.DefaultIndex < 0 || gatewayCredentials.DefaultIndex >= len(gatewayCredentials.Credentials) {
|
|
gatewayCredentials.DefaultIndex = 0
|
|
}
|
|
|
|
return gatewayCredentials.Credentials[gatewayCredentials.DefaultIndex]
|
|
}
|
|
|
|
// SetDefaultCredential sets the default credential by index
|
|
func (store *EnhancedCredentialStore) SetDefaultCredential(gatewayURL string, index int) bool {
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
if gatewayCredentials == nil || index < 0 || index >= len(gatewayCredentials.Credentials) {
|
|
return false
|
|
}
|
|
|
|
gatewayCredentials.DefaultIndex = index
|
|
gatewayCredentials.LastUsedIndex = index
|
|
return true
|
|
}
|
|
|
|
// ClearAllCredentials removes all credentials
|
|
func (store *EnhancedCredentialStore) ClearAllCredentials() {
|
|
store.Gateways = make(map[string]*GatewayCredentials)
|
|
}
|
|
|
|
// DisplayCredentialMenu shows the interactive credential selection menu
|
|
func (store *EnhancedCredentialStore) DisplayCredentialMenu(gatewayURL string) (AuthChoice, int, error) {
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
|
|
if gatewayCredentials == nil || len(gatewayCredentials.Credentials) == 0 {
|
|
fmt.Println("\n🔐 No credentials found. Choose an option:")
|
|
fmt.Println("1. Authenticate with new wallet")
|
|
fmt.Println("2. Exit")
|
|
fmt.Print("Choose (1-2): ")
|
|
|
|
choice, err := readUserChoice(2)
|
|
if err != nil {
|
|
return AuthChoiceExit, -1, err
|
|
}
|
|
|
|
switch choice {
|
|
case 1:
|
|
return AuthChoiceAddCredential, -1, nil
|
|
case 2:
|
|
return AuthChoiceExit, -1, nil
|
|
default:
|
|
return AuthChoiceExit, -1, fmt.Errorf("invalid choice")
|
|
}
|
|
}
|
|
|
|
fmt.Printf("\n🔐 Multiple wallets available for %s:\n", gatewayURL)
|
|
|
|
// Display credentials
|
|
for i, creds := range gatewayCredentials.Credentials {
|
|
defaultMark := ""
|
|
if i == gatewayCredentials.DefaultIndex {
|
|
defaultMark = " (default)"
|
|
}
|
|
|
|
// Format wallet address for display
|
|
displayAddr := creds.Wallet
|
|
if len(displayAddr) > 10 {
|
|
displayAddr = displayAddr[:6] + "..." + displayAddr[len(displayAddr)-4:]
|
|
}
|
|
|
|
statusEmoji := "✅"
|
|
if !creds.IsValid() {
|
|
statusEmoji = "❌"
|
|
}
|
|
|
|
planInfo := ""
|
|
if creds.Plan != "" {
|
|
planInfo = fmt.Sprintf(" (%s)", creds.Plan)
|
|
}
|
|
|
|
fmt.Printf("%d. %s %s%s%s\n", i+1, statusEmoji, displayAddr, planInfo, defaultMark)
|
|
}
|
|
|
|
fmt.Printf("%d. Add new wallet\n", len(gatewayCredentials.Credentials)+1)
|
|
fmt.Printf("%d. Logout (clear all credentials)\n", len(gatewayCredentials.Credentials)+2)
|
|
fmt.Printf("%d. Exit\n", len(gatewayCredentials.Credentials)+3)
|
|
|
|
maxChoice := len(gatewayCredentials.Credentials) + 3
|
|
fmt.Printf("Choose (1-%d): ", maxChoice)
|
|
|
|
choice, err := readUserChoice(maxChoice)
|
|
if err != nil {
|
|
return AuthChoiceExit, -1, err
|
|
}
|
|
|
|
if choice <= len(gatewayCredentials.Credentials) {
|
|
// User selected a credential
|
|
return AuthChoiceUseCredential, choice - 1, nil
|
|
} else if choice == len(gatewayCredentials.Credentials)+1 {
|
|
// Add new credential
|
|
return AuthChoiceAddCredential, -1, nil
|
|
} else if choice == len(gatewayCredentials.Credentials)+2 {
|
|
// Logout
|
|
return AuthChoiceLogout, -1, nil
|
|
} else {
|
|
// Exit
|
|
return AuthChoiceExit, -1, nil
|
|
}
|
|
}
|
|
|
|
// choiceStdinReader is the buffered reader shared by successive
// readUserChoice calls; choiceStdinSource records which *os.File it wraps so
// the reader is rebuilt if os.Stdin is replaced.
var (
	choiceStdinReader *bufio.Reader
	choiceStdinSource *os.File
)

// choiceInput returns the shared buffered reader over the current os.Stdin.
// It is not safe for concurrent use.
func choiceInput() *bufio.Reader {
	if choiceStdinReader == nil || choiceStdinSource != os.Stdin {
		choiceStdinSource = os.Stdin
		choiceStdinReader = bufio.NewReader(os.Stdin)
	}
	return choiceStdinReader
}

// readUserChoice reads one line from standard input and parses it as a menu
// selection in the inclusive range [1, maxChoice].
//
// The buffered reader is shared across calls. A reader created per call may
// read ahead past the newline and drop those bytes when it is discarded, so a
// later prompt would lose answers that were piped in or typed ahead.
//
// If the read fails but still delivered non-blank text (for example a final
// answer that ends at end of input without a newline), that text is parsed as
// usual. A read error with no usable text is returned wrapped. Non-numeric
// and out-of-range input each produce a descriptive error.
func readUserChoice(maxChoice int) (int, error) {
	line, err := choiceInput().ReadString('\n')
	text := strings.TrimSpace(line)
	if err != nil && text == "" {
		return 0, fmt.Errorf("failed to read input: %w", err)
	}

	choice, convErr := strconv.Atoi(text)
	if convErr != nil {
		return 0, fmt.Errorf("invalid input: please enter a number")
	}

	if choice < 1 || choice > maxChoice {
		return 0, fmt.Errorf("invalid choice: please enter a number between 1 and %d", maxChoice)
	}

	return choice, nil
}
|
|
|
|
// GetOrPromptForCredentials handles the complete authentication flow
|
|
func GetOrPromptForCredentials(gatewayURL string) (*Credentials, error) {
|
|
store, err := LoadEnhancedCredentials()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load credential store: %w", err)
|
|
}
|
|
|
|
// Check if we have a valid default credential
|
|
defaultCreds := store.GetDefaultCredential(gatewayURL)
|
|
if defaultCreds != nil && defaultCreds.IsValid() {
|
|
// Update last used time
|
|
defaultCreds.UpdateLastUsed()
|
|
if err := store.Save(); err != nil {
|
|
// Log warning but don't fail
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to update last used time: %v\n", err)
|
|
}
|
|
return defaultCreds, nil
|
|
}
|
|
|
|
// Need to prompt user for credential selection
|
|
for {
|
|
choice, credIndex, err := store.DisplayCredentialMenu(gatewayURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("menu selection failed: %w", err)
|
|
}
|
|
|
|
switch choice {
|
|
case AuthChoiceUseCredential:
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
if gatewayCredentials == nil || credIndex < 0 || credIndex >= len(gatewayCredentials.Credentials) {
|
|
fmt.Println("❌ Invalid credential selection")
|
|
continue
|
|
}
|
|
|
|
selectedCreds := gatewayCredentials.Credentials[credIndex]
|
|
if !selectedCreds.IsValid() {
|
|
fmt.Println("❌ Selected credentials are invalid or expired")
|
|
continue
|
|
}
|
|
|
|
// Update default and last used
|
|
store.SetDefaultCredential(gatewayURL, credIndex)
|
|
selectedCreds.UpdateLastUsed()
|
|
|
|
if err := store.Save(); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to save credentials: %v\n", err)
|
|
}
|
|
|
|
return selectedCreds, nil
|
|
|
|
case AuthChoiceAddCredential:
|
|
fmt.Println("\n🌐 Opening browser for wallet authentication...")
|
|
newCreds, err := PerformWalletAuthentication(gatewayURL)
|
|
if err != nil {
|
|
fmt.Printf("❌ Authentication failed: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
// Add the new credential
|
|
store.AddCredential(gatewayURL, newCreds)
|
|
|
|
// Set as default if it's the first credential
|
|
gatewayCredentials := store.Gateways[gatewayURL]
|
|
if gatewayCredentials != nil && len(gatewayCredentials.Credentials) == 1 {
|
|
store.SetDefaultCredential(gatewayURL, 0)
|
|
}
|
|
|
|
if err := store.Save(); err != nil {
|
|
return nil, fmt.Errorf("failed to save new credentials: %w", err)
|
|
}
|
|
|
|
fmt.Printf("✅ Wallet %s added successfully\n", newCreds.Wallet)
|
|
return newCreds, nil
|
|
|
|
case AuthChoiceLogout:
|
|
store.ClearAllCredentials()
|
|
if err := store.Save(); err != nil {
|
|
return nil, fmt.Errorf("failed to clear credentials: %w", err)
|
|
}
|
|
fmt.Println("✅ All credentials cleared")
|
|
continue
|
|
|
|
case AuthChoiceExit:
|
|
return nil, fmt.Errorf("authentication cancelled by user")
|
|
|
|
default:
|
|
fmt.Println("❌ Invalid choice")
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// HasValidEnhancedCredentials checks if there are valid credentials for the default gateway
|
|
func HasValidEnhancedCredentials() (bool, error) {
|
|
store, err := LoadEnhancedCredentials()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
gatewayURL := GetDefaultGatewayURL()
|
|
defaultCreds := store.GetDefaultCredential(gatewayURL)
|
|
|
|
return defaultCreds != nil && defaultCreds.IsValid(), nil
|
|
}
|
|
|
|
// GetValidEnhancedCredentials returns valid credentials for the default gateway
|
|
func GetValidEnhancedCredentials() (*Credentials, error) {
|
|
store, err := LoadEnhancedCredentials()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
gatewayURL := GetDefaultGatewayURL()
|
|
defaultCreds := store.GetDefaultCredential(gatewayURL)
|
|
|
|
if defaultCreds == nil {
|
|
return nil, fmt.Errorf("no credentials found for gateway %s", gatewayURL)
|
|
}
|
|
|
|
if !defaultCreds.IsValid() {
|
|
return nil, fmt.Errorf("credentials for gateway %s are expired or invalid", gatewayURL)
|
|
}
|
|
|
|
return defaultCreds, nil
|
|
}
|