orama/pkg/gateway/auth/service_test.go
2026-02-14 14:14:04 +02:00

419 lines
12 KiB
Go

package auth
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/logging"
)
// mockNetworkClient is a test double for client.NetworkClient. The interface
// is embedded so the type satisfies it without stubbing every method; only
// Database is overridden. When the embedded value is left nil (as
// createTestService does), calling any other method panics.
type mockNetworkClient struct {
	client.NetworkClient
	db *mockDatabaseClient
}

// Database returns the mock database client this network client was built with.
func (c *mockNetworkClient) Database() client.DatabaseClient { return c.db }
// mockDatabaseClient is a test double for client.DatabaseClient. Only Query is
// implemented; every other method is promoted from the embedded interface.
type mockDatabaseClient struct {
	client.DatabaseClient
}

// Query ignores the context, SQL text and arguments, and always answers with
// one row containing the single value 1 (the default ID intended for
// ResolveNamespaceID).
func (d *mockDatabaseClient) Query(_ context.Context, _ string, _ ...interface{}) (*client.QueryResult, error) {
	defaultRow := []interface{}{1}
	res := &client.QueryResult{
		Count: 1,
		Rows:  [][]interface{}{defaultRow},
	}
	return res, nil
}
// createTestService builds a Service for tests. It is backed by the mock
// network client, uses an RSA-2048 signing key generated for this call, and
// uses the "test-ns" namespace. Any setup failure aborts the calling test.
func createTestService(t *testing.T) *Service {
	t.Helper()
	logger, err := logging.NewColoredLogger(logging.ComponentGateway, false)
	if err != nil {
		t.Fatalf("failed to create logger: %v", err)
	}
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}
	// NewService takes the key as PEM text, so encode it as PKCS#1.
	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})
	mockDB := &mockDatabaseClient{}
	mockClient := &mockNetworkClient{db: mockDB}
	s, err := NewService(logger, mockClient, string(keyPEM), "test-ns")
	if err != nil {
		t.Fatalf("failed to create service: %v", err)
	}
	return s
}
// TestBase58Decode checks Base58Decode against hand-computed values and
// confirms that a Solana-style address decodes without error.
//
// Expected values use the standard Bitcoin/Solana base58 alphabet
// ("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"). In it, '1'
// is digit 0 and a leading '1' stands for one zero byte. Multi-character
// inputs are read big-endian in base 58.
func TestBase58Decode(t *testing.T) {
	s := &Service{}
	tests := []struct {
		input    string
		expected string // hex representation for comparison
		wantErr  bool
	}{
		{"1", "00", false}, // leading '1' encodes one zero byte
		{"2", "01", false},
		{"9", "08", false},
		{"A", "09", false},
		{"B", "0a", false},
		{"21", "3a", false}, // 1*58 + 0 = 58
		{"2p", "69", false}, // 'p' is digit 47: 1*58 + 47 = 105
	}
	for _, tt := range tests {
		got, err := s.Base58Decode(tt.input)
		if (err != nil) != tt.wantErr {
			t.Errorf("Base58Decode(%s) error = %v, wantErr %v", tt.input, err, tt.wantErr)
			continue
		}
		if tt.wantErr {
			continue
		}
		// Previously this comparison had an empty body, so a wrong decode
		// could never fail the test.
		if hexGot := hex.EncodeToString(got); hexGot != tt.expected {
			t.Errorf("Base58Decode(%s) = %s, want %s", tt.input, hexGot, tt.expected)
		}
	}
	// Test a real Solana address (Base58)
	solAddr := "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
	if _, err := s.Base58Decode(solAddr); err != nil {
		t.Errorf("failed to decode solana address: %v", err)
	}
}
// TestJWTFlow issues a token with GenerateJWT and reads it back with
// ParseAndVerifyJWT. It checks that the expiry lies in the future and that
// the subject, namespace and issuer claims survive the round trip.
func TestJWTFlow(t *testing.T) {
	const (
		wantNS  = "test-ns"
		wantSub = "0x1234567890abcdef1234567890abcdef12345678"
	)
	svc := createTestService(t)

	tok, expiresAt, err := svc.GenerateJWT(wantNS, wantSub, 15*time.Minute)
	if err != nil {
		t.Fatalf("GenerateJWT failed: %v", err)
	}
	if len(tok) == 0 {
		t.Fatal("generated token is empty")
	}
	if now := time.Now().Unix(); expiresAt <= now {
		t.Errorf("expiration time %d is in the past", expiresAt)
	}

	got, err := svc.ParseAndVerifyJWT(tok)
	if err != nil {
		t.Fatalf("ParseAndVerifyJWT failed: %v", err)
	}
	if got.Sub != wantSub {
		t.Errorf("expected subject %s, got %s", wantSub, got.Sub)
	}
	if got.Namespace != wantNS {
		t.Errorf("expected namespace %s, got %s", wantNS, got.Namespace)
	}
	if got.Iss != "orama-gateway" {
		t.Errorf("expected issuer orama-gateway, got %s", got.Iss)
	}
}
// TestVerifyEthSignature makes sure an obviously bogus Ethereum signature
// (65 zero bytes, hex-encoded) is not reported as valid. Either an error or
// a false result counts as a pass; only (true, nil) fails.
func TestVerifyEthSignature(t *testing.T) {
	svc := &Service{}
	zeroSig := hex.EncodeToString(make([]byte, 65))
	valid, err := svc.VerifySignature(context.Background(),
		"0x1234567890abcdef1234567890abcdef12345678", "test-nonce", zeroSig, "ETH")
	if err != nil || !valid {
		return
	}
	t.Error("VerifySignature should have failed for zero signature")
}
// TestVerifySolSignature requires VerifySignature to return an error when
// given a Solana (base58) wallet together with a malformed signature string.
func TestVerifySolSignature(t *testing.T) {
	const (
		solWallet = "HN7cABqL367i3jkj9684C9C3W197m8q5q1C9C3W197m8"
		badSig    = "invalid-sig"
	)
	svc := &Service{}
	if _, err := svc.VerifySignature(context.Background(), solWallet, "test-nonce", badSig, "SOL"); err != nil {
		return
	}
	t.Error("VerifySignature should have failed for invalid base64 signature")
}
// createDualKeyService starts from createTestService (RSA key only) and
// additionally installs a freshly generated Ed25519 private key through
// SetEdDSAKey, so the returned Service holds both signing keys.
func createDualKeyService(t *testing.T) *Service {
	t.Helper()
	svc := createTestService(t)
	_, priv, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		t.Fatalf("failed to generate ed25519 key: %v", genErr)
	}
	svc.SetEdDSAKey(priv)
	return svc
}
// TestEdDSAJWTFlow checks that a service holding both keys issues EdDSA
// tokens by default. The token header must carry alg=EdDSA and the Ed25519
// key's kid, and the token must verify with the expected claims.
func TestEdDSAJWTFlow(t *testing.T) {
	s := createDualKeyService(t)
	ns := "test-ns"
	sub := "0xabcdef1234567890abcdef1234567890abcdef12"
	ttl := 15 * time.Minute
	// With EdDSA preferred, GenerateJWT should produce an EdDSA token
	token, exp, err := s.GenerateJWT(ns, sub, ttl)
	if err != nil {
		t.Fatalf("GenerateJWT (EdDSA) failed: %v", err)
	}
	if token == "" {
		t.Fatal("generated EdDSA token is empty")
	}
	if exp <= time.Now().Unix() {
		t.Errorf("expiration time %d is in the past", exp)
	}

	// Inspect the header. Check the segment count before indexing, and
	// surface decode errors instead of comparing against zero values.
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		t.Fatalf("expected 3 JWT segments, got %d", len(parts))
	}
	hb, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		t.Fatalf("failed to decode JWT header: %v", err)
	}
	// Decode only the fields under test; other header fields are ignored.
	var header struct {
		Alg string `json:"alg"`
		Kid string `json:"kid"`
	}
	if err := json.Unmarshal(hb, &header); err != nil {
		t.Fatalf("failed to unmarshal JWT header: %v", err)
	}
	if header.Alg != "EdDSA" {
		t.Errorf("expected alg EdDSA, got %s", header.Alg)
	}
	if header.Kid != s.edKeyID {
		t.Errorf("expected kid %s, got %s", s.edKeyID, header.Kid)
	}

	// Verify the token
	claims, err := s.ParseAndVerifyJWT(token)
	if err != nil {
		t.Fatalf("ParseAndVerifyJWT (EdDSA) failed: %v", err)
	}
	if claims.Sub != sub {
		t.Errorf("expected subject %s, got %s", sub, claims.Sub)
	}
	if claims.Namespace != ns {
		t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
	}
}
// TestRS256BackwardCompat checks that a service holding both keys, with
// EdDSA preferred, still accepts RS256 tokens.
func TestRS256BackwardCompat(t *testing.T) {
	s := createDualKeyService(t)
	// Generate an RS256 token directly (simulating a legacy token)
	s.preferEdDSA = false
	token, _, err := s.GenerateJWT("test-ns", "user1", 15*time.Minute)
	if err != nil {
		t.Fatalf("GenerateJWT (RS256) failed: %v", err)
	}
	s.preferEdDSA = true // re-enable EdDSA preference before verifying

	// Guard against a vacuous pass. If toggling preferEdDSA ever stopped
	// taking effect, this token would be EdDSA and RS256 would go untested.
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		t.Fatalf("expected 3 JWT segments, got %d", len(parts))
	}
	hb, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		t.Fatalf("failed to decode JWT header: %v", err)
	}
	var header struct {
		Alg string `json:"alg"`
	}
	if err := json.Unmarshal(hb, &header); err != nil {
		t.Fatalf("failed to unmarshal JWT header: %v", err)
	}
	if header.Alg != "RS256" {
		t.Fatalf("expected alg RS256 for legacy token, got %s", header.Alg)
	}

	// Verify the RS256 token still works with dual-key service
	claims, err := s.ParseAndVerifyJWT(token)
	if err != nil {
		t.Fatalf("ParseAndVerifyJWT should accept RS256 token: %v", err)
	}
	if claims.Sub != "user1" {
		t.Errorf("expected subject user1, got %s", claims.Sub)
	}
}
// TestAlgorithmConfusion_Rejected feeds ParseAndVerifyJWT hand-crafted tokens
// whose headers try to bypass or confuse signature verification. Every case
// carries an otherwise plausible attacker payload and must be rejected.
func TestAlgorithmConfusion_Rejected(t *testing.T) {
	s := createDualKeyService(t)

	// signingInput marshals header plus an unexpired attacker payload and
	// returns the base64url "header.payload" string that a signature covers.
	signingInput := func(t *testing.T, header map[string]string) string {
		t.Helper()
		now := time.Now()
		payload := map[string]any{
			"iss": "orama-gateway", "sub": "attacker", "aud": "gateway",
			"iat": now.Unix(), "nbf": now.Unix(),
			"exp": now.Add(time.Hour).Unix(), "namespace": "test-ns",
		}
		hb, err := json.Marshal(header)
		if err != nil {
			t.Fatalf("failed to marshal header: %v", err)
		}
		pb, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("failed to marshal payload: %v", err)
		}
		return base64.RawURLEncoding.EncodeToString(hb) + "." +
			base64.RawURLEncoding.EncodeToString(pb)
	}
	fakeSig := base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))

	t.Run("none_algorithm", func(t *testing.T) {
		// alg=none with an empty signature segment
		token := signingInput(t, map[string]string{"alg": "none", "typ": "JWT"}) + "."
		if _, err := s.ParseAndVerifyJWT(token); err == nil {
			t.Error("should reject alg=none")
		}
	})

	t.Run("HS256_algorithm", func(t *testing.T) {
		token := signingInput(t, map[string]string{"alg": "HS256", "typ": "JWT", "kid": s.keyID}) +
			"." + fakeSig
		if _, err := s.ParseAndVerifyJWT(token); err == nil {
			t.Error("should reject alg=HS256")
		}
	})

	t.Run("kid_alg_mismatch_EdDSA_kid_RS256_alg", func(t *testing.T) {
		// Use the EdDSA kid but claim RS256, and sign with the real RSA key,
		// trying to get the verifier to use RSA under an EdDSA kid.
		input := signingInput(t, map[string]string{"alg": "RS256", "typ": "JWT", "kid": s.edKeyID})
		sum := sha256.Sum256([]byte(input))
		// 4 is the value of crypto.SHA256; the crypto package is not imported here.
		rsaSig, err := rsa.SignPKCS1v15(rand.Reader, s.signingKey, 4, sum[:])
		if err != nil {
			// Without this check a nil signature would silently test a
			// different (unsigned) token.
			t.Fatalf("failed to sign token with RSA key: %v", err)
		}
		token := input + "." + base64.RawURLEncoding.EncodeToString(rsaSig)
		_, err = s.ParseAndVerifyJWT(token)
		switch {
		case err == nil:
			t.Error("should reject kid/alg mismatch (EdDSA kid with RS256 alg)")
		case !strings.Contains(err.Error(), "algorithm mismatch"):
			t.Errorf("expected 'algorithm mismatch' error, got: %v", err)
		}
	})

	t.Run("unknown_kid", func(t *testing.T) {
		token := signingInput(t, map[string]string{"alg": "RS256", "typ": "JWT", "kid": "unknown-kid-123"}) +
			"." + fakeSig
		if _, err := s.ParseAndVerifyJWT(token); err == nil {
			t.Error("should reject unknown kid")
		}
	})

	t.Run("legacy_token_EdDSA_rejected", func(t *testing.T) {
		// Token with no kid and alg=EdDSA, validly signed with the Ed25519
		// key. It must still be rejected: legacy (kid-less) tokens must be RS256.
		input := signingInput(t, map[string]string{"alg": "EdDSA", "typ": "JWT"})
		sig := ed25519.Sign(s.edSigningKey, []byte(input))
		token := input + "." + base64.RawURLEncoding.EncodeToString(sig)
		if _, err := s.ParseAndVerifyJWT(token); err == nil {
			t.Error("should reject legacy token (no kid) with EdDSA alg")
		}
	})
}
// TestJWKSHandler_DualKey checks the JWKS endpoint of a service holding both
// keys. The response must be 200 and list exactly two keys, each with a
// non-empty kid, covering both the RS256 and EdDSA algorithms.
func TestJWKSHandler_DualKey(t *testing.T) {
	svc := createDualKeyService(t)
	rec := httptest.NewRecorder()
	svc.JWKSHandler(rec, httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}

	var doc struct {
		Keys []map[string]string `json:"keys"`
	}
	if err := json.NewDecoder(rec.Body).Decode(&doc); err != nil {
		t.Fatalf("failed to decode JWKS response: %v", err)
	}
	if n := len(doc.Keys); n != 2 {
		t.Fatalf("expected 2 keys in JWKS, got %d", n)
	}

	// Record which algorithms are advertised, flagging any key without a kid.
	seen := make(map[string]bool, len(doc.Keys))
	for i := range doc.Keys {
		jwk := doc.Keys[i]
		seen[jwk["alg"]] = true
		if jwk["kid"] == "" {
			t.Error("key missing kid")
		}
	}
	if !seen["RS256"] {
		t.Error("JWKS missing RS256 key")
	}
	if !seen["EdDSA"] {
		t.Error("JWKS missing EdDSA key")
	}
}
// TestJWKSHandler_RSAOnly checks the JWKS endpoint of a service without an
// EdDSA key. The response must be 200 and list exactly one key, the RS256 one.
func TestJWKSHandler_RSAOnly(t *testing.T) {
	s := createTestService(t) // RSA only, no EdDSA
	req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
	w := httptest.NewRecorder()
	s.JWKSHandler(w, req)
	// Check the status and decode error (as the DualKey test does) so that a
	// failed request is not misreported as "expected 1 key, got 0".
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var result struct {
		Keys []map[string]string `json:"keys"`
	}
	if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
		t.Fatalf("failed to decode JWKS response: %v", err)
	}
	if len(result.Keys) != 1 {
		t.Fatalf("expected 1 key in JWKS (RSA only), got %d", len(result.Keys))
	}
	if result.Keys[0]["alg"] != "RS256" {
		t.Errorf("expected RS256, got %s", result.Keys[0]["alg"])
	}
}