feat(oauth-provider): public client prelogin endpoint (#8214)

This commit is contained in:
Dylan Vanmali
2026-03-18 10:27:19 -07:00
committed by GitHub
parent 40e7676155
commit 20e4561c9b
8 changed files with 131 additions and 20 deletions

View File

@@ -195,6 +195,33 @@ type getOAuthClientPublic = {
```
</APIMethod>
#### Get Public Client Prelogin
To obtain a public client prior to login, you must first enable the endpoint in your configuration:
```ts title="auth.ts"
oauthProvider({
allowPublicClientPrelogin: true,
})
```
Then, the following endpoint will obtain public client information.
<APIMethod path="/oauth2/public-client-prelogin" method="POST">
```ts
type getOAuthClientPublicPrelogin = {
/**
* The OAuth client's client_id
*/
client_id: string,
/**
* Valid oauth query parameters (Sent automatically when using the provided client)
*/
oauth_query: string
}
```
</APIMethod>
#### List Clients
To obtain a list of clients owned by a specific user or organization, use the following endpoint:

View File

@@ -0,0 +1,17 @@
import { APIError, createAuthMiddleware } from "better-auth/api";
import type { OAuthOptions, Scope } from "../types";
import { verifyOAuthQueryParams } from "../utils";
export const publicSessionMiddleware = (opts: OAuthOptions<Scope[]>) =>
createAuthMiddleware(async (ctx) => {
if (!opts.allowPublicClientPrelogin) {
throw new APIError("BAD_REQUEST");
}
const query = ctx.body.oauth_query;
const isValid = await verifyOAuthQueryParams(query, ctx.context.secret);
if (!isValid) {
throw new APIError("UNAUTHORIZED", {
error: "invalid_signature",
});
}
});

View File

@@ -9,7 +9,6 @@ import {
sessionMiddleware,
} from "better-auth/api";
import { parseSetCookieHeader } from "better-auth/cookies";
import { constantTimeEqual, makeSignature } from "better-auth/crypto";
import { mergeSchema } from "better-auth/db";
import type { BetterAuthPlugin } from "better-auth/types";
import * as z from "zod";
@@ -28,7 +27,11 @@ import { tokenEndpoint } from "./token";
import type { OAuthOptions, Scope } from "./types";
import { SafeUrlSchema } from "./types/zod";
import { userInfoEndpoint } from "./userinfo";
import { deleteFromPrompt, getJwtPlugin } from "./utils";
import {
deleteFromPrompt,
getJwtPlugin,
verifyOAuthQueryParams,
} from "./utils";
declare module "@better-auth/core" {
interface BetterAuthPluginRegistry<AuthOptions, Options> {
@@ -209,27 +212,20 @@ export const oauthProvider = <O extends OAuthOptions<Scope[]>>(options: O) => {
handler: createAuthMiddleware(async (ctx) => {
// Verify query signature
const query = ctx.body.oauth_query;
let queryParams = new URLSearchParams(query);
const sig = queryParams.get("sig");
const exp = Number(queryParams.get("exp"));
queryParams.delete("sig");
queryParams = new URLSearchParams(queryParams);
const verifySig = await makeSignature(
queryParams.toString(),
const isValid = await verifyOAuthQueryParams(
query,
ctx.context.secret,
);
if (
!sig ||
!constantTimeEqual(sig, verifySig) ||
new Date(exp * 1000) < new Date()
) {
if (!isValid) {
throw new APIError("BAD_REQUEST", {
error: "invalid_signature",
});
}
const queryParams = new URLSearchParams(query);
queryParams.delete("sig");
queryParams.delete("exp");
await oAuthState.set({
query: new URLSearchParams(queryParams).toString(),
query: queryParams.toString(),
});
// If path starts oauth2 authorize (ie /sign-in/social, /sign-in/oauth2), add to additional data body
@@ -1302,6 +1298,8 @@ export const oauthProvider = <O extends OAuthOptions<Scope[]>>(options: O) => {
createOAuthClient: oauthClientEndpoints.createOAuthClient(opts),
getOAuthClient: oauthClientEndpoints.getOAuthClient(opts),
getOAuthClientPublic: oauthClientEndpoints.getOAuthClientPublic(opts),
getOAuthClientPublicPrelogin:
oauthClientEndpoints.getOAuthClientPublicPrelogin(opts),
getOAuthClients: oauthClientEndpoints.getOAuthClients(opts),
adminUpdateOAuthClient: oauthClientEndpoints.adminUpdateOAuthClient(opts),
updateOAuthClient: oauthClientEndpoints.updateOAuthClient(opts),

View File

@@ -1,4 +1,5 @@
import { createAuthClient } from "better-auth/client";
import { makeSignature } from "better-auth/crypto";
import { jwt } from "better-auth/plugins/jwt";
import { getTestInstance } from "better-auth/test";
import { describe, expect, it } from "vitest";
@@ -11,7 +12,7 @@ describe("oauthClient", async () => {
const baseUrl = "http://localhost:3000";
const rpBaseUrl = "http://localhost:5000";
const redirectUri = `${rpBaseUrl}/api/auth/oauth2/callback/${providerId}`;
const { signInWithTestUser, customFetchImpl } = await getTestInstance({
const { auth, signInWithTestUser, customFetchImpl } = await getTestInstance({
baseURL: baseUrl,
plugins: [
oauthProvider({
@@ -21,6 +22,7 @@ describe("oauthClient", async () => {
oauthAuthServerConfig: true,
openidConfig: true,
},
allowPublicClientPrelogin: true,
}),
jwt(),
],
@@ -103,6 +105,27 @@ describe("oauthClient", async () => {
});
});
it("should get public-only information about a client prelogin", async () => {
// Creates mock valid search params
const signedParams = new URLSearchParams({
exp: `${Math.floor(Date.now() / 1000) + 60}`,
});
const sig = await makeSignature(
signedParams.toString(),
(auth.options as unknown as { secret: string }).secret,
);
signedParams.set("sig", sig);
const client = await authClient.oauth2.publicClientPrelogin({
client_id: oauthUiClient.client_id,
oauth_query: signedParams.toString(),
});
expect(client.data).toMatchObject({
client_id: oauthUiClient.client_id,
...testUiClientInput,
});
});
it("should get user's clients", async () => {
const clients = await authClient.oauth2.getClients();
expect(clients?.data?.length).toBe(3);

View File

@@ -53,10 +53,11 @@ export async function getClientEndpoint(
* This is commonly used to display information on login flow pages.
*/
export async function getClientPublicEndpoint(
ctx: GenericEndpointContext & { query: { client_id: string } },
ctx: GenericEndpointContext,
opts: OAuthOptions<Scope[]>,
clientId: string,
) {
const client = await getClient(ctx, opts, ctx.query.client_id);
const client = await getClient(ctx, opts, clientId);
if (!client) {
throw new APIError("NOT_FOUND", {
error_description: "client not found",

View File

@@ -1,5 +1,6 @@
import { createAuthEndpoint, sessionMiddleware } from "better-auth/api";
import * as z from "zod";
import { publicSessionMiddleware } from "../middleware";
import { createOAuthClientEndpoint } from "../register";
import type { OAuthOptions, Scope } from "../types";
import { SafeUrlSchema } from "../types/zod";
@@ -451,12 +452,35 @@ export const getOAuthClientPublic = (opts: OAuthOptions<Scope[]>) =>
}),
metadata: {
openapi: {
description: "Gets publically available client fields",
description: "Gets publicly available client fields",
},
},
},
async (ctx) => {
return getClientPublicEndpoint(ctx, opts);
const clientId = ctx.query.client_id;
return getClientPublicEndpoint(ctx, opts, clientId);
},
);
export const getOAuthClientPublicPrelogin = (opts: OAuthOptions<Scope[]>) =>
createAuthEndpoint(
"/oauth2/public-client-prelogin",
{
method: "POST",
use: [publicSessionMiddleware(opts)],
body: z.object({
client_id: z.string(),
oauth_query: z.string().optional(),
}),
metadata: {
openapi: {
description: "Gets publicly available client fields (prior to login)",
},
},
},
async (ctx) => {
const clientId = ctx.body.client_id;
return getClientPublicEndpoint(ctx, opts, clientId);
},
);

View File

@@ -115,6 +115,11 @@ export interface OAuthOptions<
scopeExpirations?: {
[K in Scopes[number]]?: number | string | Date;
};
/**
* Allows /oauth2/public-client-prelogin endpoint to be
* requestable prior to login via a valid oauth_query.
*/
allowPublicClientPrelogin?: boolean;
/**
* Allow unauthenticated dynamic client registration.
*

View File

@@ -62,6 +62,22 @@ export const getJwtPlugin = (ctx: AuthContext) => {
const cachedTrustedClients = new TTLCache<string, SchemaClient<Scope[]>>();
export async function verifyOAuthQueryParams(
oauth_query: string,
secret: string,
) {
const queryParams = new URLSearchParams(oauth_query);
const sig = queryParams.get("sig");
const exp = Number(queryParams.get("exp"));
queryParams.delete("sig");
const verifySig = await makeSignature(queryParams.toString(), secret);
return (
!!sig &&
constantTimeEqual(sig, verifySig) &&
new Date(exp * 1000) >= new Date()
);
}
/**
* Get a client by ID, checking trusted clients first, then database
*/