Improving user management and auto-refreshing user firms

This commit is contained in:
2025-04-10 01:31:37 +02:00
parent bc059de65b
commit f1fe81a146
13 changed files with 131 additions and 67 deletions

View File

@@ -1,19 +1,19 @@
import os import os
from typing import Any, Optional from typing import Any
from beanie import PydanticObjectId from beanie import PydanticObjectId, Document
from fastapi import Depends, Response, status from fastapi import Depends, Response, status
from fastapi_users import BaseUserManager, FastAPIUsers, schemas, models from fastapi_users import BaseUserManager, FastAPIUsers, schemas, models
from fastapi_users.authentication import AuthenticationBackend, BearerTransport, CookieTransport, Strategy from fastapi_users.authentication import AuthenticationBackend, CookieTransport, Strategy
from fastapi_users.authentication.strategy import AccessTokenDatabase, DatabaseStrategy from fastapi_users.authentication.strategy import AccessTokenDatabase, DatabaseStrategy
from fastapi_users_db_beanie.access_token import BeanieBaseAccessTokenDocument, BeanieAccessTokenDatabase from fastapi_users_db_beanie.access_token import BeanieBaseAccessToken, BeanieAccessTokenDatabase
from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.openapi import OpenAPIResponseType
from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.discord import DiscordOAuth2 from httpx_oauth.clients.discord import DiscordOAuth2
from starlette.responses import JSONResponse, RedirectResponse from starlette.responses import JSONResponse, RedirectResponse
from hub.user import User, get_user_db from hub.user import User, get_user_db
from hub.user.schemas import UserSchema from hub.user.schemas import UserSchema, UserUpdateSchema
SECRET = os.getenv("FASTAPI_USERS_SECRET") SECRET = os.getenv("FASTAPI_USERS_SECRET")
@@ -23,7 +23,7 @@ discord_oauth_client = DiscordOAuth2(os.getenv("DISCORD_CLIENT_ID"), os.getenv("
TOKEN_LIFETIME = 3600 TOKEN_LIFETIME = 3600
class AccessToken(BeanieBaseAccessTokenDocument): class AccessToken(BeanieBaseAccessToken, Document):
pass pass
async def get_access_token_db(): async def get_access_token_db():
@@ -84,10 +84,11 @@ auth_router = fastapi_users.get_auth_router(auth_backend, requires_verification=
register_router = fastapi_users.get_register_router(UserSchema, schemas.BaseUserCreate) register_router = fastapi_users.get_register_router(UserSchema, schemas.BaseUserCreate)
password_router = fastapi_users.get_reset_password_router() password_router = fastapi_users.get_reset_password_router()
verification_router = fastapi_users.get_verify_router(UserSchema) verification_router = fastapi_users.get_verify_router(UserSchema)
users_router = fastapi_users.get_users_router(UserSchema, schemas.BaseUserUpdate) users_router = fastapi_users.get_users_router(UserSchema, UserUpdateSchema)
cookie_transport = CookieTransportOauth(cookie_name="rpkapiusersauth") cookie_transport = CookieTransportOauth(cookie_name="rpkapiusersauth")
auth_backend = AuthenticationBackend(name="db", transport=cookie_transport, get_strategy=get_database_strategy, ) auth_backend = AuthenticationBackend(name="db", transport=cookie_transport, get_strategy=get_database_strategy, )
google_oauth_router = fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET, is_verified_by_default=True) google_oauth_router = fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET, is_verified_by_default=True)
discord_oauth_router = fastapi_users.get_oauth_router(discord_oauth_client, auth_backend, SECRET, is_verified_by_default=True) discord_oauth_router = fastapi_users.get_oauth_router(discord_oauth_client, auth_backend, SECRET, is_verified_by_default=True)

View File

@@ -18,4 +18,4 @@ async def init_db():
await init_beanie(database=client.hub, await init_beanie(database=client.hub,
document_models=[User, AccessToken, Firm], document_models=[User, AccessToken, Firm],
allow_index_dropping=True) allow_index_dropping=True)

View File

@@ -4,44 +4,41 @@ from fastapi import APIRouter, Depends, HTTPException
from hub.auth import get_current_user from hub.auth import get_current_user
from hub.firm import Firm, FirmRead, FirmCreate, FirmUpdate from hub.firm import Firm, FirmRead, FirmCreate, FirmUpdate
model = Firm
model_read = FirmRead
model_create = FirmCreate
model_update = FirmUpdate
router = APIRouter() router = APIRouter()
@router.post("/", response_description="{} added to the database".format(model.__name__)) @router.post("/", response_description="{} added to the database".format(Firm.__name__))
async def create(item: model_create, user=Depends(get_current_user)) -> model_read: async def create(item: FirmCreate, user=Depends(get_current_user)) -> FirmRead:
exists = await Firm.find_one({"name": item.name, "instance": item.instance}) firm_dict = {"name": item.name, "instance": item.instance}
exists = await Firm.find_one(firm_dict)
if exists: if exists:
raise HTTPException(status_code=400, detail="Firm already exists") raise HTTPException(status_code=400, detail="Firm already exists")
record = model(created_by=user.id, updated_by=user.id, owner=user.id, **item.model_dump()) record = Firm(created_by=user.id, updated_by=user.id, owner=user.id, **item.model_dump())
o = await record.create() o = await record.create()
user.firms.append(o.id) user.firms.append(firm_dict)
user.save() await user.save()
return model_read(**o.model_dump()) return FirmRead(**o.model_dump())
@router.get("/{id}", response_description="{} record retrieved".format(model.__name__)) @router.get("/{id}", response_description="{} record retrieved".format(Firm.__name__))
async def read_id(id: PydanticObjectId, user=Depends(get_current_user)) -> model_read: async def read_id(id: PydanticObjectId, user=Depends(get_current_user)) -> FirmRead:
item = await model.get(id) item = await Firm.get(id)
return model_read(**item.model_dump()) return FirmRead(**item.model_dump())
@router.put("/{id}", response_description="{} record updated".format(model.__name__)) @router.put("/{id}", response_description="{} record updated".format(Firm.__name__))
async def update(id: PydanticObjectId, req: model_update, user=Depends(get_current_user)) -> model_read: async def update(id: PydanticObjectId, req: FirmUpdate, user=Depends(get_current_user)) -> FirmRead:
item = await model.get(id) item = await Firm.get(id)
if not item: if not item:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="{} record not found!".format(model.__name__) detail="{} record not found!".format(Firm.__name__)
) )
if item.owner != user.id: if item.owner != user.id:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Insufficient credentials to modify {} record".format(model.__name__) detail="Insufficient credentials to modify {} record".format(Firm.__name__)
) )
req = {k: v for k, v in req.model_dump().items() if v is not None} req = {k: v for k, v in req.model_dump().items() if v is not None}
@@ -50,23 +47,23 @@ async def update(id: PydanticObjectId, req: model_update, user=Depends(get_curre
}} }}
await item.update(update_query) await item.update(update_query)
return model_read(**item.dict()) return FirmRead(**item.dict())
@router.delete("/{id}", response_description="{} record deleted from the database".format(model.__name__)) @router.delete("/{id}", response_description="{} record deleted from the database".format(Firm.__name__))
async def delete(id: PydanticObjectId, user=Depends(get_current_user)) -> dict: async def delete(id: PydanticObjectId, user=Depends(get_current_user)) -> dict:
item = await model.get(id) item = await Firm.get(id)
if not item: if not item:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="{} record not found!".format(model.__name__) detail="{} record not found!".format(Firm.__name__)
) )
if item.owner != user.id: if item.owner != user.id:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Insufficient credentials delete {} record".format(model.__name__) detail="Insufficient credentials delete {} record".format(Firm.__name__)
) )
await item.delete() await item.delete()
return { return {
"message": "{} deleted successfully".format(model.__name__) "message": "{} deleted successfully".format(Firm.__name__)
} }

