diff --git a/packages/better-auth/src/oauth2/validate-authorization-code.ts b/packages/better-auth/src/oauth2/validate-authorization-code.ts index 297eba5ae5..0f3eed64f0 100644 --- a/packages/better-auth/src/oauth2/validate-authorization-code.ts +++ b/packages/better-auth/src/oauth2/validate-authorization-code.ts @@ -1,8 +1,8 @@ +import { base64Url } from "@better-auth/utils/base64"; import { betterFetch } from "@better-fetch/fetch"; +import { jwtVerify } from "jose"; import type { ProviderOptions } from "./types"; import { getOAuth2Tokens } from "./utils"; -import { jwtVerify } from "jose"; -import { base64Url } from "@better-auth/utils/base64"; export async function validateAuthorizationCode({ code, @@ -12,6 +12,7 @@ export async function validateAuthorizationCode({ tokenEndpoint, authentication, deviceId, + headers, }: { code: string; redirectURI: string; @@ -20,12 +21,14 @@ export async function validateAuthorizationCode({ deviceId?: string; tokenEndpoint: string; authentication?: "basic" | "post"; + headers?: Record; }) { const body = new URLSearchParams(); - const headers: Record = { + const requestHeaders: Record = { "content-type": "application/x-www-form-urlencoded", accept: "application/json", "user-agent": "better-auth", + ...headers, }; body.set("grant_type", "authorization_code"); body.set("code", code); @@ -37,7 +40,7 @@ export async function validateAuthorizationCode({ const encodedCredentials = base64Url.encode( `${options.clientId}:${options.clientSecret}`, ); - headers["authorization"] = `Basic ${encodedCredentials}`; + requestHeaders["authorization"] = `Basic ${encodedCredentials}`; } else { body.set("client_id", options.clientId); body.set("client_secret", options.clientSecret); @@ -45,7 +48,7 @@ export async function validateAuthorizationCode({ const { data, error } = await betterFetch(tokenEndpoint, { method: "POST", body: body, - headers, + headers: requestHeaders, }); if (error) { diff --git a/packages/better-auth/src/plugins/generic-oauth/generic-oauth.test.ts b/packages/better-auth/src/plugins/generic-oauth/generic-oauth.test.ts index a246cff516..ebcf90886c 100644 --- a/packages/better-auth/src/plugins/generic-oauth/generic-oauth.test.ts +++ b/packages/better-auth/src/plugins/generic-oauth/generic-oauth.test.ts @@ -1,11 +1,11 @@ import { afterAll, beforeAll, describe, expect, it } from "vitest"; -import { getTestInstance } from "../../test-utils/test-instance"; import { genericOAuth } from "."; -import { genericOAuthClient } from "./client"; import { createAuthClient } from "../../client"; +import { getTestInstance } from "../../test-utils/test-instance"; +import { genericOAuthClient } from "./client"; -import { OAuth2Server } from "oauth2-mock-server"; import { betterFetch } from "@better-fetch/fetch"; +import { OAuth2Server } from "oauth2-mock-server"; import { parseSetCookieHeader } from "../../cookies"; let server = new OAuth2Server(); @@ -389,4 +389,54 @@ describe("oauth2", async () => { ); expect(callbackURL).toBe("http://localhost:3000/dashboard"); }); + + it("should pass authorization headers in oAuth2Callback", async () => { + const customHeaders = { + "X-Custom-Header": "test-value", + }; + + let receivedHeaders: Record = {}; + server.service.once("beforeTokenSigning", (token, req) => { + receivedHeaders = req.headers as Record; + }); + + const { customFetchImpl } = await getTestInstance({ + plugins: [ + genericOAuth({ + config: [ + { + providerId: "test3", + discoveryUrl: + "http://localhost:8081/.well-known/openid-configuration", + clientId: clientId, + clientSecret: clientSecret, + pkce: true, + authorizationHeaders: customHeaders, + }, + ], + }), + ], + }); + + const authClient = createAuthClient({ + plugins: [genericOAuthClient()], + baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl, + }, + }); + + const res = await authClient.signIn.oauth2({ + providerId: "test3", + callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", + }); + + expect(res.data?.url).toContain("http://localhost:8081/authorize"); + const headers = new Headers(); + await simulateOAuthFlow(res.data?.url || "", headers, customFetchImpl); + + expect(receivedHeaders).toHaveProperty("x-custom-header"); + expect(receivedHeaders["x-custom-header"]).toBe("test-value"); + }); }); diff --git a/packages/better-auth/src/plugins/generic-oauth/index.ts b/packages/better-auth/src/plugins/generic-oauth/index.ts index 2a09d91dad..326b81189d 100644 --- a/packages/better-auth/src/plugins/generic-oauth/index.ts +++ b/packages/better-auth/src/plugins/generic-oauth/index.ts @@ -1,8 +1,10 @@ import { betterFetch } from "@better-fetch/fetch"; import { APIError } from "better-call"; +import { decodeJwt } from "jose"; import { z } from "zod"; import { createAuthEndpoint, sessionMiddleware } from "../../api"; import { setSessionCookie } from "../../cookies"; +import { BASE_ERROR_CODES } from "../../error/codes"; import { createAuthorizationURL, validateAuthorizationCode, @@ -10,11 +12,9 @@ import { type OAuthProvider, } from "../../oauth2"; import { handleOAuthUserInfo } from "../../oauth2/link-account"; +import { refreshAccessToken } from "../../oauth2/refresh-access-token"; import { generateState, parseState } from "../../oauth2/state"; import type { BetterAuthPlugin, User } from "../../types"; -import { decodeJwt } from "jose"; -import { BASE_ERROR_CODES } from "../../error/codes"; -import { refreshAccessToken } from "../../oauth2/refresh-access-token"; /** * Configuration interface for generic OAuth providers. @@ -132,6 +132,11 @@ export interface GenericOAuthConfig { * Useful for providers like Epic that require specific headers (e.g., Epic-Client-ID). */ discoveryHeaders?: Record; + /** + * Custom headers to include in the authorization request. + * Useful for providers like Qonto that require specific headers (e.g., X-Qonto-Staging-Token for local development). + */ + authorizationHeaders?: Record; /** * Override user info with the provider info. * @@ -250,6 +255,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => { }); } return validateAuthorizationCode({ + headers: c.authorizationHeaders, code: data.code, codeVerifier: data.codeVerifier, redirectURI: data.redirectURI, @@ -590,6 +596,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => { }); } tokens = await validateAuthorizationCode({ + headers: provider.authorizationHeaders, code, codeVerifier: provider.pkce ? codeVerifier : undefined, redirectURI: `${ctx.context.baseURL}/oauth2/callback/${provider.providerId}`,