diff --git a/src/app/api/auth/hf/callback/route.ts b/src/app/api/auth/hf/callback/route.ts deleted file mode 100644 index 2e353f3149ff687b697cb541e787a42442de501a..0000000000000000000000000000000000000000 --- a/src/app/api/auth/hf/callback/route.ts +++ /dev/null @@ -1,112 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { cookies } from 'next/headers'; - -const TOKEN_ENDPOINT = 'https://huggingface.co/oauth/token'; -const USERINFO_ENDPOINT = 'https://huggingface.co/oauth/userinfo'; -const STATE_COOKIE = 'hf_oauth_state'; - -function htmlResponse(script: string) { - return new NextResponse( - ``, - { - headers: { 'Content-Type': 'text/html; charset=utf-8' }, - }, - ); -} - -export async function GET(request: NextRequest) { - const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID; - const clientSecret = process.env.HF_OAUTH_CLIENT_SECRET; - - if (!clientId || !clientSecret) { - return NextResponse.json({ error: 'OAuth application is not configured' }, { status: 500 }); - } - - const { searchParams } = new URL(request.url); - const code = searchParams.get('code'); - const incomingState = searchParams.get('state'); - - const cookieStore = cookies(); - const storedState = cookieStore.get(STATE_COOKIE)?.value; - - cookieStore.delete(STATE_COOKIE); - - const origin = request.nextUrl.origin; - - if (!code || !incomingState || !storedState || incomingState !== storedState) { - const script = ` - window.opener && window.opener.postMessage({ - type: 'HF_OAUTH_ERROR', - payload: { message: 'Invalid or expired OAuth state.' } - }, '${origin}'); - window.close(); - `; - return htmlResponse(script.trim()); - } - - const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`; - - try { - const tokenResponse = await fetch(TOKEN_ENDPOINT, { - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - body: new URLSearchParams({ - grant_type: 'authorization_code', - code, - redirect_uri: redirectUri, - client_id: clientId, - client_secret: clientSecret, - }), - }); - - if (!tokenResponse.ok) { - const errorPayload = await tokenResponse.json().catch(() => ({})); - throw new Error(errorPayload?.error_description || 'Failed to exchange code for token'); - } - - const tokenData = await tokenResponse.json(); - const accessToken = tokenData?.access_token; - if (!accessToken) { - throw new Error('Access token missing in response'); - } - - const userResponse = await fetch(USERINFO_ENDPOINT, { - headers: { - Authorization: `Bearer ${accessToken}`, - }, - }); - - if (!userResponse.ok) { - throw new Error('Failed to fetch user info'); - } - - const profile = await userResponse.json(); - const namespace = profile?.preferred_username || profile?.name || 'user'; - - const script = ` - window.opener && window.opener.postMessage({ - type: 'HF_OAUTH_SUCCESS', - payload: { - token: ${JSON.stringify(accessToken)}, - namespace: ${JSON.stringify(namespace)}, - } - }, '${origin}'); - window.close(); - `; - - return htmlResponse(script.trim()); - } catch (error: any) { - const message = error?.message || 'OAuth flow failed'; - const script = ` - window.opener && window.opener.postMessage({ - type: 'HF_OAUTH_ERROR', - payload: { message: ${JSON.stringify(message)} } - }, '${origin}'); - window.close(); - `; - - return htmlResponse(script.trim()); - } -} diff --git a/src/app/api/auth/hf/login/route.ts b/src/app/api/auth/hf/login/route.ts deleted file mode 100644 index 22c252217d8b94f9db7a495892a79193df05786a..0000000000000000000000000000000000000000 --- a/src/app/api/auth/hf/login/route.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { randomUUID } from 'crypto'; -import { NextRequest, NextResponse } from 'next/server'; - -const HF_AUTHORIZE_URL = 'https://huggingface.co/oauth/authorize'; -const STATE_COOKIE = 'hf_oauth_state'; - -export async function GET(request: NextRequest) { - const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID; - if (!clientId) { - return NextResponse.json({ error: 'OAuth client ID not configured' }, { status: 500 }); - } - - const state = randomUUID(); - const origin = request.nextUrl.origin; - const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`; - - const authorizeUrl = new URL(HF_AUTHORIZE_URL); - authorizeUrl.searchParams.set('response_type', 'code'); - authorizeUrl.searchParams.set('client_id', clientId); - authorizeUrl.searchParams.set('redirect_uri', redirectUri); - authorizeUrl.searchParams.set('scope', 'openid profile read-repos'); - authorizeUrl.searchParams.set('state', state); - - const response = NextResponse.redirect(authorizeUrl.toString(), { status: 302 }); - response.cookies.set({ - name: STATE_COOKIE, - value: state, - httpOnly: true, - sameSite: 'lax', - secure: process.env.NODE_ENV === 'production', - maxAge: 60 * 5, - path: '/', - }); - - return response; -} diff --git a/src/app/api/auth/hf/validate/route.ts b/src/app/api/auth/hf/validate/route.ts deleted file mode 100644 index 32dc41fb4d3a7e82d8434ce577aa9e563c349203..0000000000000000000000000000000000000000 --- a/src/app/api/auth/hf/validate/route.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { whoAmI } from '@huggingface/hub'; - -export async function POST(request: NextRequest) { - try { - const body = await request.json().catch(() => ({})); - const token = (body?.token || '').trim(); - - if (!token) { - return NextResponse.json({ error: 'Token is required' }, { status: 400 }); - } - - const info = await whoAmI({ accessToken: token }); - return NextResponse.json({ - name: info?.name || info?.username || 'user', - email: info?.email || null, - orgs: info?.orgs || [], - }); - } catch (error: any) { - return NextResponse.json({ error: error?.message || 'Invalid token' }, { status: 401 }); - } -} diff --git a/src/app/api/auth/route.ts b/src/app/api/auth/route.ts deleted file mode 100644 index 1dc229739fbbeaabf307e3be544dd7e2bc8ab66f..0000000000000000000000000000000000000000 --- a/src/app/api/auth/route.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { NextResponse } from 'next/server'; - -export async function GET() { - // if this gets hit, auth has already been verified - return NextResponse.json({ isAuthenticated: true }); -} diff --git a/src/app/api/caption/get/route.ts b/src/app/api/caption/get/route.ts deleted file mode 100644 index 4f8d2818318805f97a80370e1a9cfc584cd9dc26..0000000000000000000000000000000000000000 --- a/src/app/api/caption/get/route.ts +++ /dev/null @@ -1,46 +0,0 @@ -/* eslint-disable */ -import { NextRequest, NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: NextRequest) { - - const body = await request.json(); - const { imgPath } = body; - console.log('Received POST request for caption:', imgPath); - try { - // Decode the path - const filepath = imgPath; - console.log('Decoded image path:', filepath); - - // caption name is the filepath without extension but with .txt - const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt'; - - // Get allowed directories - const allowedDir = await getDatasetsRoot(); - - // Security check: Ensure path is in allowed directory - const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..'); - - if (!isAllowed) { - console.warn(`Access denied: ${filepath} not in ${allowedDir}`); - return new NextResponse('Access denied', { status: 403 }); - } - - // Check if file exists - if (!fs.existsSync(captionPath)) { - // send back blank string if caption file does not exist - return new NextResponse(''); - } - - // Read caption file - const caption = fs.readFileSync(captionPath, 'utf-8'); - - // Return caption - return new NextResponse(caption); - } catch (error) { - console.error('Error getting caption:', error); - return new NextResponse('Error getting caption', { status: 500 }); - } -} diff --git a/src/app/api/datasets/create/route.tsx b/src/app/api/datasets/create/route.tsx deleted file mode 100644 index e005d058f3423db41f4830b69a1d51c7872d1351..0000000000000000000000000000000000000000 --- a/src/app/api/datasets/create/route.tsx +++ /dev/null @@ -1,25 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: Request) { - try { - const body = await request.json(); - let { name } = body; - // clean name by making lower case, removing special characters, and replacing spaces with underscores - name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_'); - - let datasetsPath = await getDatasetsRoot(); - let datasetPath = path.join(datasetsPath, name); - - // if folder doesnt exist, create it - if (!fs.existsSync(datasetPath)) { - fs.mkdirSync(datasetPath); - } - - return NextResponse.json({ success: true, name: name, path: datasetPath }); - } catch (error) { - return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); - } -} diff --git a/src/app/api/datasets/delete/route.tsx b/src/app/api/datasets/delete/route.tsx deleted file mode 100644 index 9a1d970ee415c9d040596854ce74ad5401859259..0000000000000000000000000000000000000000 --- a/src/app/api/datasets/delete/route.tsx +++ /dev/null @@ -1,24 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: Request) { - try { - const body = await request.json(); - const { name } = body; - let datasetsPath = await getDatasetsRoot(); - let datasetPath = path.join(datasetsPath, name); - - // if folder doesnt exist, ignore - if (!fs.existsSync(datasetPath)) { - return NextResponse.json({ success: true }); - } - - // delete it and return success - fs.rmdirSync(datasetPath, { recursive: true }); - return NextResponse.json({ success: true }); - } catch (error) { - return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); - } -} diff --git a/src/app/api/datasets/list/route.ts b/src/app/api/datasets/list/route.ts deleted file mode 100644 index dc829c65f3cab2829221f85341967fc1b52a921c..0000000000000000000000000000000000000000 --- a/src/app/api/datasets/list/route.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function GET() { - try { - let datasetsPath = await getDatasetsRoot(); - - // if folder doesnt exist, create it - if (!fs.existsSync(datasetsPath)) { - fs.mkdirSync(datasetsPath); - } - - // find all the folders in the datasets folder - let folders = fs - .readdirSync(datasetsPath, { withFileTypes: true }) - .filter(dirent => dirent.isDirectory()) - .filter(dirent => !dirent.name.startsWith('.')) - .map(dirent => dirent.name); - - return NextResponse.json(folders); - } catch (error) { - return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 }); - } -} diff --git a/src/app/api/datasets/listImages/route.ts b/src/app/api/datasets/listImages/route.ts deleted file mode 100644 index 06dca84ae780c7fddb200fc6de422b7a42e309ea..0000000000000000000000000000000000000000 --- a/src/app/api/datasets/listImages/route.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: Request) { - const datasetsPath = await getDatasetsRoot(); - const body = await request.json(); - const { datasetName } = body; - const datasetFolder = path.join(datasetsPath, datasetName); - - try { - // Check if folder exists - if (!fs.existsSync(datasetFolder)) { - return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 }); - } - - // Find all images recursively - const imageFiles = findImagesRecursively(datasetFolder); - - // Format response - const result = imageFiles.map(imgPath => ({ - img_path: imgPath, - })); - - return NextResponse.json({ images: result }); - } catch (error) { - console.error('Error finding images:', error); - return NextResponse.json({ error: 'Failed to process request' }, { status: 500 }); - } -} - -/** - * Recursively finds all image files in a directory and its subdirectories - * @param dir Directory to search - * @returns Array of absolute paths to image files - */ -function findImagesRecursively(dir: string): string[] { - const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp', '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv']; - let results: string[] = []; - - const items = fs.readdirSync(dir); - - for (const item of items) { - const itemPath = path.join(dir, item); - const stat = fs.statSync(itemPath); - - if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) { - // If it's a directory, recursively search it - results = results.concat(findImagesRecursively(itemPath)); - } else { - // If it's a file, check if it's an image - const ext = path.extname(itemPath).toLowerCase(); - if (imageExtensions.includes(ext)) { - results.push(itemPath); - } - } - } - - return results; -} diff --git a/src/app/api/datasets/upload/route.ts b/src/app/api/datasets/upload/route.ts deleted file mode 100644 index 51aff81fd3bf4b091f10a1df9f2da887910f4753..0000000000000000000000000000000000000000 --- a/src/app/api/datasets/upload/route.ts +++ /dev/null @@ -1,57 +0,0 @@ -// src/app/api/datasets/upload/route.ts -import { NextRequest, NextResponse } from 'next/server'; -import { writeFile, mkdir } from 'fs/promises'; -import { join } from 'path'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: NextRequest) { - try { - const datasetsPath = await getDatasetsRoot(); - if (!datasetsPath) { - return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 }); - } - const formData = await request.formData(); - const files = formData.getAll('files'); - const datasetName = formData.get('datasetName') as string; - - if (!files || files.length === 0) { - return NextResponse.json({ error: 'No files provided' }, { status: 400 }); - } - - // Create upload directory if it doesn't exist - const uploadDir = join(datasetsPath, datasetName); - await mkdir(uploadDir, { recursive: true }); - - const savedFiles: string[] = []; - - // Process files sequentially to avoid overwhelming the system - for (let i = 0; i < files.length; i++) { - const file = files[i] as any; - const bytes = await file.arrayBuffer(); - const buffer = Buffer.from(bytes); - - // Clean filename and ensure it's unique - const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_'); - const filePath = join(uploadDir, fileName); - - await writeFile(filePath, buffer); - savedFiles.push(fileName); - } - - return NextResponse.json({ - message: 'Files uploaded successfully', - files: savedFiles, - }); - } catch (error) { - console.error('Upload error:', error); - return NextResponse.json({ error: 'Error uploading files' }, { status: 500 }); - } -} - -// Increase payload size limit (default is 4mb) -export const config = { - api: { - bodyParser: false, - responseLimit: '50mb', - }, -}; diff --git a/src/app/api/files/[...filePath]/route.ts b/src/app/api/files/[...filePath]/route.ts deleted file mode 100644 index 46eb5c4ab08b9c02ba4ff8d0fe7f6dc2cd15442a..0000000000000000000000000000000000000000 --- a/src/app/api/files/[...filePath]/route.ts +++ /dev/null @@ -1,116 +0,0 @@ -/* eslint-disable */ -import { NextRequest, NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot, getTrainingFolder } from '@/server/settings'; - -export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) { - const { filePath } = await params; - try { - // Decode the path - const decodedFilePath = decodeURIComponent(filePath); - - // Get allowed directories - const datasetRoot = await getDatasetsRoot(); - const trainingRoot = await getTrainingFolder(); - const allowedDirs = [datasetRoot, trainingRoot]; - - // Security check: Ensure path is in allowed directory - const isAllowed = - allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..'); - - if (!isAllowed) { - console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`); - return new NextResponse('Access denied', { status: 403 }); - } - - // Check if file exists - if (!fs.existsSync(decodedFilePath)) { - console.warn(`File not found: ${decodedFilePath}`); - return new NextResponse('File not found', { status: 404 }); - } - - // Get file info - const stat = fs.statSync(decodedFilePath); - if (!stat.isFile()) { - return new NextResponse('Not a file', { status: 400 }); - } - - // Get filename for Content-Disposition - const filename = path.basename(decodedFilePath); - - // Determine content type - const ext = path.extname(decodedFilePath).toLowerCase(); - const contentTypeMap: { [key: string]: string } = { - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.png': 'image/png', - '.gif': 'image/gif', - '.webp': 'image/webp', - '.svg': 'image/svg+xml', - '.bmp': 'image/bmp', - '.safetensors': 'application/octet-stream', - '.zip': 'application/zip', - // Videos - '.mp4': 'video/mp4', - '.avi': 'video/x-msvideo', - '.mov': 'video/quicktime', - '.mkv': 'video/x-matroska', - '.wmv': 'video/x-ms-wmv', - '.m4v': 'video/x-m4v', - '.flv': 'video/x-flv' - }; - - const contentType = contentTypeMap[ext] || 'application/octet-stream'; - - // Get range header for partial content support - const range = request.headers.get('range'); - - // Common headers for better download handling - const commonHeaders = { - 'Content-Type': contentType, - 'Accept-Ranges': 'bytes', - 'Cache-Control': 'public, max-age=86400', - 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`, - 'X-Content-Type-Options': 'nosniff', - }; - - if (range) { - // Parse range header - const parts = range.replace(/bytes=/, '').split('-'); - const start = parseInt(parts[0], 10); - const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks - const chunkSize = end - start + 1; - - const fileStream = fs.createReadStream(decodedFilePath, { - start, - end, - highWaterMark: 64 * 1024, // 64KB buffer - }); - - return new NextResponse(fileStream as any, { - status: 206, - headers: { - ...commonHeaders, - 'Content-Range': `bytes ${start}-${end}/${stat.size}`, - 'Content-Length': String(chunkSize), - }, - }); - } else { - // For full file download, read directly without streaming wrapper - const fileStream = fs.createReadStream(decodedFilePath, { - highWaterMark: 64 * 1024, // 64KB buffer - }); - - return new NextResponse(fileStream as any, { - headers: { - ...commonHeaders, - 'Content-Length': String(stat.size), - }, - }); - } - } catch (error) { - console.error('Error serving file:', error); - return new NextResponse('Internal Server Error', { status: 500 }); - } -} diff --git a/src/app/api/gpu/route.ts b/src/app/api/gpu/route.ts deleted file mode 100644 index 8b11dbb0e6d8e8de0f191bb1e78bb8687376881a..0000000000000000000000000000000000000000 --- a/src/app/api/gpu/route.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { NextResponse } from 'next/server'; -import { exec } from 'child_process'; -import { promisify } from 'util'; -import os from 'os'; - -const execAsync = promisify(exec); - -export async function GET() { - try { - // Get platform - const platform = os.platform(); - const isWindows = platform === 'win32'; - - // Check if nvidia-smi is available - const hasNvidiaSmi = await checkNvidiaSmi(isWindows); - - if (!hasNvidiaSmi) { - return NextResponse.json({ - hasNvidiaSmi: false, - gpus: [], - error: 'nvidia-smi not found or not accessible', - }); - } - - // Get GPU stats - const gpuStats = await getGpuStats(isWindows); - - return NextResponse.json({ - hasNvidiaSmi: true, - gpus: gpuStats, - }); - } catch (error) { - console.error('Error fetching NVIDIA GPU stats:', error); - return NextResponse.json( - { - hasNvidiaSmi: false, - gpus: [], - error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`, - }, - { status: 500 }, - ); - } -} - -async function checkNvidiaSmi(isWindows: boolean): Promise { - try { - if (isWindows) { - // Check if nvidia-smi is available on Windows - // It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe - // but we'll just try to run it directly as it may be in PATH - await execAsync('nvidia-smi -L'); - } else { - // Linux/macOS check - await execAsync('which nvidia-smi'); - } - return true; - } catch (error) { - return false; - } -} - -async function getGpuStats(isWindows: boolean) { - // Command is the same for both platforms, but the path might be different - const command = - 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits'; - - // Execute command - const { stdout } = await execAsync(command); - - // Parse CSV output - const gpus = stdout - .trim() - .split('\n') - .map(line => { - const [ - index, - name, - driverVersion, - temperature, - gpuUtil, - memoryUtil, - memoryTotal, - memoryFree, - memoryUsed, - powerDraw, - powerLimit, - clockGraphics, - clockMemory, - fanSpeed, - ] = line.split(', ').map(item => item.trim()); - - return { - index: parseInt(index), - name, - driverVersion, - temperature: parseInt(temperature), - utilization: { - gpu: parseInt(gpuUtil), - memory: parseInt(memoryUtil), - }, - memory: { - total: parseInt(memoryTotal), - free: parseInt(memoryFree), - used: parseInt(memoryUsed), - }, - power: { - draw: parseFloat(powerDraw), - limit: parseFloat(powerLimit), - }, - clocks: { - graphics: parseInt(clockGraphics), - memory: parseInt(clockMemory), - }, - fan: { - speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0 - }, - }; - }); - - return gpus; -} diff --git a/src/app/api/hf-hub/route.ts b/src/app/api/hf-hub/route.ts deleted file mode 100644 index afdfb64c599b6fbad3c832d7450176ba3ca2b2c0..0000000000000000000000000000000000000000 --- a/src/app/api/hf-hub/route.ts +++ /dev/null @@ -1,165 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { whoAmI, createRepo, uploadFiles, datasetInfo } from '@huggingface/hub'; -import { readdir, stat } from 'fs/promises'; -import path from 'path'; - -export async function POST(request: NextRequest) { - try { - const body = await request.json(); - const { action, token, namespace, datasetName, datasetPath, datasetId } = body; - - if (!token) { - return NextResponse.json({ error: 'HF token is required' }, { status: 400 }); - } - - switch (action) { - case 'whoami': - try { - const user = await whoAmI({ accessToken: token }); - return NextResponse.json({ user }); - } catch (error) { - return NextResponse.json({ error: 'Invalid token or network error' }, { status: 401 }); - } - - case 'createDataset': - try { - if (!namespace || !datasetName) { - return NextResponse.json({ error: 'Namespace and dataset name required' }, { status: 400 }); - } - - const repoId = `datasets/${namespace}/${datasetName}`; - - // Create repository - await createRepo({ - repo: repoId, - accessToken: token, - private: false, - }); - - return NextResponse.json({ success: true, repoId }); - } catch (error: any) { - if (error.message?.includes('already exists')) { - return NextResponse.json({ success: true, repoId: `${namespace}/${datasetName}`, exists: true }); - } - return NextResponse.json({ error: error.message || 'Failed to create dataset' }, { status: 500 }); - } - - case 'uploadDataset': - try { - if (!namespace || !datasetName || !datasetPath) { - return NextResponse.json({ error: 'Missing required parameters' }, { status: 400 }); - } - - const repoId = `datasets/${namespace}/${datasetName}`; - - // Check if directory exists - try { - await stat(datasetPath); - } catch { - return NextResponse.json({ error: 'Dataset path does not exist' }, { status: 400 }); - } - - // Read files from directory and upload them - const files = await readdir(datasetPath); - const filesToUpload = []; - - for (const fileName of files) { - const filePath = path.join(datasetPath, fileName); - const fileStats = await stat(filePath); - - if (fileStats.isFile()) { - filesToUpload.push({ - path: fileName, - content: new URL(`file://${filePath}`) - }); - } - } - - if (filesToUpload.length > 0) { - await uploadFiles({ - repo: repoId, - accessToken: token, - files: filesToUpload, - }); - } - - return NextResponse.json({ success: true, repoId }); - } catch (error: any) { - console.error('Upload error:', error); - return NextResponse.json({ error: error.message || 'Failed to upload dataset' }, { status: 500 }); - } - - case 'listFiles': - try { - if (!datasetPath) { - return NextResponse.json({ error: 'Dataset path required' }, { status: 400 }); - } - - const files = await readdir(datasetPath, { withFileTypes: true }); - const imageExtensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp']; - - const imageFiles = files - .filter(file => file.isFile()) - .filter(file => imageExtensions.some(ext => file.name.toLowerCase().endsWith(ext))) - .map(file => ({ - name: file.name, - path: path.join(datasetPath, file.name), - })); - - const captionFiles = files - .filter(file => file.isFile()) - .filter(file => file.name.endsWith('.txt')) - .map(file => ({ - name: file.name, - path: path.join(datasetPath, file.name), - })); - - return NextResponse.json({ - images: imageFiles, - captions: captionFiles, - total: imageFiles.length - }); - } catch (error: any) { - return NextResponse.json({ error: error.message || 'Failed to list files' }, { status: 500 }); - } - - case 'validateDataset': - try { - if (!datasetId) { - return NextResponse.json({ error: 'Dataset ID required' }, { status: 400 }); - } - - // Try to get dataset info to validate it exists and is accessible - const dataset = await datasetInfo({ - name: datasetId, - accessToken: token, - }); - - return NextResponse.json({ - exists: true, - dataset: { - id: dataset.id, - author: dataset.author, - downloads: dataset.downloads, - likes: dataset.likes, - private: dataset.private, - } - }); - } catch (error: any) { - if (error.message?.includes('404') || error.message?.includes('not found')) { - return NextResponse.json({ exists: false }, { status: 200 }); - } - if (error.message?.includes('401') || error.message?.includes('403')) { - return NextResponse.json({ error: 'Dataset not accessible with current token' }, { status: 403 }); - } - return NextResponse.json({ error: error.message || 'Failed to validate dataset' }, { status: 500 }); - } - - default: - return NextResponse.json({ error: 'Invalid action' }, { status: 400 }); - } - } catch (error: any) { - console.error('HF Hub API error:', error); - return NextResponse.json({ error: error.message || 'Internal server error' }, { status: 500 }); - } -} \ No newline at end of file diff --git a/src/app/api/hf-jobs/route.ts b/src/app/api/hf-jobs/route.ts deleted file mode 100644 index 12fe64374cc3552584f8f9fdbe2948fa47996b62..0000000000000000000000000000000000000000 --- a/src/app/api/hf-jobs/route.ts +++ /dev/null @@ -1,761 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { spawn } from 'child_process'; -import { writeFile } from 'fs/promises'; -import path from 'path'; -import { tmpdir } from 'os'; - -export async function POST(request: NextRequest) { - try { - const body = await request.json(); - const { action, token, hardware, namespace, jobConfig, datasetRepo } = body; - - switch (action) { - case 'checkStatus': - try { - if (!token || !jobConfig?.hf_job_id) { - return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 }); - } - - const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id); - return NextResponse.json({ status: jobStatus }); - } catch (error: any) { - console.error('Job status check error:', error); - return NextResponse.json({ error: error.message }, { status: 500 }); - } - - case 'generateScript': - try { - const uvScript = generateUVScript({ - jobConfig, - datasetRepo, - namespace, - token: token || 'YOUR_HF_TOKEN', - }); - - return NextResponse.json({ - script: uvScript, - filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py` - }); - } catch (error: any) { - return NextResponse.json({ error: error.message }, { status: 500 }); - } - - case 'submitJob': - try { - if (!token || !hardware) { - return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 }); - } - - // Generate UV script - const uvScript = generateUVScript({ - jobConfig, - datasetRepo, - namespace, - token, - }); - - // Write script to temporary file - const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`); - await writeFile(scriptPath, uvScript); - - // Submit HF job using uv run - const jobId = await submitHFJobUV(token, hardware, scriptPath); - - return NextResponse.json({ - success: true, - jobId, - message: `Job submitted successfully with ID: ${jobId}` - }); - } catch (error: any) { - console.error('Job submission error:', error); - return NextResponse.json({ error: error.message }, { status: 500 }); - } - - default: - return NextResponse.json({ error: 'Invalid action' }, { status: 400 }); - } - } catch (error: any) { - console.error('HF Jobs API error:', error); - return NextResponse.json({ error: error.message }, { status: 500 }); - } -} - -function generateUVScript({ jobConfig, datasetRepo, namespace, token }: { - jobConfig: any; - datasetRepo: string; - namespace: string; - token: string; -}) { - const config = jobConfig.config; - const process = config.process[0]; - - return `# /// script -# dependencies = [ -# "torch>=2.0.0", -# "torchvision", -# "torchao==0.10.0", -# "safetensors", -# "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63", -# "transformers==4.52.4", -# "lycoris-lora==1.8.3", -# "flatten_json", -# "pyyaml", -# "oyaml", -# "tensorboard", -# "kornia", -# "invisible-watermark", -# "einops", -# "accelerate", -# "toml", -# "albumentations==1.4.15", -# "albucore==0.0.16", -# "pydantic", -# "omegaconf", -# "k-diffusion", -# "open_clip_torch", -# "timm", -# "prodigyopt", -# "controlnet_aux==0.0.10", -# "python-dotenv", -# "bitsandbytes", -# "hf_transfer", -# "lpips", -# "pytorch_fid", -# "optimum-quanto==0.2.4", -# "sentencepiece", -# "huggingface_hub", -# "peft", -# "python-slugify", -# "opencv-python-headless", -# "pytorch-wavelets==1.3.0", -# "matplotlib==3.10.1", -# "setuptools==69.5.1", -# "datasets==4.0.0", -# "pyarrow==20.0.0", -# "pillow", -# "ftfy", -# ] -# /// - -import os -import sys -import subprocess -import argparse -import oyaml as yaml -from datasets import load_dataset -from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download -import tempfile -import shutil -import glob -from PIL import Image - -def setup_ai_toolkit(): - """Clone and setup ai-toolkit repository""" - repo_dir = "ai-toolkit" - if not os.path.exists(repo_dir): - print("Cloning ai-toolkit repository...") - subprocess.run( - ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir], - check=True - ) - sys.path.insert(0, os.path.abspath(repo_dir)) - return repo_dir - -def download_dataset(dataset_repo: str, local_path: str): - """Download dataset from HF Hub as files""" - print(f"Downloading dataset from {dataset_repo}...") - - # Create local dataset directory - os.makedirs(local_path, exist_ok=True) - - # Use snapshot_download to get the dataset files directly - from huggingface_hub import snapshot_download - - try: - # First try to download as a structured dataset - dataset = load_dataset(dataset_repo, split="train") - - # Download images and captions from structured dataset - for i, item in enumerate(dataset): - # Save image - if "image" in item: - image_path = os.path.join(local_path, f"image_{i:06d}.jpg") - image = item["image"] - - # Convert RGBA to RGB if necessary (for JPEG compatibility) - if image.mode == 'RGBA': - # Create a white background and paste the RGBA image on it - background = Image.new('RGB', image.size, (255, 255, 255)) - background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask - image = background - elif image.mode not in ['RGB', 'L']: - # Convert any other mode to RGB - image = image.convert('RGB') - - image.save(image_path, 'JPEG') - - # Save caption - if "text" in item: - caption_path = os.path.join(local_path, f"image_{i:06d}.txt") - with open(caption_path, "w", encoding="utf-8") as f: - f.write(item["text"]) - - print(f"Downloaded {len(dataset)} items to {local_path}") - - except Exception as e: - print(f"Failed to load as structured dataset: {e}") - print("Attempting to download raw files...") - - # Download the dataset repository as files - temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset") - - # Copy all image and text files to the local path - import glob - import shutil - - print(f"Downloaded repo to: {temp_repo_path}") - print(f"Contents: {os.listdir(temp_repo_path)}") - - # Find all image files - image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG'] - image_files = [] - for ext in image_extensions: - pattern = os.path.join(temp_repo_path, "**", ext) - found_files = glob.glob(pattern, recursive=True) - image_files.extend(found_files) - print(f"Pattern {pattern} found {len(found_files)} files") - - # Find all text files - text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True) - - print(f"Found {len(image_files)} image files and {len(text_files)} text files") - - # Copy image files - for i, img_file in enumerate(image_files): - dest_path = os.path.join(local_path, f"image_{i:06d}.jpg") - - # Load and convert image if needed - try: - with Image.open(img_file) as image: - if image.mode == 'RGBA': - background = Image.new('RGB', image.size, (255, 255, 255)) - background.paste(image, mask=image.split()[-1]) - image = background - elif image.mode not in ['RGB', 'L']: - image = image.convert('RGB') - - image.save(dest_path, 'JPEG') - except Exception as img_error: - print(f"Error processing image {img_file}: {img_error}") - continue - - # Copy text files (captions) - for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images - dest_path = os.path.join(local_path, f"image_{i:06d}.txt") - try: - shutil.copy2(txt_file, dest_path) - except Exception as txt_error: - print(f"Error copying text file {txt_file}: {txt_error}") - continue - - print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}") - -def create_config(dataset_path: str, output_path: str): - """Create training configuration""" - import json - - # Load config from JSON string and fix boolean/null values for Python - config_str = """${JSON.stringify(jobConfig, null, 2)}""" - config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None') - config = eval(config_str) - - # Update paths for cloud environment - config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path - config["config"]["process"][0]["training_folder"] = output_path - - # Remove sqlite_db_path as it's not needed for cloud training - if "sqlite_db_path" in config["config"]["process"][0]: - del config["config"]["process"][0]["sqlite_db_path"] - - # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies - if config["config"]["process"][0]["type"] == "ui_trainer": - config["config"]["process"][0]["type"] = "sd_trainer" - - return config - -def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict): - """Upload trained model to HF Hub with README generation and proper file organization""" - import tempfile - import shutil - import glob - import re - import yaml - from datetime import datetime - from huggingface_hub import create_repo, upload_file, HfApi - - try: - repo_id = f"{namespace}/{model_name}" - - # Create repository - create_repo(repo_id=repo_id, token=token, exist_ok=True) - - print(f"Uploading model to {repo_id}...") - - # Create temporary directory for organized upload - with tempfile.TemporaryDirectory() as temp_upload_dir: - api = HfApi() - - # 1. Find and upload model files to root directory - safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True) - json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True) - txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True) - - uploaded_files = [] - - # Upload .safetensors files to root - for file_path in safetensors_files: - filename = os.path.basename(file_path) - print(f"Uploading {filename} to repository root...") - api.upload_file( - path_or_fileobj=file_path, - path_in_repo=filename, - repo_id=repo_id, - token=token - ) - uploaded_files.append(filename) - - # Upload relevant JSON config files to root (skip metadata.json and other internal files) - config_files_uploaded = [] - for file_path in json_files: - filename = os.path.basename(file_path) - # Only upload important config files, skip internal metadata - if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']): - print(f"Uploading {filename} to repository root...") - api.upload_file( - path_or_fileobj=file_path, - path_in_repo=filename, - repo_id=repo_id, - token=token - ) - uploaded_files.append(filename) - config_files_uploaded.append(filename) - - # 2. Handle sample images - samples_uploaded = [] - samples_dir = os.path.join(output_path, "samples") - if os.path.isdir(samples_dir): - print("Uploading sample images...") - # Create samples directory in repo - for filename in os.listdir(samples_dir): - if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): - file_path = os.path.join(samples_dir, filename) - repo_path = f"samples/{filename}" - api.upload_file( - path_or_fileobj=file_path, - path_in_repo=repo_path, - repo_id=repo_id, - token=token - ) - samples_uploaded.append(repo_path) - - # 3. Generate and upload README.md - readme_content = generate_model_card_readme( - repo_id=repo_id, - config=config, - model_name=model_name, - samples_dir=samples_dir if os.path.isdir(samples_dir) else None, - uploaded_files=uploaded_files - ) - - # Create README.md file and upload to root - readme_path = os.path.join(temp_upload_dir, "README.md") - with open(readme_path, "w", encoding="utf-8") as f: - f.write(readme_content) - - print("Uploading README.md to repository root...") - api.upload_file( - path_or_fileobj=readme_path, - path_in_repo="README.md", - repo_id=repo_id, - token=token - ) - - print(f"Model uploaded successfully to https://huggingface.co/{repo_id}") - print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md") - - except Exception as e: - print(f"Failed to upload model: {e}") - raise e - -def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str: - """Generate README.md content for the model card based on AI Toolkit's implementation""" - import re - import yaml - import os - - try: - # Extract configuration details - process_config = config.get("config", {}).get("process", [{}])[0] - model_config = process_config.get("model", {}) - train_config = process_config.get("train", {}) - sample_config = process_config.get("sample", {}) - - # Gather model info - base_model = model_config.get("name_or_path", "unknown") - trigger_word = process_config.get("trigger_word") - arch = model_config.get("arch", "") - - # Determine license based on base model - if "FLUX.1-schnell" in base_model: - license_info = {"license": "apache-2.0"} - elif "FLUX.1-dev" in base_model: - license_info = { - "license": "other", - "license_name": "flux-1-dev-non-commercial-license", - "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md" - } - else: - license_info = {"license": "creativeml-openrail-m"} - - # Generate tags based on model architecture - tags = ["text-to-image"] - - if "xl" in arch.lower(): - tags.append("stable-diffusion-xl") - if "flux" in arch.lower(): - tags.append("flux") - if "lumina" in arch.lower(): - tags.append("lumina2") - if "sd3" in arch.lower() or "v3" in arch.lower(): - tags.append("sd3") - - # Add LoRA-specific tags - tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"]) - - # Generate widgets from sample images and prompts - widgets = [] - if samples_dir and os.path.isdir(samples_dir): - sample_prompts = sample_config.get("samples", []) - if not sample_prompts: - # Fallback to old format - sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])] - - # Get sample image files - sample_files = [] - if os.path.isdir(samples_dir): - for filename in os.listdir(samples_dir): - if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): - # Parse filename pattern: timestamp__steps_index.jpg - match = re.search(r"__(\d+)_(\d+)\.jpg$", filename) - if match: - steps, index = int(match.group(1)), int(match.group(2)) - # Only use samples from final training step - final_steps = train_config.get("steps", 1000) - if steps == final_steps: - sample_files.append((index, f"samples/{filename}")) - - # Sort by index and create widgets - sample_files.sort(key=lambda x: x[0]) - - for i, prompt_obj in enumerate(sample_prompts): - prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj) - if i < len(sample_files): - _, image_path = sample_files[i] - widgets.append({ - "text": prompt, - "output": {"url": image_path} - }) - - # Determine torch dtype based on model - dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16" - - # Find the main safetensors file for usage example - main_safetensors = f"{model_name}.safetensors" - if uploaded_files: - safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')] - if safetensors_files: - main_safetensors = safetensors_files[0] - - # Construct YAML frontmatter - frontmatter = { - "tags": tags, - "base_model": base_model, - **license_info - } - - if widgets: - frontmatter["widget"] = widgets - - if trigger_word: - frontmatter["instance_prompt"] = trigger_word - - # Get first prompt for usage example - usage_prompt = trigger_word or "a beautiful landscape" - if widgets: - usage_prompt = widgets[0]["text"] - elif trigger_word: - usage_prompt = trigger_word - - # Construct README content - trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined." - - # Build YAML frontmatter string - frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip() - - readme_content = f"""--- -{frontmatter_yaml} ---- - -# {model_name} - -Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) - - - -## Trigger words - -{trigger_section} - -## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc. - -Weights for this model are available in Safetensors format. - -[Download]({repo_id}/tree/main) them in the Files & versions tab. - -## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) - -\`\`\`py -from diffusers import AutoPipelineForText2Image -import torch - -pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda') -pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}') -image = pipeline('{usage_prompt}').images[0] -image.save("my_image.png") -\`\`\` - -For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) - -""" - return readme_content - - except Exception as e: - print(f"Error generating README: {e}") - # Fallback simple README - return f"""# {model_name} - -Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) - -## Download model - -Weights for this model are available in Safetensors format. - -[Download]({repo_id}/tree/main) them in the Files & versions tab. -""" - -def main(): - # Setup environment - token comes from HF Jobs secrets - if "HF_TOKEN" not in os.environ: - raise ValueError("HF_TOKEN environment variable not set") - - # Install system dependencies for headless operation - print("Installing system dependencies...") - try: - subprocess.run(["apt-get", "update"], check=True, capture_output=True) - subprocess.run([ - "apt-get", "install", "-y", - "libgl1-mesa-glx", - "libglib2.0-0", - "libsm6", - "libxext6", - "libxrender-dev", - "libgomp1", - "ffmpeg" - ], check=True, capture_output=True) - print("System dependencies installed successfully") - except subprocess.CalledProcessError as e: - print(f"Failed to install system dependencies: {e}") - print("Continuing without system dependencies...") - - # Setup ai-toolkit - toolkit_dir = setup_ai_toolkit() - - # Create temporary directories - with tempfile.TemporaryDirectory() as temp_dir: - dataset_path = os.path.join(temp_dir, "dataset") - output_path = os.path.join(temp_dir, "output") - - # Download dataset - download_dataset("${datasetRepo}", dataset_path) - - # Create config - config = create_config(dataset_path, output_path) - config_path = os.path.join(temp_dir, "config.yaml") - - with open(config_path, "w") as f: - yaml.dump(config, f, default_flow_style=False) - - # Run training - print("Starting training...") - os.chdir(toolkit_dir) - - subprocess.run([ - sys.executable, "run.py", - config_path - ], check=True) - - print("Training completed!") - - # Upload results - model_name = f"${jobConfig.config.name}-lora" - upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config) - -if __name__ == "__main__": - main() -`; -} - -async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise { - return new Promise((resolve, reject) => { - // Ensure token is available - if (!token) { - reject(new Error('HF_TOKEN is required')); - return; - } - - console.log('Setting up environment with HF_TOKEN for job submission'); - console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`); - - // Use hf jobs uv run command with timeout and detach to get job ID - const childProcess = spawn('hf', [ - 'jobs', 'uv', 'run', - '--flavor', hardware, - '--timeout', '5h', - '--secrets', 'HF_TOKEN', - '--detach', - scriptPath - ], { - env: { - ...process.env, - HF_TOKEN: token - } - }); - - let output = ''; - let error = ''; - - childProcess.stdout.on('data', (data) => { - const text = data.toString(); - output += text; - console.log('HF Jobs stdout:', text); - }); - - childProcess.stderr.on('data', (data) => { - const text = data.toString(); - error += text; - console.log('HF Jobs stderr:', text); - }); - - childProcess.on('close', (code) => { - console.log('HF Jobs process closed with code:', code); - console.log('Full output:', output); - console.log('Full error:', error); - - if (code === 0) { - // With --detach flag, the output should be just the job ID - const fullText = (output + ' ' + error).trim(); - - // Updated patterns to handle variable-length hex job IDs (16-24+ characters) - const jobIdPatterns = [ - /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac" - /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac" - /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac" - /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac" - /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac" - /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern - /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string - ]; - - let jobId = 'unknown'; - - for (const pattern of jobIdPatterns) { - const match = fullText.match(pattern); - if (match && match[1] && match[1] !== 'started') { - jobId = match[1]; - console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`); - break; - } - } - - resolve(jobId); - } else { - reject(new Error(error || output || 'Failed to submit job')); - } - }); - - childProcess.on('error', (err) => { - console.error('HF Jobs process error:', err); - reject(new Error(`Process error: ${err.message}`)); - }); - }); -} - -async function checkHFJobStatus(token: string, jobId: string): Promise { - return new Promise((resolve, reject) => { - console.log(`Checking HF Job status for: ${jobId}`); - - const childProcess = spawn('hf', [ - 'jobs', 'inspect', jobId - ], { - env: { - ...process.env, - HF_TOKEN: token - } - }); - - let output = ''; - let error = ''; - - childProcess.stdout.on('data', (data) => { - const text = data.toString(); - output += text; - }); - - childProcess.stderr.on('data', (data) => { - const text = data.toString(); - error += text; - }); - - childProcess.on('close', (code) => { - if (code === 0) { - try { - // Parse the JSON output from hf jobs inspect - const jobInfo = JSON.parse(output); - if (Array.isArray(jobInfo) && jobInfo.length > 0) { - const job = jobInfo[0]; - resolve({ - id: job.id, - status: job.status?.stage || 'UNKNOWN', - message: job.status?.message, - created_at: job.created_at, - flavor: job.flavor, - url: job.url, - }); - } else { - reject(new Error('Invalid job info response')); - } - } catch (parseError: any) { - console.error('Failed to parse job status:', parseError, output); - reject(new Error('Failed to parse job status')); - } - } else { - reject(new Error(error || output || 'Failed to check job status')); - } - }); - - childProcess.on('error', (err) => { - console.error('HF Jobs inspect process error:', err); - reject(new Error(`Process error: ${err.message}`)); - }); - }); -} \ No newline at end of file diff --git a/src/app/api/img/[...imagePath]/route.ts b/src/app/api/img/[...imagePath]/route.ts deleted file mode 100644 index 80fc727216dd6a64e402385078725443234e636a..0000000000000000000000000000000000000000 --- a/src/app/api/img/[...imagePath]/route.ts +++ /dev/null @@ -1,78 +0,0 @@ -/* eslint-disable */ -import { NextRequest, NextResponse } from 'next/server'; -import fs from 'fs'; -import path from 'path'; -import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings'; - -export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) { - const { imagePath } = await params; - try { - // Decode the path - const filepath = decodeURIComponent(imagePath); - - // Get allowed directories - const datasetRoot = await getDatasetsRoot(); - const trainingRoot = await getTrainingFolder(); - const dataRoot = await getDataRoot(); - - const allowedDirs = [datasetRoot, trainingRoot, dataRoot]; - - // Security check: Ensure path is in allowed directory - const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..'); - - if (!isAllowed) { - console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`); - return new NextResponse('Access denied', { status: 403 }); - } - - // Check if file exists - if (!fs.existsSync(filepath)) { - console.warn(`File not found: ${filepath}`); - return new NextResponse('File not found', { status: 404 }); - } - - // Get file info - const stat = fs.statSync(filepath); - if (!stat.isFile()) { - return new NextResponse('Not a file', { status: 400 }); - } - - // Determine content type - const ext = path.extname(filepath).toLowerCase(); - const contentTypeMap: { [key: string]: string } = { - // Images - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.png': 'image/png', - '.gif': 'image/gif', - '.webp': 'image/webp', - '.svg': 'image/svg+xml', - '.bmp': 'image/bmp', - // Videos - '.mp4': 'video/mp4', - '.avi': 'video/x-msvideo', - '.mov': 'video/quicktime', - '.mkv': 'video/x-matroska', - '.wmv': 'video/x-ms-wmv', - '.m4v': 'video/x-m4v', - '.flv': 'video/x-flv' - }; - - const contentType = contentTypeMap[ext] || 'application/octet-stream'; - - // Read file as buffer - const fileBuffer = fs.readFileSync(filepath); - - // Return file with appropriate headers - return new NextResponse(fileBuffer, { - headers: { - 'Content-Type': contentType, - 'Content-Length': String(stat.size), - 'Cache-Control': 'public, max-age=86400', - }, - }); - } catch (error) { - console.error('Error serving image:', error); - return new NextResponse('Internal Server Error', { status: 500 }); - } -} diff --git a/src/app/api/img/caption/route.ts b/src/app/api/img/caption/route.ts deleted file mode 100644 index df4235f99986dedf253b45b802537b4b559b43ca..0000000000000000000000000000000000000000 --- a/src/app/api/img/caption/route.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: Request) { - try { - const body = await request.json(); - const { imgPath, caption } = body; - let datasetsPath = await getDatasetsRoot(); - // make sure the dataset path is in the image path - if (!imgPath.startsWith(datasetsPath)) { - return NextResponse.json({ error: 'Invalid image path' }, { status: 400 }); - } - - // if img doesnt exist, ignore - if (!fs.existsSync(imgPath)) { - return NextResponse.json({ error: 'Image does not exist' }, { status: 404 }); - } - - // check for caption - const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt'; - // save caption to file - fs.writeFileSync(captionPath, caption); - - return NextResponse.json({ success: true }); - } catch (error) { - return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); - } -} diff --git a/src/app/api/img/delete/route.ts b/src/app/api/img/delete/route.ts deleted file mode 100644 index d4d968f8eab6f6b1d9c988c3fd86aee2d6c2fe4f..0000000000000000000000000000000000000000 --- a/src/app/api/img/delete/route.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { NextResponse } from 'next/server'; -import fs from 'fs'; -import { getDatasetsRoot } from '@/server/settings'; - -export async function POST(request: Request) { - try { - const body = await request.json(); - const { imgPath } = body; - let datasetsPath = await getDatasetsRoot(); - // make sure the dataset path is in the image path - if (!imgPath.startsWith(datasetsPath)) { - return NextResponse.json({ error: 'Invalid image path' }, { status: 400 }); - } - - // if img doesnt exist, ignore - if (!fs.existsSync(imgPath)) { - return NextResponse.json({ success: true }); - } - - // delete it and return success - fs.unlinkSync(imgPath); - - // check for caption - const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt'; - if (fs.existsSync(captionPath)) { - // delete caption file - fs.unlinkSync(captionPath); - } - - return NextResponse.json({ success: true }); - } catch (error) { - return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); - } -} diff --git a/src/app/api/img/upload/route.ts b/src/app/api/img/upload/route.ts deleted file mode 100644 index 56615bd06c4bfee9e7aef4b81a620d4c8c7cbcb7..0000000000000000000000000000000000000000 --- a/src/app/api/img/upload/route.ts +++ /dev/null @@ -1,58 +0,0 @@ -// src/app/api/datasets/upload/route.ts -import { NextRequest, NextResponse } from 'next/server'; -import { writeFile, mkdir } from 'fs/promises'; -import { join } from 'path'; -import { getDataRoot } from '@/server/settings'; -import {v4 as uuidv4} from 'uuid'; - -export async function POST(request: NextRequest) { - try { - const dataRoot = await getDataRoot(); - if (!dataRoot) { - return NextResponse.json({ error: 'Data root path not found' }, { status: 500 }); - } - const imgRoot = join(dataRoot, 'images'); - - - const formData = await request.formData(); - const files = formData.getAll('files'); - - if (!files || files.length === 0) { - return NextResponse.json({ error: 'No files provided' }, { status: 400 }); - } - - // make it recursive if it doesn't exist - await mkdir(imgRoot, { recursive: true }); - const savedFiles = await Promise.all( - files.map(async (file: any) => { - const bytes = await file.arrayBuffer(); - const buffer = Buffer.from(bytes); - - const extension = file.name.split('.').pop() || 'jpg'; - - // Clean filename and ensure it's unique - const fileName = `${uuidv4()}`; // Use UUID for unique file names - const filePath = join(imgRoot, `${fileName}.${extension}`); - - await writeFile(filePath, buffer); - return filePath; - }), - ); - - return NextResponse.json({ - message: 'Files uploaded successfully', - files: savedFiles, - }); - } catch (error) { - console.error('Upload error:', error); - return NextResponse.json({ error: 'Error uploading files' }, { status: 500 }); - } -} - -// Increase payload size limit (default is 4mb) -export const config = { - api: { - bodyParser: false, - responseLimit: '50mb', - }, -}; diff --git a/src/app/api/jobs/[jobID]/delete/route.ts b/src/app/api/jobs/[jobID]/delete/route.ts deleted file mode 100644 index 618e33f440301495c47141bff70b99b43438c4a3..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/delete/route.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import { getTrainingFolder } from '@/server/settings'; -import path from 'path'; -import fs from 'fs'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - if (!job) { - return NextResponse.json({ error: 'Job not found' }, { status: 404 }); - } - - const trainingRoot = await getTrainingFolder(); - const trainingFolder = path.join(trainingRoot, job.name); - - if (fs.existsSync(trainingFolder)) { - fs.rmdirSync(trainingFolder, { recursive: true }); - } - - await prisma.job.delete({ - where: { id: jobID }, - }); - - return NextResponse.json(job); -} diff --git a/src/app/api/jobs/[jobID]/files/route.ts b/src/app/api/jobs/[jobID]/files/route.ts deleted file mode 100644 index 575df5e5a68cc8739aac16b55f2631d267b040fe..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/files/route.ts +++ /dev/null @@ -1,48 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import path from 'path'; -import fs from 'fs'; -import { getTrainingFolder } from '@/server/settings'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - if (!job) { - return NextResponse.json({ error: 'Job not found' }, { status: 404 }); - } - - const trainingFolder = await getTrainingFolder(); - const jobFolder = path.join(trainingFolder, job.name); - - if (!fs.existsSync(jobFolder)) { - return NextResponse.json({ files: [] }); - } - - // find all safetensors files in the job folder - let files = fs - .readdirSync(jobFolder) - .filter(file => { - return file.endsWith('.safetensors'); - }) - .map(file => { - return path.join(jobFolder, file); - }) - .sort(); - - // get the file size for each file - const fileObjects = files.map(file => { - const stats = fs.statSync(file); - return { - path: file, - size: stats.size, - }; - }); - - return NextResponse.json({ files: fileObjects }); -} diff --git a/src/app/api/jobs/[jobID]/log/route.ts b/src/app/api/jobs/[jobID]/log/route.ts deleted file mode 100644 index 10ccbdaac76b76ec20cead8e7f634af0d723ad8f..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/log/route.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import path from 'path'; -import fs from 'fs'; -import { getTrainingFolder } from '@/server/settings'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - if (!job) { - return NextResponse.json({ error: 'Job not found' }, { status: 404 }); - } - - const trainingFolder = await getTrainingFolder(); - const jobFolder = path.join(trainingFolder, job.name); - const logPath = path.join(jobFolder, 'log.txt'); - - if (!fs.existsSync(logPath)) { - return NextResponse.json({ log: '' }); - } - let log = ''; - try { - log = fs.readFileSync(logPath, 'utf-8'); - } catch (error) { - console.error('Error reading log file:', error); - log = 'Error reading log file'; - } - return NextResponse.json({ log: log }); -} diff --git a/src/app/api/jobs/[jobID]/samples/route.ts b/src/app/api/jobs/[jobID]/samples/route.ts deleted file mode 100644 index 2a98a6eac1a7581243aa7adfec6da5d5a40c938c..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/samples/route.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import path from 'path'; -import fs from 'fs'; -import { getTrainingFolder } from '@/server/settings'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - if (!job) { - return NextResponse.json({ error: 'Job not found' }, { status: 404 }); - } - - // setup the training - const trainingFolder = await getTrainingFolder(); - - const samplesFolder = path.join(trainingFolder, job.name, 'samples'); - if (!fs.existsSync(samplesFolder)) { - return NextResponse.json({ samples: [] }); - } - - // find all img (png, jpg, jpeg) files in the samples folder - const samples = fs - .readdirSync(samplesFolder) - .filter(file => { - return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp'); - }) - .map(file => { - return path.join(samplesFolder, file); - }) - .sort(); - - return NextResponse.json({ samples }); -} diff --git a/src/app/api/jobs/[jobID]/start/route.ts b/src/app/api/jobs/[jobID]/start/route.ts deleted file mode 100644 index e26c1e499373e1aa3821f2031472ec0e0727526f..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/start/route.ts +++ /dev/null @@ -1,215 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import { TOOLKIT_ROOT } from '@/paths'; -import { spawn } from 'child_process'; -import path from 'path'; -import fs from 'fs'; -import os from 'os'; -import { getTrainingFolder, getHFToken } from '@/server/settings'; -const isWindows = process.platform === 'win32'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - if (!job) { - return NextResponse.json({ error: 'Job not found' }, { status: 404 }); - } - - // update job status to 'running' - await prisma.job.update({ - where: { id: jobID }, - data: { - status: 'running', - stop: false, - info: 'Starting job...', - }, - }); - - // setup the training - const trainingRoot = await getTrainingFolder(); - - const trainingFolder = path.join(trainingRoot, job.name); - if (!fs.existsSync(trainingFolder)) { - fs.mkdirSync(trainingFolder, { recursive: true }); - } - - // make the config file - const configPath = path.join(trainingFolder, '.job_config.json'); - - //log to path - const logPath = path.join(trainingFolder, 'log.txt'); - - try { - // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num - // if the log path does not exist, create it - if (fs.existsSync(logPath)) { - const logsFolder = path.join(trainingFolder, 'logs'); - if (!fs.existsSync(logsFolder)) { - fs.mkdirSync(logsFolder, { recursive: true }); - } - - let num = 0; - while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) { - num++; - } - - fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`)); - } - } catch (e) { - console.error('Error moving log file:', e); - } - - // update the config dataset path - const jobConfig = JSON.parse(job.job_config); - jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db'); - - // write the config file - fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2)); - - let pythonPath = 'python'; - // use .venv or venv if it exists - if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) { - if (isWindows) { - pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe'); - } else { - pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python'); - } - } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) { - if (isWindows) { - pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe'); - } else { - pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python'); - } - } - - const runFilePath = path.join(TOOLKIT_ROOT, 'run.py'); - if (!fs.existsSync(runFilePath)) { - return NextResponse.json({ error: 'run.py not found' }, { status: 500 }); - } - - const additionalEnv: any = { - AITK_JOB_ID: jobID, - CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, - IS_AI_TOOLKIT_UI: '1' - }; - - // HF_TOKEN - const hfToken = await getHFToken(); - if (hfToken && hfToken.trim() !== '') { - additionalEnv.HF_TOKEN = hfToken; - } - - // Add the --log argument to the command - const args = [runFilePath, configPath, '--log', logPath]; - - try { - let subprocess; - - if (isWindows) { - // For Windows, use 'cmd.exe' to open a new command window - subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], { - env: { - ...process.env, - ...additionalEnv, - }, - cwd: TOOLKIT_ROOT, - windowsHide: false, - }); - } else { - // For non-Windows platforms - subprocess = spawn(pythonPath, args, { - detached: true, - stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output - env: { - ...process.env, - ...additionalEnv, - }, - cwd: TOOLKIT_ROOT, - }); - } - - // Start monitoring in the background without blocking the response - const monitorProcess = async () => { - const startTime = Date.now(); - let errorOutput = ''; - let stdoutput = ''; - - if (subprocess.stderr) { - subprocess.stderr.on('data', data => { - errorOutput += data.toString(); - }); - subprocess.stdout.on('data', data => { - stdoutput += data.toString(); - // truncate to only get the last 500 characters - if (stdoutput.length > 500) { - stdoutput = stdoutput.substring(stdoutput.length - 500); - } - }); - } - - subprocess.on('exit', async code => { - const currentTime = Date.now(); - const duration = (currentTime - startTime) / 1000; - console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`); - // wait for 5 seconds to give it time to stop itself. It id still has a status of running in the db, update it to stopped - await new Promise(resolve => setTimeout(resolve, 5000)); - const updatedJob = await prisma.job.findUnique({ - where: { id: jobID }, - }); - if (updatedJob?.status === 'running') { - let errorString = errorOutput; - if (errorString.trim() === '') { - errorString = stdoutput; - } - await prisma.job.update({ - where: { id: jobID }, - data: { - status: 'error', - info: `Error launching job: ${errorString.substring(0, 500)}`, - }, - }); - } - }); - - // Wait 30 seconds before releasing the process - await new Promise(resolve => setTimeout(resolve, 30000)); - // Detach the process for non-Windows systems - if (!isWindows && subprocess.unref) { - subprocess.unref(); - } - }; - - // Start the monitoring without awaiting it - monitorProcess().catch(err => { - console.error(`Error in process monitoring for job ${jobID}:`, err); - }); - - // Return the response immediately - return NextResponse.json(job); - } catch (error: any) { - // Handle any exceptions during process launch - console.error('Error launching process:', error); - - await prisma.job.update({ - where: { id: jobID }, - data: { - status: 'error', - info: `Error launching job: ${error?.message || 'Unknown error'}`, - }, - }); - - return NextResponse.json( - { - error: 'Failed to launch job process', - details: error?.message || 'Unknown error', - }, - { status: 500 }, - ); - } -} diff --git a/src/app/api/jobs/[jobID]/stop/route.ts b/src/app/api/jobs/[jobID]/stop/route.ts deleted file mode 100644 index 73b352dfc55664b1b689075727f7245589523005..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/[jobID]/stop/route.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { NextRequest, NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; - -const prisma = new PrismaClient(); - -export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { - const { jobID } = await params; - - const job = await prisma.job.findUnique({ - where: { id: jobID }, - }); - - // update job status to 'running' - await prisma.job.update({ - where: { id: jobID }, - data: { - stop: true, - info: 'Stopping job...', - }, - }); - - return NextResponse.json(job); -} diff --git a/src/app/api/jobs/route.ts b/src/app/api/jobs/route.ts deleted file mode 100644 index 8f0419b924cfa6724371712b279e89c666437eb6..0000000000000000000000000000000000000000 --- a/src/app/api/jobs/route.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; - -const prisma = new PrismaClient(); - -export async function GET(request: Request) { - const { searchParams } = new URL(request.url); - const id = searchParams.get('id'); - - try { - if (id) { - const job = await prisma.job.findUnique({ - where: { id }, - }); - return NextResponse.json(job); - } - - const jobs = await prisma.job.findMany({ - orderBy: { created_at: 'desc' }, - }); - return NextResponse.json({ jobs: jobs }); - } catch (error) { - console.error(error); - return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 }); - } -} - -export async function POST(request: Request) { - try { - const body = await request.json(); - const { id, name, job_config, gpu_ids } = body; - - // Ensure gpu_ids is never null/undefined - provide default value - const safeGpuIds = gpu_ids || '0'; - - if (id) { - // Update existing training - const training = await prisma.job.update({ - where: { id }, - data: { - name, - gpu_ids: safeGpuIds, - job_config: JSON.stringify(job_config), - }, - }); - return NextResponse.json(training); - } else { - // Create new training - const training = await prisma.job.create({ - data: { - name, - gpu_ids: safeGpuIds, - job_config: JSON.stringify(job_config), - }, - }); - return NextResponse.json(training); - } - } catch (error: any) { - if (error.code === 'P2002') { - // Handle unique constraint violation, 409=Conflict - return NextResponse.json({ error: 'Job name already exists' }, { status: 409 }); - } - console.error(error); - // Handle other errors - return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 }); - } -} diff --git a/src/app/api/settings/route.ts b/src/app/api/settings/route.ts deleted file mode 100644 index 62528cdd0b6a7de39c7ade3e96ea9f0b1ec2a226..0000000000000000000000000000000000000000 --- a/src/app/api/settings/route.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { NextResponse } from 'next/server'; -import { PrismaClient } from '@prisma/client'; -import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths'; -import { flushCache } from '@/server/settings'; - -const prisma = new PrismaClient(); - -export async function GET() { - try { - const settings = await prisma.settings.findMany(); - const settingsObject = settings.reduce((acc: any, setting) => { - acc[setting.key] = setting.value; - return acc; - }, {}); - // if TRAINING_FOLDER is not set, use default - if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') { - settingsObject.TRAINING_FOLDER = defaultTrainFolder; - } - // if DATASETS_FOLDER is not set, use default - if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') { - settingsObject.DATASETS_FOLDER = defaultDatasetsFolder; - } - return NextResponse.json(settingsObject); - } catch (error) { - return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 }); - } -} - -export async function POST(request: Request) { - try { - const body = await request.json(); - const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body; - - // Upsert both settings - await Promise.all([ - prisma.settings.upsert({ - where: { key: 'HF_TOKEN' }, - update: { value: HF_TOKEN }, - create: { key: 'HF_TOKEN', value: HF_TOKEN }, - }), - prisma.settings.upsert({ - where: { key: 'TRAINING_FOLDER' }, - update: { value: TRAINING_FOLDER }, - create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER }, - }), - prisma.settings.upsert({ - where: { key: 'DATASETS_FOLDER' }, - update: { value: DATASETS_FOLDER }, - create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER }, - }), - ]); - - flushCache(); - - return NextResponse.json({ success: true }); - } catch (error) { - return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 }); - } -} diff --git a/src/app/api/zip/route.ts b/src/app/api/zip/route.ts deleted file mode 100644 index fc4b946da5f6265d4d193849bf218fea41ea6e01..0000000000000000000000000000000000000000 --- a/src/app/api/zip/route.ts +++ /dev/null @@ -1,78 +0,0 @@ -/* eslint-disable */ -import { NextRequest, NextResponse } from 'next/server'; -import fs from 'fs'; -import fsp from 'fs/promises'; -import path from 'path'; -import archiver from 'archiver'; -import { getTrainingFolder } from '@/server/settings'; - -export const runtime = 'nodejs'; // ensure Node APIs are available -export const dynamic = 'force-dynamic'; // long-running, non-cached - -type PostBody = { - zipTarget: 'samples'; //only samples for now - jobName: string; -}; - -async function resolveSafe(p: string) { - // resolve symlinks + normalize - return await fsp.realpath(p); -} - -export async function POST(request: NextRequest) { - try { - const body = (await request.json()) as PostBody; - if (!body || !body.jobName) { - return NextResponse.json({ error: 'jobName is required' }, { status: 400 }); - } - - const trainingRoot = await resolveSafe(await getTrainingFolder()); - const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples')); - const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip'); - - // Must be a directory - let stat: fs.Stats; - try { - stat = await fsp.stat(folderPath); - } catch { - return new NextResponse('Folder not found', { status: 404 }); - } - if (!stat.isDirectory()) { - return new NextResponse('Not a directory', { status: 400 }); - } - - // delete current one if it exists - if (fs.existsSync(outputPath)) { - await fsp.unlink(outputPath); - } - - // Create write stream & archive - await new Promise((resolve, reject) => { - const output = fs.createWriteStream(outputPath); - const archive = archiver('zip', { zlib: { level: 9 } }); - - output.on('close', () => resolve()); - output.on('error', reject); - archive.on('error', reject); - - archive.pipe(output); - - // Add the directory contents (place them under the folder's base name in the zip) - const rootName = path.basename(folderPath); - archive.directory(folderPath, rootName); - - archive.finalize().catch(reject); - }); - - // Return the absolute path so your existing /api/files/[...filePath] can serve it - // Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}` - return NextResponse.json({ - ok: true, - zipPath: outputPath, - fileName: path.basename(outputPath), - }); - } catch (err) { - console.error('Zip error:', err); - return new NextResponse('Internal Server Error', { status: 500 }); - } -} diff --git a/src/app/apple-icon.png b/src/app/apple-icon.png deleted file mode 100644 index 595cb880e5cff0ab9605c2ef76dba8ebb7e7fc62..0000000000000000000000000000000000000000 Binary files a/src/app/apple-icon.png and /dev/null differ diff --git a/src/app/dashboard/page.tsx b/src/app/dashboard/page.tsx deleted file mode 100644 index 45d5596afc5831c419579afafdfe3cd515c4e3d0..0000000000000000000000000000000000000000 --- a/src/app/dashboard/page.tsx +++ /dev/null @@ -1,85 +0,0 @@ -'use client'; - -import JobsTable from '@/components/JobsTable'; -import { TopBar, MainContent } from '@/components/layout'; -import Link from 'next/link'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; - -export default function Dashboard() { - const { status: authStatus, namespace } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - - return ( - <> - -
-

Dashboard

-
-
- - -
-
-

- {isAuthenticated ? `Welcome back, ${namespace || 'creator'}!` : 'Welcome to Ostris AI Toolkit'} -

-

- {isAuthenticated - ? 'You are signed in with Hugging Face and can manage jobs, datasets, and submissions.' - : 'Authenticate with Hugging Face or add a personal access token to create jobs, upload datasets, and launch training.'} -

-
- {isAuthenticated ? ( -
- - Create a Training Job - - - Manage Datasets - - - Settings - -
- ) : ( -
- - - Or manage tokens in Settings - -
- )} -
- -
-
-

Active Jobs

-
- View All -
-
- {isAuthenticated ? ( - - ) : ( -
- Sign in with Hugging Face or add an access token in Settings to view and manage jobs. -
- )} -
-
- - ); -} diff --git a/src/app/datasets/[datasetName]/page.tsx b/src/app/datasets/[datasetName]/page.tsx deleted file mode 100644 index 776eeb525fdacf0621828734ffdd79bbd21697a8..0000000000000000000000000000000000000000 --- a/src/app/datasets/[datasetName]/page.tsx +++ /dev/null @@ -1,190 +0,0 @@ -'use client'; - -import { useEffect, useState, use, useMemo } from 'react'; -import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu'; -import { FaChevronLeft } from 'react-icons/fa'; -import DatasetImageCard from '@/components/DatasetImageCard'; -import { Button } from '@headlessui/react'; -import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal'; -import { TopBar, MainContent } from '@/components/layout'; -import { apiClient } from '@/utils/api'; -import FullscreenDropOverlay from '@/components/FullscreenDropOverlay'; -import { useRouter } from 'next/navigation'; -import { usingBrowserDb } from '@/utils/env'; -import { hasUserDataset } from '@/utils/storage/datasetStorage'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; -import Link from 'next/link'; - -export default function DatasetPage({ params }: { params: { datasetName: string } }) { - const [imgList, setImgList] = useState<{ img_path: string }[]>([]); - const usableParams = use(params as any) as { datasetName: string }; - const datasetName = usableParams.datasetName; - const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); - const router = useRouter(); - const { status: authStatus } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - const hasDatasetEntry = !usingBrowserDb || hasUserDataset(datasetName); - const allowAccess = hasDatasetEntry && isAuthenticated; - - const refreshImageList = (dbName: string) => { - setStatus('loading'); - console.log('Fetching images for dataset:', dbName); - apiClient - .post('/api/datasets/listImages', { datasetName: dbName }) - .then((res: any) => { - const data = res.data; - console.log('Images:', data.images); - // sort - data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path)); - setImgList(data.images); - setStatus('success'); - }) - .catch(error => { - console.error('Error fetching images:', error); - setStatus('error'); - }); - }; - useEffect(() => { - if (!datasetName) { - return; - } - - if (!isAuthenticated) { - return; - } - - if (!hasDatasetEntry) { - setImgList([]); - setStatus('error'); - router.replace('/datasets'); - return; - } - - refreshImageList(datasetName); - }, [datasetName, hasDatasetEntry, isAuthenticated, router]); - - if (!allowAccess) { - return ( - <> - -
- -
-
-

Dataset: {datasetName}

-
-
-
- -
-

You need to sign in with Hugging Face or provide a valid token to view this dataset.

-
- - - Manage authentication in Settings - -
-
-
- - ); - } - - const PageInfoContent = useMemo(() => { - let icon = null; - let text = ''; - let subtitle = ''; - let showIt = false; - let bgColor = ''; - let textColor = ''; - let iconColor = ''; - - if (status == 'loading') { - icon = ; - text = 'Loading Images'; - subtitle = 'Please wait while we fetch your dataset images...'; - showIt = true; - bgColor = 'bg-gray-50 dark:bg-gray-800/50'; - textColor = 'text-gray-900 dark:text-gray-100'; - iconColor = 'text-gray-500 dark:text-gray-400'; - } - if (status == 'error') { - icon = ; - text = 'Error Loading Images'; - subtitle = 'There was a problem fetching the images. Please try refreshing the page.'; - showIt = true; - bgColor = 'bg-red-50 dark:bg-red-950/20'; - textColor = 'text-red-900 dark:text-red-100'; - iconColor = 'text-red-600 dark:text-red-400'; - } - if (status == 'success' && imgList.length === 0) { - icon = ; - text = 'No Images Found'; - subtitle = 'This dataset is empty. Click "Add Images" to get started.'; - showIt = true; - bgColor = 'bg-gray-50 dark:bg-gray-800/50'; - textColor = 'text-gray-900 dark:text-gray-100'; - iconColor = 'text-gray-500 dark:text-gray-400'; - } - - if (!showIt) return null; - - return ( -
-
{icon}
-

{text}

-

{subtitle}

-
- ); - }, [status, imgList.length]); - - return ( - <> - {/* Fixed top bar */} - -
- -
-
-

Dataset: {datasetName}

-
-
-
- -
-
- - {PageInfoContent} - {status === 'success' && imgList.length > 0 && ( -
- {imgList.map(img => ( - refreshImageList(datasetName)} - /> - ))} -
- )} -
- - refreshImageList(datasetName)} - /> - - ); -} diff --git a/src/app/datasets/page.tsx b/src/app/datasets/page.tsx deleted file mode 100644 index eec8310f9ba6f38f5eca345a0b6400e754241a64..0000000000000000000000000000000000000000 --- a/src/app/datasets/page.tsx +++ /dev/null @@ -1,217 +0,0 @@ -'use client'; - -import { useState } from 'react'; -import { Modal } from '@/components/Modal'; -import Link from 'next/link'; -import { TextInput } from '@/components/formInputs'; -import useDatasetList from '@/hooks/useDatasetList'; -import { Button } from '@headlessui/react'; -import { FaRegTrashAlt } from 'react-icons/fa'; -import { openConfirm } from '@/components/ConfirmModal'; -import { TopBar, MainContent } from '@/components/layout'; -import UniversalTable, { TableColumn } from '@/components/UniversalTable'; -import { apiClient } from '@/utils/api'; -import { useRouter } from 'next/navigation'; -import { usingBrowserDb } from '@/utils/env'; -import { addUserDataset, removeUserDataset } from '@/utils/storage/datasetStorage'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; - -export default function Datasets() { - const router = useRouter(); - const { datasets, status, refreshDatasets } = useDatasetList(); - const [newDatasetName, setNewDatasetName] = useState(''); - const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false); - const { status: authStatus } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - - // Transform datasets array into rows with objects - const tableRows = datasets.map(dataset => ({ - name: dataset, - actions: dataset, // Pass full dataset name for actions - })); - - const columns: TableColumn[] = [ - { - title: 'Dataset Name', - key: 'name', - render: row => ( - - {row.name} - - ), - }, - { - title: 'Actions', - key: 'actions', - className: 'w-20 text-right', - render: row => ( - - ), - }, - ]; - - const handleDeleteDataset = (datasetName: string) => { - openConfirm({ - title: 'Delete Dataset', - message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`, - type: 'warning', - confirmText: 'Delete', - onConfirm: () => { - apiClient - .post('/api/datasets/delete', { name: datasetName }) - .then(() => { - console.log('Dataset deleted:', datasetName); - if (usingBrowserDb) { - removeUserDataset(datasetName); - } - refreshDatasets(); - }) - .catch(error => { - console.error('Error deleting dataset:', error); - }); - }, - }); - }; - - const handleCreateDataset = async (e: React.FormEvent) => { - e.preventDefault(); - if (!isAuthenticated) { - return; - } - try { - const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data); - console.log('New dataset created:', data); - if (usingBrowserDb && data?.name) { - addUserDataset(data.name, data?.path || ''); - } - refreshDatasets(); - setNewDatasetName(''); - setIsNewDatasetModalOpen(false); - } catch (error) { - console.error('Error creating new dataset:', error); - } - }; - - const openNewDatasetModal = () => { - if (!isAuthenticated) { - return; - } - openConfirm({ - title: 'New Dataset', - message: 'Enter the name of the new dataset:', - type: 'info', - confirmText: 'Create', - inputTitle: 'Dataset Name', - onConfirm: async (name?: string) => { - if (!name) { - console.error('Dataset name is required.'); - return; - } - if (!isAuthenticated) { - return; - } - try { - const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data); - console.log('New dataset created:', data); - if (usingBrowserDb && data?.name) { - addUserDataset(data.name, data?.path || ''); - } - if (data.name) { - router.push(`/datasets/${data.name}`); - } else { - refreshDatasets(); - } - } catch (error) { - console.error('Error creating new dataset:', error); - } - }, - }); - }; - - return ( - <> - -
-

