import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
import { Button, buttonVariants } from "@/components/ui/button";
import {
  AlertDialog,
  AlertDialogAction,
  AlertDialogCancel,
  AlertDialogContent,
  AlertDialogDescription,
  AlertDialogFooter,
  AlertDialogHeader,
  AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import {
  Dialog,
  DialogContent,
  DialogDescription,
  DialogHeader,
  DialogTitle,
} from "@/components/ui/dialog";
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuSeparator,
  DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Toaster } from "@/components/ui/sonner";
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import useKeyboardListener from "@/hooks/use-keyboard-listener";
import useOptimisticState from "@/hooks/use-optimistic-state";
import { cn } from "@/lib/utils";
import { CustomClassificationModelConfig } from "@/types/frigateConfig";
import { TooltipPortal } from "@radix-ui/react-tooltip";
import axios from "axios";
import {
  MutableRefObject,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { isDesktop, isMobileOnly } from "react-device-detect";
import { Trans, useTranslation } from "react-i18next";
import { LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
import BlurredIconButton from "@/components/button/BlurredIconButton";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useNavigate } from "react-router-dom";
import { IoMdArrowRoundBack } from "react-icons/io";
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
import useApiFilter from "@/hooks/use-api-filter";
import {
  ClassificationDatasetResponse,
  ClassificationItemData,
  TrainFilter,
} from "@/types/classification";
import {
  ClassificationCard,
  GroupedClassificationCard,
} from "@/components/card/ClassificationCard";
import { Event } from "@/types/event";
import SearchDetailDialog, {
  SearchTab,
} from "@/components/overlay/detail/SearchDetailDialog";
import { SearchResult } from "@/types/search";
import { HiSparkles } from "react-icons/hi";

type ModelTrainingViewProps = {
  model: CustomClassificationModelConfig;
};

/**
 * Training page for a custom classification model.
 *
 * Shows either the recent classification attempts (`pageToggle == "train"`)
 * or the images of one dataset category (any other `pageToggle` value).
 * Supports multi-select with ctrl+a / Escape, deleting images or whole
 * categories, renaming categories and starting a training run. It also
 * reports training completion or failure through the model-state websocket.
 */
export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
  const { t } = useTranslation(["views/classificationModel"]);
  const navigate = useNavigate();
  // "train" selects the train-images page; any other value is a category name.
  const [page, setPage] = useState("train");
  const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);

  // title

  useEffect(() => {
    document.title = `${model.name} - ${t("documentTitle")}`;
  }, [model.name, t]);

  // model state

  const [wasTraining, setWasTraining] = useState(false);
  const { payload: lastModelState } = useModelState(model.name, true);
  const modelState = useMemo(() => {
    // No reported state, or a freshly downloaded model, is treated as ready.
    if (!lastModelState || lastModelState == "downloaded") {
      return "complete";
    }

    return lastModelState;
  }, [lastModelState]);

  useEffect(() => {
    // Only report outcomes for a training run this view started.
    if (!wasTraining) {
      return;
    }

    if (modelState == "complete") {
      toast.success(t("toast.success.trainedModel"), {
        position: "top-center",
        closeButton: true,
      });
      setWasTraining(false);
      // Called from an effect, so the binding declared below is initialized.
      refreshDataset();
    } else if (modelState == "failed") {
      toast.error(t("toast.error.trainingFailed"), {
        position: "top-center",
        closeButton: true,
      });
      setWasTraining(false);
    }
    // only refresh when modelState changes
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [modelState]);

  // dataset

  const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
    `classification/${model.name}/train`,
  );
  const { data: datasetResponse, mutate: refreshDataset } =
    useSWR<ClassificationDatasetResponse>(
      `classification/${model.name}/dataset`,
    );

  const dataset = datasetResponse?.categories || {};
  const trainingMetadata = datasetResponse?.training_metadata;

  const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();

  const refreshAll = useCallback(() => {
    refreshTrain();
    refreshDataset();
  }, [refreshTrain, refreshDataset]);

  // image multiselect

  const [selectedImages, setSelectedImages] = useState<string[]>([]);

  /**
   * Toggle each id in `images` in the current selection.
   * Without `ctrl`, a click does nothing until a selection is already in progress.
   */
  const onClickImages = useCallback(
    (images: string[], ctrl: boolean) => {
      if (selectedImages.length == 0 && !ctrl) {
        return;
      }

      // Toggle against the working copy so that a call mixing selected and
      // unselected ids never drops ids added earlier in the same call.
      const newSelectedImages = [...selectedImages];
      images.forEach((imageId) => {
        const index = newSelectedImages.indexOf(imageId);
        if (index != -1) {
          newSelectedImages.splice(index, 1);
        } else {
          newSelectedImages.push(imageId);
        }
      });

      setSelectedImages(newSelectedImages);
    },
    [selectedImages, setSelectedImages],
  );

  // actions

  /** POST a training request; on HTTP 200 start watching the model state. */
  const trainModel = useCallback(() => {
    axios
      .post(`classification/${model.name}/train`)
      .then((resp) => {
        if (resp.status == 200) {
          setWasTraining(true);
          toast.success(t("toast.success.trainingModel"), {
            position: "top-center",
            closeButton: true,
          });
        }
      })
      .catch((error) => {
        const errorMessage =
          error.response?.data?.message ||
          error.response?.data?.detail ||
          "Unknown error";

        toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
          position: "top-center",
          closeButton: true,
        });
      });
  }, [model, t]);

  // ids pending confirmation in the delete dialog, or null when it is closed
  const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
    null,
  );

  /** Rename a dataset category and switch the page to its new name. */
  const onRename = useCallback(
    (old_name: string, new_name: string) => {
      axios
        .put(`/classification/${model.name}/dataset/${old_name}/rename`, {
          new_category: new_name,
        })
        .then((resp) => {
          if (resp.status == 200) {
            toast.success(
              t("toast.success.renamedCategory", { name: new_name }),
              {
                position: "top-center",
              },
            );
            setPageToggle(new_name);
            refreshDataset();
          }
        })
        .catch((error) => {
          const errorMessage =
            error.response?.data?.message ||
            error.response?.data?.detail ||
            "Unknown error";
          toast.error(t("toast.error.renameCategoryFailed", { errorMessage }), {
            position: "top-center",
          });
        });
    },
    [model, setPageToggle, refreshDataset, t],
  );

  /**
   * Delete images from the train set or from a dataset category.
   *
   * @param ids image ids to delete
   * @param isName true when deleting a whole category (only changes toast text)
   * @param category target category; defaults to the current page
   */
  const onDelete = useCallback(
    (ids: string[], isName: boolean = false, category?: string) => {
      const targetCategory = category || pageToggle;
      const api =
        targetCategory == "train"
          ? `/classification/${model.name}/train/delete`
          : `/classification/${model.name}/dataset/${targetCategory}/delete`;

      axios
        .post(api, { ids })
        .then((resp) => {
          setSelectedImages([]);

          if (resp.status == 200) {
            if (isName) {
              toast.success(
                t("toast.success.deletedCategory", { count: ids.length }),
                {
                  position: "top-center",
                },
              );
            } else {
              toast.success(
                t("toast.success.deletedImage", { count: ids.length }),
                {
                  position: "top-center",
                },
              );
            }

            // Always refresh dataset to update the categories list
            refreshDataset();

            if (pageToggle == "train") {
              refreshTrain();
            }
          }
        })
        .catch((error) => {
          const errorMessage =
            error.response?.data?.message ||
            error.response?.data?.detail ||
            "Unknown error";

          if (isName) {
            toast.error(
              t("toast.error.deleteCategoryFailed", { errorMessage }),
              {
                position: "top-center",
              },
            );
          } else {
            toast.error(t("toast.error.deleteImageFailed", { errorMessage }), {
              position: "top-center",
            });
          }
        });
    },
    [pageToggle, model, refreshTrain, refreshDataset, t],
  );

  // keyboard

  // NOTE(review): assumes the grid container this ref is attached to is a <div> — confirm.
  const contentRef = useRef<HTMLDivElement | null>(null);
  useKeyboardListener(
    ["a", "Escape"],
    (key, modifiers) => {
      if (!modifiers.down) {
        return true;
      }

      switch (key) {
        case "a":
          // ctrl+a toggles between "select everything on this page" and "none"
          if (modifiers.ctrl && !modifiers.repeat) {
            if (selectedImages.length) {
              setSelectedImages([]);
            } else {
              setSelectedImages([
                ...(pageToggle === "train"
                  ? trainImages || []
                  : dataset?.[pageToggle] || []),
              ]);
            }
            return true;
          }
          break;
        case "Escape":
          setSelectedImages([]);
          return true;
      }

      return false;
    },
    contentRef,
  );

  // a selection never carries over between pages
  useEffect(() => {
    setSelectedImages([]);
  }, [pageToggle]);

  return (
setDeleteDialogOpen(null)} > {t( pageToggle == "train" ? "deleteTrainImages.title" : "deleteDatasetImages.title", )} {pageToggle == "train" ? "deleteTrainImages.desc" : "deleteDatasetImages.desc"} {t("button.cancel", { ns: "common" })} { if (deleteDialogOpen) { onDelete(deleteDialogOpen); setDeleteDialogOpen(null); } }} > {t("button.delete", { ns: "common" })}
{(isDesktop || !selectedImages?.length) && (
)} {selectedImages?.length > 0 ? (
{`${selectedImages.length} selected`}
{"|"}
setSelectedImages([])} > {t("button.unselect", { ns: "common" })}
) : (
{(!trainingMetadata?.dataset_changed || (modelState != "complete" && modelState != "failed")) && ( {modelState == "training" ? t("tooltip.trainingInProgress") : !trainingMetadata?.dataset_changed ? t("tooltip.noChanges") : t("tooltip.modelNotReady")} )}
)}
{pageToggle == "train" ? ( ) : ( )}
  );
}

