import {type ChildProcess, execFileSync, execSync, spawn} from 'node:child_process';
import {mkdtempSync} from 'node:fs';
import fs from 'node:fs/promises';
import {tmpdir} from 'node:os';
import * as path from 'node:path';
import Path, {join} from 'node:path';
import type {AbortablePromise, Ai} from './ai.ts';

/** One entry of whisper.cpp's `-oj` JSON `transcription` array (only the fields read here). */
interface WhisperToken {
	/** Token text; a leading space marks the start of a new word, no leading space continues the previous one. */
	text: string;
	/** Start/end offsets; treated as milliseconds (divided by 1000 to compare with pyannote's seconds). */
	offsets: {from: number; to: number};
	/** Formatted start/end times, kept in step with `offsets` when tokens are merged. */
	timestamps: {from: string; to: string};
}

/** Parsed whisper.cpp JSON output. */
interface WhisperTranscript {
	transcription: WhisperToken[];
}

/** One speaker turn printed by the pyannote script (times in seconds). */
interface SpeakerSegment {
	start: number;
	end: number;
	speaker: string;
}

/**
 * How many transcript words ahead to search when aligning a sentence word to its timestamp.
 * Bounding the search stops one unmatched word (e.g. altered by the LLM) from jumping the cursor far ahead.
 */
const WORD_LOOKAHEAD = 50;

/** Rough English syllable count, used to estimate how long a word should take to say. */
function countSyllables(word: string): number {
	word = word.toLowerCase().replace(/[^a-z]/g, '');
	if(word.length <= 3) return 1;
	const matches = word.match(/[aeiouy]+/g);
	let count = matches ? matches.length : 1;
	if(word.endsWith('e')) count--;
	return Math.max(1, count);
}

/** Lower-cases a word and strips everything except word characters, for loose sentence/token matching. */
function normalizeWord(word: string): string {
	return word.trim().toLowerCase().replace(/[^\w]/g, '');
}

/**
 * Speech-to-text through whisper.cpp, with optional pyannote speaker diarization and
 * optional LLM clean-up / speaker naming through `ai.language`.
 */
export class Audio {
	/** In-flight model downloads keyed by model file name, so concurrent callers share one fetch. */
	private downloads = new Map<string, Promise<string>>();
	/** Python source run with `python -c`; the wav path is passed as `sys.argv[1]`. */
	private readonly pyannote: string;
	private whisperModel!: string;

	constructor(private ai: Ai) {
		if(ai.options.whisper) {
			this.whisperModel = ai.options.asr || 'ggml-base.en.bin';
			// Warm the model cache in the background. A failure is not fatal here: the in-flight entry is
			// evicted, so the first asr() call retries the download and reports the error to its caller.
			this.downloadAsrModel().catch(() => {});
		}
		// Values are embedded as JSON string literals (also valid Python string literals), so quotes or
		// backslashes in the path/token cannot break the script. A missing token becomes Python `None`.
		const token = ai.options.hfToken ? JSON.stringify(String(ai.options.hfToken)) : 'None';
		this.pyannote = [
			'import sys',
			'import json',
			'import os',
			'from pyannote.audio import Pipeline',
			ai.options.path ? `os.environ['TORCH_HOME'] = ${JSON.stringify(String(ai.options.path))}` : '',
			`pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=${token})`,
			'output = pipeline(sys.argv[1])',
			'segments = []',
			'for turn, speaker in output.speaker_diarization:',
			'    segments.append({"start": turn.start, "end": turn.end, "speaker": speaker})',
			'print(json.dumps(segments))',
		].join('\n');
	}

	/**
	 * Collapses whisper's sub-word tokens into whole words. Returns copies; the input is not mutated.
	 * - A token without a leading space is appended to the last kept word and extends its end time.
	 * - An empty token is dropped, but its start time is carried onto the following token,
	 *   stretching that token's span.
	 */
	private mergeTokens(tokens: WhisperToken[]): WhisperToken[] {
		const words: WhisperToken[] = [];
		let carried: {offset: number; timestamp: string} | undefined;
		for(const token of tokens) {
			const word: WhisperToken = {...token, offsets: {...token.offsets}, timestamps: {...token.timestamps}};
			if(carried) {
				word.offsets.from = carried.offset;
				word.timestamps.from = carried.timestamp;
				carried = undefined;
			}
			if(!word.text) {
				carried = {offset: word.offsets.from, timestamp: word.timestamps.from};
				continue;
			}
			const last = words[words.length - 1];
			if(last && word.text[0] !== ' ') {
				// Append to the last *kept* word so words split into 3+ tokens stay whole
				last.text += word.text;
				last.offsets.to = word.offsets.to;
				last.timestamps.to = word.timestamps.to;
				continue;
			}
			words.push(word);
		}
		return words;
	}

