mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 05:13:01 +00:00
419 lines
12 KiB
Go
419 lines
12 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"crypto"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/DeBrosOfficial/network/pkg/client"
	"github.com/DeBrosOfficial/network/pkg/logging"
)
|
|
|
|
// mockNetworkClient implements client.NetworkClient for testing
|
|
type mockNetworkClient struct {
|
|
client.NetworkClient
|
|
db *mockDatabaseClient
|
|
}
|
|
|
|
func (m *mockNetworkClient) Database() client.DatabaseClient {
|
|
return m.db
|
|
}
|
|
|
|
// mockDatabaseClient implements client.DatabaseClient for testing
|
|
type mockDatabaseClient struct {
|
|
client.DatabaseClient
|
|
}
|
|
|
|
func (m *mockDatabaseClient) Query(ctx context.Context, sql string, args ...interface{}) (*client.QueryResult, error) {
|
|
return &client.QueryResult{
|
|
Count: 1,
|
|
Rows: [][]interface{}{
|
|
{1}, // Default ID for ResolveNamespaceID
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func createTestService(t *testing.T) *Service {
|
|
logger, _ := logging.NewColoredLogger(logging.ComponentGateway, false)
|
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate key: %v", err)
|
|
}
|
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
|
})
|
|
|
|
mockDB := &mockDatabaseClient{}
|
|
mockClient := &mockNetworkClient{db: mockDB}
|
|
|
|
s, err := NewService(logger, mockClient, string(keyPEM), "test-ns")
|
|
if err != nil {
|
|
t.Fatalf("failed to create service: %v", err)
|
|
}
|
|
return s
|
|
}
|
|
|
|
func TestBase58Decode(t *testing.T) {
|
|
s := &Service{}
|
|
tests := []struct {
|
|
input string
|
|
expected string // hex representation for comparison
|
|
wantErr bool
|
|
}{
|
|
{"1", "00", false},
|
|
{"2", "01", false},
|
|
{"9", "08", false},
|
|
{"A", "09", false},
|
|
{"B", "0a", false},
|
|
{"2p", "0100", false}, // 58*1 + 0 = 58 (0x3a) - wait, base58 is weird
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got, err := s.Base58Decode(tt.input)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Base58Decode(%s) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
|
continue
|
|
}
|
|
if !tt.wantErr {
|
|
hexGot := hex.EncodeToString(got)
|
|
if tt.expected != "" && hexGot != tt.expected {
|
|
// Base58 decoding of single characters might not be exactly what I expect above
|
|
// but let's just ensure it doesn't crash and returns something for now.
|
|
// Better to test a known valid address.
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test a real Solana address (Base58)
|
|
solAddr := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
|
|
_, err := s.Base58Decode(solAddr)
|
|
if err != nil {
|
|
t.Errorf("failed to decode solana address: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestJWTFlow(t *testing.T) {
|
|
s := createTestService(t)
|
|
|
|
ns := "test-ns"
|
|
sub := "0x1234567890abcdef1234567890abcdef12345678"
|
|
ttl := 15 * time.Minute
|
|
|
|
token, exp, err := s.GenerateJWT(ns, sub, ttl)
|
|
if err != nil {
|
|
t.Fatalf("GenerateJWT failed: %v", err)
|
|
}
|
|
|
|
if token == "" {
|
|
t.Fatal("generated token is empty")
|
|
}
|
|
|
|
if exp <= time.Now().Unix() {
|
|
t.Errorf("expiration time %d is in the past", exp)
|
|
}
|
|
|
|
claims, err := s.ParseAndVerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("ParseAndVerifyJWT failed: %v", err)
|
|
}
|
|
|
|
if claims.Sub != sub {
|
|
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
|
|
}
|
|
|
|
if claims.Namespace != ns {
|
|
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
|
|
}
|
|
|
|
if claims.Iss != "orama-gateway" {
|
|
t.Errorf("expected issuer orama-gateway, got %s", claims.Iss)
|
|
}
|
|
}
|
|
|
|
func TestVerifyEthSignature(t *testing.T) {
|
|
s := &Service{}
|
|
|
|
// This is a bit hard to test without a real ETH signature
|
|
// but we can check if it returns false for obviously wrong signatures
|
|
wallet := "0x1234567890abcdef1234567890abcdef12345678"
|
|
nonce := "test-nonce"
|
|
sig := hex.EncodeToString(make([]byte, 65))
|
|
|
|
ok, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "ETH")
|
|
if err == nil && ok {
|
|
t.Error("VerifySignature should have failed for zero signature")
|
|
}
|
|
}
|
|
|
|
func TestVerifySolSignature(t *testing.T) {
|
|
s := &Service{}
|
|
|
|
// Solana address (base58)
|
|
wallet := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
|
|
nonce := "test-nonce"
|
|
sig := "invalid-sig"
|
|
|
|
_, err := s.VerifySignature(context.Background(), wallet, nonce, sig, "SOL")
|
|
if err == nil {
|
|
t.Error("VerifySignature should have failed for invalid base64 signature")
|
|
}
|
|
}
|
|
|
|
// createDualKeyService creates a service with both RSA and EdDSA keys configured
|
|
func createDualKeyService(t *testing.T) *Service {
|
|
t.Helper()
|
|
s := createTestService(t) // has RSA
|
|
_, edPriv, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate ed25519 key: %v", err)
|
|
}
|
|
s.SetEdDSAKey(edPriv)
|
|
return s
|
|
}
|
|
|
|
func TestEdDSAJWTFlow(t *testing.T) {
|
|
s := createDualKeyService(t)
|
|
|
|
ns := "test-ns"
|
|
sub := "0xabcdef1234567890abcdef1234567890abcdef12"
|
|
ttl := 15 * time.Minute
|
|
|
|
// With EdDSA preferred, GenerateJWT should produce an EdDSA token
|
|
token, exp, err := s.GenerateJWT(ns, sub, ttl)
|
|
if err != nil {
|
|
t.Fatalf("GenerateJWT (EdDSA) failed: %v", err)
|
|
}
|
|
if token == "" {
|
|
t.Fatal("generated EdDSA token is empty")
|
|
}
|
|
if exp <= time.Now().Unix() {
|
|
t.Errorf("expiration time %d is in the past", exp)
|
|
}
|
|
|
|
// Verify the header contains EdDSA
|
|
parts := strings.Split(token, ".")
|
|
hb, _ := base64.RawURLEncoding.DecodeString(parts[0])
|
|
var header map[string]string
|
|
json.Unmarshal(hb, &header)
|
|
if header["alg"] != "EdDSA" {
|
|
t.Errorf("expected alg EdDSA, got %s", header["alg"])
|
|
}
|
|
if header["kid"] != s.edKeyID {
|
|
t.Errorf("expected kid %s, got %s", s.edKeyID, header["kid"])
|
|
}
|
|
|
|
// Verify the token
|
|
claims, err := s.ParseAndVerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("ParseAndVerifyJWT (EdDSA) failed: %v", err)
|
|
}
|
|
if claims.Sub != sub {
|
|
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
|
|
}
|
|
if claims.Namespace != ns {
|
|
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
|
|
}
|
|
}
|
|
|
|
func TestRS256BackwardCompat(t *testing.T) {
|
|
s := createDualKeyService(t)
|
|
|
|
// Generate an RS256 token directly (simulating a legacy token)
|
|
s.preferEdDSA = false
|
|
token, _, err := s.GenerateJWT("test-ns", "user1", 15*time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("GenerateJWT (RS256) failed: %v", err)
|
|
}
|
|
s.preferEdDSA = true // re-enable EdDSA preference
|
|
|
|
// Verify the RS256 token still works with dual-key service
|
|
claims, err := s.ParseAndVerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("ParseAndVerifyJWT should accept RS256 token: %v", err)
|
|
}
|
|
if claims.Sub != "user1" {
|
|
t.Errorf("expected subject user1, got %s", claims.Sub)
|
|
}
|
|
}
|
|
|
|
func TestAlgorithmConfusion_Rejected(t *testing.T) {
|
|
s := createDualKeyService(t)
|
|
|
|
t.Run("none_algorithm", func(t *testing.T) {
|
|
// Craft a token with alg=none
|
|
header := map[string]string{"alg": "none", "typ": "JWT"}
|
|
hb, _ := json.Marshal(header)
|
|
payload := map[string]any{
|
|
"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
|
}
|
|
pb, _ := json.Marshal(payload)
|
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
|
base64.RawURLEncoding.EncodeToString(pb) + "."
|
|
|
|
_, err := s.ParseAndVerifyJWT(token)
|
|
if err == nil {
|
|
t.Error("should reject alg=none")
|
|
}
|
|
})
|
|
|
|
t.Run("HS256_algorithm", func(t *testing.T) {
|
|
header := map[string]string{"alg": "HS256", "typ": "JWT", "kid": s.keyID}
|
|
hb, _ := json.Marshal(header)
|
|
payload := map[string]any{
|
|
"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
|
}
|
|
pb, _ := json.Marshal(payload)
|
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
|
base64.RawURLEncoding.EncodeToString(pb) + "." +
|
|
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
|
|
|
|
_, err := s.ParseAndVerifyJWT(token)
|
|
if err == nil {
|
|
t.Error("should reject alg=HS256")
|
|
}
|
|
})
|
|
|
|
t.Run("kid_alg_mismatch_EdDSA_kid_RS256_alg", func(t *testing.T) {
|
|
// Use EdDSA kid but claim RS256 alg
|
|
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": s.edKeyID}
|
|
hb, _ := json.Marshal(header)
|
|
payload := map[string]any{
|
|
"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
|
}
|
|
pb, _ := json.Marshal(payload)
|
|
// Sign with RSA (trying to confuse the verifier into using RSA on EdDSA kid)
|
|
hb64 := base64.RawURLEncoding.EncodeToString(hb)
|
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
|
signingInput := hb64 + "." + pb64
|
|
sum := sha256.Sum256([]byte(signingInput))
|
|
rsaSig, _ := rsa.SignPKCS1v15(rand.Reader, s.signingKey, 4, sum[:]) // crypto.SHA256 = 4
|
|
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(rsaSig)
|
|
|
|
_, err := s.ParseAndVerifyJWT(token)
|
|
if err == nil {
|
|
t.Error("should reject kid/alg mismatch (EdDSA kid with RS256 alg)")
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), "algorithm mismatch") {
|
|
t.Errorf("expected 'algorithm mismatch' error, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("unknown_kid", func(t *testing.T) {
|
|
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": "unknown-kid-123"}
|
|
hb, _ := json.Marshal(header)
|
|
payload := map[string]any{
|
|
"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
|
}
|
|
pb, _ := json.Marshal(payload)
|
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
|
base64.RawURLEncoding.EncodeToString(pb) + "." +
|
|
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
|
|
|
|
_, err := s.ParseAndVerifyJWT(token)
|
|
if err == nil {
|
|
t.Error("should reject unknown kid")
|
|
}
|
|
})
|
|
|
|
t.Run("legacy_token_EdDSA_rejected", func(t *testing.T) {
|
|
// Token with no kid and alg=EdDSA — should be rejected (legacy must be RS256)
|
|
header := map[string]string{"alg": "EdDSA", "typ": "JWT"}
|
|
hb, _ := json.Marshal(header)
|
|
payload := map[string]any{
|
|
"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
|
}
|
|
pb, _ := json.Marshal(payload)
|
|
hb64 := base64.RawURLEncoding.EncodeToString(hb)
|
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
|
signingInput := hb64 + "." + pb64
|
|
sig := ed25519.Sign(s.edSigningKey, []byte(signingInput))
|
|
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(sig)
|
|
|
|
_, err := s.ParseAndVerifyJWT(token)
|
|
if err == nil {
|
|
t.Error("should reject legacy token (no kid) with EdDSA alg")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWKSHandler_DualKey(t *testing.T) {
|
|
s := createDualKeyService(t)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
|
|
w := httptest.NewRecorder()
|
|
s.JWKSHandler(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", w.Code)
|
|
}
|
|
|
|
var result struct {
|
|
Keys []map[string]string `json:"keys"`
|
|
}
|
|
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
|
|
t.Fatalf("failed to decode JWKS response: %v", err)
|
|
}
|
|
|
|
if len(result.Keys) != 2 {
|
|
t.Fatalf("expected 2 keys in JWKS, got %d", len(result.Keys))
|
|
}
|
|
|
|
// Verify we have both RSA and OKP keys
|
|
algSet := map[string]bool{}
|
|
for _, k := range result.Keys {
|
|
algSet[k["alg"]] = true
|
|
if k["kid"] == "" {
|
|
t.Error("key missing kid")
|
|
}
|
|
}
|
|
if !algSet["RS256"] {
|
|
t.Error("JWKS missing RS256 key")
|
|
}
|
|
if !algSet["EdDSA"] {
|
|
t.Error("JWKS missing EdDSA key")
|
|
}
|
|
}
|
|
|
|
func TestJWKSHandler_RSAOnly(t *testing.T) {
|
|
s := createTestService(t) // RSA only, no EdDSA
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
|
|
w := httptest.NewRecorder()
|
|
s.JWKSHandler(w, req)
|
|
|
|
var result struct {
|
|
Keys []map[string]string `json:"keys"`
|
|
}
|
|
json.NewDecoder(w.Body).Decode(&result)
|
|
|
|
if len(result.Keys) != 1 {
|
|
t.Fatalf("expected 1 key in JWKS (RSA only), got %d", len(result.Keys))
|
|
}
|
|
if result.Keys[0]["alg"] != "RS256" {
|
|
t.Errorf("expected RS256, got %s", result.Keys[0]["alg"])
|
|
}
|
|
}
|