View File

@@ -1,13 +1,17 @@
from beanie import PydanticObjectId from beanie import Document
from fastapi_users_db_beanie import BaseOAuthAccount, BeanieUserDatabase, BeanieBaseUserDocument from fastapi_users_db_beanie import BaseOAuthAccount, BeanieUserDatabase, BeanieBaseUser
from pydantic import Field from pydantic import Field
from hub.firm import FirmRead
from hub.user.schemas import UserSchema, UserUpdateSchema
class OAuthAccount(BaseOAuthAccount): class OAuthAccount(BaseOAuthAccount):
pass pass
class User(BeanieBaseUserDocument): class User(BeanieBaseUser, Document):
oauth_accounts: list[OAuthAccount] = Field(default_factory=list) oauth_accounts: list[OAuthAccount] = Field(default_factory=list)
firms: list[PydanticObjectId] = Field(default_factory=list) firms: list[FirmRead] = Field(default_factory=list)
class UserDatabase(BeanieUserDatabase): class UserDatabase(BeanieUserDatabase):
pass pass

View File

@@ -1,7 +1,12 @@
from beanie import PydanticObjectId from beanie import PydanticObjectId
from fastapi_users.schemas import BaseUser from fastapi_users.schemas import BaseUser, BaseUserUpdate
from pydantic import Field from pydantic import Field
from hub.firm import FirmRead
class UserSchema(BaseUser[PydanticObjectId]): class UserSchema(BaseUser[PydanticObjectId]):
firms: list[PydanticObjectId] = Field() firms: list[FirmRead] = Field()
class UserUpdateSchema(BaseUserUpdate):
pass