	/**
	 * Rebuilds transcript text from whisper tokens. A '.' is inserted before a capitalised word
	 * whose time span exceeds twice its expected duration (syllables × cadence).
	 * @param timestampData Parsed whisper JSON output
	 * @param llm Also ask the language model to remove misplaced punctuation via a find/replace tool
	 * @param cadence Expected duration per syllable, in the same units as token offsets
	 * @returns The trimmed transcript text
	 */
	private async addPunctuation(timestampData: WhisperTranscript, llm?: boolean, cadence = 150): Promise<string> {
		let result = '';
		for(const word of this.mergeTokens(timestampData.transcription ?? [])) {
			const text = word.text.trim();
			const duration = word.offsets.to - word.offsets.from;
			const expected = countSyllables(text) * cadence;
			if(/^[A-Z]/.test(text) && duration > expected * 2 && word.text[0] === ' ') result += '.';
			result += word.text;
		}
		result = result.trim();
		if(!llm) return result;
		await this.ai.language.ask(result, {
			system: 'Remove any misplaced punctuation from the following ASR transcript using the replace tool. Avoid modifying words unless there is an obvious typo',
			temperature: 0.1,
			tools: [{
				name: 'replace',
				description: 'Use find and replace to fix errors',
				args: {
					find: {type: 'string', description: 'Text to find', required: true},
					replace: {type: 'string', description: 'Text to replace', required: true}
				},
				fn: (args) => {
					// Empty `find` would prepend text; the function form stops `$&`-style patterns in model output
					if(args.find) result = result.replace(args.find, () => args.replace);
					return result;
				}
			}]
		});
		return result;
	}

	/**
	 * Splits the punctuated transcript into sentences and labels each one with the speaker whose
	 * pyannote turns contain most of the sentence's words. Adjacent sentences from the same
	 * speaker are then merged.
	 * @param timestampData Parsed whisper JSON output
	 * @param speakers Speaker turns from pyannote; speakers are numbered 1..n by first appearance
	 * @param llm Use the language model for punctuation clean-up and to replace speaker numbers with names
	 * @returns Lines of the form `[Speaker N]: text` (or `[Name]: text` after LLM naming)
	 */
	private async diarizeTranscript(timestampData: WhisperTranscript, speakers: SpeakerSegment[], llm: boolean): Promise<string> {
		const speakerIds = new Map<string, number>();
		for(const seg of speakers) {
			if(!speakerIds.has(seg.speaker)) speakerIds.set(seg.speaker, speakerIds.size + 1);
		}

		const punctuated = await this.addPunctuation(timestampData, llm);
		// `|$` keeps trailing text that has no terminal punctuation
		const sentences = (punctuated.match(/[^.!?]+(?:[.!?]+|$)/g) ?? [punctuated])
			.map(s => s.trim())
			.filter(s => s.length > 0);

		const words = this.mergeTokens(timestampData.transcription ?? []).filter(w => w.text.trim());
		const keys = words.map(w => normalizeWord(w.text));
		// Walk the transcript in order so repeated words map to their own occurrence, not the first one
		let cursor = 0;
		const locate = (needle: string): WhisperToken | undefined => {
			const end = Math.min(words.length, cursor + WORD_LOOKAHEAD);
			for(let i = cursor; i < end; i++) {
				if(keys[i] === needle) {
					cursor = i + 1;
					return words[i];
				}
			}
			return undefined;
		};

		const labelled = sentences.map(text => {
			const votes = new Map<number, number>();
			for(const raw of text.split(/\s+/)) {
				const needle = normalizeWord(raw);
				if(!needle) continue;
				const word = locate(needle);
				if(!word) continue;
				const time = word.offsets.from / 1000;
				const seg = speakers.find(s => time >= s.start && time <= s.end);
				const id = seg ? speakerIds.get(seg.speaker) : undefined;
				if(id !== undefined) votes.set(id, (votes.get(id) ?? 0) + 1);
			}
			let speaker = 1;
			let best = 0;
			for(const [id, count] of votes) {
				if(count > best) {
					best = count;
					speaker = id;
				}
			}
			return {speaker, text};
		});

		const merged: Array<{speaker: number, text: string}> = [];
		for(const item of labelled) {
			const last = merged[merged.length - 1];
			if(last && last.speaker === item.speaker) last.text += ' ' + item.text;
			else merged.push({...item});
		}

		let transcript = merged.map(item => `[Speaker ${item.speaker}]: ${item.text}`).join('\n').trim();
		if(!llm) return transcript;

		let chunks = this.ai.language.chunk(transcript, 500, 0);
		if(chunks.length > 4) chunks = [...chunks.slice(0, 3), chunks.at(-1)];
		const names = await this.ai.language.json(chunks.join('\n'), '{1: "Detected Name", 2: "Second Name"}', {
			system: 'Use the following transcript to identify speakers. Only identify speakers you are positive about, dont mention speakers you are unsure about in your response',
			temperature: 0.1,
		});
		for(const [speaker, name] of Object.entries(names ?? {})) {
			if(name) transcript = transcript.replaceAll(`[Speaker ${speaker}]`, `[${name}]`);
		}
		return transcript;
	}

