mirror of
https://github.com/better-auth/better-auth.git
synced 2026-06-02 20:36:19 -05:00
feat(generic-oauth): authorization request headers (#2507)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import { base64Url } from "@better-auth/utils/base64";
|
||||
import { betterFetch } from "@better-fetch/fetch";
|
||||
import { jwtVerify } from "jose";
|
||||
import type { ProviderOptions } from "./types";
|
||||
import { getOAuth2Tokens } from "./utils";
|
||||
import { jwtVerify } from "jose";
|
||||
import { base64Url } from "@better-auth/utils/base64";
|
||||
|
||||
export async function validateAuthorizationCode({
|
||||
code,
|
||||
@@ -12,6 +12,7 @@ export async function validateAuthorizationCode({
|
||||
tokenEndpoint,
|
||||
authentication,
|
||||
deviceId,
|
||||
headers,
|
||||
}: {
|
||||
code: string;
|
||||
redirectURI: string;
|
||||
@@ -20,12 +21,14 @@ export async function validateAuthorizationCode({
|
||||
deviceId?: string;
|
||||
tokenEndpoint: string;
|
||||
authentication?: "basic" | "post";
|
||||
headers?: Record<string, string>;
|
||||
}) {
|
||||
const body = new URLSearchParams();
|
||||
const headers: Record<string, any> = {
|
||||
const requestHeaders: Record<string, any> = {
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
accept: "application/json",
|
||||
"user-agent": "better-auth",
|
||||
...headers,
|
||||
};
|
||||
body.set("grant_type", "authorization_code");
|
||||
body.set("code", code);
|
||||
@@ -37,7 +40,7 @@ export async function validateAuthorizationCode({
|
||||
const encodedCredentials = base64Url.encode(
|
||||
`${options.clientId}:${options.clientSecret}`,
|
||||
);
|
||||
headers["authorization"] = `Basic ${encodedCredentials}`;
|
||||
requestHeaders["authorization"] = `Basic ${encodedCredentials}`;
|
||||
} else {
|
||||
body.set("client_id", options.clientId);
|
||||
body.set("client_secret", options.clientSecret);
|
||||
@@ -45,7 +48,7 @@ export async function validateAuthorizationCode({
|
||||
const { data, error } = await betterFetch<object>(tokenEndpoint, {
|
||||
method: "POST",
|
||||
body: body,
|
||||
headers,
|
||||
headers: requestHeaders,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { afterAll, beforeAll, describe, expect, it } from "vitest";
|
||||
import { getTestInstance } from "../../test-utils/test-instance";
|
||||
import { genericOAuth } from ".";
|
||||
import { genericOAuthClient } from "./client";
|
||||
import { createAuthClient } from "../../client";
|
||||
import { getTestInstance } from "../../test-utils/test-instance";
|
||||
import { genericOAuthClient } from "./client";
|
||||
|
||||
import { OAuth2Server } from "oauth2-mock-server";
|
||||
import { betterFetch } from "@better-fetch/fetch";
|
||||
import { OAuth2Server } from "oauth2-mock-server";
|
||||
import { parseSetCookieHeader } from "../../cookies";
|
||||
|
||||
let server = new OAuth2Server();
|
||||
@@ -389,4 +389,54 @@ describe("oauth2", async () => {
|
||||
);
|
||||
expect(callbackURL).toBe("http://localhost:3000/dashboard");
|
||||
});
|
||||
|
||||
it("should pass authorization headers in oAuth2Callback", async () => {
|
||||
const customHeaders = {
|
||||
"X-Custom-Header": "test-value",
|
||||
};
|
||||
|
||||
let receivedHeaders: Record<string, string> = {};
|
||||
server.service.once("beforeTokenSigning", (token, req) => {
|
||||
receivedHeaders = req.headers as Record<string, string>;
|
||||
});
|
||||
|
||||
const { customFetchImpl } = await getTestInstance({
|
||||
plugins: [
|
||||
genericOAuth({
|
||||
config: [
|
||||
{
|
||||
providerId: "test3",
|
||||
discoveryUrl:
|
||||
"http://localhost:8081/.well-known/openid-configuration",
|
||||
clientId: clientId,
|
||||
clientSecret: clientSecret,
|
||||
pkce: true,
|
||||
authorizationHeaders: customHeaders,
|
||||
},
|
||||
],
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
const authClient = createAuthClient({
|
||||
plugins: [genericOAuthClient()],
|
||||
baseURL: "http://localhost:3000",
|
||||
fetchOptions: {
|
||||
customFetchImpl,
|
||||
},
|
||||
});
|
||||
|
||||
const res = await authClient.signIn.oauth2({
|
||||
providerId: "test3",
|
||||
callbackURL: "http://localhost:3000/dashboard",
|
||||
newUserCallbackURL: "http://localhost:3000/new_user",
|
||||
});
|
||||
|
||||
expect(res.data?.url).toContain("http://localhost:8081/authorize");
|
||||
const headers = new Headers();
|
||||
await simulateOAuthFlow(res.data?.url || "", headers, customFetchImpl);
|
||||
|
||||
expect(receivedHeaders).toHaveProperty("x-custom-header");
|
||||
expect(receivedHeaders["x-custom-header"]).toBe("test-value");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { betterFetch } from "@better-fetch/fetch";
|
||||
import { APIError } from "better-call";
|
||||
import { decodeJwt } from "jose";
|
||||
import { z } from "zod";
|
||||
import { createAuthEndpoint, sessionMiddleware } from "../../api";
|
||||
import { setSessionCookie } from "../../cookies";
|
||||
import { BASE_ERROR_CODES } from "../../error/codes";
|
||||
import {
|
||||
createAuthorizationURL,
|
||||
validateAuthorizationCode,
|
||||
@@ -10,11 +12,9 @@ import {
|
||||
type OAuthProvider,
|
||||
} from "../../oauth2";
|
||||
import { handleOAuthUserInfo } from "../../oauth2/link-account";
|
||||
import { refreshAccessToken } from "../../oauth2/refresh-access-token";
|
||||
import { generateState, parseState } from "../../oauth2/state";
|
||||
import type { BetterAuthPlugin, User } from "../../types";
|
||||
import { decodeJwt } from "jose";
|
||||
import { BASE_ERROR_CODES } from "../../error/codes";
|
||||
import { refreshAccessToken } from "../../oauth2/refresh-access-token";
|
||||
|
||||
/**
|
||||
* Configuration interface for generic OAuth providers.
|
||||
@@ -132,6 +132,11 @@ export interface GenericOAuthConfig {
|
||||
* Useful for providers like Epic that require specific headers (e.g., Epic-Client-ID).
|
||||
*/
|
||||
discoveryHeaders?: Record<string, string>;
|
||||
/**
|
||||
* Custom headers to include in the authorization request.
|
||||
* Useful for providers like Qonto that require specific headers (e.g., X-Qonto-Staging-Token for local development).
|
||||
*/
|
||||
authorizationHeaders?: Record<string, string>;
|
||||
/**
|
||||
* Override user info with the provider info.
|
||||
*
|
||||
@@ -250,6 +255,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
|
||||
});
|
||||
}
|
||||
return validateAuthorizationCode({
|
||||
headers: c.authorizationHeaders,
|
||||
code: data.code,
|
||||
codeVerifier: data.codeVerifier,
|
||||
redirectURI: data.redirectURI,
|
||||
@@ -590,6 +596,7 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
|
||||
});
|
||||
}
|
||||
tokens = await validateAuthorizationCode({
|
||||
headers: provider.authorizationHeaders,
|
||||
code,
|
||||
codeVerifier: provider.pkce ? codeVerifier : undefined,
|
||||
redirectURI: `${ctx.context.baseURL}/oauth2/callback/${provider.providerId}`,
|
||||
|
||||
Reference in New Issue
Block a user