Datasets

-
-
-
- {isAuthenticated ? ( - - ) : ( - - Sign in to add datasets - - )} -
-
- - - {isAuthenticated ? ( - - ) : ( -
-

Sign in with Hugging Face or add an access token to manage datasets.

-
- - - Manage authentication in Settings - -
-
- )} -
- - setIsNewDatasetModalOpen(false)} - title="New Dataset" - size="md" - > -
-
-
- This will create a new folder with the name below in your dataset folder. -
-
- setNewDatasetName(value)} /> -
- -
- - -
-
-
-
- - ); -} diff --git a/src/app/favicon.ico b/src/app/favicon.ico deleted file mode 100644 index a20b629a5996a0b62c038bf356f1e28eab9bdb99..0000000000000000000000000000000000000000 Binary files a/src/app/favicon.ico and /dev/null differ diff --git a/src/app/globals.css b/src/app/globals.css deleted file mode 100644 index 890dc5bc7b9125662f38d11d758350ba5a80f744..0000000000000000000000000000000000000000 --- a/src/app/globals.css +++ /dev/null @@ -1,72 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; - -:root { - --background: #ffffff; - --foreground: #171717; -} - -@media (prefers-color-scheme: dark) { - :root { - --background: #0a0a0a; - --foreground: #ededed; - } -} - -body { - color: var(--foreground); - background: var(--background); - font-family: Arial, Helvetica, sans-serif; -} - -@layer components { - /* control */ - .aitk-react-select-container .aitk-react-select__control { - @apply flex w-full h-8 min-h-0 px-0 text-sm bg-gray-800 border border-gray-700 rounded-sm hover:border-gray-600 items-center; - } - - /* selected label */ - .aitk-react-select-container .aitk-react-select__single-value { - @apply flex-1 min-w-0 truncate text-sm text-neutral-200; - } - - /* invisible input (keeps focus & typing, never wraps) */ - .aitk-react-select-container .aitk-react-select__input-container { - @apply text-neutral-200; - } - - /* focus */ - .aitk-react-select-container .aitk-react-select__control--is-focused { - @apply ring-2 ring-gray-600 border-transparent hover:border-transparent shadow-none; - } - - /* menu */ - .aitk-react-select-container .aitk-react-select__menu { - @apply bg-gray-800 border border-gray-700; - } - - /* options */ - .aitk-react-select-container .aitk-react-select__option { - @apply text-sm text-neutral-200 bg-gray-800 hover:bg-gray-700; - } - - /* indicator separator */ - .aitk-react-select-container .aitk-react-select__indicator-separator { - @apply bg-gray-600; - } - - /* indicators */ - .aitk-react-select-container .aitk-react-select__indicators, - .aitk-react-select-container .aitk-react-select__indicator { - @apply py-0 flex items-center; - } - - /* placeholder */ - .aitk-react-select-container .aitk-react-select__placeholder { - @apply text-sm text-neutral-200; - } -} - - - diff --git a/src/app/icon.png b/src/app/icon.png deleted file mode 100644 index 8bcfbf80f1f08f9b1f6678914370f00a105a37b2..0000000000000000000000000000000000000000 Binary files a/src/app/icon.png and /dev/null differ diff --git a/src/app/icon.svg b/src/app/icon.svg deleted file mode 100644 index 2689ae5393931a68144db7d92555343aeef0155c..0000000000000000000000000000000000000000 --- a/src/app/icon.svg +++ /dev/null @@ -1,3 +0,0 @@ - \ No newline at end of file diff --git a/src/app/jobs/[jobID]/page.tsx b/src/app/jobs/[jobID]/page.tsx deleted file mode 100644 index ae001714d36deec0f26e52f0d1542f4684a7ef7e..0000000000000000000000000000000000000000 --- a/src/app/jobs/[jobID]/page.tsx +++ /dev/null @@ -1,147 +0,0 @@ -'use client'; - -import { useState, use } from 'react'; -import { FaChevronLeft } from 'react-icons/fa'; -import { Button } from '@headlessui/react'; -import { TopBar, MainContent } from '@/components/layout'; -import useJob from '@/hooks/useJob'; -import SampleImages, {SampleImagesMenu} from '@/components/SampleImages'; -import JobOverview from '@/components/JobOverview'; -import { redirect } from 'next/navigation'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; -import Link from 'next/link'; -import JobActionBar from '@/components/JobActionBar'; -import JobConfigViewer from '@/components/JobConfigViewer'; -import { JobRecord } from '@/types'; - -type PageKey = 'overview' | 'samples' | 'config'; - -interface Page { - name: string; - value: PageKey; - component: React.ComponentType<{ job: JobRecord }>; - menuItem?: React.ComponentType<{ job?: JobRecord | null }> | null; - mainCss?: string; -} - -const pages: Page[] = [ - { - name: 'Overview', - value: 'overview', - component: JobOverview, - mainCss: 'pt-24', - }, - { - name: 'Samples', - value: 'samples', - component: SampleImages, - menuItem: SampleImagesMenu, - mainCss: 'pt-24', - }, - { - name: 'Config File', - value: 'config', - component: JobConfigViewer, - mainCss: 'pt-[80px] px-0 pb-0', - }, -]; - -export default function JobPage({ params }: { params: { jobID: string } }) { - const usableParams = use(params as any) as { jobID: string }; - const jobID = usableParams.jobID; - const { job, status, refreshJob } = useJob(jobID, 5000); - const [pageKey, setPageKey] = useState('overview'); - const { status: authStatus } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - - const page = pages.find(p => p.value === pageKey); - - if (!isAuthenticated) { - return ( - <> - -
- -
-
-

Job Details

-
-
-
- -
-

Sign in with Hugging Face or add an access token to view job details.

-
- - - Manage authentication in Settings - -
-
-
- - ); - } - - return ( - <> - {/* Fixed top bar */} - -
- -
-
-

Job: {job?.name}

-
-
- {job && ( - { - redirect('/jobs'); - }} - /> - )} -
- page.value === pageKey)?.mainCss}> - {status === 'loading' && job == null &&

Loading...

} - {status === 'error' && job == null &&

Error fetching job

} - {job && ( - <> - {pages.map(page => { - const Component = page.component; - return page.value === pageKey ? : null; - })} - - )} -
-
- {pages.map(page => ( - - ))} - { - page?.menuItem && ( - <> -
-
- - - ) - } -
- - ); -} diff --git a/src/app/jobs/new/AdvancedJob.tsx b/src/app/jobs/new/AdvancedJob.tsx deleted file mode 100644 index bccc4da22a57660ae23e0882f641362f1dfd4dec..0000000000000000000000000000000000000000 --- a/src/app/jobs/new/AdvancedJob.tsx +++ /dev/null @@ -1,146 +0,0 @@ -'use client'; -import { useEffect, useState, useRef } from 'react'; -import { JobConfig } from '@/types'; -import YAML from 'yaml'; -import Editor, { OnMount } from '@monaco-editor/react'; -import type { editor } from 'monaco-editor'; -import { SettingsData } from '@/types'; -import { migrateJobConfig } from './jobConfig'; - -type Props = { - jobConfig: JobConfig; - setJobConfig: (value: any, key?: string) => void; - status: 'idle' | 'saving' | 'success' | 'error'; - handleSubmit: (event: React.FormEvent) => void; - runId: string | null; - gpuIDs: string | null; - setGpuIDs: (value: string | null) => void; - gpuList: any; - datasetOptions: any; - settings: SettingsData; -}; - -const isDev = process.env.NODE_ENV === 'development'; - -const yamlConfig: YAML.DocumentOptions & - YAML.SchemaOptions & - YAML.ParseOptions & - YAML.CreateNodeOptions & - YAML.ToStringOptions = { - indent: 2, - lineWidth: 999999999999, - defaultStringType: 'QUOTE_DOUBLE', - defaultKeyType: 'PLAIN', - directives: true, -}; - -export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) { - const [editorValue, setEditorValue] = useState(''); - const lastJobConfigUpdateStringRef = useRef(''); - const editorRef = useRef(null); - - // Track if the editor has been mounted - const isEditorMounted = useRef(false); - - // Handler for editor mounting - const handleEditorDidMount: OnMount = editor => { - editorRef.current = editor; - isEditorMounted.current = true; - - // Initial content setup - try { - const yamlContent = YAML.stringify(jobConfig, yamlConfig); - setEditorValue(yamlContent); - lastJobConfigUpdateStringRef.current = JSON.stringify(jobConfig); - } catch (e) { - console.warn(e); - } - }; - - useEffect(() => { - const lastUpdate = lastJobConfigUpdateStringRef.current; - const currentUpdate = JSON.stringify(jobConfig); - - // Skip if no changes or editor not yet mounted - if (lastUpdate === currentUpdate || !isEditorMounted.current) { - return; - } - - try { - // Preserve cursor position and selection - const editor = editorRef.current; - if (editor) { - // Save current editor state - const position = editor.getPosition(); - const selection = editor.getSelection(); - const scrollTop = editor.getScrollTop(); - - // Update content - const yamlContent = YAML.stringify(jobConfig, yamlConfig); - - // Only update if the content is actually different - if (yamlContent !== editor.getValue()) { - // Set value directly on the editor model instead of using React state - editor.getModel()?.setValue(yamlContent); - - // Restore cursor position and selection - if (position) editor.setPosition(position); - if (selection) editor.setSelection(selection); - editor.setScrollTop(scrollTop); - } - - lastJobConfigUpdateStringRef.current = currentUpdate; - } - } catch (e) { - console.warn(e); - } - }, [jobConfig]); - - const handleChange = (value: string | undefined) => { - if (value === undefined) return; - - try { - const parsed = YAML.parse(value); - // Don't update jobConfig if the change came from the editor itself - // to avoid a circular update loop - if (JSON.stringify(parsed) !== lastJobConfigUpdateStringRef.current) { - lastJobConfigUpdateStringRef.current = JSON.stringify(parsed); - - // We have to ensure certain things are always set - try { - parsed.config.process[0].type = 'ui_trainer'; - parsed.config.process[0].sqlite_db_path = './aitk_db.db'; - parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; - parsed.config.process[0].device = 'cuda'; - parsed.config.process[0].performance_log_every = 10; - } catch (e) { - console.warn(e); - } - migrateJobConfig(parsed); - setJobConfig(parsed); - } - } catch (e) { - // Don't update on parsing errors - console.warn(e); - } - }; - - return ( - <> - - - ); -} diff --git a/src/app/jobs/new/SimpleJob.tsx b/src/app/jobs/new/SimpleJob.tsx deleted file mode 100644 index 080c383de00f4858199e0937cbca92385910a598..0000000000000000000000000000000000000000 --- a/src/app/jobs/new/SimpleJob.tsx +++ /dev/null @@ -1,973 +0,0 @@ -'use client'; -import { useMemo, useState } from 'react'; -import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; -import { defaultDatasetConfig } from './jobConfig'; -import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; -import { objectCopy } from '@/utils/basic'; -import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs'; -import Card from '@/components/Card'; -import { X } from 'lucide-react'; -import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; -import {FlipHorizontal2, FlipVertical2} from "lucide-react"; -import HFJobsWorkflow from '@/components/HFJobsWorkflow'; - -type Props = { - jobConfig: JobConfig; - setJobConfig: (value: any, key: string) => void; - status: 'idle' | 'saving' | 'success' | 'error'; - handleSubmit: (event: React.FormEvent) => void; - runId: string | null; - gpuIDs: string | null; - setGpuIDs: (value: string | null) => void; - gpuList: any; - datasetOptions: any; - trainingBackend?: 'local' | 'hf-jobs'; - setTrainingBackend?: (backend: 'local' | 'hf-jobs') => void; - hfJobSubmitted?: boolean; - onHFJobComplete?: (jobId: string, localJobId?: string) => void; - forceHFBackend?: boolean; -}; - -const isDev = process.env.NODE_ENV === 'development'; - -export default function SimpleJob({ - jobConfig, - setJobConfig, - handleSubmit, - status, - runId, - gpuIDs, - setGpuIDs, - gpuList, - datasetOptions, - trainingBackend: parentTrainingBackend, - setTrainingBackend: parentSetTrainingBackend, - hfJobSubmitted, - onHFJobComplete, - forceHFBackend = false, -}: Props) { - const [localTrainingBackend, setLocalTrainingBackend] = useState(forceHFBackend ? 'hf-jobs' : 'local'); - const trainingBackend = parentTrainingBackend || localTrainingBackend; - const setTrainingBackend = forceHFBackend - ? (_: 'local' | 'hf-jobs') => undefined - : parentSetTrainingBackend || setLocalTrainingBackend; - const backendOptions = forceHFBackend - ? [{ value: 'hf-jobs', label: 'HF Jobs (Cloud)' }] - : [ - { value: 'local', label: 'Local GPU' }, - { value: 'hf-jobs', label: 'HF Jobs (Cloud)' }, - ]; - const modelArch = useMemo(() => { - return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; - }, [jobConfig.config.process[0].model.arch]); - - const isVideoModel = !!(modelArch?.group === 'video'); - - const numTopCards = useMemo(() => { - let count = 4; // job settings, model config, target config, save config - if (modelArch?.additionalSections?.includes('model.multistage')) { - count += 1; // add multistage card - } - if (!modelArch?.disableSections?.includes('model.quantize')) { - count += 1; // add quantization card - } - return count; - - }, [modelArch]); - - let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; - - if (numTopCards == 5) { - topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6'; - } - if (numTopCards == 6) { - topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; - } - - const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { - const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; - if (!hasARA) { - return quantizationOptions; - } - let newQuantizationOptions = [ - { - label: 'Standard', - options: [quantizationOptions[0], quantizationOptions[1]], - }, - ]; - - // add ARAs if they exist for the model - let ARAs: SelectOption[] = []; - if (modelArch.accuracyRecoveryAdapters) { - for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { - ARAs.push({ value, label }); - } - } - if (ARAs.length > 0) { - newQuantizationOptions.push({ - label: 'Accuracy Recovery Adapters', - options: ARAs, - }); - } - - let additionalQuantizationOptions: SelectOption[] = []; - // add the quantization options if they are not already included - for (let i = 2; i < quantizationOptions.length; i++) { - const option = quantizationOptions[i]; - additionalQuantizationOptions.push(option); - } - if (additionalQuantizationOptions.length > 0) { - newQuantizationOptions.push({ - label: 'Additional Quantization Options', - options: additionalQuantizationOptions, - }); - } - return newQuantizationOptions; - }, [modelArch]); - - return ( - <> -
-
- - setJobConfig(value, 'config.name')} - placeholder="Enter training name" - disabled={runId !== null} - required - /> - { - setTrainingBackend(value); - }} - options={backendOptions} - disabled={forceHFBackend} - /> - {trainingBackend === 'local' && ( - setGpuIDs(value)} - options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} - /> - )} - { - if (value?.trim() === '') { - value = null; - } - setJobConfig(value, 'config.process[0].trigger_word'); - }} - placeholder="" - required - /> - {trainingBackend === 'hf-jobs' && ( -
-

- {hfJobSubmitted - ? '✓ HF Job already submitted! You can modify settings and resubmit if needed.' - : '⏳ HF Job ready for submission. Submit to the cloud below.' - } -

-
- )} -
- - {/* Model Configuration Section */} - - { - const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch); - if (!currentArch || currentArch.name === value) { - return; - } - // update the defaults when a model is selected - const newArch = modelArchs.find(model => model.name === value); - - // update vram setting - if (!newArch?.additionalSections?.includes('model.low_vram')) { - setJobConfig(false, 'config.process[0].model.low_vram'); - } - - // revert defaults from previous model - for (const key in currentArch.defaults) { - setJobConfig(currentArch.defaults[key][1], key); - } - - if (newArch?.defaults) { - for (const key in newArch.defaults) { - setJobConfig(newArch.defaults[key][0], key); - } - } - // set new model - setJobConfig(value, 'config.process[0].model.arch'); - - // update datasets - const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; - const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; - const controls = newArch?.controls ?? []; - const datasets = jobConfig.config.process[0].datasets.map(dataset => { - const newDataset = objectCopy(dataset); - newDataset.controls = controls; - if (!hasControlPath) { - newDataset.control_path = null; // reset control path if not applicable - } - if (!hasNumFrames) { - newDataset.num_frames = 1; // reset num_frames if not applicable - } - return newDataset; - }); - setJobConfig(datasets, 'config.process[0].datasets'); - - // update samples - const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; - const samples = jobConfig.config.process[0].sample.samples.map(sample => { - const newSample = objectCopy(sample); - if (!hasSampleCtrlImg) { - delete newSample.ctrl_img; // remove ctrl_img if not applicable - } - return newSample; - }); - setJobConfig(samples, 'config.process[0].sample.samples'); - }} - options={groupedModelOptions} - /> - { - if (value?.trim() === '') { - value = null; - } - setJobConfig(value, 'config.process[0].model.name_or_path'); - }} - placeholder="" - required - /> - {modelArch?.additionalSections?.includes('model.low_vram') && ( - - setJobConfig(value, 'config.process[0].model.low_vram')} - /> - - )} - - {modelArch?.disableSections?.includes('model.quantize') ? null : ( - - { - if (value === '') { - setJobConfig(false, 'config.process[0].model.quantize'); - value = defaultQtype; - } else { - setJobConfig(true, 'config.process[0].model.quantize'); - } - setJobConfig(value, 'config.process[0].model.qtype'); - }} - options={transformerQuantizationOptions} - /> - { - if (value === '') { - setJobConfig(false, 'config.process[0].model.quantize_te'); - value = defaultQtype; - } else { - setJobConfig(true, 'config.process[0].model.quantize_te'); - } - setJobConfig(value, 'config.process[0].model.qtype_te'); - }} - options={quantizationOptions} - /> - - )} - {modelArch?.additionalSections?.includes('model.multistage') && ( - - - setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')} - /> - setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')} - /> - - setJobConfig(value, 'config.process[0].train.switch_boundary_every')} - placeholder="eg. 1" - docKey={'train.switch_boundary_every'} - min={1} - required - /> - - )} - - setJobConfig(value, 'config.process[0].network.type')} - options={[ - { value: 'lora', label: 'LoRA' }, - { value: 'lokr', label: 'LoKr' }, - ]} - /> - {jobConfig.config.process[0].network?.type == 'lokr' && ( - setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')} - options={[ - { value: '-1', label: 'Auto' }, - { value: '4', label: '4' }, - { value: '8', label: '8' }, - { value: '16', label: '16' }, - { value: '32', label: '32' }, - ]} - /> - )} - {jobConfig.config.process[0].network?.type == 'lora' && ( - <> - { - console.log('onChange', value); - setJobConfig(value, 'config.process[0].network.linear'); - setJobConfig(value, 'config.process[0].network.linear_alpha'); - }} - placeholder="eg. 16" - min={0} - max={1024} - required - /> - {modelArch?.disableSections?.includes('network.conv') ? null : ( - { - console.log('onChange', value); - setJobConfig(value, 'config.process[0].network.conv'); - setJobConfig(value, 'config.process[0].network.conv_alpha'); - }} - placeholder="eg. 16" - min={0} - max={1024} - /> - )} - - )} - - - setJobConfig(value, 'config.process[0].save.dtype')} - options={[ - { value: 'bf16', label: 'BF16' }, - { value: 'fp16', label: 'FP16' }, - { value: 'fp32', label: 'FP32' }, - ]} - /> - setJobConfig(value, 'config.process[0].save.save_every')} - placeholder="eg. 250" - min={1} - required - /> - setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')} - placeholder="eg. 4" - min={1} - required - /> - -
-
- -
-
- setJobConfig(value, 'config.process[0].train.batch_size')} - placeholder="eg. 4" - min={1} - required - /> - setJobConfig(value, 'config.process[0].train.gradient_accumulation')} - placeholder="eg. 1" - min={1} - required - /> - setJobConfig(value, 'config.process[0].train.steps')} - placeholder="eg. 2000" - min={1} - required - /> -
-
- setJobConfig(value, 'config.process[0].train.optimizer')} - options={[ - { value: 'adamw8bit', label: 'AdamW8Bit' }, - { value: 'adafactor', label: 'Adafactor' }, - ]} - /> - setJobConfig(value, 'config.process[0].train.lr')} - placeholder="eg. 0.0001" - min={0} - required - /> - setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')} - placeholder="eg. 0.0001" - min={0} - required - /> -
-
- {modelArch?.disableSections?.includes('train.timestep_type') ? null : ( - setJobConfig(value, 'config.process[0].train.timestep_type')} - options={[ - { value: 'sigmoid', label: 'Sigmoid' }, - { value: 'linear', label: 'Linear' }, - { value: 'shift', label: 'Shift' }, - { value: 'weighted', label: 'Weighted' }, - ]} - /> - )} - setJobConfig(value, 'config.process[0].train.content_or_style')} - options={[ - { value: 'balanced', label: 'Balanced' }, - { value: 'content', label: 'High Noise' }, - { value: 'style', label: 'Low Noise' }, - ]} - /> - setJobConfig(value, 'config.process[0].train.noise_scheduler')} - options={[ - { value: 'flowmatch', label: 'FlowMatch' }, - { value: 'ddpm', label: 'DDPM' }, - ]} - /> -
-
- - setJobConfig(value, 'config.process[0].train.ema_config.use_ema')} - /> - - {jobConfig.config.process[0].train.ema_config?.use_ema && ( - setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')} - placeholder="eg. 0.99" - min={0} - /> - )} - - - { - setJobConfig(value, 'config.process[0].train.unload_text_encoder'); - if (value) { - setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); - } - }} - /> - { - setJobConfig(value, 'config.process[0].train.cache_text_embeddings'); - if (value) { - setJobConfig(false, 'config.process[0].train.unload_text_encoder'); - } - }} - /> - -
-
- - setJobConfig(value, 'config.process[0].train.diff_output_preservation')} - /> - - {jobConfig.config.process[0].train.diff_output_preservation && ( - <> - - setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') - } - placeholder="eg. 1.0" - min={0} - /> - setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} - placeholder="eg. woman" - /> - - )} -
-
-
-
-
- - <> - {jobConfig.config.process[0].datasets.map((dataset, i) => ( -
- -

Dataset {i + 1}

-
-
- setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} - options={datasetOptions} - /> - {modelArch?.additionalSections?.includes('datasets.control_path') && ( - - setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) - } - options={[{ value: '', label: <>  }, ...datasetOptions]} - /> - )} - setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)} - placeholder="eg. 1.0" - /> -
-
- setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)} - placeholder="eg. A photo of a cat" - /> - setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)} - placeholder="eg. 0.05" - min={0} - required - /> - {modelArch?.additionalSections?.includes('datasets.num_frames') && ( - setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)} - placeholder="eg. 41" - min={1} - required - /> - )} -
-
- - - setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`) - } - /> - setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} - /> - {modelArch?.additionalSections?.includes('datasets.do_i2v') && ( - setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} - docKey="datasets.do_i2v" - /> - )} - - - Flip X } - checked={dataset.flip_x || false} - onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} - /> - Flip Y } - checked={dataset.flip_y || false} - onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} - /> - -
-
- -
- {[ - [256, 512, 768], - [1024, 1280, 1536], - ].map(resGroup => ( -
- {resGroup.map(res => ( - { - const resolutions = dataset.resolution.includes(res) - ? dataset.resolution.filter(r => r !== res) - : [...dataset.resolution, res]; - setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); - }} - /> - ))} -
- ))} -
-
-
-
-
- ))} - - -
-
-
- -
-
- setJobConfig(value, 'config.process[0].sample.sample_every')} - placeholder="eg. 250" - min={1} - required - /> - setJobConfig(value, 'config.process[0].sample.sampler')} - options={[ - { value: 'flowmatch', label: 'FlowMatch' }, - { value: 'ddpm', label: 'DDPM' }, - ]} - /> - setJobConfig(value, 'config.process[0].sample.guidance_scale')} - placeholder="eg. 1.0" - className="pt-2" - min={0} - required - /> - setJobConfig(value, 'config.process[0].sample.sample_steps')} - placeholder="eg. 1" - className="pt-2" - min={1} - required - /> -
-
- setJobConfig(value, 'config.process[0].sample.width')} - placeholder="eg. 1024" - min={0} - required - /> - setJobConfig(value, 'config.process[0].sample.height')} - placeholder="eg. 1024" - className="pt-2" - min={0} - required - /> - {isVideoModel && ( -
- setJobConfig(value, 'config.process[0].sample.num_frames')} - placeholder="eg. 0" - className="pt-2" - min={0} - required - /> - setJobConfig(value, 'config.process[0].sample.fps')} - placeholder="eg. 0" - className="pt-2" - min={0} - required - /> -
- )} -
- -
- setJobConfig(value, 'config.process[0].sample.seed')} - placeholder="eg. 0" - min={0} - required - /> - setJobConfig(value, 'config.process[0].sample.walk_seed')} - /> -
-
- -
- setJobConfig(value, 'config.process[0].train.skip_first_sample')} - /> -
-
- setJobConfig(value, 'config.process[0].train.disable_sampling')} - /> -
-
-
-
- -
-
- {jobConfig.config.process[0].sample.samples.map((sample, i) => ( -
-
-
-
-
- setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)} - placeholder="Enter prompt" - required - /> -
- - {modelArch?.additionalSections?.includes('sample.ctrl_img') && ( -
{ - openAddImageModal(imagePath => { - console.log('Selected image path:', imagePath); - if (!imagePath) return; - setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`); - }); - }} - > - {!sample.ctrl_img && ( -
Add Control Image
- )} -
- )} -
-
-
-
- -
-
-
- ))} - -
-
- - {status === 'success' &&

Training saved successfully!

} - {status === 'error' &&

Error saving training. Please try again.

} -
- - {trainingBackend === 'hf-jobs' && ( -
- { - console.log('HF Job submitted:', jobId, 'Local job ID:', localJobId); - if (onHFJobComplete) { - onHFJobComplete(jobId, localJobId); - } - }} - /> -
- )} - - - - ); -} diff --git a/src/app/jobs/new/jobConfig.ts b/src/app/jobs/new/jobConfig.ts deleted file mode 100644 index df257bb985dad2eaada5d2913ab1e6347cf36ec1..0000000000000000000000000000000000000000 --- a/src/app/jobs/new/jobConfig.ts +++ /dev/null @@ -1,167 +0,0 @@ -import { JobConfig, DatasetConfig } from '@/types'; - -export const defaultDatasetConfig: DatasetConfig = { - folder_path: '/path/to/images/folder', - control_path: null, - mask_path: null, - mask_min_value: 0.1, - default_caption: '', - caption_ext: 'txt', - caption_dropout_rate: 0.05, - cache_latents_to_disk: false, - is_reg: false, - network_weight: 1, - resolution: [512, 768, 1024], - controls: [], - shrink_video_to_frames: true, - num_frames: 1, - do_i2v: true, - flip_x: false, - flip_y: false, -}; - -export const defaultJobConfig: JobConfig = { - job: 'extension', - config: { - name: 'my_first_lora_v1', - process: [ - { - type: 'ui_trainer', - training_folder: 'output', - sqlite_db_path: './aitk_db.db', - device: 'cuda', - trigger_word: null, - performance_log_every: 10, - network: { - type: 'lora', - linear: 32, - linear_alpha: 32, - conv: 16, - conv_alpha: 16, - lokr_full_rank: true, - lokr_factor: -1, - network_kwargs: { - ignore_if_contains: [], - }, - }, - save: { - dtype: 'bf16', - save_every: 250, - max_step_saves_to_keep: 4, - save_format: 'diffusers', - push_to_hub: false, - }, - datasets: [defaultDatasetConfig], - train: { - batch_size: 1, - bypass_guidance_embedding: true, - steps: 3000, - gradient_accumulation: 1, - train_unet: true, - train_text_encoder: false, - gradient_checkpointing: true, - noise_scheduler: 'flowmatch', - optimizer: 'adamw8bit', - timestep_type: 'sigmoid', - content_or_style: 'balanced', - optimizer_params: { - weight_decay: 1e-4, - }, - unload_text_encoder: false, - cache_text_embeddings: false, - lr: 0.0001, - ema_config: { - use_ema: false, - ema_decay: 0.99, - }, - skip_first_sample: false, - disable_sampling: false, - dtype: 'bf16', - diff_output_preservation: false, - diff_output_preservation_multiplier: 1.0, - diff_output_preservation_class: 'person', - switch_boundary_every: 1, - }, - model: { - name_or_path: 'ostris/Flex.1-alpha', - quantize: true, - qtype: 'qfloat8', - quantize_te: true, - qtype_te: 'qfloat8', - arch: 'flex1', - low_vram: false, - model_kwargs: {}, - }, - sample: { - sampler: 'flowmatch', - sample_every: 250, - width: 1024, - height: 1024, - samples: [ - { - prompt: 'woman with red hair, playing chess at the park, bomb going off in the background' - }, - { - prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe', - }, - { - prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini', - }, - { - prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', - }, - { - prompt: 'a bear building a log cabin in the snow covered mountains', - }, - { - prompt: 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker', - }, - { - prompt: 'hipster man with a beard, building a chair, in a wood shop', - }, - { - prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', - }, - { - prompt: "a man holding a sign that says, 'this is a sign'", - }, - { - prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', - }, - ], - neg: '', - seed: 42, - walk_seed: true, - guidance_scale: 4, - sample_steps: 25, - num_frames: 1, - fps: 1, - }, - }, - ], - }, - meta: { - name: '[name]', - version: '1.0', - }, -}; - -export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => { - // upgrade prompt strings to samples - if ( - jobConfig?.config?.process && - jobConfig.config.process[0]?.sample && - Array.isArray(jobConfig.config.process[0].sample.prompts) && - jobConfig.config.process[0].sample.prompts.length > 0 - ) { - let newSamples = []; - for (const prompt of jobConfig.config.process[0].sample.prompts) { - newSamples.push({ - prompt: prompt, - }); - } - jobConfig.config.process[0].sample.samples = newSamples; - delete jobConfig.config.process[0].sample.prompts; - } - return jobConfig; -}; diff --git a/src/app/jobs/new/options.ts b/src/app/jobs/new/options.ts deleted file mode 100644 index 71fdc9d8e767d2cbc078475d32b37e6996948199..0000000000000000000000000000000000000000 --- a/src/app/jobs/new/options.ts +++ /dev/null @@ -1,441 +0,0 @@ -import { GroupedSelectOption, SelectOption } from '@/types'; - -type Control = 'depth' | 'line' | 'pose' | 'inpaint'; - -type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; -type AdditionalSections = - | 'datasets.control_path' - | 'datasets.do_i2v' - | 'sample.ctrl_img' - | 'datasets.num_frames' - | 'model.multistage' - | 'model.low_vram'; -type ModelGroup = 'image' | 'instruction' | 'video'; - -export interface ModelArch { - name: string; - label: string; - group: ModelGroup; - controls?: Control[]; - isVideoModel?: boolean; - defaults?: { [key: string]: any }; - disableSections?: DisableableSections[]; - additionalSections?: AdditionalSections[]; - accuracyRecoveryAdapters?: { [key: string]: string }; -} - -const defaultNameOrPath = ''; - -export const modelArchs: ModelArch[] = [ - { - name: 'flux', - label: 'FLUX.1', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - }, - disableSections: ['network.conv'], - }, - { - name: 'flux_kontext', - label: 'FLUX.1-Kontext-dev', - group: 'instruction', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.control_path', 'sample.ctrl_img'], - }, - { - name: 'flex1', - label: 'Flex.1', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].train.bypass_guidance_embedding': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - }, - disableSections: ['network.conv'], - }, - { - name: 'flex2', - label: 'Flex.2', - group: 'image', - controls: ['depth', 'line', 'pose', 'inpaint'], - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.model_kwargs': [ - { - invert_inpaint_mask_chance: 0.2, - inpaint_dropout: 0.5, - control_dropout: 0.5, - inpaint_random_chance: 0.2, - do_random_inpainting: true, - random_blur_mask: true, - random_dialate_mask: true, - }, - {}, - ], - 'config.process[0].train.bypass_guidance_embedding': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - }, - disableSections: ['network.conv'], - }, - { - name: 'chroma', - label: 'Chroma', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['lodestones/Chroma1-Base', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - }, - disableSections: ['network.conv'], - }, - { - name: 'wan21:1b', - label: 'Wan 2.1 (1.3B)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath], - 'config.process[0].model.quantize': [false, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.num_frames', 'model.low_vram'], - }, - { - name: 'wan21_i2v:14b480p', - label: 'Wan 2.1 I2V (14B-480P)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - }, - disableSections: ['network.conv'], - additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], - }, - { - name: 'wan21_i2v:14b', - label: 'Wan 2.1 I2V (14B-720P)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - }, - disableSections: ['network.conv'], - additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], - }, - { - name: 'wan21:14b', - label: 'Wan 2.1 (14B)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.num_frames', 'model.low_vram'], - }, - { - name: 'wan22_14b:t2v', - label: 'Wan 2.2 (14B)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - 'config.process[0].model.low_vram': [true, false], - 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], - 'config.process[0].model.model_kwargs': [ - { - train_high_noise: true, - train_low_noise: true, - }, - {}, - ], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'], - accuracyRecoveryAdapters: { - // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors', - '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors', - }, - }, - { - name: 'wan22_14b_i2v', - label: 'Wan 2.2 I2V (14B)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [41, 1], - 'config.process[0].sample.fps': [16, 1], - 'config.process[0].model.low_vram': [true, false], - 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], - 'config.process[0].model.model_kwargs': [ - { - train_high_noise: true, - train_low_noise: true, - }, - {}, - ], - }, - disableSections: ['network.conv'], - additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'], - accuracyRecoveryAdapters: { - '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors', - }, - }, - { - name: 'wan22_5b', - label: 'Wan 2.2 TI2V (5B)', - group: 'video', - isVideoModel: true, - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.low_vram': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [121, 1], - 'config.process[0].sample.fps': [24, 1], - 'config.process[0].sample.width': [768, 1024], - 'config.process[0].sample.height': [768, 1024], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - }, - disableSections: ['network.conv'], - additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'], - }, - { - name: 'lumina2', - label: 'Lumina2', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath], - 'config.process[0].model.quantize': [false, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - }, - disableSections: ['network.conv'], - }, - { - name: 'qwen_image', - label: 'Qwen-Image', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.low_vram': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], - }, - disableSections: ['network.conv'], - additionalSections: ['model.low_vram'], - accuracyRecoveryAdapters: { - '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors', - }, - }, - { - name: 'qwen_image_edit', - label: 'Qwen-Image-Edit', - group: 'instruction', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.low_vram': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], - accuracyRecoveryAdapters: { - '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', - }, - }, - { - name: 'hidream', - label: 'HiDream', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.lr': [0.0002, 0.0001], - 'config.process[0].train.timestep_type': ['shift', 'sigmoid'], - 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], - }, - disableSections: ['network.conv'], - additionalSections: ['model.low_vram'], - }, - { - name: 'hidream_e1', - label: 'HiDream E1', - group: 'instruction', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath], - 'config.process[0].model.quantize': [true, false], - 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.lr': [0.0001, 0.0001], - 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], - 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], - }, - { - name: 'sdxl', - label: 'SDXL', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], - 'config.process[0].model.quantize': [false, false], - 'config.process[0].model.quantize_te': [false, false], - 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], - 'config.process[0].sample.guidance_scale': [6, 4], - }, - disableSections: ['model.quantize', 'train.timestep_type'], - }, - { - name: 'sd15', - label: 'SD 1.5', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], - 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], - 'config.process[0].sample.width': [512, 1024], - 'config.process[0].sample.height': [512, 1024], - 'config.process[0].sample.guidance_scale': [6, 4], - }, - disableSections: ['model.quantize', 'train.timestep_type'], - }, - { - name: 'omnigen2', - label: 'OmniGen2', - group: 'image', - defaults: { - // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], - 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], - 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].model.quantize': [false, false], - 'config.process[0].model.quantize_te': [true, false], - }, - disableSections: ['network.conv'], - additionalSections: ['datasets.control_path', 'sample.ctrl_img'], - }, -].sort((a, b) => { - // Sort by label, case-insensitive - return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); -}) as any; - -export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => { - const group = acc.find(g => g.label === arch.group); - if (group) { - group.options.push({ value: arch.name, label: arch.label }); - } else { - acc.push({ - label: arch.group, - options: [{ value: arch.name, label: arch.label }], - }); - } - return acc; -}, [] as GroupedSelectOption[]); - -export const quantizationOptions: SelectOption[] = [ - { value: '', label: '- NONE -' }, - { value: 'qfloat8', label: 'float8 (default)' }, - { value: 'uint8', label: '8 bit' }, - { value: 'uint7', label: '7 bit' }, - { value: 'uint6', label: '6 bit' }, - { value: 'uint5', label: '5 bit' }, - { value: 'uint4', label: '4 bit' }, - { value: 'uint3', label: '3 bit' }, - { value: 'uint2', label: '2 bit' }, -]; - -export const defaultQtype = 'qfloat8'; diff --git a/src/app/jobs/new/page.tsx b/src/app/jobs/new/page.tsx deleted file mode 100644 index 1da413490f5703ceb577dd5bb29502a9e3970045..0000000000000000000000000000000000000000 --- a/src/app/jobs/new/page.tsx +++ /dev/null @@ -1,306 +0,0 @@ -'use client'; - -import { useEffect, useState } from 'react'; -import { useSearchParams, useRouter } from 'next/navigation'; -import Link from 'next/link'; -import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; -import { JobConfig } from '@/types'; -import { objectCopy } from '@/utils/basic'; -import { useNestedState } from '@/utils/hooks'; -import { SelectInput } from '@/components/formInputs'; -import useSettings from '@/hooks/useSettings'; -import useGPUInfo from '@/hooks/useGPUInfo'; -import useDatasetList from '@/hooks/useDatasetList'; -import path from 'path'; -import { TopBar, MainContent } from '@/components/layout'; -import { Button } from '@headlessui/react'; -import { FaChevronLeft } from 'react-icons/fa'; -import SimpleJob from './SimpleJob'; -import AdvancedJob from './AdvancedJob'; -import ErrorBoundary from '@/components/ErrorBoundary'; -import { getJob, upsertJob } from '@/utils/storage/jobStorage'; -import { usingBrowserDb } from '@/utils/env'; -import { getUserDatasetPath, updateUserDatasetPath } from '@/utils/storage/datasetStorage'; -import { apiClient } from '@/utils/api'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; - -const isDev = process.env.NODE_ENV === 'development'; - -export default function TrainingForm() { - const router = useRouter(); - const searchParams = useSearchParams(); - const runId = searchParams.get('id'); - const { status: authStatus } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - const [gpuIDs, setGpuIDs] = useState(null); - const { settings, isSettingsLoaded } = useSettings(); - const { gpuList, isGPUInfoLoaded } = useGPUInfo(); - const { datasets, status: datasetFetchStatus } = useDatasetList(); - const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]); - const [showAdvancedView, setShowAdvancedView] = useState(false); - - const [jobConfig, setJobConfig] = useNestedState(objectCopy(defaultJobConfig)); - const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); - - // Track HF Jobs backend state - const [trainingBackend, setTrainingBackend] = useState<'local' | 'hf-jobs'>( - usingBrowserDb ? 'hf-jobs' : 'local', - ); - const [hfJobSubmitted, setHfJobSubmitted] = useState(false); - - useEffect(() => { - if (!isSettingsLoaded || !isAuthenticated) return; - if (datasetFetchStatus !== 'success') return; - - let isMounted = true; - - const buildDatasetOptions = async () => { - const options = await Promise.all( - datasets.map(async name => { - let datasetPath = settings.DATASETS_FOLDER ? path.join(settings.DATASETS_FOLDER, name) : ''; - - if (usingBrowserDb) { - const storedPath = getUserDatasetPath(name); - if (storedPath) { - datasetPath = storedPath; - } else { - try { - const response = await apiClient - .post('/api/datasets/create', { name }) - .then(res => res.data); - if (response?.path) { - datasetPath = response.path; - updateUserDatasetPath(name, datasetPath); - } - } catch (err) { - console.error('Error resolving dataset path:', err); - } - } - } - - if (!datasetPath) { - datasetPath = name; - } - - return { value: datasetPath, label: name }; - }), - ); - - if (!isMounted) { - return; - } - - setDatasetOptions(options); - const defaultDatasetPath = defaultDatasetConfig.folder_path; - - for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) { - const dataset = jobConfig.config.process[0].datasets[i]; - if (dataset.folder_path === defaultDatasetPath) { - if (options.length > 0) { - setJobConfig(options[0].value, `config.process[0].datasets[${i}].folder_path`); - } - } - } - }; - - buildDatasetOptions(); - - return () => { - isMounted = false; - }; - }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]); - - useEffect(() => { - if (runId) { - getJob(runId) - .then(data => { - if (!data) { - throw new Error('Job not found'); - } - setGpuIDs(data.gpu_ids); - const parsedJobConfig = migrateJobConfig(JSON.parse(data.job_config)); - setJobConfig(parsedJobConfig); - - if (parsedJobConfig.is_hf_job) { - setTrainingBackend('hf-jobs'); - setHfJobSubmitted(true); - } - }) - .catch(error => console.error('Error fetching training:', error)); - } - }, [runId]); - - useEffect(() => { - if (isGPUInfoLoaded) { - if (gpuIDs === null && gpuList.length > 0) { - setGpuIDs(`${gpuList[0].index}`); - } - } - }, [gpuList, isGPUInfoLoaded]); - - useEffect(() => { - if (isSettingsLoaded) { - setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder'); - } - }, [settings, isSettingsLoaded]); - - const saveJob = async () => { - if (!isAuthenticated) return; - if (status === 'saving') return; - setStatus('saving'); - - try { - const savedJob = await upsertJob({ - id: runId || undefined, - name: jobConfig.config.name, - gpu_ids: gpuIDs, - job_config: { - ...jobConfig, - is_hf_job: trainingBackend === 'hf-jobs', - hf_job_submitted: hfJobSubmitted, - training_backend: trainingBackend, - }, - status: trainingBackend === 'hf-jobs' ? (hfJobSubmitted ? 'submitted' : 'stopped') : undefined, - }); - - setStatus('success'); - router.push(`/jobs/${savedJob.id}`); - } catch (error: any) { - console.log('Error saving training:', error); - if (error?.code === 'P2002') { - alert('Training name already exists. Please choose a different name.'); - } else { - alert('Failed to save job. Please try again.'); - } - } finally { - setTimeout(() => { - setStatus('idle'); - }, 2000); - } - }; - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - saveJob(); - }; - - return ( - <> - -
- -
-
-

{runId ? 'Edit Training Job' : 'New Training Job'}

-
-
- {showAdvancedView && isAuthenticated && ( - <> -
- setGpuIDs(value)} - options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} - /> -
-
- - )} - -
- -
-
- -
-
- - {!isAuthenticated ? ( - -
-

You need to sign in with Hugging Face or provide a valid access token before creating or editing jobs.

-
- - - Manage authentication in Settings - -
-
-
- ) : showAdvancedView ? ( -
- -
- ) : ( - - - Advanced job detected. Please switch to advanced view to continue. -
- } - > - { - setHfJobSubmitted(true); - // Redirect to the job detail page - if (localJobId) { - router.push(`/jobs/${localJobId}`); - } - }} - forceHFBackend={usingBrowserDb} - /> - - -
- - )} - - ); -} - useEffect(() => { - if (!isAuthenticated) { - setDatasetOptions([]); - } - }, [isAuthenticated]); diff --git a/src/app/jobs/page.tsx b/src/app/jobs/page.tsx deleted file mode 100644 index a29dd77c3a5463069f683f20f903add3b343fe40..0000000000000000000000000000000000000000 --- a/src/app/jobs/page.tsx +++ /dev/null @@ -1,49 +0,0 @@ -'use client'; - -import JobsTable from '@/components/JobsTable'; -import { TopBar, MainContent } from '@/components/layout'; -import Link from 'next/link'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; - -export default function Dashboard() { - const { status: authStatus } = useAuth(); - const isAuthenticated = authStatus === 'authenticated'; - - return ( - <> - -
-

Training Jobs

-
-
-
- {isAuthenticated ? ( - - New Training Job - - ) : ( - - Sign in to create jobs - - )} -
-
- - {isAuthenticated ? ( - - ) : ( -
-

Sign in with Hugging Face or add a personal access token to view and manage training jobs.

-
- - - Manage tokens in Settings - -
-
- )} -
- - ); -} diff --git a/src/app/layout.tsx b/src/app/layout.tsx deleted file mode 100644 index b3ce381e88faf2bbd71b8cb67f61662d6bead943..0000000000000000000000000000000000000000 --- a/src/app/layout.tsx +++ /dev/null @@ -1,50 +0,0 @@ -import type { Metadata } from 'next'; -import { Inter } from 'next/font/google'; -import './globals.css'; -import Sidebar from '@/components/Sidebar'; -import { ThemeProvider } from '@/components/ThemeProvider'; -import ConfirmModal from '@/components/ConfirmModal'; -import SampleImageModal from '@/components/SampleImageModal'; -import { Suspense } from 'react'; -import AuthWrapper from '@/components/AuthWrapper'; -import DocModal from '@/components/DocModal'; -import { AuthProvider } from '@/contexts/AuthContext'; - -export const dynamic = 'force-dynamic'; - -const inter = Inter({ subsets: ['latin'] }); - -export const metadata: Metadata = { - title: 'Ostris - AI Toolkit', - description: 'A toolkit for building AI things.', -}; - -export default function RootLayout({ children }: { children: React.ReactNode }) { - // Check if the AI_TOOLKIT_AUTH environment variable is set - const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false; - - return ( - - - - - - - - -
- -
- {children} -
-
-
-
-
- - - - - - ); -} diff --git a/src/app/manifest.json b/src/app/manifest.json deleted file mode 100644 index ced3ca5d79e5ec230be33c6f0e0907fb419c5588..0000000000000000000000000000000000000000 --- a/src/app/manifest.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "name": "AI Toolkit", - "short_name": "AIToolkit", - "icons": [ - { - "src": "/web-app-manifest-192x192.png", - "sizes": "192x192", - "type": "image/png", - "purpose": "maskable" - }, - { - "src": "/web-app-manifest-512x512.png", - "sizes": "512x512", - "type": "image/png", - "purpose": "maskable" - } - ], - "theme_color": "#000000", - "background_color": "#000000", - "display": "standalone" -} \ No newline at end of file diff --git a/src/app/page.tsx b/src/app/page.tsx deleted file mode 100644 index f889cb6122cb25d33bccef074ef51a6b26b692c9..0000000000000000000000000000000000000000 --- a/src/app/page.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import { redirect } from 'next/navigation'; - -export default function Home() { - redirect('/dashboard'); -} diff --git a/src/app/settings/page.tsx b/src/app/settings/page.tsx deleted file mode 100644 index 25fc6bd922360cda4452a119810529a9c132b695..0000000000000000000000000000000000000000 --- a/src/app/settings/page.tsx +++ /dev/null @@ -1,264 +0,0 @@ -'use client'; - -import { useEffect, useState } from 'react'; -import useSettings from '@/hooks/useSettings'; -import { TopBar, MainContent } from '@/components/layout'; -import { persistSettings } from '@/utils/storage/settingsStorage'; -import { useAuth } from '@/contexts/AuthContext'; -import HFLoginButton from '@/components/HFLoginButton'; -import { useMemo } from 'react'; -import Link from 'next/link'; - -export default function Settings() { - const { settings, setSettings } = useSettings(); - const { status: authStatus, namespace, oauthAvailable, loginWithOAuth, logout, setManualToken, error: authError, token: authToken } = useAuth(); - const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); - const [manualToken, setManualTokenInput] = useState(settings.HF_TOKEN || ''); - const isAuthenticated = authStatus === 'authenticated'; - - useEffect(() => { - setManualTokenInput(settings.HF_TOKEN || ''); - }, [settings.HF_TOKEN]); - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - setStatus('saving'); - - persistSettings(settings) - .then(() => { - setStatus('success'); - }) - .catch(error => { - console.error('Error saving settings:', error); - setStatus('error'); - }) - .finally(() => { - setTimeout(() => setStatus('idle'), 2000); - }); - }; - - const handleChange = (e: React.ChangeEvent) => { - const { name, value } = e.target; - setSettings(prev => ({ ...prev, [name]: value })); - }; - - const handleManualSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - await setManualToken(manualToken); - }; - - const authDescription = useMemo(() => { - if (authStatus === 'checking') { - return 'Checking your Hugging Face session…'; - } - if (isAuthenticated) { - return `Connected as ${namespace}`; - } - return 'Sign in to use Hugging Face Jobs or submit your own access token.'; - }, [authStatus, isAuthenticated, namespace]); - - return ( - <> - -
-

Settings

-
-
-
- {isAuthenticated ? ( - Welcome, {namespace || 'user'} - ) : ( - Authenticate to unlock training features - )} -
-
- -
-
-
-
-

Sign in with Hugging Face

-

{authDescription}

-
- {isAuthenticated && ( - Authenticated - )} -
-
- {isAuthenticated ? ( - - ) : ( - <> - - {!oauthAvailable && ( - - OAuth is unavailable. Set HF_OAUTH_CLIENT_ID/SECRET on the server. - - )} - - )} -
- {!isAuthenticated && authError && ( -

{authError}

- )} -
- -
-

Manual Token

-

- Paste an access token created at{' '} - - huggingface.co/settings/tokens - - . -

-
- setManualTokenInput(event.target.value)} - className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent" - placeholder="Enter Hugging Face token" - /> -
-
- - {isAuthenticated && authToken === manualToken && ( - Active token - )} -
- {authError && ( -

{authError}

- )} -
-
- -
-
-
-
-
- - -
- -
- - -
-
-
-
-
-