type LibrarySelectorProps = {
  pageToggle: string | undefined;
  dataset: { [id: string]: string[] };
  trainImages: string[];
  setPageToggle: (toggle: string) => void;
  onDelete: (ids: string[], isName: boolean, category?: string) => void;
  onRename: (old_name: string, new_name: string) => void;
};

/**
 * Page switcher between the train images and each dataset category.
 * It also offers category rename and delete actions.
 */
function LibrarySelector({
  pageToggle,
  dataset,
  trainImages,
  setPageToggle,
  onDelete,
  onRename,
}: LibrarySelectorProps) {
  const { t } = useTranslation(["views/classificationModel"]);

  // data

  // category awaiting delete confirmation / being renamed, or null
  const [confirmDelete, setConfirmDelete] = useState<string | null>(null);
  const [renameClass, setRenameClass] = useState<string | null>(null);
  const pageTitle = useMemo(() => {
    if (pageToggle != "train") {
      return pageToggle;
    }

    if (isMobileOnly) {
      return t("train.titleShort");
    }

    return t("train.title");
  }, [pageToggle, t]);

  // interaction

  /** Delete every image of a category and return to the train page. */
  const handleDeleteCategory = useCallback(
    (name: string) => {
      // Get all image IDs for this category
      const imageIds = dataset?.[name] || [];
      onDelete(imageIds, true, name);
      setPageToggle("train");
    },
    [dataset, onDelete, setPageToggle],
  );

  const handleSetOpen = useCallback(
    (open: boolean) => {
      setRenameClass(open ? renameClass : null);
    },
    [renameClass],
  );

  return ( <> !open && setConfirmDelete(null)} > {Object.keys(dataset).length <= 2 ? t("deleteCategory.minClassesTitle") : t("deleteCategory.title")} {Object.keys(dataset).length <= 2 ? t("deleteCategory.minClassesDesc") : t("deleteCategory.desc", { name: confirmDelete })}
{Object.keys(dataset).length <= 2 ? ( ) : ( <> )}
{ onRename(renameClass!, newName); setRenameClass(null); }} defaultValue={renameClass || ""} regexPattern={/^[\p{L}\p{N}\s'_-]{1,50}$/u} regexErrorMessage={t("description.invalidName")} /> setPageToggle("train")} >
{t("train.title")}
({trainImages.length})
{trainImages.length > 0 && Object.keys(dataset).length > 0 && ( <>
{t("categories")}
)} {Object.keys(dataset).map((id) => (
setPageToggle(id)} > {id.replaceAll("_", " ")} ({dataset?.[id].length})
{id != "none" && (
{t("button.renameCategory")} {t("button.deleteCategory")}
)}
))}
  );
}

