From d47d3f9694b05af842109c4e0166c5cfa273dc36 Mon Sep 17 00:00:00 2001 From: Aidan Timson Date: Mon, 19 Jan 2026 08:42:36 +0000 Subject: [PATCH] Create shared ai task metadata suggestion task (#29012) Co-authored-by: Wendelin <12148533+wendevlin@users.noreply.github.com> --- .../dialog-automation-save.ts | 213 +++--------- .../config/common/suggest-metadata-ai.ts | 306 ++++++++++++++++++ 2 files changed, 351 insertions(+), 168 deletions(-) create mode 100644 src/panels/config/common/suggest-metadata-ai.ts diff --git a/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts b/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts index 30bbe60ae0..5513becc5c 100644 --- a/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts +++ b/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts @@ -1,5 +1,4 @@ import { mdiPlus } from "@mdi/js"; -import { dump } from "js-yaml"; import type { CSSResultGroup } from "lit"; import { css, html, LitElement, nothing } from "lit"; import { customElement, property, state } from "lit/decorators"; @@ -21,13 +20,10 @@ import "../../../../components/ha-textfield"; import "../../../../components/ha-wa-dialog"; import "../../category/ha-category-picker"; -import { computeStateDomain } from "../../../../common/entity/compute_state_domain"; import { supportsMarkdownHelper } from "../../../../common/translations/markdown_support"; -import { subscribeOne } from "../../../../common/util/subscribe-one"; import type { GenDataTaskResult } from "../../../../data/ai_task"; -import { fetchCategoryRegistry } from "../../../../data/category_registry"; -import { subscribeEntityRegistry } from "../../../../data/entity/entity_registry"; -import { subscribeLabelRegistry } from "../../../../data/label/label_registry"; +import type { AutomationConfig } from "../../../../data/automation"; +import type { ScriptConfig } from "../../../../data/script"; import type { HassDialog } from "../../../../dialogs/make-dialog-manager"; import { haStyle, haStyleDialog } from "../../../../resources/styles"; import type { HomeAssistant } from "../../../../types"; @@ -35,6 +31,12 @@ import type { EntityRegistryUpdate, SaveDialogParams, } from "./show-dialog-automation-save"; +import { + type MetadataSuggestionResult, + SUGGESTION_INCLUDE_ALL, + generateMetadataSuggestionTask, + processMetadataSuggestion, +} from "../../common/suggest-metadata-ai"; @customElement("ha-dialog-automation-save") class DialogAutomationSave extends LitElement implements HassDialog { @@ -333,184 +335,59 @@ class DialogAutomationSave extends LitElement implements HassDialog { this.closeDialog(); } - private _getSuggestData() { - return Promise.all([ - subscribeOne(this.hass.connection, subscribeLabelRegistry).then((labs) => - Object.fromEntries(labs.map((lab) => [lab.label_id, lab.name])) - ), - subscribeOne(this.hass.connection, subscribeEntityRegistry).then((ents) => - Object.fromEntries(ents.map((ent) => [ent.entity_id, ent])) - ), - fetchCategoryRegistry(this.hass.connection, "automation").then((cats) => - Object.fromEntries(cats.map((cat) => [cat.category_id, cat.name])) - ), - ]); - } - private _generateTask = async (): Promise => { if (!this._params) { throw new Error("Dialog params not set"); } - - const [labels, entities, categories] = await this._getSuggestData(); - const inspirations: string[] = []; - - const domain = this._params.domain; - - for (const entity of Object.values(this.hass.states)) { - const entityEntry = entities[entity.entity_id]; - if ( - computeStateDomain(entity) !== domain || - entity.attributes.restored || - !entity.attributes.friendly_name || - !entityEntry - ) { - continue; - } - - let inspiration = `- ${entity.attributes.friendly_name}`; - - const category = categories[entityEntry.categories.automation]; - if (category) { - inspiration += ` (category: ${category})`; - } - - if (entityEntry.labels.length) { - inspiration += ` (labels: ${entityEntry.labels - .map((label) => labels[label]) - .join(", ")})`; - } - - inspirations.push(inspiration); - } - - const term = this._params.domain === "script" ? "script" : "automation"; - - return { - type: "data", - task: { - task_name: `frontend__${term}__save`, - instructions: `Suggest in language "${this.hass.language}" a name, description, category and labels for the following Home Assistant ${term}. - -The name should be relevant to the ${term}'s purpose. -${ - inspirations.length - ? `The name should be in same style and sentence capitalization as existing ${term}s. -Suggest a category and labels if relevant to the ${term}'s purpose. -Only suggest category and labels that are already used by existing ${term}s.` - : `The name should be short, descriptive, sentence case, and written in the language ${this.hass.language}.` -} -If the ${term} contains 5+ steps, include a short description. - -For inspiration, here are existing ${term}s: -${inspirations.join("\n")} - -The ${term} configuration is as follows: - -${dump(this._params.config)} -`, - structure: { - name: { - description: "The name of the automation", - required: true, - selector: { - text: {}, - }, - }, - description: { - description: "A short description of the automation", - required: false, - selector: { - text: {}, - }, - }, - labels: { - description: "Labels for the automation", - required: false, - selector: { - text: { - multiple: true, - }, - }, - }, - category: { - description: "The category of the automation", - required: false, - selector: { - select: { - options: Object.entries(categories).map(([id, name]) => ({ - value: id, - label: name, - })), - }, - }, - }, - }, - }, - }; + return generateMetadataSuggestionTask( + this.hass.connection, + this.hass.states, + this.hass.language, + this._params.domain, + this._params.config + ); }; private async _handleSuggestion( - event: CustomEvent< - GenDataTaskResult<{ - name: string; - description?: string; - category?: string; - labels?: string[]; - }> - > + event: CustomEvent> ) { + if (!this._params) { + throw new Error("Dialog params not set"); + } const result = event.detail; - const [labels, _entities, categories] = await this._getSuggestData(); + const processed = await processMetadataSuggestion( + this.hass.connection, + this._params.domain, + result, + SUGGESTION_INCLUDE_ALL + ); - this._newName = result.data.name; - if (result.data.description) { - this._newDescription = result.data.description; + this._newName = processed.name; + + if (processed.description) { + this._newDescription = processed.description; if (!this._visibleOptionals.includes("description")) { this._visibleOptionals = [...this._visibleOptionals, "description"]; } } - if (result.data.category) { - // We get back category name, convert it to ID - const categoryId = Object.entries(categories).find( - ([, name]) => name === result.data.category - )?.[0]; - if (categoryId) { - this._entryUpdates = { - ...this._entryUpdates, - category: categoryId, - }; - if (!this._visibleOptionals.includes("category")) { - this._visibleOptionals = [...this._visibleOptionals, "category"]; - } + + if (processed.category) { + this._entryUpdates = { + ...this._entryUpdates, + category: processed.category, + }; + if (!this._visibleOptionals.includes("category")) { + this._visibleOptionals = [...this._visibleOptionals, "category"]; } } - if (result.data.labels?.length) { - // We get back label names, convert them to IDs - const newLabels: Record = Object.fromEntries( - result.data.labels.map((name) => [name, undefined]) - ); - let toFind = result.data.labels.length; - for (const [labelId, labelName] of Object.entries(labels)) { - if (labelName in newLabels && newLabels[labelName] === undefined) { - newLabels[labelName] = labelId; - toFind--; - if (toFind === 0) { - break; - } - } - } - const foundLabels = Object.values(newLabels).filter( - (labelId) => labelId !== undefined - ); - if (foundLabels.length) { - this._entryUpdates = { - ...this._entryUpdates, - labels: foundLabels, - }; - if (!this._visibleOptionals.includes("labels")) { - this._visibleOptionals = [...this._visibleOptionals, "labels"]; - } + + if (processed.labels?.length) { + this._entryUpdates = { + ...this._entryUpdates, + labels: processed.labels, + }; + if (!this._visibleOptionals.includes("labels")) { + this._visibleOptionals = [...this._visibleOptionals, "labels"]; } } } diff --git a/src/panels/config/common/suggest-metadata-ai.ts b/src/panels/config/common/suggest-metadata-ai.ts new file mode 100644 index 0000000000..4a4d6a911e --- /dev/null +++ b/src/panels/config/common/suggest-metadata-ai.ts @@ -0,0 +1,306 @@ +import { dump } from "js-yaml"; +import { computeDomain } from "../../../common/entity/compute_domain"; +import { subscribeOne } from "../../../common/util/subscribe-one"; +import type { AITaskStructure, GenDataTaskResult } from "../../../data/ai_task"; +import { fetchCategoryRegistry } from "../../../data/category_registry"; +import { + subscribeEntityRegistry, + type EntityRegistryEntry, +} from "../../../data/entity/entity_registry"; +import { subscribeLabelRegistry } from "../../../data/label/label_registry"; +import type { HomeAssistant } from "../../../types"; +import type { SuggestWithAIGenerateTask } from "../../../components/ha-suggest-with-ai-button"; + +export interface MetadataSuggestionResult { + name: string; + description?: string; + category?: string; + labels?: string[]; +} + +export type MetadataSuggestionDomain = "automation" | "script"; + +export interface MetadataSuggestionInclude { + description?: boolean; + categories?: boolean; + labels?: boolean; +} + +type Categories = Record; +type Entities = Record; +type Labels = Record; + +export const SUGGESTION_INCLUDE_ALL: MetadataSuggestionInclude = { + description: true, + categories: true, + labels: true, +} as const; + +const tryCatchEmptyObject = (promise: Promise): Promise => + promise.catch((err) => { + // eslint-disable-next-line no-console + console.error("Error fetching data for suggestion: ", err); + return {} as T; + }); + +const fetchCategories = ( + connection: HomeAssistant["connection"], + domain: MetadataSuggestionDomain +): Promise => + tryCatchEmptyObject( + fetchCategoryRegistry(connection, domain).then((cats) => + Object.fromEntries(cats.map((cat) => [cat.category_id, cat.name])) + ) + ); + +const fetchEntities = ( + connection: HomeAssistant["connection"] +): Promise => + tryCatchEmptyObject( + subscribeOne(connection, subscribeEntityRegistry).then((ents) => + Object.fromEntries(ents.map((ent) => [ent.entity_id, ent])) + ) + ); + +const fetchLabels = ( + connection: HomeAssistant["connection"] +): Promise => + tryCatchEmptyObject( + subscribeOne(connection, subscribeLabelRegistry).then((labs) => + Object.fromEntries(labs.map((lab) => [lab.label_id, lab.name])) + ) + ); + +function buildMetadataInspirations( + domain: MetadataSuggestionDomain, + states: HomeAssistant["states"], + entities: Entities, + categories?: Categories, + labels?: Labels +): string[] { + const inspirations: string[] = []; + + for (const entityId of Object.keys(entities)) { + const entityEntry = entities[entityId]; + if (!entityEntry || computeDomain(entityId) !== domain) { + continue; + } + + const entity = states[entityId]; + if ( + !entity || + entity.attributes.restored || + !entity.attributes.friendly_name + ) { + continue; + } + + let inspiration = `- ${entity.attributes.friendly_name}`; + + // Get the category for this domain + if (categories && categories[entityEntry.categories[domain]]) { + inspiration += ` (category: ${categories[entityEntry.categories[domain]]})`; + } + + if (labels && entityEntry.labels.length) { + inspiration += ` (labels: ${entityEntry.labels + .map((label) => labels[label]) + .join(", ")})`; + } + + inspirations.push(inspiration); + } + + return inspirations; +} + +/** + * Generates an AI task for suggesting metadata + * for automations or scripts based on their configuration. + * + * @param connection - Home Assistant connection + * @param states - Current state objects + * @param language - User's language preference + * @param domain - The domain to suggest metadata for (automation, script) + * @param config - The configuration to suggest metadata for + * @param include - The metadata fields to include in the suggestion + * @returns Promise resolving to the AI task structure + */ +export async function generateMetadataSuggestionTask( + connection: HomeAssistant["connection"], + states: HomeAssistant["states"], + language: HomeAssistant["language"], + domain: MetadataSuggestionDomain, + config: T, + include = SUGGESTION_INCLUDE_ALL +): Promise { + const [categories, entities, labels] = await Promise.all([ + include.categories + ? fetchCategories(connection, domain) + : Promise.resolve(undefined), + fetchEntities(connection), + include.labels ? fetchLabels(connection) : Promise.resolve(undefined), + ]); + + const inspirations = buildMetadataInspirations( + domain, + states, + entities, + categories, + labels + ); + + const structure: AITaskStructure = { + name: { + description: `The name of the ${domain}`, + required: true, + selector: { + text: {}, + }, + }, + ...(include.description && { + description: { + description: `A short description of the ${domain}`, + required: false, + selector: { + text: {}, + }, + }, + }), + ...(include.labels && { + labels: { + description: `Labels for the ${domain}`, + required: false, + selector: { + text: { + multiple: true, + }, + }, + }, + }), + ...(include.categories && + categories && { + category: { + description: `The category of the ${domain}`, + required: false, + selector: { + select: { + options: Object.entries(categories).map(([id, name]) => ({ + value: id, + label: name, + })), + }, + }, + }, + }), + }; + + const categoryLabelText: string[] = []; + if (include.categories) { + categoryLabelText.push("category"); + } + if (include.labels) { + categoryLabelText.push("labels"); + } + const categoryLabelString = + categoryLabelText.length > 0 ? `, ${categoryLabelText.join(" and ")}` : ""; + + return { + type: "data", + task: { + task_name: `frontend__${domain}__save`, + instructions: `Suggest in language "${language}" a name${include.description ? ", description" : ""}${categoryLabelString} for the following Home Assistant ${domain}. + +The name should be relevant to the ${domain}'s purpose. +${ + inspirations.length + ? `The name should be in same style and sentence capitalization as existing ${domain}s.${ + include.categories || include.labels + ? ` +Suggest ${categoryLabelText.join(" and ")} if relevant to the ${domain}'s purpose. +Only suggest ${categoryLabelText.join(" and ")} that are already used by existing ${domain}s.` + : "" + }` + : `The name should be short, descriptive, sentence case, and written in the language ${language}.` +}${ + include.description + ? ` +If the ${domain} contains 5+ steps, include a short description.` + : "" + } + +For inspiration, here are existing ${domain}s: +${inspirations.join("\n")} + +The ${domain} configuration is as follows: + +${dump(config)} +`, + structure, + }, + }; +} + +/** + * Processes the result of an AI task for suggesting metadata + * for automations or scripts based on their configuration. + * + * @param connection - Home Assistant connection + * @param domain - The domain of the ${domain} + * @param result - The result of the AI task + * @param include - The metadata fields to include in the suggestion + * @returns Promise resolving to the processed metadata suggestion + */ +export async function processMetadataSuggestion( + connection: HomeAssistant["connection"], + domain: MetadataSuggestionDomain, + result: GenDataTaskResult, + include: MetadataSuggestionInclude +): Promise { + const [categories, labels] = await Promise.all([ + include.categories + ? fetchCategories(connection, domain) + : Promise.resolve(undefined), + include.labels ? fetchLabels(connection) : Promise.resolve(undefined), + ]); + + const processed: MetadataSuggestionResult = { + name: result.data.name, + description: include.description ? result.data.description : undefined, + }; + + // Convert category name to ID + if (include.categories && categories && result.data.category) { + const categoryId = Object.entries(categories).find( + ([, name]) => name === result.data.category + )?.[0]; + if (categoryId) { + processed.category = categoryId; + } + } + + // Convert label names to IDs + if (include.labels && labels && result.data.labels?.length) { + const newLabels: Record = Object.fromEntries( + result.data.labels.map((name) => [name, undefined]) + ); + let toFind = result.data.labels.length; + for (const [labelId, labelName] of Object.entries(labels)) { + if (labelName in newLabels && newLabels[labelName] === undefined) { + newLabels[labelName] = labelId; + toFind--; + if (toFind === 0) { + break; + } + } + } + const foundLabels = Object.values(newLabels).filter( + (labelId): labelId is string => labelId !== undefined + ); + if (foundLabels.length) { + processed.labels = foundLabels; + } + } + + return processed; +}