Spaces:
Running
Running
| import React, { useState, useEffect, useRef } from "react"; | |
| import { useToast } from "@/components/ui/use-toast"; | |
| import { TrainingConfig, TrainingStatus, LogEntry } from "@/components/training/types"; | |
| import TrainingHeader from "@/components/training/TrainingHeader"; | |
| import TrainingTabs from "@/components/training/TrainingTabs"; | |
| import ConfigurationTab from "@/components/training/ConfigurationTab"; | |
| import MonitoringTab from "@/components/training/MonitoringTab"; | |
| import TrainingControls from "@/components/training/TrainingControls"; | |
| const Training = () => { | |
| const { toast } = useToast(); | |
| const logContainerRef = useRef<HTMLDivElement>(null); | |
| const [trainingConfig, setTrainingConfig] = useState<TrainingConfig>({ | |
| dataset_repo_id: "", | |
| policy_type: "act", | |
| steps: 10000, | |
| batch_size: 8, | |
| seed: 1000, | |
| num_workers: 4, | |
| log_freq: 250, | |
| save_freq: 1000, | |
| eval_freq: 0, | |
| save_checkpoint: true, | |
| output_dir: "outputs/train", | |
| resume: false, | |
| wandb_enable: false, | |
| wandb_mode: "online", | |
| wandb_disable_artifact: false, | |
| eval_n_episodes: 10, | |
| eval_batch_size: 50, | |
| eval_use_async_envs: false, | |
| policy_device: "cuda", | |
| policy_use_amp: false, | |
| optimizer_type: "adam", | |
| use_policy_training_preset: true, | |
| }); | |
| const [trainingStatus, setTrainingStatus] = useState<TrainingStatus>({ | |
| training_active: false, | |
| current_step: 0, | |
| total_steps: 0, | |
| available_controls: { | |
| stop_training: false, | |
| pause_training: false, | |
| resume_training: false, | |
| }, | |
| }); | |
| const [logs, setLogs] = useState<LogEntry[]>([]); | |
| const [isStartingTraining, setIsStartingTraining] = useState(false); | |
| const [activeTab, setActiveTab] = useState<"config" | "monitoring">("config"); | |
| // Poll for training status and logs | |
| useEffect(() => { | |
| const pollInterval = setInterval(async () => { | |
| if (trainingStatus.training_active) { | |
| try { | |
| // Get status | |
| const statusResponse = await fetch("/training-status"); | |
| if (statusResponse.ok) { | |
| const status = await statusResponse.json(); | |
| setTrainingStatus(status); | |
| } | |
| // Get logs | |
| const logsResponse = await fetch("/training-logs"); | |
| if (logsResponse.ok) { | |
| const logsData = await logsResponse.json(); | |
| if (logsData.logs && logsData.logs.length > 0) { | |
| setLogs((prevLogs) => [...prevLogs, ...logsData.logs]); | |
| } | |
| } | |
| } catch (error) { | |
| console.error("Error polling training status:", error); | |
| } | |
| } | |
| }, 1000); | |
| return () => clearInterval(pollInterval); | |
| }, [trainingStatus.training_active]); | |
| // Auto-scroll logs | |
| useEffect(() => { | |
| if (logContainerRef.current) { | |
| logContainerRef.current.scrollTop = logContainerRef.current.scrollHeight; | |
| } | |
| }, [logs]); | |
| const handleStartTraining = async () => { | |
| if (!trainingConfig.dataset_repo_id.trim()) { | |
| toast({ | |
| title: "Error", | |
| description: "Dataset repository ID is required", | |
| variant: "destructive", | |
| }); | |
| return; | |
| } | |
| setIsStartingTraining(true); | |
| try { | |
| const response = await fetch("/start-training", { | |
| method: "POST", | |
| headers: { | |
| "Content-Type": "application/json", | |
| }, | |
| body: JSON.stringify(trainingConfig), | |
| }); | |
| if (response.ok) { | |
| const result = await response.json(); | |
| if (result.success) { | |
| toast({ | |
| title: "Training Started", | |
| description: "Training session has been started successfully", | |
| }); | |
| setActiveTab("monitoring"); | |
| setLogs([]); | |
| } else { | |
| toast({ | |
| title: "Error", | |
| description: result.message || "Failed to start training", | |
| variant: "destructive", | |
| }); | |
| } | |
| } else { | |
| toast({ | |
| title: "Error", | |
| description: "Failed to start training", | |
| variant: "destructive", | |
| }); | |
| } | |
| } catch (error) { | |
| console.error("Error starting training:", error); | |
| toast({ | |
| title: "Error", | |
| description: "Failed to start training", | |
| variant: "destructive", | |
| }); | |
| } finally { | |
| setIsStartingTraining(false); | |
| } | |
| }; | |
| const handleStopTraining = async () => { | |
| try { | |
| const response = await fetch("/stop-training", { | |
| method: "POST", | |
| }); | |
| if (response.ok) { | |
| const result = await response.json(); | |
| if (result.success) { | |
| toast({ | |
| title: "Training Stopped", | |
| description: "Training session has been stopped", | |
| }); | |
| } else { | |
| toast({ | |
| title: "Error", | |
| description: result.message || "Failed to stop training", | |
| variant: "destructive", | |
| }); | |
| } | |
| } | |
| } catch (error) { | |
| console.error("Error stopping training:", error); | |
| toast({ | |
| title: "Error", | |
| description: "Failed to stop training", | |
| variant: "destructive", | |
| }); | |
| } | |
| }; | |
| const updateConfig = <T extends keyof TrainingConfig>( | |
| key: T, | |
| value: TrainingConfig[T] | |
| ) => { | |
| setTrainingConfig((prev) => ({ ...prev, [key]: value })); | |
| }; | |
| const formatTime = (seconds: number): string => { | |
| const hours = Math.floor(seconds / 3600); | |
| const minutes = Math.floor((seconds % 3600) / 60); | |
| const secs = Math.floor(seconds % 60); | |
| return `${hours.toString().padStart(2, "0")}:${minutes | |
| .toString() | |
| .padStart(2, "0")}:${secs.toString().padStart(2, "0")}`; | |
| }; | |
| const getProgressPercentage = () => { | |
| if (trainingStatus.total_steps === 0) return 0; | |
| return (trainingStatus.current_step / trainingStatus.total_steps) * 100; | |
| }; | |
| return ( | |
| <div className="min-h-screen bg-slate-900 text-white p-4"> | |
| <div className="max-w-7xl mx-auto"> | |
| <TrainingHeader trainingStatus={trainingStatus} /> | |
| <TrainingTabs activeTab={activeTab} setActiveTab={setActiveTab} /> | |
| {activeTab === "config" && ( | |
| <ConfigurationTab config={trainingConfig} updateConfig={updateConfig} /> | |
| )} | |
| {activeTab === "monitoring" && ( | |
| <MonitoringTab | |
| trainingStatus={trainingStatus} | |
| logs={logs} | |
| logContainerRef={logContainerRef} | |
| getProgressPercentage={getProgressPercentage} | |
| formatTime={formatTime} | |
| /> | |
| )} | |
| <TrainingControls | |
| trainingStatus={trainingStatus} | |
| isStartingTraining={isStartingTraining} | |
| trainingConfig={trainingConfig} | |
| handleStartTraining={handleStartTraining} | |
| handleStopTraining={handleStopTraining} | |
| /> | |
| </div> | |
| </div> | |
| ); | |
| }; | |
| export default Training; | |