type DatasetGridProps = {
  contentRef: MutableRefObject<HTMLDivElement | null>;
  modelName: string;
  categoryName: string;
  images: string[];
  selectedImages: string[];
  onClickImages: (images: string[], ctrl: boolean) => void;
  onDelete: (ids: string[]) => void;
};

/** Grid of a single dataset category's images, sorted by filename. */
function DatasetGrid({
  contentRef,
  modelName,
  categoryName,
  images,
  selectedImages,
  onClickImages,
  onDelete,
}: DatasetGridProps) {
  const { t } = useTranslation(["views/classificationModel"]);

  // Sort a copy: `images` comes from SWR-cached data and must not be mutated.
  const classData = useMemo(
    () => [...images].sort((a, b) => a.localeCompare(b)),
    [images],
  );

  return (
{classData.map((image) => (
onClickImages([data.filename], true)} > { e.stopPropagation(); onDelete([image]); }} /> {t("button.deleteClassificationAttempts")}
))}
  );
}

type TrainGridProps = {
  model: CustomClassificationModelConfig;
  contentRef: MutableRefObject<HTMLDivElement | null>;
  classes: string[];
  trainImages: string[];
  trainFilter?: TrainFilter;
  selectedImages: string[];
  onClickImages: (images: string[], ctrl: boolean) => void;
  onRefresh: () => void;
  onDelete: (ids: string[]) => void;
};

