mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 14:56:58 +00:00
607 lines
15 KiB
Go
607 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// On-chain program IDs used when listing token accounts and reading NFT
// metadata.
const (
	// tokenProgramID is the SPL Token program; only token accounts owned by
	// this program are requested from the RPC node.
	tokenProgramID = "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"

	// metaplexProgramID is the Metaplex Token Metadata program under which
	// metadata PDAs are derived.
	metaplexProgramID = "metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s"
)

// Defaults wired in by NewDefaultSolanaNFTVerifier.
const (
	// defaultSolanaRPCURL is the public Solana mainnet-beta JSON-RPC endpoint.
	defaultSolanaRPCURL = "https://api.mainnet-beta.solana.com"

	// defaultNFTCollectionAddress is the collection an NFT must verifiably
	// belong to for Phantom wallet authentication.
	defaultNFTCollectionAddress = "GtsCViqB9fWriKeDMQdveDvYmqqvBCEoxRfu1gzE48uh"
)
|
|
|
|
// SolanaNFTVerifier checks, over Solana JSON-RPC, whether a wallet holds an
// NFT that is a verified member of a specific Metaplex collection.
type SolanaNFTVerifier struct {
	rpcURL            string       // JSON-RPC endpoint that receives POST requests
	collectionAddress string       // base58 address of the required collection
	httpClient        *http.Client // client used for every RPC round trip
}
|
|
|
|
// NewDefaultSolanaNFTVerifier creates a verifier with the hardcoded collection and RPC endpoint.
|
|
func NewDefaultSolanaNFTVerifier() *SolanaNFTVerifier {
|
|
return &SolanaNFTVerifier{
|
|
rpcURL: defaultSolanaRPCURL,
|
|
collectionAddress: defaultNFTCollectionAddress,
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// VerifyNFTOwnership checks if the wallet owns at least one NFT from the configured collection.
|
|
func (v *SolanaNFTVerifier) VerifyNFTOwnership(ctx context.Context, walletAddress string) (bool, error) {
|
|
// 1. Get all token accounts owned by the wallet
|
|
tokenAccounts, err := v.getTokenAccountsByOwner(ctx, walletAddress)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get token accounts: %w", err)
|
|
}
|
|
|
|
// 2. Filter for NFT-like accounts (amount == 1, decimals == 0)
|
|
var mints []string
|
|
for _, ta := range tokenAccounts {
|
|
if ta.amount == "1" && ta.decimals == 0 {
|
|
mints = append(mints, ta.mint)
|
|
}
|
|
}
|
|
|
|
if len(mints) == 0 {
|
|
return false, nil
|
|
}
|
|
|
|
// Cap mints to prevent excessive RPC calls from wallets with many tokens
|
|
const maxMints = 500
|
|
if len(mints) > maxMints {
|
|
mints = mints[:maxMints]
|
|
}
|
|
|
|
// 3. Derive metadata PDA for each mint
|
|
metaplexProgram, err := base58Decode(metaplexProgramID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decode metaplex program: %w", err)
|
|
}
|
|
|
|
var pdas []string
|
|
for _, mint := range mints {
|
|
mintBytes, err := base58Decode(mint)
|
|
if err != nil || len(mintBytes) != 32 {
|
|
continue
|
|
}
|
|
pda, err := findProgramAddress(
|
|
[][]byte{[]byte("metadata"), metaplexProgram, mintBytes},
|
|
metaplexProgram,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
pdas = append(pdas, base58Encode(pda))
|
|
}
|
|
|
|
if len(pdas) == 0 {
|
|
return false, nil
|
|
}
|
|
|
|
// 4. Batch fetch metadata accounts (max 100 per call)
|
|
targetCollection, err := base58Decode(v.collectionAddress)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decode collection address: %w", err)
|
|
}
|
|
|
|
for i := 0; i < len(pdas); i += 100 {
|
|
end := i + 100
|
|
if end > len(pdas) {
|
|
end = len(pdas)
|
|
}
|
|
batch := pdas[i:end]
|
|
|
|
accounts, err := v.getMultipleAccounts(ctx, batch)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get metadata accounts: %w", err)
|
|
}
|
|
|
|
for _, acct := range accounts {
|
|
if acct == nil {
|
|
continue
|
|
}
|
|
collKey, verified := parseMetaplexCollection(acct)
|
|
if verified && bytes.Equal(collKey, targetCollection) {
|
|
return true, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// tokenAccountInfo is the subset of a jsonParsed SPL token account that the
// NFT ownership check needs.
type tokenAccountInfo struct {
	mint     string // mint address, taken from "info.mint"
	amount   string // raw token amount string, from "tokenAmount.amount"
	decimals int    // mint decimals, from "tokenAmount.decimals"
}
|
|
|
|
// getTokenAccountsByOwner fetches all SPL token accounts for a wallet.
|
|
func (v *SolanaNFTVerifier) getTokenAccountsByOwner(ctx context.Context, wallet string) ([]tokenAccountInfo, error) {
|
|
params := []interface{}{
|
|
wallet,
|
|
map[string]string{"programId": tokenProgramID},
|
|
map[string]string{"encoding": "jsonParsed"},
|
|
}
|
|
|
|
result, err := v.rpcCall(ctx, "getTokenAccountsByOwner", params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse the result
|
|
resultMap, ok := result.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected result format")
|
|
}
|
|
|
|
valueArr, ok := resultMap["value"].([]interface{})
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
|
|
var accounts []tokenAccountInfo
|
|
for _, item := range valueArr {
|
|
itemMap, ok := item.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
account, ok := itemMap["account"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
data, ok := account["data"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
parsed, ok := data["parsed"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
info, ok := parsed["info"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
mint, _ := info["mint"].(string)
|
|
tokenAmount, ok := info["tokenAmount"].(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
amount, _ := tokenAmount["amount"].(string)
|
|
decimals, _ := tokenAmount["decimals"].(float64)
|
|
|
|
if mint != "" && amount != "" {
|
|
accounts = append(accounts, tokenAccountInfo{
|
|
mint: mint,
|
|
amount: amount,
|
|
decimals: int(decimals),
|
|
})
|
|
}
|
|
}
|
|
|
|
return accounts, nil
|
|
}
|
|
|
|
// getMultipleAccounts fetches multiple accounts by their addresses.
|
|
// Returns raw account data (base64-decoded) for each address, nil for missing accounts.
|
|
func (v *SolanaNFTVerifier) getMultipleAccounts(ctx context.Context, addresses []string) ([][]byte, error) {
|
|
params := []interface{}{
|
|
addresses,
|
|
map[string]string{"encoding": "base64"},
|
|
}
|
|
|
|
result, err := v.rpcCall(ctx, "getMultipleAccounts", params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resultMap, ok := result.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected result format")
|
|
}
|
|
|
|
valueArr, ok := resultMap["value"].([]interface{})
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
|
|
accounts := make([][]byte, len(valueArr))
|
|
for i, item := range valueArr {
|
|
if item == nil {
|
|
continue
|
|
}
|
|
acct, ok := item.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
dataArr, ok := acct["data"].([]interface{})
|
|
if !ok || len(dataArr) < 1 {
|
|
continue
|
|
}
|
|
dataStr, ok := dataArr[0].(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
decoded, err := base64.StdEncoding.DecodeString(dataStr)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
accounts[i] = decoded
|
|
}
|
|
|
|
return accounts, nil
|
|
}
|
|
|
|
// rpcCall executes a Solana JSON-RPC call.
|
|
func (v *SolanaNFTVerifier) rpcCall(ctx context.Context, method string, params []interface{}) (interface{}, error) {
|
|
reqBody := map[string]interface{}{
|
|
"jsonrpc": "2.0",
|
|
"id": 1,
|
|
"method": method,
|
|
"params": params,
|
|
}
|
|
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal RPC request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", v.rpcURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create RPC request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("RPC request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("RPC returned HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Limit response size to 10MB to prevent memory exhaustion
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read RPC response: %w", err)
|
|
}
|
|
|
|
var rpcResp struct {
|
|
Result interface{} `json:"result"`
|
|
Error map[string]interface{} `json:"error"`
|
|
}
|
|
if err := json.Unmarshal(body, &rpcResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse RPC response: %w", err)
|
|
}
|
|
|
|
if rpcResp.Error != nil {
|
|
msg, _ := rpcResp.Error["message"].(string)
|
|
return nil, fmt.Errorf("RPC error: %s", msg)
|
|
}
|
|
|
|
return rpcResp.Result, nil
|
|
}
|
|
|
|
// parseMetaplexCollection extracts the collection key and verified flag from
|
|
// Borsh-encoded Metaplex metadata account data.
|
|
//
|
|
// Metaplex Token Metadata v1 layout (simplified):
|
|
// - [0]: key (1 byte, should be 4 for MetadataV1)
|
|
// - [1..33]: update_authority (32 bytes)
|
|
// - [33..65]: mint (32 bytes)
|
|
// - [65..]: name (4-byte len prefix + UTF-8, borsh string)
|
|
// - then: symbol (borsh string)
|
|
// - then: uri (borsh string)
|
|
// - then: seller_fee_basis_points (u16, 2 bytes)
|
|
// - then: creators (Option<Vec<Creator>>)
|
|
// - then: primary_sale_happened (bool, 1 byte)
|
|
// - then: is_mutable (bool, 1 byte)
|
|
// - then: edition_nonce (Option<u8>)
|
|
// - then: token_standard (Option<u8>)
|
|
// - then: collection (Option<Collection>)
|
|
// - Collection: { verified: bool(1), key: Pubkey(32) }
|
|
func parseMetaplexCollection(data []byte) (collectionKey []byte, verified bool) {
|
|
if len(data) < 66 {
|
|
return nil, false
|
|
}
|
|
|
|
// Validate metadata key byte (must be 4 = MetadataV1)
|
|
if data[0] != 4 {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip: key(1) + update_authority(32) + mint(32)
|
|
offset := 65
|
|
|
|
// Skip name (borsh string: 4-byte LE length + bytes)
|
|
offset, _ = skipBorshString(data, offset)
|
|
if offset < 0 {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip symbol
|
|
offset, _ = skipBorshString(data, offset)
|
|
if offset < 0 {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip uri
|
|
offset, _ = skipBorshString(data, offset)
|
|
if offset < 0 {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip seller_fee_basis_points (u16)
|
|
offset += 2
|
|
if offset > len(data) {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip creators (Option<Vec<Creator>>)
|
|
// Option: 1 byte (0 = None, 1 = Some)
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
if data[offset] == 1 {
|
|
offset++ // skip option byte
|
|
if offset+4 > len(data) {
|
|
return nil, false
|
|
}
|
|
numCreators := int(binary.LittleEndian.Uint32(data[offset : offset+4]))
|
|
offset += 4
|
|
// Solana limits creators to 5, but be generous with 20
|
|
if numCreators > 20 {
|
|
return nil, false
|
|
}
|
|
// Each Creator: pubkey(32) + verified(1) + share(1) = 34 bytes
|
|
creatorBytes := numCreators * 34
|
|
if offset+creatorBytes > len(data) {
|
|
return nil, false
|
|
}
|
|
offset += creatorBytes
|
|
} else {
|
|
offset++ // skip None byte
|
|
}
|
|
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip primary_sale_happened (bool)
|
|
offset++
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip is_mutable (bool)
|
|
offset++
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
|
|
// Skip edition_nonce (Option<u8>)
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
if data[offset] == 1 {
|
|
offset += 2 // option byte + u8
|
|
} else {
|
|
offset++ // None
|
|
}
|
|
|
|
// Skip token_standard (Option<u8>)
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
if data[offset] == 1 {
|
|
offset += 2
|
|
} else {
|
|
offset++
|
|
}
|
|
|
|
// Collection (Option<Collection>)
|
|
if offset >= len(data) {
|
|
return nil, false
|
|
}
|
|
if data[offset] == 0 {
|
|
// No collection
|
|
return nil, false
|
|
}
|
|
offset++ // skip option byte
|
|
|
|
// Collection: verified(1 byte bool) + key(32 bytes)
|
|
if offset+33 > len(data) {
|
|
return nil, false
|
|
}
|
|
verified = data[offset] == 1
|
|
offset++
|
|
collectionKey = data[offset : offset+32]
|
|
|
|
return collectionKey, verified
|
|
}
|
|
|
|
// skipBorshString skips a Borsh-encoded string (u32 little-endian byte length
// followed by that many bytes) that starts at offset in data.
//
// It returns the offset just past the string, together with the string's
// contents. It returns -1 and "" if offset is negative, or if data is too
// short for the 4-byte length prefix or for the declared number of bytes.
func skipBorshString(data []byte, offset int) (int, string) {
	if offset < 0 || len(data)-offset < 4 {
		return -1, ""
	}

	// Keep the untrusted length unsigned until it is known to fit. Converting
	// it straight to int can go negative on 32-bit platforms, and offset+len
	// can overflow; either would defeat the bounds check.
	strLen := binary.LittleEndian.Uint32(data[offset : offset+4])
	offset += 4
	if uint64(strLen) > uint64(len(data)-offset) {
		return -1, ""
	}

	end := offset + int(strLen)
	return end, string(data[offset:end])
}
|
|
|
|
// findProgramAddress derives a Solana Program Derived Address (PDA).
|
|
// It finds the first valid PDA by trying bump seeds from 255 down to 0.
|
|
func findProgramAddress(seeds [][]byte, programID []byte) ([]byte, error) {
|
|
for bump := byte(255); ; bump-- {
|
|
candidate := derivePDA(seeds, bump, programID)
|
|
if !isOnCurve(candidate) {
|
|
return candidate, nil
|
|
}
|
|
if bump == 0 {
|
|
break
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("could not find valid PDA")
|
|
}
|
|
|
|
// derivePDA returns SHA-256(seed_1 || ... || seed_n || bump || programID ||
// "ProgramDerivedAddress"), the 32-byte candidate address for one bump value.
// It performs no curve check; see findProgramAddress for that.
func derivePDA(seeds [][]byte, bump byte, programID []byte) []byte {
	const marker = "ProgramDerivedAddress"

	size := 1 + len(programID) + len(marker)
	for _, s := range seeds {
		size += len(s)
	}

	preimage := make([]byte, 0, size)
	for _, s := range seeds {
		preimage = append(preimage, s...)
	}
	preimage = append(preimage, bump)
	preimage = append(preimage, programID...)
	preimage = append(preimage, marker...)

	digest := sha256.Sum256(preimage)
	return digest[:]
}
|
|
|
|
// isOnCurve reports whether the 32-byte key decodes to a point on the Ed25519
// curve. PDAs must NOT be on the curve, so that no private key can exist for
// them. Keys of any other length return false.
//
// The key is read as a compressed Edwards Y coordinate. The low 255 bits,
// little-endian, form y, and the top bit (the sign of x) is ignored. A point
// exists iff x^2 = (y^2 - 1) / (d*y^2 + 1) has a solution mod p, which is
// tested with Euler's criterion.
//
// Non-canonical encodings (p <= y < 2^255) are reduced mod p rather than
// rejected. This matches curve25519-dalek's CompressedEdwardsY::decompress,
// which the Solana SDK uses for its on-curve check.
func isOnCurve(key []byte) bool {
	if len(key) != 32 {
		return false
	}

	one := big.NewInt(1)

	// Ed25519 field prime: p = 2^255 - 19.
	p := new(big.Int).Lsh(one, 255)
	p.Sub(p, big.NewInt(19))

	// Decode y: little-endian with the top (sign) bit cleared. big.Int wants
	// big-endian bytes, so copy in reverse.
	yBytes := make([]byte, 32)
	for i, b := range key {
		yBytes[31-i] = b
	}
	yBytes[0] &= 0x7f
	y := new(big.Int).SetBytes(yBytes)
	y.Mod(y, p) // reduce non-canonical encodings instead of rejecting them

	// big.Int.Mod is Euclidean, so every reduction below lands in [0, p).
	y2 := new(big.Int).Mul(y, y)
	y2.Mod(y2, p)

	// u = y^2 - 1
	u := new(big.Int).Sub(y2, one)
	u.Mod(u, p)

	// d = -121665/121666 mod p (the Edwards curve constant).
	d := new(big.Int).ModInverse(big.NewInt(121666), p)
	d.Mul(d, big.NewInt(-121665))
	d.Mod(d, p)

	// v = d*y^2 + 1
	v := new(big.Int).Mul(d, y2)
	v.Add(v, one)
	v.Mod(v, p)

	vInv := new(big.Int).ModInverse(v, p)
	if vInv == nil {
		return false
	}

	// x^2 = u / v. A point exists iff x^2 is 0 or a quadratic residue mod p.
	x2 := new(big.Int).Mul(u, vInv)
	x2.Mod(x2, p)
	if x2.Sign() == 0 {
		return true
	}

	// Euler's criterion: x2^((p-1)/2) == 1 (mod p) iff x2 is a nonzero square.
	exp := new(big.Int).Sub(p, one)
	exp.Rsh(exp, 1)
	return new(big.Int).Exp(x2, exp, p).Cmp(one) == 0
}
|
|
|
|
// base58Decode decodes a Bitcoin-alphabet base58 string into bytes (same as
// Service.Base58Decode but standalone).
//
// Each leading '1' becomes a leading zero byte. Any character outside the
// alphabet yields an error. An empty input decodes to an empty slice.
func base58Decode(input string) ([]byte, error) {
	const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

	radix := big.NewInt(58)
	acc := new(big.Int)
	digit := new(big.Int)

	// Horner's rule: acc = acc*58 + digit, most significant character first.
	for k := 0; k < len(input); k++ {
		pos := strings.IndexByte(alphabet, input[k])
		if pos < 0 {
			return nil, fmt.Errorf("invalid base58 character")
		}
		acc.Mul(acc, radix)
		acc.Add(acc, digit.SetInt64(int64(pos)))
	}

	zeros := 0
	for zeros < len(input) && input[zeros] == alphabet[0] {
		zeros++
	}

	return append(make([]byte, zeros), acc.Bytes()...), nil
}
|
|
|
|
// base58Encode encodes bytes with the Bitcoin base58 alphabet. Each leading
// zero byte becomes a leading '1', and empty input encodes to "".
func base58Encode(input []byte) string {
	const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

	leading := 0
	for leading < len(input) && input[leading] == 0 {
		leading++
	}

	// Collect base-58 digits least significant first.
	n := new(big.Int).SetBytes(input)
	radix := big.NewInt(58)
	rem := new(big.Int)
	var digits []byte
	for n.Sign() > 0 {
		n.QuoRem(n, radix, rem)
		digits = append(digits, alphabet[rem.Int64()])
	}

	var sb strings.Builder
	sb.Grow(leading + len(digits))
	for i := 0; i < leading; i++ {
		sb.WriteByte(alphabet[0])
	}
	for i := len(digits) - 1; i >= 0; i-- {
		sb.WriteByte(digits[i])
	}
	return sb.String()
}
|