Ajout du fichier de script

Signed-off-by: Francois Pelletier <francois@noreply.git.jevalide.ca>
This commit is contained in:
Francois Pelletier 2024-03-09 23:57:28 +00:00
parent bdc0978d1b
commit b1976ba391

72
whisper.py Normal file
View file

@ -0,0 +1,72 @@
# %% Utilisation de Whisper pour la transcription de podcasts en français
import numpy as np
import torch
import torchaudio
import tqdm
from transformers import (
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
)
# %% File paths
audio_paths = ["METTRE LES LIENS DES FICHIERS MP3 OU WAV ICI"]
# %% load PyTorch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# %% Load model
model_name_or_path = "bofenghuang/whisper-large-v3-french"
processor = AutoProcessor.from_pretrained(model_name_or_path)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_name_or_path,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
model.to(device)
# %% Load draft model
assistant_model_name_or_path = "bofenghuang/whisper-large-v3-french-distil-dec2"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_name_or_path,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
assistant_model.to(device)
# %% Init pipeline
pipe = pipeline(
"automatic-speech-recognition",
model=model,
feature_extractor=processor.feature_extractor,
tokenizer=processor.tokenizer,
torch_dtype=torch_dtype,
device=device,
generate_kwargs={"assistant_model": assistant_model},
max_new_tokens=128,
)
# %% Transcript function
def transcript(audio_path):
# Load audio
model_sr = 16000
speech, sr = torchaudio.load(audio_path)
speech_16000 = torchaudio.functional.resample(speech, orig_freq=sr, new_freq=model_sr)
speech_16000 = speech_16000.squeeze()
# Run pipeline
result = pipe(np.array(speech_16000))
# Save text result to file
transcript_path = f'whisper-large/{audio_path.replace(".mp3", "_transcript_whisper.txt")}'
with open(transcript_path, "w") as f:
f.write(result["text"])
return None
# %% Transcription loop
for audio_path in tqdm.tqdm(audio_paths):
transcript(audio_path)