/**
 * Parses train image filenames into item data. It applies the class and
 * score filter, sorts newest first, and then renders the state-model or
 * object-model grid.
 */
function TrainGrid({
  model,
  contentRef,
  classes,
  trainImages,
  trainFilter,
  selectedImages,
  onClickImages,
  onRefresh,
  onDelete,
}: TrainGridProps) {
  const trainData = useMemo(
    () =>
      trainImages
        .map((raw) => {
          // Layout: <id part 0>-<id part 1>-<timestamp>-<class name>-<score>.webp
          const parts = raw.replaceAll(".webp", "").split("-");
          // Category names may themselves contain "-" (the rename regex allows
          // it), so every field between the timestamp and the trailing score
          // belongs to the name. For the plain 5-field layout this is parts[3]/parts[4].
          const nameEnd = Math.max(4, parts.length - 1);
          const rawScore = Number.parseFloat(parts[nameEnd]);
          return {
            filename: raw,
            filepath: `clips/${model.name}/train/${raw}`,
            timestamp: Number.parseFloat(parts[2]),
            eventId: `${parts[0]}-${parts[1]}`,
            name: parts.slice(3, nameEnd).join("-"),
            score: rawScore,
          };
        })
        .filter((data) => {
          if (!trainFilter) {
            return true;
          }

          if (trainFilter.classes && !trainFilter.classes.includes(data.name)) {
            return false;
          }

          if (trainFilter.min_score && trainFilter.min_score > data.score) {
            return false;
          }

          if (trainFilter.max_score && trainFilter.max_score < data.score) {
            return false;
          }

          return true;
        })
        .sort((a, b) => b.timestamp - a.timestamp),
    [model, trainImages, trainFilter],
  );

  if (model.state_config) { return ( ); } return ( ); }

