feat(generic-oauth): authorization request headers (#2507)

This commit is contained in:
Kevin Gallet
2025-05-01 16:16:54 +02:00
committed by GitHub
parent 5c59506fc2
commit 4bdeb3020c
3 changed files with 71 additions and 11 deletions

View File

@@ -1,8 +1,8 @@
import { base64Url } from "@better-auth/utils/base64";
import { betterFetch } from "@better-fetch/fetch";
import { jwtVerify } from "jose";
import type { ProviderOptions } from "./types";
import { getOAuth2Tokens } from "./utils";
import { jwtVerify } from "jose";
import { base64Url } from "@better-auth/utils/base64";
export async function validateAuthorizationCode({
code,
@@ -12,6 +12,7 @@ export async function validateAuthorizationCode({
tokenEndpoint,
authentication,
deviceId,
headers,
}: {
code: string;
redirectURI: string;
@@ -20,12 +21,14 @@ export async function validateAuthorizationCode({
deviceId?: string;
tokenEndpoint: string;
authentication?: "basic" | "post";
headers?: Record<string, string>;
}) {
const body = new URLSearchParams();
const headers: Record<string, any> = {
const requestHeaders: Record<string, any> = {
"content-type": "application/x-www-form-urlencoded",
accept: "application/json",
"user-agent": "better-auth",
...headers,
};
body.set("grant_type", "authorization_code");
body.set("code", code);
@@ -37,7 +40,7 @@ export async function validateAuthorizationCode({
const encodedCredentials = base64Url.encode(
`${options.clientId}:${options.clientSecret}`,
);
headers["authorization"] = `Basic ${encodedCredentials}`;
requestHeaders["authorization"] = `Basic ${encodedCredentials}`;
} else {
body.set("client_id", options.clientId);
body.set("client_secret", options.clientSecret);
@@ -45,7 +48,7 @@ export async function validateAuthorizationCode({
const { data, error } = await betterFetch<object>(tokenEndpoint, {
method: "POST",
body: body,
headers,
headers: requestHeaders,
});
if (error) {

View File

@@ -1,11 +1,11 @@
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { getTestInstance } from "../../test-utils/test-instance";
import { genericOAuth } from ".";
import { genericOAuthClient } from "./client";
import { createAuthClient } from "../../client";
import { getTestInstance } from "../../test-utils/test-instance";
import { genericOAuthClient } from "./client";
import { OAuth2Server } from "oauth2-mock-server";
import { betterFetch } from "@better-fetch/fetch";
import { OAuth2Server } from "oauth2-mock-server";
import { parseSetCookieHeader } from "../../cookies";
let server = new OAuth2Server();
@@ -389,4 +389,54 @@ describe("oauth2", async () => {
);
expect(callbackURL).toBe("http://localhost:3000/dashboard");
});
it("should pass authorization headers in oAuth2Callback", async () => {
const customHeaders = {
"X-Custom-Header": "test-value",
};
let receivedHeaders: Record<string, string> = {};
server.service.once("beforeTokenSigning", (token, req) => {
receivedHeaders = req.headers as Record<string, string>;
});
const { customFetchImpl } = await getTestInstance({
plugins: [
genericOAuth({
config: [
{
providerId: "test3",
discoveryUrl:
"http://localhost:8081/.well-known/openid-configuration",
clientId: clientId,
clientSecret: clientSecret,
pkce: true,
authorizationHeaders: customHeaders,
},
],
}),
],
});
const authClient = createAuthClient({
plugins: [genericOAuthClient()],
baseURL: "http://localhost:3000",
fetchOptions: {
customFetchImpl,
},
});
const res = await authClient.signIn.oauth2({
providerId: "test3",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
expect(res.data?.url).toContain("http://localhost:8081/authorize");
const headers = new Headers();
await simulateOAuthFlow(res.data?.url || "", headers, customFetchImpl);
expect(receivedHeaders).toHaveProperty("x-custom-header");
expect(receivedHeaders["x-custom-header"]).toBe("test-value");
});
});

View File

@@ -1,8 +1,10 @@
import { betterFetch } from "@better-fetch/fetch";
import { APIError } from "better-call";
import { decodeJwt } from "jose";
import { z } from "zod";
import { createAuthEndpoint, sessionMiddleware } from "../../api";
import { setSessionCookie } from "../../cookies";
import { BASE_ERROR_CODES } from "../../error/codes";
import {
createAuthorizationURL,
validateAuthorizationCode,
@@ -10,11 +12,9 @@ import {
type OAuthProvider,
} from "../../oauth2";
import { handleOAuthUserInfo } from "../../oauth2/link-account";
import { refreshAccessToken } from "../../oauth2/refresh-access-token";
import { generateState, parseState } from "../../oauth2/state";
import type { BetterAuthPlugin, User } from "../../types";
import { decodeJwt } from "jose";
import { BASE_ERROR_CODES } from "../../error/codes";
import { refreshAccessToken } from "../../oauth2/refresh-access-token";
/**
* Configuration interface for generic OAuth providers.
@@ -132,6 +132,11 @@ export interface GenericOAuthConfig {
* Useful for providers like Epic that require specific headers (e.g., Epic-Client-ID).
*/
discoveryHeaders?: Record<string, string>;
/**
* Custom headers to include in the authorization request.
* Useful for providers like Qonto that require specific headers (e.g., X-Qonto-Staging-Token for local development).
*/
authorizationHeaders?: Record<string, string>;
/**
* Override user info with the provider info.
*
@@ -250,6 +255,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
});
}
return validateAuthorizationCode({
headers: c.authorizationHeaders,
code: data.code,
codeVerifier: data.codeVerifier,
redirectURI: data.redirectURI,
@@ -590,6 +596,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
});
}
tokens = await validateAuthorizationCode({
headers: provider.authorizationHeaders,
code,
codeVerifier: provider.pkce ? codeVerifier : undefined,
redirectURI: `${ctx.context.baseURL}/oauth2/callback/${provider.providerId}`,