diff --git a/package.json b/package.json
index 908ac7c..d4011ac 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
 	"name": "@ztimson/ai-utils",
-	"version": "0.6.10",
+	"version": "0.7.0",
 	"description": "AI Utility library",
 	"author": "Zak Timson",
 	"license": "MIT",
diff --git a/src/asr.ts b/src/asr.ts
index f076766..b544214 100644
--- a/src/asr.ts
+++ b/src/asr.ts
@@ -1,14 +1,17 @@
 import { pipeline } from '@xenova/transformers';
 import { parentPort } from 'worker_threads';
-import * as fs from 'node:fs';
-import wavefile from 'wavefile';
 import { spawn } from 'node:child_process';
+import { execSync } from 'node:child_process';
+import { mkdtempSync, rmSync, readFileSync } from 'node:fs';
+import { join } from 'node:path';
+import { tmpdir } from 'node:os';
+import wavefile from 'wavefile';
 
 let whisperPipeline: any;
 
 export async function canDiarization(): Promise<boolean> {
 	return new Promise((resolve) => {
-		const proc = spawn('python3', ['-c', 'import pyannote.audio']);
+		const proc = spawn('python', ['-c', 'import pyannote.audio']);
 		proc.on('close', (code: number) => resolve(code === 0));
 		proc.on('error', () => resolve(false));
 	});
@@ -21,25 +24,20 @@
 import json
 import os
 from pyannote.audio import Pipeline
 
-os.environ['TORCH_HOME'] = "${dir}"
-os.environ['HF_TOKEN'] = "${token}"
-pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
-diarization = pipeline(sys.argv[1])
+os.environ['TORCH_HOME'] = r"${dir}"
+pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token="${token}")
+output = pipeline(sys.argv[1])
 
 segments = []
-for turn, _, speaker in diarization.itertracks(yield_label=True):
-    segments.append({
-        "start": turn.start,
-        "end": turn.end,
-        "speaker": speaker
-    })
+for turn, speaker in output.speaker_diarization:
+    segments.append({"start": turn.start, "end": turn.end, "speaker": speaker})
 
 print(json.dumps(segments))
 `;
 	return new Promise((resolve, reject) => {
 		let output = '';
-		const proc = spawn('python3', ['-c', script, audioPath]);
+		const proc = spawn('python', ['-c', script, audioPath]);
 		proc.stdout.on('data', (data: Buffer) => output += data.toString());
 		proc.stderr.on('data', (data: Buffer) => console.error(data.toString()));
 		proc.on('close', (code: number) => {
@@ -72,55 +70,65 @@ function combineSpeakerTranscript(chunks: any[], speakers: any[]): string {
 		const speaker = speakers.find((s: any) => time >= s.start && time <= s.end);
 		const speakerNum = speaker ? speakerMap.get(speaker.speaker) : 1;
 		if (speakerNum !== currentSpeaker) {
-			if(currentText) lines.push(`[speaker ${currentSpeaker}]: ${currentText.trim()}`);
+			if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`);
 			currentSpeaker = speakerNum;
 			currentText = chunk.text;
 		} else {
 			currentText += chunk.text;
 		}
 	});
-	if(currentText) lines.push(`[speaker ${currentSpeaker}]: ${currentText.trim()}`);
+	if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`);
 	return lines.join('\n');
 }
 
+function prepareAudioBuffer(file: string): [string, Float32Array] {
+	let wav: any, tmp;
+	try {
+		wav = new wavefile.WaveFile(readFileSync(file));
+	} catch(err) {
+		tmp = join(mkdtempSync(join(tmpdir(), 'audio-')), 'converted.wav');
+		execSync(`ffmpeg -i "${file}" -ar 16000 -ac 1 -f wav "${tmp}"`, { stdio: 'ignore' });
+		wav = new wavefile.WaveFile(readFileSync(tmp));
+	} finally {
+		wav.toBitDepth('32f');
+		wav.toSampleRate(16000);
+		const samples = wav.getSamples();
+		if(Array.isArray(samples)) {
+			const left = samples[0];
+			const right = samples[1];
+			const buffer = new Float32Array(left.length);
+			for (let i = 0; i < left.length; i++) buffer[i] = (left[i] + right[i]) / 2;
+			return [tmp || file, buffer];
+		}
+		return [tmp || file, samples];
+	}
+}
+
 parentPort?.on('message', async ({ file, speaker, model, modelDir, token }) => {
 	try {
-		console.log('worker', file);
 		if(!whisperPipeline) whisperPipeline = await pipeline('automatic-speech-recognition', `Xenova/${model}`, {cache_dir: modelDir, quantized: true});
 
-		// Prepare audio file (convert to mono channel wave)
-		const wav = new wavefile.WaveFile(fs.readFileSync(file));
-		wav.toBitDepth('32f');
-		wav.toSampleRate(16000);
-		const samples = wav.getSamples();
-		let buffer;
-		if(Array.isArray(samples)) { // stereo to mono - average the channels
-			const left = samples[0];
-			const right = samples[1];
-			buffer = new Float32Array(left.length);
-			for (let i = 0; i < left.length; i++) buffer[i] = (left[i] + right[i]) / 2;
-		} else {
-			buffer = samples;
-		}
+		// Prepare audio file
+		const [f, buffer] = prepareAudioBuffer(file);
 
-		// Transcribe
-		const transcriptResult = await whisperPipeline(buffer, {return_timestamps: speaker ? 'word' : false});
-		if(!speaker) {
-			parentPort?.postMessage({ text: transcriptResult.text?.trim() || null });
-			return;
-		}
+		// Fetch transcript and speakers
+		const hasDiarization = speaker && await canDiarization();
+		const [transcript, speakers] = await Promise.all([
+			whisperPipeline(buffer, {return_timestamps: speaker ? 'word' : false}),
+			(!speaker || !token || !hasDiarization) ? Promise.resolve() : runDiarization(f, modelDir, token),
+		]);
+		if(file != f) rmSync(f, { recursive: true, force: true });
 
-		// Speaker Diarization
-		const hasDiarization = await canDiarization();
-		if(!token || !hasDiarization) {
-			parentPort?.postMessage({ text: transcriptResult.text?.trim() || null, error: 'Speaker diarization unavailable' });
-			return;
-		}
+		// Return any results / errors if no more processing required
+		const text = transcript.text?.trim() || null;
+		if(!speaker) return parentPort?.postMessage({ text });
+		if(!token) return parentPort?.postMessage({ text, error: 'HuggingFace token required' });
+		if(!hasDiarization) return parentPort?.postMessage({ text, error: 'Speaker diarization unavailable' });
 
-		const speakers = await runDiarization(file, modelDir, token);
-		const combined = combineSpeakerTranscript(transcriptResult.chunks || [], speakers);
+		// Combine transcript and speakers
+		const combined = combineSpeakerTranscript(transcript.chunks || [], speakers || []);
 		parentPort?.postMessage({ text: combined });
-	} catch (err) {
-		parentPort?.postMessage({ error: (err as Error).message });
+	} catch (err: any) {
+		parentPort?.postMessage({ error: err.stack || err.message });
 	}
 });
diff --git a/src/audio.ts b/src/audio.ts
index e17c885..ae80bf7 100644
--- a/src/audio.ts
+++ b/src/audio.ts
@@ -7,12 +7,12 @@ import {dirname, join} from 'path';
 export class Audio {
 	constructor(private ai: Ai) {}
 
-	asr(file: string, options: { model?: string; speaker?: boolean } = {}): AbortablePromise {
+	asr(file: string, options: { model?: string; speaker?: boolean | 'id' } = {}): AbortablePromise {
 		const { model = this.ai.options.asr || 'whisper-base', speaker = false } = options;
 		let aborted = false;
 		const abort = () => { aborted = true; };
 
-		const p = new Promise((resolve, reject) => {
+		let p = new Promise((resolve, reject) => {
 			const worker = new Worker(join(dirname(fileURLToPath(import.meta.url)), 'asr.js'));
 			const handleMessage = ({ text, warning, error }: any) => {
 				worker.terminate();
@@ -34,6 +34,23 @@ export class Audio {
 			});
 			worker.postMessage({file, model, speaker, modelDir: this.ai.options.path, token: this.ai.options.hfToken});
 		});
+
+		// Name speakers using AI
+		if(options.speaker == 'id') {
+			if(!this.ai.language.defaultModel) throw new Error('Configure an LLM for advanced ASR speaker detection');
+			p = p.then(async transcript => {
+				if(!transcript) return transcript;
+				const names = await this.ai.language.json(transcript, '{1: "Detected Name"}', {
+					system: 'Use this following transcript to identify speakers. Only identify speakers you are sure about',
+					temperature: 0.2,
+				});
+				Object.entries(names).forEach(([speaker, name]) => {
+					transcript = (<string>transcript).replaceAll(`[Speaker ${speaker}]`, `[${name}]`);
+				});
+				return transcript;
+			})
+		}
+
 		return Object.assign(p, { abort });
 	}
 
diff --git a/src/llm.ts b/src/llm.ts
index fc0c2d0..139e565 100644
--- a/src/llm.ts
+++ b/src/llm.ts
@@ -75,8 +75,8 @@ export type LLMRequest = {
 }
 
 class LLM {
-	private models: {[model: string]: LLMProvider} = {};
-	private defaultModel!: string;
+	defaultModel!: string;
+	models: {[model: string]: LLMProvider} = {};
 
 	constructor(public readonly ai: Ai) {
 		if(!ai.options.llm?.models) return;
@@ -184,7 +184,12 @@
 		const system = history[0].role == 'system' ? history[0] : null,
 			recent = keep == 0 ? [] : history.slice(-keep),
 			process = (keep == 0 ? history : history.slice(0, -keep)).filter(h => h.role === 'assistant' || h.role === 'user');
-		const summary: any = await this.json(`Create the smallest summary possible, no more than 500 tokens. Create a list of NEW facts (split by subject [pro]noun and fact) about what you learned from this conversation that you didn't already know or get from a tool call or system prompt. Focus only on new information about people, topics, or facts. Avoid generating facts about the AI. Match this format: {summary: string, facts: [[subject, fact]]}\n\n${process.map(m => `${m.role}: ${m.content}`).join('\n\n')}`, {model: options?.model, temperature: options?.temperature || 0.3});
+
+		const summary: any = await this.json(process.map(m => `${m.role}: ${m.content}`).join('\n\n'), '{summary: string, facts: [[subject, fact]]}', {
+			system: 'Create the smallest summary possible, no more than 500 tokens. Create a list of NEW facts (split by subject [pro]noun and fact) about what you learned from this conversation that you didn\'t already know or get from a tool call or system prompt. Focus only on new information about people, topics, or facts. Avoid generating facts about the AI.',
+			model: options?.model,
+			temperature: options?.temperature || 0.3
+		});
 		const timestamp = new Date();
 		const memory = await Promise.all((summary?.facts || [])?.map(async ([owner, fact]: [string, string]) => {
 			const e = await Promise.all([this.embedding(owner), this.embedding(`${owner}: ${fact}`)]);
@@ -312,12 +317,16 @@
 
 	/**
 	 * Ask a question with JSON response
-	 * @param {string} message Question
+	 * @param {string} text Text to process
+	 * @param {string} schema JSON schema the AI should match
 	 * @param {LLMRequest} options Configuration options and chat history
 	 * @returns {Promise<{} | {} | RegExpExecArray | null>}
 	 */
-	async json(message: string, options?: LLMRequest): Promise<any> {
-		let resp = await this.ask(message, {system: 'Respond using a JSON blob matching any provided examples', ...options});
+	async json(text: string, schema: string, options?: LLMRequest): Promise<any> {
+		let resp = await this.ask(text, {...options, system: (options?.system ? `${options.system}\n` : '') + `Only respond using a JSON code block matching this schema:
+\`\`\`json
+${schema}
+\`\`\``});
 		if(!resp) return {};
 		const codeBlock = /```(?:.+)?\s*([\s\S]*?)```/.exec(resp);
 		const jsonStr = codeBlock ? codeBlock[1].trim() : resp;
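Reviewer note (not part of the patch): a minimal usage sketch of the API surface this diff touches. It assumes the package exports an `Ai` class whose instances expose `audio` and `language` (matching the `Audio` class and `this.ai.language` calls above), and that the constructor accepts the option fields read in the diff (`asr`, `path`, `hfToken`, `llm`); the exact constructor shape, file names, and model names are illustrative guesses, not confirmed by this diff.

```ts
import { Ai } from '@ztimson/ai-utils'; // assumed entry point

async function demo() {
	// Option names mirror the fields read above (ai.options.asr / path / hfToken / llm); shape is illustrative.
	const ai = new Ai({
		asr: 'whisper-base',            // default ASR model used by Audio.asr()
		path: './models',               // model cache directory passed to the worker
		hfToken: process.env.HF_TOKEN,  // needed for pyannote speaker diarization
		llm: { models: { /* LLM provider config, required for speaker: 'id' */ } },
	});

	// Plain transcription (worker replies with { text }):
	const text = await ai.audio.asr('meeting.wav');
	console.log(text);

	// Diarized transcript; 'id' additionally asks the LLM to replace "[Speaker 1]" with detected names.
	// The returned value is an AbortablePromise: Object.assign(p, { abort }) attaches an abort handle.
	const labelled = ai.audio.asr('meeting.wav', { speaker: 'id' });
	console.log(await labelled);

	// New json() signature: the text to process first, then the schema the reply must match.
	const facts = await ai.language.json(
		'alice: I moved to Toronto last spring.',
		'{summary: string, facts: [[subject, fact]]}',
		{ temperature: 0.3 },
	);
	console.log(facts);
}

demo();
```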