Skip to content

Commit

Permalink
Discord OAuth now functional
Browse files Browse the repository at this point in the history
  • Loading branch information
p0t4t0sandwich committed Apr 18, 2024
1 parent 19e8794 commit fa8ca17
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 46 deletions.
2 changes: 0 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
mw "github.com/NeuralNexusDev/neuralnexus-api/middleware"
authroutes "github.com/NeuralNexusDev/neuralnexus-api/modules/auth/routes"
beenamegenerator "github.com/NeuralNexusDev/neuralnexus-api/modules/bee_name_generator"
cctturtle "github.com/NeuralNexusDev/neuralnexus-api/modules/cct_turtle"
gss "github.com/NeuralNexusDev/neuralnexus-api/modules/game_server_status"
"github.com/NeuralNexusDev/neuralnexus-api/modules/mcstatus"
petpictures "github.com/NeuralNexusDev/neuralnexus-api/modules/pet_pictures"
Expand Down Expand Up @@ -46,7 +45,6 @@ func (s *APIServer) Setup() http.Handler {
routerStack := routes.CreateStack(
authroutes.ApplyRoutes,
beenamegenerator.ApplyRoutes,
cctturtle.ApplyRoutes,
gss.ApplyRoutes,
mcstatus.ApplyRoutes,
petpictures.ApplyRoutes,
Expand Down
9 changes: 3 additions & 6 deletions modules/auth/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ func GetAccountByID(userID uuid.UUID) database.Response[Account] {
return database.ErrorResponse[Account]("Unable to get account", err)
}

var account *Account
account, err = pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
account, err := pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
if err != nil {
return database.ErrorResponse[Account]("Unable to get account", err)
}
Expand All @@ -138,8 +137,7 @@ func GetAccountByUsername(username string) database.Response[Account] {
return database.ErrorResponse[Account]("Unable to get account", err)
}

var account *Account
account, err = pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
account, err := pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
if err != nil {
return database.ErrorResponse[Account]("Unable to get account", err)
}
Expand All @@ -156,8 +154,7 @@ func GetAccountByEmail(email string) database.Response[Account] {
return database.ErrorResponse[Account]("Unable to get account", err)
}

var account *Account
account, err = pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
account, err := pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[Account])
if err != nil {
return database.ErrorResponse[Account]("Unable to get account", err)
}
Expand Down
85 changes: 68 additions & 17 deletions modules/auth/linking/account_linking.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
package account_linking
package accountlinking

import "github.com/google/uuid"
import (
"context"
"time"

"github.com/NeuralNexusDev/neuralnexus-api/modules/database"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)

// CREATE TABLE linked_accounts (
// user_id UUID FOREIGN KEY REFERENCES accounts(user_id),
// platform TEXT NOT NULL,
// platform_username TEXT NOT NULL,
// platform_id TEXT NOT NULL,
// data JSONB NOT NULL,
// data_updated_at timestamp with time zone default current_timestamp,
// created_at timestamp with time zone default current_timestamp,
// CONSTRAINT linked_accounts_unique UNIQUE (user_id, platform)
// user_id UUID NOT NULL,
// platform TEXT NOT NULL,
// platform_username TEXT NOT NULL,
// platform_id TEXT NOT NULL,
// data JSONB NOT NULL,
// data_updated_at timestamp with time zone default current_timestamp,
// created_at timestamp with time zone default current_timestamp,
// FOREIGN KEY (user_id) REFERENCES accounts(user_id),
// CONSTRAINT linked_accounts_unique UNIQUE (user_id, platform)
// );

// -------------- Structs --------------

// LinkedAccount struct
type LinkedAccount struct {
UserID uuid.UUID `db:"user_id" validate:"required"`
Platform string `db:"platform" validate:"required"`
PlatformUsername string `db:"platform_username" validate:"required_without=PlatformID"`
PlatformID string `db:"platform_id" validate:"required_without=PlatformUsername"`
Data Data `db:"data" validate:"required"`
DataUpdatedAt string `db:"data_updated_at"`
CreatedAt string `db:"created_at"`
UserID uuid.UUID `db:"user_id" validate:"required"`
Platform string `db:"platform" validate:"required"`
PlatformUsername string `db:"platform_username" validate:"required_without=PlatformID"`
PlatformID string `db:"platform_id" validate:"required_without=PlatformUsername"`
Data interface{} `db:"data" validate:"required"`
DataUpdatedAt time.Time `db:"data_updated_at"`
CreatedAt time.Time `db:"created_at"`
}

// NewLinkedAccount creates a new linked account
Expand Down Expand Up @@ -55,4 +63,47 @@ var (

// -------------- Functions --------------

func AddLinkedAccountToDB(linkedAccount LinkedAccount) database.Response[LinkedAccount] {
db := database.GetDB("neuralnexus")
defer db.Close()

_, err := db.Exec(context.Background(), "INSERT INTO linked_accounts (user_id, platform, platform_username, platform_id, data) VALUES ($1, $2, $3, $4, $5)", linkedAccount.UserID, linkedAccount.Platform, linkedAccount.PlatformUsername, linkedAccount.PlatformID, linkedAccount.Data)
if err != nil {
return database.ErrorResponse[LinkedAccount]("Failed to create linked account", err)
}
return database.SuccessResponse(linkedAccount)
}

func GetLinkedAccountByPlatformID(platform, platformID string) database.Response[LinkedAccount] {
db := database.GetDB("neuralnexus")
defer db.Close()

rows, err := db.Query(context.Background(), "SELECT * FROM linked_accounts WHERE platform = $1 AND platform_id = $2", platform, platformID)
if err != nil {
return database.ErrorResponse[LinkedAccount]("Failed to get linked account", err)
}

linkedAccount, err := pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[LinkedAccount])
if err != nil {
return database.ErrorResponse[LinkedAccount]("Failed to get linked account", err)
}
return database.SuccessResponse(*linkedAccount)
}

func GetLinkedAccountByUserID(userID uuid.UUID, platform string) database.Response[LinkedAccount] {
db := database.GetDB("neuralnexus")
defer db.Close()

rows, err := db.Query(context.Background(), "SELECT * FROM linked_accounts WHERE user_id = $1 AND platform = $2", userID, platform)
if err != nil {
return database.ErrorResponse[LinkedAccount]("Failed to get linked account", err)
}

linkedAccount, err := pgx.CollectExactlyOneRow(rows, pgx.RowToAddrOfStructByName[LinkedAccount])
if err != nil {
return database.ErrorResponse[LinkedAccount]("Failed to get linked account", err)
}
return database.SuccessResponse(*linkedAccount)
}

// -------------- Handlers --------------
103 changes: 92 additions & 11 deletions modules/auth/linking/discord.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package account_linking
package accountlinking

import (
"encoding/json"
"errors"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/NeuralNexusDev/neuralnexus-api/modules/auth"
"github.com/google/uuid"
)

Expand All @@ -21,8 +24,8 @@ var (

// -------------- Structs --------------

// AccessTokenResponse struct
type AccessTokenResponse struct {
// DiscordTokenResponse struct
type DiscordTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Expand Down Expand Up @@ -73,8 +76,8 @@ func (d DiscordData) CreateLinkedAccount(userID uuid.UUID) LinkedAccount {

// -------------- Functions --------------

// ExchangeCodeForAccessToken exchanges a code for an access token
func ExchangeCodeForAccessToken(code string) (*AccessTokenResponse, error) {
// DiscordExtCodeForToken exchanges a code for an access token
func DiscordExtCodeForToken(code string) (*DiscordTokenResponse, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
Expand All @@ -99,7 +102,7 @@ func ExchangeCodeForAccessToken(code string) (*AccessTokenResponse, error) {
return nil, errors.New("failed to exchange code for access token")
}

var token AccessTokenResponse
var token DiscordTokenResponse
err = json.NewDecoder(resp.Body).Decode(&token)
if err != nil {
return nil, err
Expand All @@ -108,8 +111,8 @@ func ExchangeCodeForAccessToken(code string) (*AccessTokenResponse, error) {
return &token, nil
}

// RefreshAccessToken refreshes an access token
func RefreshAccessToken(refreshToken string) (*AccessTokenResponse, error) {
// DiscordRefreshToken refreshes an access token
func DiscordRefreshToken(refreshToken string) (*DiscordTokenResponse, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
Expand All @@ -133,7 +136,7 @@ func RefreshAccessToken(refreshToken string) (*AccessTokenResponse, error) {
return nil, errors.New("failed to refresh access token")
}

var token AccessTokenResponse
var token DiscordTokenResponse
err = json.NewDecoder(resp.Body).Decode(&token)
if err != nil {
return nil, err
Expand All @@ -142,8 +145,8 @@ func RefreshAccessToken(refreshToken string) (*AccessTokenResponse, error) {
return &token, nil
}

// RevokeAccessToken revokes an access token
func RevokeAccessToken(accessToken string) error {
// DiscordRevokeToken revokes an access token
func DiscordRevokeToken(accessToken string) error {
data := url.Values{}
data.Set("token", accessToken)

Expand Down Expand Up @@ -186,6 +189,7 @@ func GetDiscordUser(accessToken string) (*DiscordData, error) {
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
log.Println("Failed to get Discord user:", resp.Status)
return nil, errors.New("failed to get Discord user")
}

Expand All @@ -197,3 +201,80 @@ func GetDiscordUser(accessToken string) (*DiscordData, error) {

return &user, nil
}

// DiscordOAuth process the Discord OAuth flow
func DiscordOAuth(code, state string) (*auth.Session, error) {
var a *auth.Account
// TODO: Sign the state so it can't be tampered with/impersonated
if state != "" && false { // TEMPORARILY DISABLED
// Get account by state (which is the user ID)
id, err := uuid.Parse(state)
if err != nil {
log.Println("Failed to parse state as UUID")
return nil, err
}
ad := auth.GetAccountByID(id)
if !ad.Success {
return nil, errors.New("failed to get account")
}
a = &ad.Data
}

token, err := DiscordExtCodeForToken(code)
if err != nil {
log.Println("Failed to exchange code for token")
return nil, err
}

user, err := GetDiscordUser(token.AccessToken)
if err != nil {
log.Println("Failed to get user from Discord API")
return nil, err
}

// Check if platform account is linked to an account
lad := GetLinkedAccountByPlatformID(PlatformDiscord, user.ID)
if lad.Success {
// If the account IDs don't match, default to OAuth as the source of truth
if a == nil || a.UserID != lad.Data.UserID {
ad := auth.GetAccountByID(lad.Data.UserID)
if !ad.Success {
return nil, errors.New("failed to get account")
}
s := ad.Data.NewSession(time.Now().Add(time.Hour * 24).Unix())
auth.AddSessionToCache(s)
defer auth.AddSessionToDB(s)
return &s, nil
} else if a.UserID == lad.Data.UserID {
s := a.NewSession(time.Now().Add(time.Hour * 24).Unix())
auth.AddSessionToCache(s)
defer auth.AddSessionToDB(s)
return &s, nil
}
}

// Check if the email is already in use -- simple account merging
ad := auth.GetAccountByEmail(user.Email)
if ad.Success {
a = &ad.Data
} else if a == nil {
// Create account
act := auth.NewPasswordLessAccount(user.Username, user.Email)
a = &act
dbResponse := auth.CreateAccount(*a)
if !dbResponse.Success {
return nil, errors.New("failed to create account")
}
}

// Link account
la := NewLinkedAccount(a.UserID, PlatformDiscord, user.Username, user.ID, user)
linkAcctData := AddLinkedAccountToDB(la)
if !linkAcctData.Success {
return nil, errors.New("failed to link account")
}
s := a.NewSession(time.Now().Add(time.Hour * 24).Unix())
auth.AddSessionToCache(s)
defer auth.AddSessionToDB(s)
return &s, nil
}
2 changes: 1 addition & 1 deletion modules/auth/linking/minecraft.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package account_linking
package accountlinking

import (
"encoding/json"
Expand Down
2 changes: 1 addition & 1 deletion modules/auth/linking/twitch.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package account_linking
package accountlinking

import (
"encoding/json"
Expand Down
14 changes: 6 additions & 8 deletions modules/auth/routes/authroutes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

mw "github.com/NeuralNexusDev/neuralnexus-api/middleware"
"github.com/NeuralNexusDev/neuralnexus-api/modules/auth"
accountlinking "github.com/NeuralNexusDev/neuralnexus-api/modules/auth/linking"
"github.com/NeuralNexusDev/neuralnexus-api/modules/database"
"github.com/NeuralNexusDev/neuralnexus-api/responses"
)
Expand Down Expand Up @@ -66,20 +67,17 @@ func LogoutHandler(w http.ResponseWriter, r *http.Request) {

// DiscordOAuthHandler handles the Discord OAuth route
func DiscordOAuthHandler(w http.ResponseWriter, r *http.Request) {
// Get the code from the query parameters
code := r.URL.Query().Get("code")
if code == "" {
responses.SendAndEncodeBadRequest(w, r, "Invalid request")
return
}
log.Println(code)
// Get the state from the query parameters
state := r.URL.Query().Get("state")
if state == "" {
responses.SendAndEncodeBadRequest(w, r, "Invalid request")
session, err := accountlinking.DiscordOAuth(code, state)
if err != nil {
log.Println("Failed to authenticate with Discord:\n\t", err)
responses.SendAndEncodeBadRequest(w, r, "Failed to authenticate with Discord")
return
}
log.Println(state)

responses.SendAndEncodeStruct(w, r, http.StatusOK, "Discord OAuth")
responses.SendAndEncodeStruct(w, r, http.StatusOK, session)
}

0 comments on commit fa8ca17

Please sign in to comment.