View File

@@ -13,7 +13,7 @@ import routerBindings, {
DocumentTitleHandler, DocumentTitleHandler,
UnsavedChangesNotifier, UnsavedChangesNotifier,
} from "@refinedev/react-router"; } from "@refinedev/react-router";
import { BrowserRouter, Outlet, Route, Routes } from "react-router"; import { BrowserRouter, Link, Outlet, Route, Routes } from "react-router";
import { authProvider } from "./providers/auth-provider"; import { authProvider } from "./providers/auth-provider";
import { dataProvider } from "./providers/data-provider"; import { dataProvider } from "./providers/data-provider";
import { ColorModeContextProvider } from "./contexts/color-mode"; import { ColorModeContextProvider } from "./contexts/color-mode";
@@ -58,7 +58,7 @@ function App() {
<Route path="/hub" element={ <Hub /> } /> <Route path="/hub" element={ <Hub /> } />
<Route path="/hub/create-firm" element={ <CreateFirm /> } /> <Route path="/hub/create-firm" element={ <CreateFirm /> } />
</Route> </Route>
<Route index element={<h1>HOME</h1>} /> <Route index element={<h1>HOME&nbsp;<Link to={"/login"}>Login</Link></h1>} />
<Route path="/login" element={<Login />} /> <Route path="/login" element={<Login />} />
<Route path="/register" element={<Register />} /> <Route path="/register" element={<Register />} />
<Route path="/forgot-password" element={<ForgotPassword />} /> <Route path="/forgot-password" element={<ForgotPassword />} />

View File

@@ -1,11 +1,7 @@
import {Navigate, useSearchParams} from "react-router";
import {AuthPage} from "@refinedev/mui";
import GoogleIcon from "@mui/icons-material/Google";
import DiscordIcon from "../DiscordIcon";
import { useLogout } from "@refinedev/core"; import { useLogout } from "@refinedev/core";
export const Logout = () => { export const Logout = () => {
const { mutate: logout } = useLogout(); const { mutate: logout } = useLogout();
return <button onClick={() => logout()}>Logout</button>; return <button onClick={() => logout()} >Logout</button>;
}; };

View File

@@ -10,13 +10,8 @@ import { useGetIdentity } from "@refinedev/core";
import { HamburgerMenu, RefineThemedLayoutV2HeaderProps } from "@refinedev/mui"; import { HamburgerMenu, RefineThemedLayoutV2HeaderProps } from "@refinedev/mui";
import React, { useContext } from "react"; import React, { useContext } from "react";
import { ColorModeContext } from "../../contexts/color-mode"; import { ColorModeContext } from "../../contexts/color-mode";
import {Logout} from "../auth/Logout"; import { Logout } from "../auth/Logout";
import { IUser } from "../../interfaces";
type IUser = {
id: number;
email: string;
avatar: string;
};
export const Header: React.FC<RefineThemedLayoutV2HeaderProps> = ({ export const Header: React.FC<RefineThemedLayoutV2HeaderProps> = ({
sticky = true, sticky = true,
@@ -50,7 +45,7 @@ export const Header: React.FC<RefineThemedLayoutV2HeaderProps> = ({
{mode === "dark" ? <LightModeOutlined /> : <DarkModeOutlined />} {mode === "dark" ? <LightModeOutlined /> : <DarkModeOutlined />}
</IconButton> </IconButton>
{(user?.avatar || user?.email) && ( {(user?.email) && (
<Stack <Stack
direction="row" direction="row"
gap="16px" gap="16px"
@@ -70,7 +65,7 @@ export const Header: React.FC<RefineThemedLayoutV2HeaderProps> = ({
{user?.email} {user?.email}
</Typography> </Typography>
)} )}
<Avatar src={user?.avatar} alt={user?.email} /> <Avatar src={"user?.avatar"} alt={user?.email} />
<Logout /> <Logout />
</Stack> </Stack>
)} )}

View File

@@ -0,0 +1,13 @@
export type IFirm = {
instance: string,
name: string
}
type User = {
id: number,
email: string,
firms: [IFirm],
};
export type IUser = User | null;

View File

@@ -8,10 +8,11 @@ import CrudTextWidget from "./widgets/crud-text-widget";
import UnionEnumField from "./fields/union-enum"; import UnionEnumField from "./fields/union-enum";
type Props = { type Props = {
schemaName: string, schemaName: string,
resource: string, resource: string,
id?: string, id?: string,
//onSubmit: (data: IChangeEvent, event: FormEvent<any>) => void //onSubmit: (data: IChangeEvent, event: FormEvent<any>) => void
onSuccess?: (data: any) => void
} }
const customWidgets: RegistryWidgetsType = { const customWidgets: RegistryWidgetsType = {
@@ -22,12 +23,13 @@ const customFields: RegistryFieldsType = {
AnyOfField: UnionEnumField AnyOfField: UnionEnumField
} }
export const CrudForm: React.FC<Props> = ({schemaName, resource, id}) => { export const CrudForm: React.FC<Props> = ({ schemaName, resource, id, onSuccess }) => {
const { onFinish, query, formLoading } = useForm({ const { onFinish, query, formLoading } = useForm({
resource: resource, resource: resource,
action: id === undefined ? "create" : "edit", action: id === undefined ? "create" : "edit",
redirect: "show", redirect: "show",
id, id,
onMutationSuccess: (data: any) => { if (onSuccess) { onSuccess(data) } },
}); });
const record = query?.data?.data; const record = query?.data?.data;

View File

@@ -1,7 +1,19 @@
import { useInvalidateAuthStore } from "@refinedev/core";
import { CrudForm } from "../../lib/crud/components/crud-form"; import { CrudForm } from "../../lib/crud/components/crud-form";
import {empty_user} from "../../providers/auth-provider";
export const CreateFirm = () => { export const CreateFirm = () => {
const invalidateAuthStore = useInvalidateAuthStore()
const refreshUser = () => {
empty_user();
invalidateAuthStore().then();
}
return ( return (
<CrudForm schemaName={"FirmCreate"} resource={"firms"} /> <CrudForm
schemaName={"FirmCreate"}
resource={"firms"}
onSuccess={() => { refreshUser() }}
/>
) )
} }

View File

@@ -1,12 +1,38 @@
import { Button } from "@mui/material"; import { Button } from "@mui/material";
import { Link } from "react-router"; import { Link } from "react-router";
import { useGetIdentity } from "@refinedev/core";
type Firm = {
name: string,
instance: string,
}
type User = {
firms: [Firm],
}
export const Hub = () => { export const Hub = () => {
const user = useGetIdentity<User>();
console.log(user);
let ownFirms = [];
let workFirms = [];
//firms.forEach((f, index) => {
// workFirms.push(<li>{f.instance}/{f.name}</li>)
//})
//{firms.map((f: Firm, index) => (
// <li key={index}>{f.instance} / {f.name}</li>
// ))}
return ( return (
<div> <div>
<h1>HUB</h1> <h1>HUB</h1>
<p>List of managed firms</p> <p>List of managed firms</p>
<p>List of firm you're working atx</p> <ul>
<li></li>
</ul>
<p>List of firm you're working at</p>
<ul>
</ul>
<Link to="/hub/create-firm" ><Button >Create a new firm</Button></Link> <Link to="/hub/create-firm" ><Button >Create a new firm</Button></Link>
</div> </div>
); );

View File

@@ -1,4 +1,6 @@
import isEmpty from 'lodash/isEmpty';
import { AuthProvider } from "@refinedev/core"; import { AuthProvider } from "@refinedev/core";
import {IUser} from "../interfaces";
const API_URL = "/api/v1"; const API_URL = "/api/v1";
const LOCAL_STORAGE_USER_KEY = "rpk-gui-current-user"; const LOCAL_STORAGE_USER_KEY = "rpk-gui-current-user";
@@ -55,17 +57,24 @@ export const authProvider: AuthProvider = {
return { success: false }; return { success: false };
}, },
check: async () => { check: async () => {
return { authenticated: Boolean(get_user()) }; if (get_user() == null) {
return {
authenticated: false,
redirectTo: "/login",
logout: true
}
}
return { authenticated: true };
}, },
getIdentity: async () => { getIdentity: async (): Promise<IUser> => {
const user = get_user(); const user = get_user();
if (user != null) { if (user !== null && !isEmpty(user)) {
return user; return user;
} }
const response = await fetch(`${API_URL}/users/me`); const response = await fetch(`${API_URL}/users/me`);
if (response.status < 200 || response.status > 299) { if (response.status < 200 || response.status > 299) {
return return null;
} }
const user_data = await response.json(); const user_data = await response.json();
store_user(user_data) store_user(user_data)
@@ -163,6 +172,10 @@ function forget_user() {
localStorage.removeItem(LOCAL_STORAGE_USER_KEY); localStorage.removeItem(LOCAL_STORAGE_USER_KEY);
} }
export function empty_user() {
store_user({})
}
function findGetParameter(parameterName: string) { function findGetParameter(parameterName: string) {
let result = null, tmp = []; let result = null, tmp = [];
location.search.substr(1).split("&") location.search.substr(1).split("&")