Hugging Face Jobs (Cloud Training)

- -
- - -
- -
- - -
-
-
-
- - - - {status === 'success' &&

Settings saved successfully!

} - {status === 'error' &&

Error saving settings. Please try again.

} -
-
- - ); -} diff --git a/src/components/AddImagesModal.tsx b/src/components/AddImagesModal.tsx deleted file mode 100644 index ff91a8836dcfe7dc67a9ee237d7d5a1b16941cf2..0000000000000000000000000000000000000000 --- a/src/components/AddImagesModal.tsx +++ /dev/null @@ -1,152 +0,0 @@ -'use client'; -import { createGlobalState } from 'react-global-hooks'; -import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; -import { FaUpload } from 'react-icons/fa'; -import { useCallback, useState } from 'react'; -import { useDropzone } from 'react-dropzone'; -import { apiClient } from '@/utils/api'; - -export interface AddImagesModalState { - datasetName: string; - onComplete?: () => void; -} - -export const addImagesModalState = createGlobalState(null); - -export const openImagesModal = (datasetName: string, onComplete: () => void) => { - addImagesModalState.set({ datasetName, onComplete }); -}; - -export default function AddImagesModal() { - const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use(); - const [uploadProgress, setUploadProgress] = useState(0); - const [isUploading, setIsUploading] = useState(false); - const open = addImagesModalInfo !== null; - - const onCancel = () => { - if (!isUploading) { - setAddImagesModalInfo(null); - } - }; - - const onDone = () => { - if (addImagesModalInfo?.onComplete && !isUploading) { - addImagesModalInfo.onComplete(); - setAddImagesModalInfo(null); - } - }; - - const onDrop = useCallback( - async (acceptedFiles: File[]) => { - if (acceptedFiles.length === 0) return; - - setIsUploading(true); - setUploadProgress(0); - - const formData = new FormData(); - acceptedFiles.forEach(file => { - formData.append('files', file); - }); - formData.append('datasetName', addImagesModalInfo?.datasetName || ''); - - try { - await apiClient.post(`/api/datasets/upload`, formData, { - headers: { - 'Content-Type': 'multipart/form-data', - }, - onUploadProgress: progressEvent => { - const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); - setUploadProgress(percentCompleted); - }, - timeout: 0, // Disable timeout - }); - - onDone(); - } catch (error) { - console.error('Upload failed:', error); - } finally { - setIsUploading(false); - setUploadProgress(0); - } - }, - [addImagesModalInfo], - ); - - const { getRootProps, getInputProps, isDragActive } = useDropzone({ - onDrop, - accept: { - 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'], - 'video/*': ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'], - 'text/*': ['.txt'], - }, - multiple: true, - }); - - return ( - - - -
-
- -
-
- - Add Images to: {addImagesModalInfo?.datasetName} - -
-
- - -

- {isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'} -

-
- {isUploading && ( -
-
-
-
-

Uploading... {uploadProgress}%

-
- )} -
-
-
-
- - -
-
-
-
-
- ); -} diff --git a/src/components/AddSingleImageModal.tsx b/src/components/AddSingleImageModal.tsx deleted file mode 100644 index ba32ef9dff916b5f6e605909f5d328dfce49783a..0000000000000000000000000000000000000000 --- a/src/components/AddSingleImageModal.tsx +++ /dev/null @@ -1,141 +0,0 @@ -'use client'; -import { createGlobalState } from 'react-global-hooks'; -import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; -import { FaUpload } from 'react-icons/fa'; -import { useCallback, useState } from 'react'; -import { useDropzone } from 'react-dropzone'; -import { apiClient } from '@/utils/api'; - -export interface AddSingleImageModalState { - - onComplete?: (imagePath: string|null) => void; -} - -export const addSingleImageModalState = createGlobalState(null); - -export const openAddImageModal = (onComplete: (imagePath: string|null) => void) => { - addSingleImageModalState.set({onComplete }); -}; - -export default function AddSingleImageModal() { - const [addSingleImageModalInfo, setAddSingleImageModalInfo] = addSingleImageModalState.use(); - const [uploadProgress, setUploadProgress] = useState(0); - const [isUploading, setIsUploading] = useState(false); - const open = addSingleImageModalInfo !== null; - - const onCancel = () => { - if (!isUploading) { - setAddSingleImageModalInfo(null); - } - }; - - const onDone = (imagePath: string|null) => { - if (addSingleImageModalInfo?.onComplete && !isUploading) { - addSingleImageModalInfo.onComplete(imagePath); - setAddSingleImageModalInfo(null); - } - }; - - const onDrop = useCallback( - async (acceptedFiles: File[]) => { - if (acceptedFiles.length === 0) return; - - setIsUploading(true); - setUploadProgress(0); - - const formData = new FormData(); - acceptedFiles.forEach(file => { - formData.append('files', file); - }); - - try { - const resp = await apiClient.post(`/api/img/upload`, formData, { - headers: { - 'Content-Type': 'multipart/form-data', - }, - onUploadProgress: progressEvent => { - const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); - setUploadProgress(percentCompleted); - }, - timeout: 0, // Disable timeout - }); - console.log('Upload successful:', resp.data); - - onDone(resp.data.files[0] || null); - } catch (error) { - console.error('Upload failed:', error); - } finally { - setIsUploading(false); - setUploadProgress(0); - } - }, - [addSingleImageModalInfo], - ); - - const { getRootProps, getInputProps, isDragActive } = useDropzone({ - onDrop, - accept: { - 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'], - }, - multiple: false, - }); - - return ( - - - -
-
- -
-
- - Add Control Image - -
-
- - -

- {isDragActive ? 'Drop the image here...' : 'Drag & drop an image here, or click to select one'} -

-
- {isUploading && ( -
-
-
-
-

Uploading... {uploadProgress}%

-
- )} -
-
-
-
- -
-
-
-
-
- ); -} diff --git a/src/components/AuthWrapper.tsx b/src/components/AuthWrapper.tsx deleted file mode 100644 index bdf287a8dca4aa022b852680a13c8c3b0bb33926..0000000000000000000000000000000000000000 --- a/src/components/AuthWrapper.tsx +++ /dev/null @@ -1,166 +0,0 @@ -'use client'; - -import { useState, useEffect, useRef } from 'react'; -import { apiClient, isAuthorizedState } from '@/utils/api'; -import { createGlobalState } from 'react-global-hooks'; - -interface AuthWrapperProps { - authRequired: boolean; - children: React.ReactNode | React.ReactNode[]; -} - -export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) { - const [token, setToken] = useState(''); - // start with true, and deauth if needed - const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use(); - const [isLoading, setIsLoading] = useState(false); - const [error, setError] = useState(''); - const [isBrowser, setIsBrowser] = useState(false); - const inputRef = useRef(null); - - const isAuthorized = authRequired ? isAuthorizedGlobal : true; - - // Set isBrowser to true when component mounts - useEffect(() => { - setIsBrowser(true); - // Get token from localStorage only after component has mounted - const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; - setToken(storedToken); - checkAuth(); - }, []); - - // auto focus on input when not authorized - useEffect(() => { - if (isAuthorized) { - return; - } - setTimeout(() => { - if (inputRef.current) { - inputRef.current.focus(); - } - }, 100); - }, [isAuthorized]); - - const checkAuth = async () => { - // always get current stored token here to avoid state race conditions - const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; - if (!authRequired || isLoading || currentToken === '') { - return; - } - setIsLoading(true); - setError(''); - try { - const response = await apiClient.get('/api/auth'); - if (response.data.isAuthenticated) { - setIsAuthorized(true); - } else { - setIsAuthorized(false); - setError('Invalid token. Please try again.'); - } - } catch (err) { - setIsAuthorized(false); - console.log(err); - setError('Invalid token. Please try again.'); - } - setIsLoading(false); - }; - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - setError(''); - - if (!token.trim()) { - setError('Please enter your token'); - return; - } - - if (isBrowser) { - localStorage.setItem('AI_TOOLKIT_AUTH', token); - checkAuth(); - } - }; - - if (isAuthorized) { - return <>{children}; - } - - return ( -
- {/* Left side - decorative or brand area */} -
-
- {/* Replace with your own logo */} -
- Ostris AI Toolkit -
-
-

AI Toolkit

-
- - {/* Right side - login form */} -
-
-
- {/* Mobile logo */} -
- Ostris AI Toolkit -
-
- -

AI Toolkit

- -
-
- - setToken(e.target.value)} - className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200" - placeholder="Enter your password" - /> -
- The password is set with the environment variable AI_TOOLKIT_AUTH, the default is the super secure secret word "password" -
-
- - {error && ( -
{error}
- )} - - -
-
-
-
- ); -} diff --git a/src/components/Card.tsx b/src/components/Card.tsx deleted file mode 100644 index 13c7409b8be089a104eb6613664a188cb35d78d7..0000000000000000000000000000000000000000 --- a/src/components/Card.tsx +++ /dev/null @@ -1,15 +0,0 @@ -interface CardProps { - title?: string; - children?: React.ReactNode; -} - -const Card: React.FC = ({ title, children }) => { - return ( -
- {title &&

{title}

} - {children ? children : null} -
- ); -}; - -export default Card; diff --git a/src/components/ConfirmModal.tsx b/src/components/ConfirmModal.tsx deleted file mode 100644 index 6ecea8136accffeb9f312afb0130d2988ef485d3..0000000000000000000000000000000000000000 --- a/src/components/ConfirmModal.tsx +++ /dev/null @@ -1,201 +0,0 @@ -'use client'; -import { useRef } from 'react'; -import { useState, useEffect } from 'react'; -import { createGlobalState } from 'react-global-hooks'; -import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; -import { FaExclamationTriangle, FaInfo } from 'react-icons/fa'; -import { TextInput } from './formInputs'; -import React from 'react'; -import { useFromNull } from '@/hooks/useFromNull'; -import classNames from 'classnames'; - -export interface ConfirmState { - title: string; - message?: string; - confirmText?: string; - type?: 'danger' | 'warning' | 'info'; - inputTitle?: string; - onConfirm?: (value?: string) => void | Promise; - onCancel?: () => void; -} - -export const confirmstate = createGlobalState(null); - -export const openConfirm = (confirmProps: ConfirmState) => { - confirmstate.set(confirmProps); -}; - -export default function ConfirmModal() { - const [confirm, setConfirm] = confirmstate.use(); - const [isOpen, setIsOpen] = useState(false); - const [inputValue, setInputValue] = useState(''); - const inputRef = useRef(null); - - useFromNull(() => { - setTimeout(() => { - if (inputRef.current) { - inputRef.current.focus(); - } - }, 100); - }, [confirm]); - - useEffect(() => { - if (confirm) { - setIsOpen(true); - setInputValue(''); - } - }, [confirm]); - - useEffect(() => { - if (!isOpen) { - // use timeout to allow the dialog to close before resetting the state - setTimeout(() => { - setConfirm(null); - }, 500); - } - }, [isOpen]); - - const onCancel = () => { - if (confirm?.onCancel) { - confirm.onCancel(); - } - setIsOpen(false); - }; - - const onConfirm = () => { - if (confirm?.onConfirm) { - confirm.onConfirm(inputValue); - } - setIsOpen(false); - }; - - let Icon = FaExclamationTriangle; - let color = confirm?.type || 'danger'; - - // Use conditional rendering for icon - if (color === 'info') { - Icon = FaInfo; - } - - // Color mapping for background colors - const getBgColor = () => { - switch (color) { - case 'danger': - return 'bg-red-500'; - case 'warning': - return 'bg-yellow-500'; - case 'info': - return 'bg-blue-500'; - default: - return 'bg-red-500'; - } - }; - - // Color mapping for text colors - const getTextColor = () => { - switch (color) { - case 'danger': - return 'text-red-950'; - case 'warning': - return 'text-yellow-950'; - case 'info': - return 'text-blue-950'; - default: - return 'text-red-950'; - } - }; - - // Color mapping for titles - const getTitleColor = () => { - switch (color) { - case 'danger': - return 'text-red-500'; - case 'warning': - return 'text-yellow-500'; - case 'info': - return 'text-blue-500'; - default: - return 'text-red-500'; - } - }; - - // Button background color mapping - const getButtonBgColor = () => { - switch (color) { - case 'danger': - return 'bg-red-700 hover:bg-red-500'; - case 'warning': - return 'bg-yellow-700 hover:bg-yellow-500'; - case 'info': - return 'bg-blue-700 hover:bg-blue-500'; - default: - return 'bg-red-700 hover:bg-red-500'; - } - }; - - return ( - - - -
-
- -
-
-
-
-
- - {confirm?.title} - -
-

{confirm?.message}

-
-
{ - e.preventDefault() - onConfirm() - }}> - - -
-
-
-
-
-
- - -
-
-
-
-
- ); -} diff --git a/src/components/DatasetImageCard.tsx b/src/components/DatasetImageCard.tsx deleted file mode 100644 index 7eb562b5cd6edb7906f7e9e55507223ac5141878..0000000000000000000000000000000000000000 --- a/src/components/DatasetImageCard.tsx +++ /dev/null @@ -1,231 +0,0 @@ -import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 'react'; -import { FaTrashAlt, FaEye, FaEyeSlash } from 'react-icons/fa'; -import { openConfirm } from './ConfirmModal'; -import classNames from 'classnames'; -import { apiClient } from '@/utils/api'; -import { isVideo } from '@/utils/basic'; - -interface DatasetImageCardProps { - imageUrl: string; - alt: string; - children?: ReactNode; - className?: string; - onDelete?: () => void; -} - -const DatasetImageCard: React.FC = ({ - imageUrl, - alt, - children, - className = '', - onDelete = () => {}, -}) => { - const cardRef = useRef(null); - const [isVisible, setIsVisible] = useState(false); - const [inViewport, setInViewport] = useState(false); - const [loaded, setLoaded] = useState(false); - const [isCaptionLoaded, setIsCaptionLoaded] = useState(false); - const [caption, setCaption] = useState(''); - const [savedCaption, setSavedCaption] = useState(''); - const isGettingCaption = useRef(false); - - const fetchCaption = async () => { - if (isGettingCaption.current || isCaptionLoaded) return; - isGettingCaption.current = true; - apiClient - .post(`/api/caption/get`, { imgPath: imageUrl }) - .then(res => res.data) - .then(data => { - console.log('Caption fetched:', data); - - setCaption(data || ''); - setSavedCaption(data || ''); - setIsCaptionLoaded(true); - }) - .catch(error => { - console.error('Error fetching caption:', error); - }) - .finally(() => { - isGettingCaption.current = false; - }); - }; - - const saveCaption = () => { - const trimmedCaption = caption.trim(); - if (trimmedCaption === savedCaption) return; - apiClient - .post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption }) - .then(res => res.data) - .then(data => { - console.log('Caption saved:', data); - setSavedCaption(trimmedCaption); - }) - .catch(error => { - console.error('Error saving caption:', error); - }); - }; - - // Only fetch caption when the component is both in viewport and visible - useEffect(() => { - if (inViewport && isVisible) { - fetchCaption(); - } - }, [inViewport, isVisible]); - - useEffect(() => { - // Create intersection observer to check viewport visibility - const observer = new IntersectionObserver( - entries => { - if (entries[0].isIntersecting) { - setInViewport(true); - // Initialize isVisible to true when first coming into view - if (!isVisible) { - setIsVisible(true); - } - } else { - setInViewport(false); - } - }, - { threshold: 0.1 }, - ); - - if (cardRef.current) { - observer.observe(cardRef.current); - } - - return () => { - observer.disconnect(); - }; - }, []); - - const toggleVisibility = (): void => { - setIsVisible(prev => !prev); - if (!isVisible && !isCaptionLoaded) { - fetchCaption(); - } - }; - - const handleLoad = (): void => { - setLoaded(true); - }; - - const handleKeyDown = (e: KeyboardEvent): void => { - // If Enter is pressed without Shift, prevent default behavior and save - if (e.key === 'Enter' && !e.shiftKey) { - e.preventDefault(); - saveCaption(); - } - }; - - const isCaptionCurrent = caption.trim() === savedCaption; - - const isItAVideo = isVideo(imageUrl); - - return ( -
- {/* Square image container */} -
-
- {inViewport && isVisible && ( - <> - {isItAVideo ? ( -
- {inViewport && isVisible && ( -
- {imageUrl} -
- )} -
-
- {inViewport && isVisible && isCaptionLoaded && ( -
{ - e.preventDefault(); - saveCaption(); - }} - onBlur={saveCaption} - > -