diff --git a/package.json b/package.json index db20f88..c538a30 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@ztimson/ai-utils", - "version": "0.7.7", + "version": "0.7.8", "description": "AI Utility library", "author": "Zak Timson", "license": "MIT", @@ -32,8 +32,7 @@ "@ztimson/utils": "^0.28.13", "cheerio": "^1.2.0", "openai": "^6.22.0", - "tesseract.js": "^7.0.0", - "wavefile": "^11.0.0" + "tesseract.js": "^7.0.0" }, "devDependencies": { "@types/node": "^24.8.1", diff --git a/src/ai.ts b/src/ai.ts index a3dc58a..bb6fa76 100644 --- a/src/ai.ts +++ b/src/ai.ts @@ -12,7 +12,7 @@ export type AiOptions = { hfToken?: string; /** Path to models */ path?: string; - /** ASR model: whisper-tiny, whisper-base */ + /** Whisper ASR model: ggml-tiny.en.bin, ggml-base.en.bin */ asr?: string; /** Embedding model: all-MiniLM-L6-v2, bge-small-en-v1.5, bge-large-en-v1.5 */ embedder?: string; @@ -22,6 +22,8 @@ export type AiOptions = { } /** OCR model: eng, eng_best, eng_fast */ ocr?: string; + /** Whisper binary */ + whisper?: string; } export class Ai { diff --git a/src/asr.ts b/src/asr.ts deleted file mode 100644 index f934481..0000000 --- a/src/asr.ts +++ /dev/null @@ -1,137 +0,0 @@ -import { pipeline } from '@xenova/transformers'; -import { parentPort } from 'worker_threads'; -import { spawn } from 'node:child_process'; -import { execSync } from 'node:child_process'; -import { mkdtempSync, rmSync, readFileSync } from 'node:fs'; -import { join } from 'node:path'; -import { tmpdir } from 'node:os'; -import wavefile from 'wavefile'; - -let whisperPipeline: any; - -export async function canDiarization(): Promise { - const checkPython = (cmd: string) => { - return new Promise((resolve) => { - const proc = spawn(cmd, ['-c', 'import pyannote.audio']); - proc.on('close', (code: number) => resolve(code === 0)); - proc.on('error', () => resolve(false)); - }); - }; - if(await checkPython('python3')) return 'python3'; - if(await checkPython('python')) return 'python'; - return null; -} - -async function runDiarization(binary: string, audioPath: string, dir: string, token: string): Promise { - const script = ` -import sys -import json -import os -from pyannote.audio import Pipeline - -os.environ['TORCH_HOME'] = r"${dir}" -pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token="${token}") -output = pipeline(sys.argv[1]) - -segments = [] -for turn, speaker in output.speaker_diarization: - segments.append({"start": turn.start, "end": turn.end, "speaker": speaker}) - -print(json.dumps(segments)) -`; - - return new Promise((resolve, reject) => { - let output = ''; - const proc = spawn(binary, ['-c', script, audioPath]); - proc.stdout.on('data', (data: Buffer) => output += data.toString()); - proc.stderr.on('data', (data: Buffer) => console.error(data.toString())); - proc.on('close', (code: number) => { - if(code === 0) { - try { - resolve(JSON.parse(output)); - } catch (err) { - reject(new Error('Failed to parse diarization output')); - } - } else { - reject(new Error(`Python process exited with code ${code}`)); - } - }); - proc.on('error', reject); - }); -} - -function combineSpeakerTranscript(chunks: any[], speakers: any[]): string { - const speakerMap = new Map(); - let speakerCount = 0; - speakers.forEach((seg: any) => { - if(!speakerMap.has(seg.speaker)) speakerMap.set(seg.speaker, ++speakerCount); - }); - - const lines: string[] = []; - let currentSpeaker = -1; - let currentText = ''; - chunks.forEach((chunk: any) => { - const time = chunk.timestamp[0]; - const speaker = speakers.find((s: any) => time >= s.start && time <= s.end); - const speakerNum = speaker ? speakerMap.get(speaker.speaker) : 1; - if (speakerNum !== currentSpeaker) { - if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`); - currentSpeaker = speakerNum; - currentText = chunk.text; - } else { - currentText += chunk.text; - } - }); - if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`); - return lines.join('\n'); -} - -function prepareAudioBuffer(file: string): [string, Float32Array] { - let wav: any, tmp; - try { - wav = new wavefile.WaveFile(readFileSync(file)); - } catch(err) { - tmp = join(mkdtempSync(join(tmpdir(), 'audio-')), 'converted.wav'); - execSync(`ffmpeg -i "${file}" -ar 16000 -ac 1 -f wav "${tmp}"`, { stdio: 'ignore' }); - wav = new wavefile.WaveFile(readFileSync(tmp)); - } finally { - wav.toBitDepth('32f'); - wav.toSampleRate(16000); - const samples = wav.getSamples(); - if(Array.isArray(samples)) { - const left = samples[0]; - const right = samples[1]; - const buffer = new Float32Array(left.length); - for (let i = 0; i < left.length; i++) buffer[i] = (left[i] + right[i]) / 2; - return [tmp || file, buffer]; - } - return [tmp || file, samples]; - } -} - -parentPort?.on('message', async ({ file, speaker, model, modelDir, token }) => { - let tempFile = null; - try { - if(!whisperPipeline) whisperPipeline = await pipeline('automatic-speech-recognition', `Xenova/${model}`, {cache_dir: modelDir, quantized: true}); - - const [f, buffer] = prepareAudioBuffer(file); - tempFile = f !== file ? f : null; - const hasDiarization = await canDiarization(); - const [transcript, speakers] = await Promise.all([ - whisperPipeline(buffer, {return_timestamps: speaker ? 'word' : false}), - (!speaker || !token || !hasDiarization) ? Promise.resolve(): runDiarization(hasDiarization, f, modelDir, token), - ]); - - const text = transcript.text?.trim() || null; - if(!speaker) return parentPort?.postMessage({ text }); - if(!token) return parentPort?.postMessage({ text, error: 'HuggingFace token required' }); - if(!hasDiarization) return parentPort?.postMessage({ text, error: 'Speaker diarization unavailable' }); - - const combined = combineSpeakerTranscript(transcript.chunks || [], speakers || []); - parentPort?.postMessage({ text: combined }); - } catch (err: any) { - parentPort?.postMessage({ error: err.stack || err.message }); - } finally { - if(tempFile) rmSync(tempFile, { recursive: true, force: true }); - } -}); diff --git a/src/audio.ts b/src/audio.ts index 6e3ea3a..ec3b5bf 100644 --- a/src/audio.ts +++ b/src/audio.ts @@ -1,82 +1,172 @@ -import {fileURLToPath} from 'url'; -import {Worker} from 'worker_threads'; +import {execSync, spawn} from 'node:child_process'; +import {mkdtempSync, rmSync} from 'node:fs'; +import fs from 'node:fs/promises'; +import {tmpdir} from 'node:os'; +import Path, {join} from 'node:path'; import {AbortablePromise, Ai} from './ai.ts'; -import {canDiarization} from './asr.ts'; -import {dirname, join} from 'path'; export class Audio { - private busy = false; - private currentJob: any; - private queue: Array<{file: string, model: string, speaker: boolean | 'id', modelDir: string, token: string, resolve: any, reject: any}> = []; - private worker: Worker | null = null; + private downloads: {[key: string]: Promise} = {}; + private pyannote!: string; + private whisperModel!: string; - constructor(private ai: Ai) {} - - private processQueue() { - if(this.busy || !this.queue.length) return; - - this.busy = true; - const job = this.queue.shift()!; - if(!this.worker) { - this.worker = new Worker(join(dirname(fileURLToPath(import.meta.url)), 'asr.js')); - this.worker.on('message', this.handleMessage.bind(this)); - this.worker.on('error', this.handleError.bind(this)); + constructor(private ai: Ai) { + if(ai.options.whisper) { + this.whisperModel = ai.options.asr?.endsWith('.bin') ? ai.options.asr : ai.options.asr + '.bin'; + this.downloadAsrModel(); } - this.currentJob = job; - this.worker.postMessage({file: job.file, model: job.model, speaker: job.speaker, modelDir: job.modelDir, token: job.token}); + this.pyannote = ` +import sys +import json +import os +from pyannote.audio import Pipeline + +os.environ['TORCH_HOME'] = r"${ai.options.path}" +pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token="${ai.options.hfToken}") +output = pipeline(sys.argv[1]) + +segments = [] +for turn, speaker in output.speaker_diarization: + segments.append({"start": turn.start, "end": turn.end, "speaker": speaker}) + +print(json.dumps(segments)) +`; } - private handleMessage({text, warning, error}: any) { - const job = this.currentJob!; - this.busy = false; - if(error) job.reject(new Error(error)); - else { - if(warning) console.warn(warning); - job.resolve(text); - } - this.processQueue(); - } - - private handleError(err: Error) { - if(this.currentJob) { - this.currentJob.reject(err); - this.busy = false; - this.processQueue(); - } - } - - asr(file: string, options: { model?: string; speaker?: boolean | 'id' } = {}): AbortablePromise { - const { model = this.ai.options.asr || 'whisper-base', speaker = false } = options; - let aborted = false; - const abort = () => { aborted = true; }; - let p = new Promise((resolve, reject) => { - this.queue.push({file, model, speaker, modelDir: this.ai.options.path, token: this.ai.options.hfToken, - resolve: (text: string | null) => !aborted && resolve(text), - reject: (err: Error) => !aborted && reject(err) + private runAsr(file: string, opts: {model?: string, diarization?: boolean} = {}): AbortablePromise { + let proc: any; + const p = new Promise((resolve, reject) => { + this.downloadAsrModel(opts.model).then(m => { + let output = ''; + const args = [opts.diarization ? '-owts' : '-nt', '-m', m, '-f', file]; + proc = spawn(this.ai.options.whisper, args, {stdio: ['ignore', 'pipe', 'ignore']}); + proc.on('error', (err: Error) => reject(err)); + proc.stdout.on('data', (data: Buffer) => output += data.toString()); + proc.on('close', (code: number) => { + if(code === 0) { + if(opts.diarization) { + try { resolve(JSON.parse(output)); } + catch(e) { reject(new Error('Failed to parse whisper JSON')); } + } else { + resolve(output.trim() || null); + } + } else { + reject(new Error(`Exit code ${code}`)); + } + }); }); - this.processQueue(); + }); + return Object.assign(p, {abort: () => proc?.kill('SIGTERM')}); + } + + private runDiarization(file: string): AbortablePromise { + let aborted = false, abort = () => { aborted = true; }; + const checkPython = (cmd: string) => { + return new Promise((resolve) => { + const proc = spawn(cmd, ['-c', 'import pyannote.audio']); + proc.on('close', (code: number) => resolve(code === 0)); + proc.on('error', () => resolve(false)); + }); + }; + const p = Promise.all([ + checkPython('python'), + checkPython('python3'), + ]).then((async ([p, p3]: [boolean, boolean]) => { + if(aborted) return; + if(!p && !p3) throw new Error('Pyannote is not installed: pip install pyannote.audio'); + const binary = p3 ? 'python3' : 'python'; + let tmp: string | null = null; + return new Promise((resolve, reject) => { + tmp = join(mkdtempSync(join(tmpdir(), 'audio-')), 'converted.wav'); + execSync(`ffmpeg -i "${file}" -ar 16000 -ac 1 -f wav "${tmp}"`, { stdio: 'ignore' }); + if(aborted) return; + let output = ''; + const proc = spawn(binary, ['-c', this.pyannote, tmp]); + proc.stdout.on('data', (data: Buffer) => output += data.toString()); + proc.stderr.on('data', (data: Buffer) => console.error(data.toString())); + proc.on('close', (code: number) => { + if(code === 0) { + try { resolve(JSON.parse(output)); } + catch (err) { reject(new Error('Failed to parse diarization output')); } + } else { + reject(new Error(`Python process exited with code ${code}`)); + } + }); + proc.on('error', reject); + abort = () => proc.kill('SIGTERM'); + }).finally(() => { if(tmp) rmSync(Path.dirname(tmp), { recursive: true, force: true }); }); + })); + return Object.assign(p, {abort}); + } + + private combineSpeakerTranscript(transcript: any, speakers: any[]): string { + const speakerMap = new Map(); + let speakerCount = 0; + speakers.forEach((seg: any) => { + if(!speakerMap.has(seg.speaker)) speakerMap.set(seg.speaker, ++speakerCount); }); - if(options.speaker == 'id') { - if(!this.ai.language.defaultModel) throw new Error('Configure an LLM for advanced ASR speaker detection'); - p = p.then(async transcript => { - if(!transcript) return transcript; - let chunks = this.ai.language.chunk(transcript, 500, 0); + const lines: string[] = []; + let currentSpeaker = -1; + let currentText = ''; + transcript.transcription.forEach((word: any) => { + const time = word.offsets.from / 1000; // Convert ms to seconds + const speaker = speakers.find((s: any) => time >= s.start && time <= s.end); + const speakerNum = speaker ? speakerMap.get(speaker.speaker) : 1; + if (speakerNum !== currentSpeaker) { + if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`); + currentSpeaker = speakerNum; + currentText = word.text; + } else { + currentText += ' ' + word.text; + } + }); + if(currentText) lines.push(`[Speaker ${currentSpeaker}]: ${currentText.trim()}`); + return lines.join('\n'); + } + + asr(file: string, options: { model?: string; diarization?: boolean | 'id' } = {}): AbortablePromise { + if(!this.ai.options.whisper) throw new Error('Whisper not configured'); + + const transcript = this.runAsr(file, {model: options.model, diarization: !!options.diarization}); + const diarization: any = options.diarization ? this.runDiarization(file) : Promise.resolve(null); + const abort = () => { + transcript.abort(); + diarization?.abort?.(); + }; + + const response = Promise.all([transcript, diarization]).then(async ([t, d]) => { + if(!options.diarization) return t; + t = this.combineSpeakerTranscript(t, d); + if(options.diarization === 'id') { + if(!this.ai.language.defaultModel) throw new Error('Configure an LLM for advanced ASR speaker detection'); + let chunks = this.ai.language.chunk(t, 500, 0); if(chunks.length > 4) chunks = [...chunks.slice(0, 3), chunks.at(-1)]; const names = await this.ai.language.json(chunks.join('\n'), '{1: "Detected Name", 2: "Second Name"}', { system: 'Use the following transcript to identify speakers. Only identify speakers you are positive about, dont mention speakers you are unsure about in your response', temperature: 0.1, }); - Object.entries(names).forEach(([speaker, name]) => { - transcript = (transcript).replaceAll(`[Speaker ${speaker}]`, `[${name}]`); - }); - return transcript; - }) - } - - return Object.assign(p, { abort }); + Object.entries(names).forEach(([speaker, name]) => t = t.replaceAll(`[Speaker ${speaker}]`, `[${name}]`)); + } + return t; + }); + return Object.assign(response, {abort}); } - canDiarization = () => canDiarization().then(resp => !!resp); + async downloadAsrModel(model: string = this.whisperModel): Promise { + if(!this.ai.options.whisper) throw new Error('Whisper not configured'); + if(!model.endsWith('.bin')) model += '.bin'; + const p = Path.join(this.ai.options.path, model); + if(await fs.stat(p).then(() => true).catch(() => false)) return p; + if(!!this.downloads[model]) return this.downloads[model]; + this.downloads[model] = fetch(`https://huggingface.co/ggerganov/whisper.cpp/resolve/main/${model}`) + .then(resp => resp.arrayBuffer()) + .then(arr => Buffer.from(arr)).then(async buffer => { + await fs.writeFile(p, buffer); + delete this.downloads[model]; + return p; + }); + return this.downloads[model]; + } } diff --git a/src/embedder.ts b/src/embedder.ts index 492f43a..8b0d0ba 100644 --- a/src/embedder.ts +++ b/src/embedder.ts @@ -1,11 +1,13 @@ import { pipeline } from '@xenova/transformers'; -import { parentPort } from 'worker_threads'; -let embedder: any; +const [modelDir, model] = process.argv.slice(2); -parentPort?.on('message', async ({text, model, modelDir }) => { - if(!embedder) embedder = await pipeline('feature-extraction', 'Xenova/' + model, {quantized: true, cache_dir: modelDir}); +let text = ''; +process.stdin.on('data', chunk => text += chunk); +process.stdin.on('end', async () => { + const embedder = await pipeline('feature-extraction', 'Xenova/' + model, {quantized: true, cache_dir: modelDir}); const output = await embedder(text, { pooling: 'mean', normalize: true }); const embedding = Array.from(output.data); - parentPort?.postMessage({embedding}); + console.log(JSON.stringify({embedding})); + process.exit(); }); diff --git a/src/index.ts b/src/index.ts index a27ae96..b25cf61 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,5 @@ export * from './ai'; export * from './antrhopic'; -export * from './asr'; export * from './audio'; export * from './embedder' export * from './llm'; diff --git a/src/llm.ts b/src/llm.ts index 4a9f0eb..4e4c650 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -4,9 +4,9 @@ import {Anthropic} from './antrhopic.ts'; import {OpenAi} from './open-ai.ts'; import {LLMProvider} from './provider.ts'; import {AiTool} from './tools.ts'; -import {Worker} from 'worker_threads'; import {fileURLToPath} from 'url'; import {dirname, join} from 'path'; +import { spawn } from 'node:child_process'; export type AnthropicConfig = {proto: 'anthropic', token: string}; export type OllamaConfig = {proto: 'ollama', host: string}; @@ -258,34 +258,54 @@ class LLM { * @param {maxTokens?: number, overlapTokens?: number} opts Options for embedding such as chunk sizes * @returns {Promise[]>} Chunked embeddings */ - async embedding(target: object | string, opts: {maxTokens?: number, overlapTokens?: number} = {}) { + embedding(target: object | string, opts: {maxTokens?: number, overlapTokens?: number} = {}): AbortablePromise { let {maxTokens = 500, overlapTokens = 50} = opts; + let aborted = false; + const abort = () => { aborted = true; }; + const embed = (text: string): Promise => { return new Promise((resolve, reject) => { - const worker = new Worker(join(dirname(fileURLToPath(import.meta.url)), 'embedder.js')); - const handleMessage = ({ embedding }: any) => { - worker.terminate(); - resolve(embedding); - }; - const handleError = (err: Error) => { - worker.terminate(); - reject(err); - }; - worker.on('message', handleMessage); - worker.on('error', handleError); - worker.on('exit', (code) => { - if(code !== 0) reject(new Error(`Worker exited with code ${code}`)); + if(aborted) return reject(new Error('Aborted')); + + const args: string[] = [ + join(dirname(fileURLToPath(import.meta.url)), 'embedder.js'), + this.ai.options.path, + this.ai.options?.embedder || 'bge-small-en-v1.5' + ]; + const proc = spawn('node', args, {stdio: ['pipe', 'pipe', 'ignore']}); + proc.stdin.write(text); + proc.stdin.end(); + + let output = ''; + proc.stdout.on('data', (data: Buffer) => output += data.toString()); + proc.on('close', (code: number) => { + if(aborted) return reject(new Error('Aborted')); + if(code === 0) { + try { + const result = JSON.parse(output); + resolve(result.embedding); + } catch(err) { + reject(new Error('Failed to parse embedding output')); + } + } else { + reject(new Error(`Embedder process exited with code ${code}`)); + } }); - worker.postMessage({text, model: this.ai.options?.embedder || 'bge-small-en-v1.5', modelDir: this.ai.options.path}); + proc.on('error', reject); }); }; - const chunks = this.chunk(target, maxTokens, overlapTokens), results: any[] = []; - for(let i = 0; i < chunks.length; i++) { - const text= chunks[i]; - const embedding = await embed(text); - results.push({index: i, embedding, text, tokens: this.estimateTokens(text)}); - } - return results; + + const p = (async () => { + const chunks = this.chunk(target, maxTokens, overlapTokens), results: any[] = []; + for(let i = 0; i < chunks.length; i++) { + if(aborted) break; + const text = chunks[i]; + const embedding = await embed(text); + results.push({index: i, embedding, text, tokens: this.estimateTokens(text)}); + } + return results; + })(); + return Object.assign(p, { abort }); } /** diff --git a/src/vision.ts b/src/vision.ts index 911e5fb..28e2d23 100644 --- a/src/vision.ts +++ b/src/vision.ts @@ -2,43 +2,22 @@ import {createWorker} from 'tesseract.js'; import {AbortablePromise, Ai} from './ai.ts'; export class Vision { - private worker: any = null; - private queue: Array<{ path: string, resolve: any, reject: any }> = []; - private busy = false; constructor(private ai: Ai) {} - private async processQueue() { - if(this.busy || !this.queue.length) return; - this.busy = true; - const job = this.queue.shift()!; - if(!this.worker) this.worker = await createWorker(this.ai.options.ocr || 'eng', 2, {cachePath: this.ai.options.path}); - try { - const {data} = await this.worker.recognize(job.path); - job.resolve(data.text.trim() || null); - } catch(err) { - job.reject(err); - } - this.busy = false; - this.processQueue(); - } - /** * Convert image to text using Optical Character Recognition * @param {string} path Path to image * @returns {AbortablePromise} Promise of extracted text with abort method */ ocr(path: string): AbortablePromise { - let aborted = false; - const abort = () => { aborted = true; }; - const p = new Promise((resolve, reject) => { - this.queue.push({ - path, - resolve: (text: string | null) => !aborted && resolve(text), - reject: (err: Error) => !aborted && reject(err) - }); - this.processQueue(); + let worker: any; + const p = new Promise(async res => { + worker = await createWorker(this.ai.options.ocr || 'eng', 2, {cachePath: this.ai.options.path}); + const {data} = await worker.recognize(path); + await worker.terminate(); + res(data.text.trim() || null); }); - return Object.assign(p, {abort}); + return Object.assign(p, {abort: () => worker?.terminate()}); } } diff --git a/vite.config.ts b/vite.config.ts index 66b7f7a..4202699 100644 --- a/vite.config.ts +++ b/vite.config.ts @@ -5,7 +5,6 @@ export default defineConfig({ build: { lib: { entry: { - asr: './src/asr.ts', index: './src/index.ts', embedder: './src/embedder.ts', },