type StateTrainGridProps = {
  model: CustomClassificationModelConfig;
  contentRef: MutableRefObject<HTMLDivElement | null>;
  classes: string[];
  trainData?: ClassificationItemData[];
  selectedImages: string[];
  onClickImages: (images: string[], ctrl: boolean) => void;
  onRefresh: () => void;
  onDelete: (ids: string[]) => void;
};

/** Flat grid of train items for state-classification models. */
function StateTrainGrid({
  model,
  contentRef,
  classes,
  trainData,
  selectedImages,
  onClickImages,
  onRefresh,
}: StateTrainGridProps) {
  // the model's single threshold is used for both cut-offs
  const threshold = useMemo(() => {
    return {
      recognition: model.threshold,
      unknown: model.threshold,
    };
  }, [model]);

  return (
{trainData?.map((data) => (
onClickImages([data.filename], meta)} >
))}
  );
}

type ObjectTrainGridProps = {
  model: CustomClassificationModelConfig;
  contentRef: MutableRefObject<HTMLDivElement | null>;
  classes: string[];
  trainData?: ClassificationItemData[];
  selectedImages: string[];
  onClickImages: (images: string[], ctrl: boolean) => void;
  onRefresh: () => void;
};

/**
 * Train items for object-classification models, grouped by event id.
 * Clicking a group opens the event detail dialog, or toggles the group's
 * selection when a selection is active or the meta key is held.
 */
function ObjectTrainGrid({
  model,
  contentRef,
  classes,
  trainData,
  selectedImages,
  onClickImages,
  onRefresh,
}: ObjectTrainGridProps) {
  // item data

  const groups = useMemo(() => {
    const groups: { [eventId: string]: ClassificationItemData[] } = {};

    // Copy before sorting: `trainData` is the parent's memoized array, and
    // both sort() and reverse() work in place.
    [...(trainData ?? [])]
      .sort((a, b) => a.eventId!.localeCompare(b.eventId!))
      .reverse()
      .forEach((data) => {
        if (groups[data.eventId!]) {
          groups[data.eventId!].push(data);
        } else {
          groups[data.eventId!] = [data];
        }
      });

    return groups;
  }, [trainData]);

  const eventIdsQuery = useMemo(() => Object.keys(groups).join(","), [groups]);

  const { data: events } = useSWR<Event[]>([
    "event_ids",
    { ids: eventIdsQuery },
  ]);

  // the model's single threshold is used for both cut-offs
  const threshold = useMemo(() => {
    return {
      recognition: model.threshold,
      unknown: model.threshold,
    };
  }, [model]);

  // selection

  const [selectedEvent, setSelectedEvent] = useState<Event>();
  const [dialogTab, setDialogTab] = useState<SearchTab>("snapshot");

  // handlers

  /**
   * Open the event when nothing is selected and meta is not held.
   * Otherwise deselect the group's selected items, or select the whole
   * group if none of it is selected.
   */
  const handleClickEvent = useCallback(
    (
      group: ClassificationItemData[],
      event: Event | undefined,
      meta: boolean,
    ) => {
      if (event && selectedImages.length == 0 && !meta) {
        setSelectedEvent(event);
      } else {
        const anySelected =
          group.find((item) => selectedImages.includes(item.filename)) !=
          undefined;

        if (anySelected) {
          // deselect all
          const toDeselect: string[] = [];
          group.forEach((item) => {
            if (selectedImages.includes(item.filename)) {
              toDeselect.push(item.filename);
            }
          });
          onClickImages(toDeselect, false);
        } else {
          // select all
          onClickImages(
            group.map((item) => item.filename),
            true,
          );
        }
      }
    },
    [selectedImages, onClickImages],
  );

  return ( <> setSelectedEvent(search as unknown as Event)} setInputFocused={() => {}} />
{Object.entries(groups).map(([key, group]) => { const event = events?.find((ev) => ev.id == key); return (
{ if (data) { onClickImages([data.filename], true); } else { handleClickEvent(group, event, true); } }} > {(data) => ( <> )}
); })}
); }