	/**
	 * Runs whisper.cpp on a 16 kHz mono wav file.
	 * With `diarization`, runs with `-ml 1 -oj -of <dir>/transcript` and resolves the parsed
	 * `<dir>/transcript.json`. Otherwise it resolves the trimmed stdout text, or null if empty.
	 * Rejects with `Error('Aborted')` if aborted before whisper starts.
	 */
	private runAsr(file: string, opts: {model?: string, diarization: true}): AbortablePromise<WhisperTranscript>;
	private runAsr(file: string, opts?: {model?: string, diarization?: false}): AbortablePromise<string | null>;
	private runAsr(file: string, opts: {model?: string, diarization?: boolean} = {}): AbortablePromise<WhisperTranscript | string | null> {
		let proc: ChildProcess | undefined;
		let aborted = false;
		// Download failures propagate through this chain instead of leaving the promise pending
		const p = this.downloadAsrModel(opts.model).then(model => new Promise<WhisperTranscript | string | null>((resolve, reject) => {
			if(aborted) return reject(new Error('Aborted'));
			if(opts.diarization) {
				const outBase = path.join(path.dirname(file), 'transcript');
				const jsonFile = outBase + '.json';
				// stderr is ignored rather than piped: an undrained pipe can block the child
				const child = spawn(this.ai.options.whisper, ['-m', model, '-f', file, '-np', '-ml', '1', '-oj', '-of', outBase], {stdio: 'ignore'});
				proc = child;
				child.on('error', reject);
				child.on('close', (code: number | null) => {
					if(code !== 0) return reject(new Error(`Exit code ${code}`));
					fs.readFile(jsonFile, 'utf-8').then(raw => {
						try {
							const parsed = JSON.parse(raw);
							if(!parsed || !Array.isArray(parsed.transcription)) throw new Error('missing transcription');
							resolve(parsed as WhisperTranscript);
						} catch {
							reject(new Error('Failed to parse whisper JSON'));
						}
					}, reject).finally(() => fs.rm(jsonFile, {force: true}).catch(() => {}));
				});
			} else {
				let output = '';
				const child = spawn(this.ai.options.whisper, ['-m', model, '-f', file, '-np', '-nt'], {stdio: ['ignore', 'pipe', 'ignore']});
				proc = child;
				child.on('error', reject);
				child.stdout?.on('data', (data: Buffer) => output += data.toString());
				child.on('close', (code: number | null) => {
					if(code === 0) resolve(output.trim() || null);
					else reject(new Error(`Exit code ${code}`));
				});
			}
		}));
		return Object.assign(p, {abort: () => {
			aborted = true;
			proc?.kill('SIGTERM');
		}});
	}

	/**
	 * Runs the pyannote script on a wav file with whichever of `python3`/`python` can import
	 * pyannote.audio (python3 preferred). Resolves the speaker turns it prints as JSON.
	 */
	private runDiarization(file: string): AbortablePromise<SpeakerSegment[]> {
		let aborted = false;
		let proc: ChildProcess | undefined;
		const hasPyannote = (cmd: string) => new Promise<boolean>(resolve => {
			const probe = spawn(cmd, ['-W', 'ignore', '-c', 'import pyannote.audio'], {stdio: 'ignore'});
			probe.on('close', code => resolve(code === 0));
			probe.on('error', () => resolve(false));
		});
		const p = Promise.all([hasPyannote('python'), hasPyannote('python3')]).then(([py, py3]) => {
			if(aborted) throw new Error('Aborted');
			if(!py && !py3) throw new Error('Pyannote is not installed: pip install pyannote.audio');
			return new Promise<SpeakerSegment[]>((resolve, reject) => {
				let output = '';
				const child = spawn(py3 ? 'python3' : 'python', ['-W', 'ignore', '-c', this.pyannote, file]);
				proc = child;
				child.stdout.on('data', (data: Buffer) => output += data.toString());
				child.stderr.on('data', (data: Buffer) => console.error(data.toString()));
				child.on('error', reject);
				child.on('close', (code: number | null) => {
					if(code !== 0) return reject(new Error(`Python process exited with code ${code}`));
					try {
						resolve(JSON.parse(output));
					} catch {
						reject(new Error('Failed to parse diarization output'));
					}
				});
			});
		});
		// `proc` is read when abort() is called, so this also kills a python process started later
		return Object.assign(p, {abort: () => {
			aborted = true;
			proc?.kill('SIGTERM');
		}});
	}

