test: providers

This commit is contained in:
Bereket Engida
2024-09-24 02:10:48 +03:00
parent 675a9206f9
commit bc82bf3d30
11 changed files with 270 additions and 213 deletions

View File

@@ -5,8 +5,8 @@ import {
createRouter,
} from "better-call";
import type { AuthContext } from "../init";
import type { BetterAuthOptions, InferSession, InferUser } from "../types";
import type { Prettify, UnionToIntersection } from "../types/helper";
import type { BetterAuthOptions } from "../types";
import type { UnionToIntersection } from "../types/helper";
import { csrfMiddleware } from "./middlewares/csrf";
import {
callbackOAuth,
@@ -78,56 +78,11 @@ export function getEndpoints<
.filter((plugin) => plugin !== undefined)
.flat() || [];
/**
* Helper function to type the session output
* TODO: find a better way to do this
*/
async function typedSession(
ctx: Context<
"/session",
{
method: "GET";
requireHeaders: true;
}
>,
) {
const handler = await getSession(ctx);
return handler as {
session: Prettify<InferSession<Option>>;
user: Prettify<InferUser<Option>>;
} | null;
}
typedSession.path = getSession.path;
typedSession.method = getSession.method;
typedSession.options = getSession.options;
typedSession.headers = getSession.headers;
/**
* Helper function to type the list sessions output
* TODO: find a better way to do this
*/
async function typeListSessions(
ctx: Context<
"/user/sessions",
{
method: "GET";
requireHeaders: true;
}
>,
) {
const handler = await listSessions(ctx);
return handler as unknown as Prettify<InferSession<Option>>[];
}
typeListSessions.path = listSessions.path;
typeListSessions.method = listSessions.method;
typeListSessions.options = listSessions.options;
typeListSessions.headers = listSessions.headers;
const baseEndpoints = {
signInOAuth,
callbackOAuth,
getCSRFToken,
getSession: typedSession,
getSession: getSession<Option>(),
signOut,
signUpEmail,
signInEmail,
@@ -137,7 +92,7 @@ export function getEndpoints<
sendVerificationEmail,
changePassword,
updateUser,
listSessions: typeListSessions,
listSessions: listSessions<Option>(),
revokeSession,
revokeSessions,
};
@@ -232,6 +187,7 @@ export const router = <C extends AuthContext, Option extends BetterAuthOptions>(
},
...middlewares,
],
onError(e) {
if (options.disableLog !== true) {
if (e instanceof APIError) {

View File

@@ -98,10 +98,16 @@ export const callbackOAuth = createAuthEndpoint(
const isTrustedProvider = trustedProviders
? trustedProviders.includes(provider.id as "apple")
: false;
if (!hasBeenLinked && (!user.emailVerified || !isTrustedProvider)) {
const url = new URL(currentURL || callbackURL);
url.searchParams.set("error", "user_already_exists");
let url: URL;
try {
url = new URL(currentURL || callbackURL);
url.searchParams.set("error", "user_already_exists");
} catch (e) {
throw c.redirect(
`${c.context.baseURL}/error?error=user_already_exists`,
);
}
throw c.redirect(url.toString());
}

View File

@@ -1,4 +1,4 @@
import type { Context } from "better-call";
import type { Context, InferUse } from "better-call";
import { createAuthEndpoint } from "../call";
import { getDate } from "../../utils/date";
import { deleteSessionCookie, setSessionCookie } from "../../utils/cookies";
@@ -6,6 +6,12 @@ import { sessionMiddleware } from "../middlewares/session";
import type { Session, User } from "../../adapters/schema";
import { z } from "zod";
import { getIp } from "../../utils/get-request-ip";
import type {
BetterAuthOptions,
InferSession,
InferUser,
Prettify,
} from "../../types";
const sessionCache = new Map<
string,
@@ -37,113 +43,124 @@ function getRequestUniqueKey(ctx: Context<any, any>, token: string): string {
return uniqueString;
}
export const getSession = createAuthEndpoint(
"/session",
{
method: "GET",
requireHeaders: true,
},
async (ctx) => {
try {
const sessionCookieToken = await ctx.getSignedCookie(
ctx.context.authCookies.sessionToken.name,
ctx.context.secret,
);
export const getSession = <Option extends BetterAuthOptions>() =>
createAuthEndpoint(
"/session",
{
method: "GET",
requireHeaders: true,
},
async (ctx) => {
console.log("called");
try {
const sessionCookieToken = await ctx.getSignedCookie(
ctx.context.authCookies.sessionToken.name,
ctx.context.secret,
);
if (!sessionCookieToken) {
return ctx.json(null, {
status: 401,
});
}
const key = getRequestUniqueKey(ctx, sessionCookieToken);
const cachedSession = sessionCache.get(key);
if (cachedSession) {
if (cachedSession.expiresAt > Date.now()) {
return ctx.json(cachedSession.data);
if (!sessionCookieToken) {
return ctx.json(null, {
status: 401,
});
}
sessionCache.delete(key);
}
const session =
await ctx.context.internalAdapter.findSession(sessionCookieToken);
if (!session || session.session.expiresAt < new Date()) {
deleteSessionCookie(ctx);
if (session) {
/**
* if session expired clean up the session
*/
await ctx.context.internalAdapter.deleteSession(session.session.id);
const key = getRequestUniqueKey(ctx, sessionCookieToken);
const cachedSession = sessionCache.get(key);
if (cachedSession) {
if (cachedSession.expiresAt > Date.now()) {
return ctx.json(cachedSession.data);
}
sessionCache.delete(key);
}
return ctx.json(null, {
status: 401,
});
}
const dontRememberMe = await ctx.getSignedCookie(
ctx.context.authCookies.dontRememberToken.name,
ctx.context.secret,
);
/**
* We don't need to update the session if the user doesn't want to be remembered
*/
if (dontRememberMe) {
return ctx.json(session);
}
const expiresIn = ctx.context.session.expiresIn;
const updateAge = ctx.context.session.updateAge;
/**
* Calculate last updated date to throttle write updates to database
* Formula: ({expiry date} - sessionMaxAge) + sessionUpdateAge
*
* e.g. ({expiry date} - 30 days) + 1 hour
*
* inspired by: https://github.com/nextauthjs/next-auth/blob/main/packages/core/src/lib/
* actions/session.ts
*/
const sessionIsDueToBeUpdatedDate =
session.session.expiresAt.valueOf() -
expiresIn * 1000 +
updateAge * 1000;
const shouldBeUpdated = sessionIsDueToBeUpdatedDate <= Date.now();
if (shouldBeUpdated) {
const updatedSession = await ctx.context.internalAdapter.updateSession(
session.session.id,
{
expiresAt: getDate(ctx.context.session.expiresIn, true),
const session =
await ctx.context.internalAdapter.findSession(sessionCookieToken);
if (!session || session.session.expiresAt < new Date()) {
deleteSessionCookie(ctx);
if (session) {
/**
* if session expired clean up the session
*/
await ctx.context.internalAdapter.deleteSession(session.session.id);
}
return ctx.json(null, {
status: 401,
});
}
const dontRememberMe = await ctx.getSignedCookie(
ctx.context.authCookies.dontRememberToken.name,
ctx.context.secret,
);
/**
* We don't need to update the session if the user doesn't want to be remembered
*/
if (dontRememberMe) {
return ctx.json(session);
}
const expiresIn = ctx.context.session.expiresIn;
const updateAge = ctx.context.session.updateAge;
/**
* Calculate last updated date to throttle write updates to database
* Formula: ({expiry date} - sessionMaxAge) + sessionUpdateAge
*
* e.g. ({expiry date} - 30 days) + 1 hour
*
* inspired by: https://github.com/nextauthjs/next-auth/blob/main/packages/core/src/lib/
* actions/session.ts
*/
const sessionIsDueToBeUpdatedDate =
session.session.expiresAt.valueOf() -
expiresIn * 1000 +
updateAge * 1000;
const shouldBeUpdated = sessionIsDueToBeUpdatedDate <= Date.now();
if (shouldBeUpdated) {
const updatedSession =
await ctx.context.internalAdapter.updateSession(
session.session.id,
{
expiresAt: getDate(ctx.context.session.expiresIn, true),
},
);
if (!updatedSession) {
/**
* Handle case where session update fails (e.g., concurrent deletion)
*/
deleteSessionCookie(ctx);
return ctx.json(null, { status: 401 });
}
const maxAge =
(updatedSession.expiresAt.valueOf() - Date.now()) / 1000;
await setSessionCookie(ctx, updatedSession.id, false, {
maxAge,
});
return ctx.json({
session: updatedSession as unknown as Prettify<
InferSession<Option>
>,
user: session.user as unknown as Prettify<InferUser<Option>>,
});
}
sessionCache.set(key, {
data: session,
expiresAt: Date.now() + 5000,
});
return ctx.json(
session as unknown as {
session: Prettify<InferSession<Option>>;
user: Prettify<InferUser<Option>>;
},
);
if (!updatedSession) {
/**
* Handle case where session update fails (e.g., concurrent deletion)
*/
deleteSessionCookie(ctx);
return ctx.json(null, { status: 401 });
}
const maxAge = (updatedSession.expiresAt.valueOf() - Date.now()) / 1000;
await setSessionCookie(ctx, updatedSession.id, false, {
maxAge,
});
return ctx.json({
session: updatedSession,
user: session.user,
});
} catch (error) {
ctx.context.logger.error(error);
return ctx.json(null, { status: 500 });
}
sessionCache.set(key, {
data: session,
expiresAt: Date.now() + 5000,
});
return ctx.json(session);
} catch (error) {
ctx.context.logger.error(error);
return ctx.json(null, { status: 500 });
}
},
);
},
);
export const getSessionFromCtx = async (ctx: Context<any, any>) => {
const session = await getSession({
const session = await getSession()({
...ctx,
//@ts-expect-error: By default since this request context comes from a router it'll have a `router` flag which force it to be a request object
_flag: undefined,
@@ -154,29 +171,32 @@ export const getSessionFromCtx = async (ctx: Context<any, any>) => {
/**
* user active sessions list
*/
export const listSessions = createAuthEndpoint(
"/user/list-sessions",
{
method: "GET",
use: [sessionMiddleware],
requireHeaders: true,
},
async (ctx) => {
const sessions = await ctx.context.adapter.findMany<Session>({
model: ctx.context.tables.session.tableName,
where: [
{
field: "userId",
value: ctx.context.session.user.id,
},
],
});
const activeSessions = sessions.filter((session) => {
return session.expiresAt > new Date();
});
return ctx.json(activeSessions);
},
);
export const listSessions = <Option extends BetterAuthOptions>() =>
createAuthEndpoint(
"/user/list-sessions",
{
method: "GET",
use: [sessionMiddleware],
requireHeaders: true,
},
async (ctx) => {
const sessions = await ctx.context.adapter.findMany<Session>({
model: ctx.context.tables.session.tableName,
where: [
{
field: "userId",
value: ctx.context.session.user.id,
},
],
});
const activeSessions = sessions.filter((session) => {
return session.expiresAt > new Date();
});
return ctx.json(
activeSessions as unknown as Prettify<InferSession<Option>>[],
);
},
);
/**
* revoke a single session

View File

@@ -2,6 +2,7 @@ import { alphabet, generateRandomString } from "oslo/crypto";
import { z } from "zod";
import { createAuthEndpoint } from "../call";
import { createEmailVerificationToken } from "./verify-email";
import { setSessionCookie } from "../../utils";
export const signUpEmail = createAuthEndpoint(
"/sign-up/email",
@@ -79,12 +80,7 @@ export const signUpEmail = createAuthEndpoint(
createdUser.id,
ctx.request,
);
await ctx.setSignedCookie(
ctx.context.authCookies.sessionToken.name,
session.id,
ctx.context.secret,
ctx.context.authCookies.sessionToken.options,
);
await setSessionCookie(ctx, session.id);
if (ctx.context.options.emailAndPassword.sendEmailVerificationOnSignUp) {
const token = await createEmailVerificationToken(
ctx.context.secret,

View File

@@ -47,7 +47,6 @@ export function createDynamicPathProxy<T extends Record<string, any>>(
break;
}
}
if (typeof current === "function") {
return current;
}
@@ -61,7 +60,6 @@ export function createDynamicPathProxy<T extends Record<string, any>>(
segment.replace(/[A-Z]/g, (letter) => `-${letter.toLowerCase()}`),
)
.join("/");
const arg = (args[0] || {}) as ProxyRequest;
const method = getMethod(routePath, knownPathMethods, arg);
const { query, fetchOptions: options, ...body } = arg;

View File

@@ -67,7 +67,6 @@ export const useAuthQuery = <T>(
},
});
};
fn();
initializedAtom = Array.isArray(initializedAtom)
? initializedAtom
: [initializedAtom];

View File

@@ -8,7 +8,7 @@ import { useAuthQuery } from "./query";
import type { BetterAuthPlugin } from "../plugins";
export function getSessionAtom<Option extends ClientOptions>(
client: BetterFetch,
$fetch: BetterFetch,
) {
type Plugins = Option["plugins"] extends Array<BetterAuthClientPlugin>
? Array<
@@ -39,7 +39,7 @@ export function getSessionAtom<Option extends ClientOptions>(
const session = useAuthQuery<{
user: Prettify<UserWithAdditionalFields>;
session: Prettify<SessionWithAdditionalFields>;
}>($signal, "/session", client, {
}>($signal, "/session", $fetch, {
method: "GET",
});
return {

View File

@@ -6,7 +6,7 @@ import { logger } from "../../utils/logger";
export async function getRateLimitKey(req: Request) {
if (req.headers.get("Authorization") || req.headers.get("cookie")) {
try {
const session = await getSession({
const session = await getSession()({
headers: req.headers,
// @ts-ignore
_flag: undefined,

View File

@@ -24,7 +24,7 @@ describe("rate-limiter", async () => {
password: testUser.password,
});
if (i === 9) {
if (i === 10) {
expect(response.error?.status).toBe(429);
} else {
expect(response.error).toBeNull();

View File

@@ -3,7 +3,7 @@ import type { OAuthProvider } from ".";
import { parseJWT } from "oslo/jwt";
import { betterFetch } from "@better-fetch/fetch";
import { BetterAuthError } from "../error/better-auth-error";
import { getRedirectURI } from "./utils";
import { getRedirectURI, validateAuthorizationCode } from "./utils";
export interface AppleProfile {
/**
* The subject registered claim identifies the principal thats the subject
@@ -53,41 +53,30 @@ export interface AppleOptions {
redirectURI?: string;
}
export const apple = ({
clientId,
clientSecret,
redirectURI,
}: AppleOptions) => {
export const apple = (options: AppleOptions) => {
const tokenEndpoint = "https://appleid.apple.com/auth/token";
redirectURI = getRedirectURI("apple", redirectURI);
return {
id: "apple",
name: "Apple",
createAuthorizationURL({ state, scopes }) {
createAuthorizationURL({ state, scopes, redirectURI }) {
const _scope = scopes || ["email", "name", "openid"];
return new URL(
`https://appleid.apple.com/auth/authorize?client_id=${clientId}&response_type=code&redirect_uri=${redirectURI}&scope=${_scope.join(
" ",
)}&state=${state}`,
`https://appleid.apple.com/auth/authorize?client_id=${
options.clientId
}&response_type=code&redirect_uri=${
redirectURI || options.redirectURI
}&scope=${_scope.join(" ")}&state=${state}`,
);
},
validateAuthorizationCode: async (code) => {
const data = await betterFetch<OAuth2Tokens>(tokenEndpoint, {
method: "POST",
body: new URLSearchParams({
client_id: clientId,
client_secret: clientSecret,
grant_type: "authorization_code",
code,
}),
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
validateAuthorizationCode: async (code, codeVerifier, redirectURI) => {
return validateAuthorizationCode({
code,
codeVerifier,
redirectURI:
redirectURI || getRedirectURI("apple", options.redirectURI),
options,
tokenEndpoint,
});
if (data.error) {
throw new BetterAuthError(data.error?.message || "");
}
return data.data;
},
async getUserInfo(token) {
const data = parseJWT(token.idToken())?.payload as AppleProfile | null;

View File

@@ -0,0 +1,93 @@
import { describe, expect, it, vi } from "vitest";
import { getTestInstance } from "../test-utils/test-instance";
import { OAuth2Tokens } from "arctic";
import { createJWT } from "oslo/jwt";
import { DEFAULT_SECRET } from "../utils/constants";
import type { GoogleProfile } from "./google";
import { parseSetCookieHeader } from "../utils";
import { createAuthClient } from "../client";
vi.mock("./utils", async (importOriginal) => {
const original = (await importOriginal()) as any;
return {
...original,
validateAuthorizationCode: vi
.fn()
.mockImplementation(async (...args: any) => {
const data: GoogleProfile = {
email: "user@email.com",
email_verified: true,
name: "First Last",
picture: "https://lh3.googleusercontent.com/a-/AOh14GjQ4Z7Vw",
exp: 1234567890,
sub: "1234567890",
iat: 1234567890,
aud: "test",
azp: "test",
nbf: 1234567890,
iss: "test",
locale: "en",
jti: "test",
given_name: "First",
family_name: "Last",
};
const testIdToken = await createJWT(
"HS256",
Buffer.from(DEFAULT_SECRET),
data,
);
const tokens = new OAuth2Tokens({
access_token: "test",
refresh_token: "test",
id_token: testIdToken,
});
return tokens;
}),
};
});
describe("Social Providers", async () => {
const { auth, customFetchImpl, client } = await getTestInstance({
socialProviders: {
google: {
clientId: "test",
clientSecret: "test",
enabled: true,
},
apple: {
clientId: "test",
clientSecret: "test",
},
},
});
let state = "";
it("should be able to add social providers", async () => {
const signInRes = await client.signIn.social({
provider: "google",
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("google.com"),
state: expect.any(String),
codeVerifier: expect.any(String),
redirect: true,
});
state = signInRes.data?.state || "";
});
it("should be able to sign in with social providers", async () => {
await client.$fetch("/callback/google", {
query: {
state,
code: "test",
},
method: "GET",
onError(context) {
expect(context.response.status).toBe(302);
const cookies = parseSetCookieHeader(
context.response.headers.get("set-cookie") || "",
);
expect(cookies.get("better-auth.session_token")?.value).toBeDefined();
},
});
});
});