Create shared ai task metadata suggestion task (#29012)

Co-authored-by: Wendelin <12148533+wendevlin@users.noreply.github.com>
This commit is contained in:
Aidan Timson 2026-01-19 08:42:36 +00:00 committed by GitHub
parent 622df52167
commit d47d3f9694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 351 additions and 168 deletions

View File

@ -1,5 +1,4 @@
import { mdiPlus } from "@mdi/js";
import { dump } from "js-yaml";
import type { CSSResultGroup } from "lit";
import { css, html, LitElement, nothing } from "lit";
import { customElement, property, state } from "lit/decorators";
@ -21,13 +20,10 @@ import "../../../../components/ha-textfield";
import "../../../../components/ha-wa-dialog";
import "../../category/ha-category-picker";
import { computeStateDomain } from "../../../../common/entity/compute_state_domain";
import { supportsMarkdownHelper } from "../../../../common/translations/markdown_support";
import { subscribeOne } from "../../../../common/util/subscribe-one";
import type { GenDataTaskResult } from "../../../../data/ai_task";
import { fetchCategoryRegistry } from "../../../../data/category_registry";
import { subscribeEntityRegistry } from "../../../../data/entity/entity_registry";
import { subscribeLabelRegistry } from "../../../../data/label/label_registry";
import type { AutomationConfig } from "../../../../data/automation";
import type { ScriptConfig } from "../../../../data/script";
import type { HassDialog } from "../../../../dialogs/make-dialog-manager";
import { haStyle, haStyleDialog } from "../../../../resources/styles";
import type { HomeAssistant } from "../../../../types";
@ -35,6 +31,12 @@ import type {
EntityRegistryUpdate,
SaveDialogParams,
} from "./show-dialog-automation-save";
import {
type MetadataSuggestionResult,
SUGGESTION_INCLUDE_ALL,
generateMetadataSuggestionTask,
processMetadataSuggestion,
} from "../../common/suggest-metadata-ai";
@customElement("ha-dialog-automation-save")
class DialogAutomationSave extends LitElement implements HassDialog {
@ -333,184 +335,59 @@ class DialogAutomationSave extends LitElement implements HassDialog {
this.closeDialog();
}
private _getSuggestData() {
return Promise.all([
subscribeOne(this.hass.connection, subscribeLabelRegistry).then((labs) =>
Object.fromEntries(labs.map((lab) => [lab.label_id, lab.name]))
),
subscribeOne(this.hass.connection, subscribeEntityRegistry).then((ents) =>
Object.fromEntries(ents.map((ent) => [ent.entity_id, ent]))
),
fetchCategoryRegistry(this.hass.connection, "automation").then((cats) =>
Object.fromEntries(cats.map((cat) => [cat.category_id, cat.name]))
),
]);
}
private _generateTask = async (): Promise<SuggestWithAIGenerateTask> => {
if (!this._params) {
throw new Error("Dialog params not set");
}
const [labels, entities, categories] = await this._getSuggestData();
const inspirations: string[] = [];
const domain = this._params.domain;
for (const entity of Object.values(this.hass.states)) {
const entityEntry = entities[entity.entity_id];
if (
computeStateDomain(entity) !== domain ||
entity.attributes.restored ||
!entity.attributes.friendly_name ||
!entityEntry
) {
continue;
}
let inspiration = `- ${entity.attributes.friendly_name}`;
const category = categories[entityEntry.categories.automation];
if (category) {
inspiration += ` (category: ${category})`;
}
if (entityEntry.labels.length) {
inspiration += ` (labels: ${entityEntry.labels
.map((label) => labels[label])
.join(", ")})`;
}
inspirations.push(inspiration);
}
const term = this._params.domain === "script" ? "script" : "automation";
return {
type: "data",
task: {
task_name: `frontend__${term}__save`,
instructions: `Suggest in language "${this.hass.language}" a name, description, category and labels for the following Home Assistant ${term}.
The name should be relevant to the ${term}'s purpose.
${
inspirations.length
? `The name should be in same style and sentence capitalization as existing ${term}s.
Suggest a category and labels if relevant to the ${term}'s purpose.
Only suggest category and labels that are already used by existing ${term}s.`
: `The name should be short, descriptive, sentence case, and written in the language ${this.hass.language}.`
}
If the ${term} contains 5+ steps, include a short description.
For inspiration, here are existing ${term}s:
${inspirations.join("\n")}
The ${term} configuration is as follows:
${dump(this._params.config)}
`,
structure: {
name: {
description: "The name of the automation",
required: true,
selector: {
text: {},
},
},
description: {
description: "A short description of the automation",
required: false,
selector: {
text: {},
},
},
labels: {
description: "Labels for the automation",
required: false,
selector: {
text: {
multiple: true,
},
},
},
category: {
description: "The category of the automation",
required: false,
selector: {
select: {
options: Object.entries(categories).map(([id, name]) => ({
value: id,
label: name,
})),
},
},
},
},
},
};
return generateMetadataSuggestionTask<AutomationConfig | ScriptConfig>(
this.hass.connection,
this.hass.states,
this.hass.language,
this._params.domain,
this._params.config
);
};
private async _handleSuggestion(
event: CustomEvent<
GenDataTaskResult<{
name: string;
description?: string;
category?: string;
labels?: string[];
}>
>
event: CustomEvent<GenDataTaskResult<MetadataSuggestionResult>>
) {
if (!this._params) {
throw new Error("Dialog params not set");
}
const result = event.detail;
const [labels, _entities, categories] = await this._getSuggestData();
const processed = await processMetadataSuggestion(
this.hass.connection,
this._params.domain,
result,
SUGGESTION_INCLUDE_ALL
);
this._newName = result.data.name;
if (result.data.description) {
this._newDescription = result.data.description;
this._newName = processed.name;
if (processed.description) {
this._newDescription = processed.description;
if (!this._visibleOptionals.includes("description")) {
this._visibleOptionals = [...this._visibleOptionals, "description"];
}
}
if (result.data.category) {
// We get back category name, convert it to ID
const categoryId = Object.entries(categories).find(
([, name]) => name === result.data.category
)?.[0];
if (categoryId) {
this._entryUpdates = {
...this._entryUpdates,
category: categoryId,
};
if (!this._visibleOptionals.includes("category")) {
this._visibleOptionals = [...this._visibleOptionals, "category"];
}
if (processed.category) {
this._entryUpdates = {
...this._entryUpdates,
category: processed.category,
};
if (!this._visibleOptionals.includes("category")) {
this._visibleOptionals = [...this._visibleOptionals, "category"];
}
}
if (result.data.labels?.length) {
// We get back label names, convert them to IDs
const newLabels: Record<string, undefined | string> = Object.fromEntries(
result.data.labels.map((name) => [name, undefined])
);
let toFind = result.data.labels.length;
for (const [labelId, labelName] of Object.entries(labels)) {
if (labelName in newLabels && newLabels[labelName] === undefined) {
newLabels[labelName] = labelId;
toFind--;
if (toFind === 0) {
break;
}
}
}
const foundLabels = Object.values(newLabels).filter(
(labelId) => labelId !== undefined
);
if (foundLabels.length) {
this._entryUpdates = {
...this._entryUpdates,
labels: foundLabels,
};
if (!this._visibleOptionals.includes("labels")) {
this._visibleOptionals = [...this._visibleOptionals, "labels"];
}
if (processed.labels?.length) {
this._entryUpdates = {
...this._entryUpdates,
labels: processed.labels,
};
if (!this._visibleOptionals.includes("labels")) {
this._visibleOptionals = [...this._visibleOptionals, "labels"];
}
}
}

View File

@ -0,0 +1,306 @@
import { dump } from "js-yaml";
import { computeDomain } from "../../../common/entity/compute_domain";
import { subscribeOne } from "../../../common/util/subscribe-one";
import type { AITaskStructure, GenDataTaskResult } from "../../../data/ai_task";
import { fetchCategoryRegistry } from "../../../data/category_registry";
import {
subscribeEntityRegistry,
type EntityRegistryEntry,
} from "../../../data/entity/entity_registry";
import { subscribeLabelRegistry } from "../../../data/label/label_registry";
import type { HomeAssistant } from "../../../types";
import type { SuggestWithAIGenerateTask } from "../../../components/ha-suggest-with-ai-button";
export interface MetadataSuggestionResult {
name: string;
description?: string;
category?: string;
labels?: string[];
}
export type MetadataSuggestionDomain = "automation" | "script";
export interface MetadataSuggestionInclude {
description?: boolean;
categories?: boolean;
labels?: boolean;
}
type Categories = Record<string, string>;
type Entities = Record<string, EntityRegistryEntry>;
type Labels = Record<string, string>;
export const SUGGESTION_INCLUDE_ALL: MetadataSuggestionInclude = {
description: true,
categories: true,
labels: true,
} as const;
const tryCatchEmptyObject = <T>(promise: Promise<T>): Promise<T> =>
promise.catch((err) => {
// eslint-disable-next-line no-console
console.error("Error fetching data for suggestion: ", err);
return {} as T;
});
const fetchCategories = (
connection: HomeAssistant["connection"],
domain: MetadataSuggestionDomain
): Promise<Categories> =>
tryCatchEmptyObject<Categories>(
fetchCategoryRegistry(connection, domain).then((cats) =>
Object.fromEntries(cats.map((cat) => [cat.category_id, cat.name]))
)
);
const fetchEntities = (
connection: HomeAssistant["connection"]
): Promise<Entities> =>
tryCatchEmptyObject<Entities>(
subscribeOne(connection, subscribeEntityRegistry).then((ents) =>
Object.fromEntries(ents.map((ent) => [ent.entity_id, ent]))
)
);
const fetchLabels = (
connection: HomeAssistant["connection"]
): Promise<Labels> =>
tryCatchEmptyObject<Labels>(
subscribeOne(connection, subscribeLabelRegistry).then((labs) =>
Object.fromEntries(labs.map((lab) => [lab.label_id, lab.name]))
)
);
function buildMetadataInspirations(
domain: MetadataSuggestionDomain,
states: HomeAssistant["states"],
entities: Entities,
categories?: Categories,
labels?: Labels
): string[] {
const inspirations: string[] = [];
for (const entityId of Object.keys(entities)) {
const entityEntry = entities[entityId];
if (!entityEntry || computeDomain(entityId) !== domain) {
continue;
}
const entity = states[entityId];
if (
!entity ||
entity.attributes.restored ||
!entity.attributes.friendly_name
) {
continue;
}
let inspiration = `- ${entity.attributes.friendly_name}`;
// Get the category for this domain
if (categories && categories[entityEntry.categories[domain]]) {
inspiration += ` (category: ${categories[entityEntry.categories[domain]]})`;
}
if (labels && entityEntry.labels.length) {
inspiration += ` (labels: ${entityEntry.labels
.map((label) => labels[label])
.join(", ")})`;
}
inspirations.push(inspiration);
}
return inspirations;
}
/**
* Generates an AI task for suggesting metadata
* for automations or scripts based on their configuration.
*
* @param connection - Home Assistant connection
* @param states - Current state objects
* @param language - User's language preference
* @param domain - The domain to suggest metadata for (automation, script)
* @param config - The configuration to suggest metadata for
* @param include - The metadata fields to include in the suggestion
* @returns Promise resolving to the AI task structure
*/
export async function generateMetadataSuggestionTask<T>(
connection: HomeAssistant["connection"],
states: HomeAssistant["states"],
language: HomeAssistant["language"],
domain: MetadataSuggestionDomain,
config: T,
include = SUGGESTION_INCLUDE_ALL
): Promise<SuggestWithAIGenerateTask> {
const [categories, entities, labels] = await Promise.all([
include.categories
? fetchCategories(connection, domain)
: Promise.resolve(undefined),
fetchEntities(connection),
include.labels ? fetchLabels(connection) : Promise.resolve(undefined),
]);
const inspirations = buildMetadataInspirations(
domain,
states,
entities,
categories,
labels
);
const structure: AITaskStructure = {
name: {
description: `The name of the ${domain}`,
required: true,
selector: {
text: {},
},
},
...(include.description && {
description: {
description: `A short description of the ${domain}`,
required: false,
selector: {
text: {},
},
},
}),
...(include.labels && {
labels: {
description: `Labels for the ${domain}`,
required: false,
selector: {
text: {
multiple: true,
},
},
},
}),
...(include.categories &&
categories && {
category: {
description: `The category of the ${domain}`,
required: false,
selector: {
select: {
options: Object.entries(categories).map(([id, name]) => ({
value: id,
label: name,
})),
},
},
},
}),
};
const categoryLabelText: string[] = [];
if (include.categories) {
categoryLabelText.push("category");
}
if (include.labels) {
categoryLabelText.push("labels");
}
const categoryLabelString =
categoryLabelText.length > 0 ? `, ${categoryLabelText.join(" and ")}` : "";
return {
type: "data",
task: {
task_name: `frontend__${domain}__save`,
instructions: `Suggest in language "${language}" a name${include.description ? ", description" : ""}${categoryLabelString} for the following Home Assistant ${domain}.
The name should be relevant to the ${domain}'s purpose.
${
inspirations.length
? `The name should be in same style and sentence capitalization as existing ${domain}s.${
include.categories || include.labels
? `
Suggest ${categoryLabelText.join(" and ")} if relevant to the ${domain}'s purpose.
Only suggest ${categoryLabelText.join(" and ")} that are already used by existing ${domain}s.`
: ""
}`
: `The name should be short, descriptive, sentence case, and written in the language ${language}.`
}${
include.description
? `
If the ${domain} contains 5+ steps, include a short description.`
: ""
}
For inspiration, here are existing ${domain}s:
${inspirations.join("\n")}
The ${domain} configuration is as follows:
${dump(config)}
`,
structure,
},
};
}
/**
* Processes the result of an AI task for suggesting metadata
* for automations or scripts based on their configuration.
*
* @param connection - Home Assistant connection
* @param domain - The domain of the ${domain}
* @param result - The result of the AI task
* @param include - The metadata fields to include in the suggestion
* @returns Promise resolving to the processed metadata suggestion
*/
export async function processMetadataSuggestion(
connection: HomeAssistant["connection"],
domain: MetadataSuggestionDomain,
result: GenDataTaskResult<MetadataSuggestionResult>,
include: MetadataSuggestionInclude
): Promise<MetadataSuggestionResult> {
const [categories, labels] = await Promise.all([
include.categories
? fetchCategories(connection, domain)
: Promise.resolve(undefined),
include.labels ? fetchLabels(connection) : Promise.resolve(undefined),
]);
const processed: MetadataSuggestionResult = {
name: result.data.name,
description: include.description ? result.data.description : undefined,
};
// Convert category name to ID
if (include.categories && categories && result.data.category) {
const categoryId = Object.entries(categories).find(
([, name]) => name === result.data.category
)?.[0];
if (categoryId) {
processed.category = categoryId;
}
}
// Convert label names to IDs
if (include.labels && labels && result.data.labels?.length) {
const newLabels: Record<string, undefined | string> = Object.fromEntries(
result.data.labels.map((name) => [name, undefined])
);
let toFind = result.data.labels.length;
for (const [labelId, labelName] of Object.entries(labels)) {
if (labelName in newLabels && newLabels[labelName] === undefined) {
newLabels[labelName] = labelId;
toFind--;
if (toFind === 0) {
break;
}
}
}
const foundLabels = Object.values(newLabels).filter(
(labelId): labelId is string => labelId !== undefined
);
if (foundLabels.length) {
processed.labels = foundLabels;
}
}
return processed;
}