From b077e7a7eec1fe42ad2dd2f418b52796d4096989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Pelletier?= Date: Sun, 3 Nov 2024 23:12:05 -0500 Subject: [PATCH] =?UTF-8?q?Changements=20de=20param=C3=A8tres=20pour=20la?= =?UTF-8?q?=20performance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index 7404248..9ce255c 100644 --- a/app.py +++ b/app.py @@ -5,6 +5,7 @@ from moviepy.editor import VideoFileClip import ffmpeg from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer import tqdm +import torch # Add this line to import PyTorch # Load Whisper model @st.cache_resource @@ -38,11 +39,19 @@ def transcribe_audio(audio_file, language, chunk_length=3): chunk = audio_input[i:i+samples_per_chunk] # Pad/trim audio chunk - chunk_input = processor.feature_extractor(chunk, sampling_rate=sr, return_tensors="pt").input_features + inputs = processor(chunk, sampling_rate=sr, return_tensors="pt") + input_features = inputs.input_features + + # Generate attention mask + attention_mask = torch.ones_like(input_features) # Generate token ids forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task="transcribe") - predicted_ids = model.generate(chunk_input, forced_decoder_ids=forced_decoder_ids) + predicted_ids = model.generate( + input_features, + forced_decoder_ids=forced_decoder_ids, + attention_mask=attention_mask + ) # Decode token ids to text transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) @@ -156,8 +165,11 @@ def convert_to_web_compatible(input_path): ( ffmpeg .input(input_path) - .output(output_path, vcodec='libx264', acodec='aac', - video_bitrate='1000k', audio_bitrate='128k') + .output(output_path, + vcodec='h264_videotoolbox', # Use VideoToolbox for hardware-accelerated encoding + acodec='aac', + video_bitrate='1000k', + audio_bitrate='128k') .overwrite_output() .run(capture_stdout=True, capture_stderr=True) )