	/**
	 * Transcribes an audio file. The file is first converted to 16 kHz mono wav with ffmpeg,
	 * in a private temp directory that is removed when the work settles.
	 * @param file Path to any ffmpeg-readable audio file
	 * @param options.model whisper.cpp ggml model name (".bin" appended if missing)
	 * @param options.diarization true to label speakers with pyannote; 'llm' to also use the language
	 *   model for punctuation clean-up and speaker naming
	 * @returns Transcript text (null when whisper produced none, or when aborted during diarization)
	 * @throws Synchronously if whisper is not configured or ffmpeg fails
	 */
	asr(file: string, options: { model?: string; diarization?: boolean | 'llm' } = {}): AbortablePromise<string | null> {
		if(!this.ai.options.whisper) throw new Error('Whisper not configured');
		const workDir = mkdtempSync(join(tmpdir(), 'audio-'));
		const tmp = join(workDir, 'converted.wav');
		const clean = () => fs.rm(workDir, {recursive: true, force: true}).catch(() => {});

		try {
			// Argument vector, not a shell string: quotes or $(...) in the file name cannot inject commands
			execFileSync('ffmpeg', ['-i', file, '-ar', '16000', '-ac', '1', '-f', 'wav', tmp], {stdio: 'ignore'});
		} catch(err) {
			void clean();
			throw err;
		}

		if(!options.diarization) {
			const plain = this.runAsr(tmp, {model: options.model});
			return Object.assign(plain.finally(clean), {abort: () => plain.abort()});
		}

		const timestamps = this.runAsr(tmp, {model: options.model, diarization: true});
		const diarization = this.runDiarization(tmp);
		let aborted = false;
		const abort = () => {
			aborted = true;
			timestamps.abort();
			diarization.abort();
			void clean();
		};
		const response = Promise.allSettled([timestamps, diarization]).then(([ts, d]) => {
			if(ts.status === 'rejected') throw new Error('Whisper.cpp timestamps:\n' + ts.reason);
			if(d.status === 'rejected') throw new Error('Pyannote:\n' + d.reason);
			if(aborted) return null;
			return this.diarizeTranscript(ts.value, d.value, options.diarization === 'llm');
		}).finally(clean);
		return Object.assign(response, {abort});
	}

	/**
	 * Ensures a whisper.cpp ggml model exists under `ai.options.path`, downloading it if needed.
	 * Concurrent calls for the same model share one download; a failed download is evicted so it can be retried.
	 * @returns Absolute path of the model file
	 */
	async downloadAsrModel(model: string = this.whisperModel): Promise<string> {
		if(!this.ai.options.whisper) throw new Error('Whisper not configured');
		if(!model.endsWith('.bin')) model += '.bin';
		const dest = Path.join(this.ai.options.path, model);
		const inFlight = this.downloads.get(model);
		if(inFlight) return inFlight;
		if(await fs.stat(dest).then(() => true, () => false)) return dest;
		// Re-check: another call may have started the download while stat() was pending
		const started = this.downloads.get(model);
		if(started) return started;
		const download = this.fetchModel(model, dest).finally(() => this.downloads.delete(model));
		this.downloads.set(model, download);
		return download;
	}

	/**
	 * Fetches a ggml model from the whisper.cpp Hugging Face repository. It is written to a
	 * temporary sibling file and renamed into place, so `dest` never holds a partial or error download.
	 */
	private async fetchModel(model: string, dest: string): Promise<string> {
		const resp = await fetch(`https://huggingface.co/ggerganov/whisper.cpp/resolve/main/${model}`);
		if(!resp.ok) throw new Error(`Failed to download ASR model ${model}: HTTP ${resp.status}`);
		const data = Buffer.from(await resp.arrayBuffer());
		await fs.mkdir(Path.dirname(dest), {recursive: true});
		const partial = `${dest}.${Math.random().toString(36).slice(2)}.part`;
		try {
			await fs.writeFile(partial, data);
			await fs.rename(partial, dest);
		} catch(err) {
			await fs.rm(partial, {force: true}).catch(() => {});
			throw err;
		}
		return dest;
	}
}