# Reel Caption Maker: a Streamlit app that transcribes an uploaded video with
# Whisper, lets the user edit the resulting SRT captions, and burns them onto
# the video with ffmpeg.

import os
import tempfile

import ffmpeg
import streamlit as st
import torch  # Used to build the attention mask passed to model.generate()
import tqdm
from moviepy.editor import VideoFileClip
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)


# Load Whisper model
@st.cache_resource
def load_whisper_model():
    """Load the Whisper processor and model ("openai/whisper-medium").

    The result is cached with ``st.cache_resource`` so the model is not
    reloaded on every script rerun.

    Returns:
        A ``(processor, model)`` tuple, or ``(None, None)`` if loading failed
        (the error is reported in the UI via ``st.error``).
    """
    try:
        processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
        return processor, model
    except Exception as e:
        st.error(f"Failed to load Whisper model: {str(e)}")
        return None, None


processor, model = load_whisper_model()


def transcribe_audio(audio_file, language, chunk_length=3):
    """Transcribe an audio file in fixed-length chunks.

    Args:
        audio_file: Path of the audio file to transcribe.
        language: Language code forwarded to the Whisper tokenizer
            (the app passes "fr" or "en").
        chunk_length: Length of each transcription chunk, in seconds.

    Returns:
        A list of dicts with "start" and "end" (seconds) and "text" keys, one
        per chunk, in chronological order. Returns an empty list when the
        Whisper model is not loaded.
    """
    if model is None or processor is None:
        st.error("Whisper model is not loaded. Cannot transcribe audio.")
        return []

    # Load audio
    audio_input, sr = AudioLoader.load_audio(audio_file)

    # Calculate number of samples per chunk
    samples_per_chunk = int(chunk_length * sr)

    # Get the tokenizer
    tokenizer = WhisperTokenizer.from_pretrained(model.config._name_or_path, language=language)

    # The decoder prompt depends only on language/task, so build it once
    # instead of once per chunk.
    forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")

    total_duration = len(audio_input) / sr
    segments = []
    for i in tqdm.tqdm(range(0, len(audio_input), samples_per_chunk)):
        chunk = audio_input[i:i + samples_per_chunk]

        # Pad/trim audio chunk
        inputs = processor(chunk, sampling_rate=sr, return_tensors="pt")
        input_features = inputs.input_features

        # Generate attention mask
        # NOTE(review): this mask has the same shape as the feature tensor
        # rather than one value per frame -- confirm that generate() accepts
        # it with the installed transformers version.
        attention_mask = torch.ones_like(input_features)

        # Generate token ids
        predicted_ids = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            attention_mask=attention_mask
        )

        # Decode token ids to text
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # The last chunk may be shorter than chunk_length, so clamp its end
        # to the real duration of the audio.
        start_time = i / sr
        end_time = min((i + samples_per_chunk) / sr, total_duration)

        segments.append({
            "start": start_time,
            "end": end_time,
            "text": transcription[0].strip()
        })

    return segments


def format_srt(segments):
    """Render transcription segments as SRT text.

    Segments whose text is empty after stripping are skipped. Cues are
    numbered consecutively from 1 over the segments that are actually
    emitted, so skipped segments leave no gaps in the numbering.

    Args:
        segments: Iterable of dicts with "start", "end" (seconds) and "text".

    Returns:
        The SRT document as a string ("" when there is nothing to emit).
    """
    cues = []
    for segment in segments:
        text = segment['text'].strip()
        if not text:  # Only add non-empty segments
            continue
        start_time = format_timestamp(segment['start'])
        end_time = format_timestamp(segment['end'])
        cues.append(f"{len(cues) + 1}\n{start_time} --> {end_time}\n{text}\n\n")
    return "".join(cues)


def format_timestamp(seconds):
    """Format a duration in seconds as an SRT timestamp (HH:MM:SS,mmm).

    The value is rounded to the nearest millisecond once, up front, so
    floating-point error in the fractional part cannot drop a millisecond
    (e.g. 1.001 is rendered as 00:00:01,001). Negative input is clamped to 0.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


# Add this helper class for audio loading
class AudioLoader:
    """Namespace for loading audio as a waveform."""

    @staticmethod
    def load_audio(file_path):
        """Load ``file_path`` with librosa at a 16000 Hz sample rate.

        Returns:
            The ``(audio, sr)`` pair returned by ``librosa.load``.
        """
        import librosa
        audio, sr = librosa.load(file_path, sr=16000)
        return audio, sr


def burn_subtitles(video_path, srt_content, subtitle_style):
    """Burn SRT captions onto a video and return the path of the result.

    The video is re-encoded with the subtitles filter (without audio), the
    audio is extracted separately, and both are then combined into
    ``<video_path stem>_with_captions.mp4``.

    Args:
        video_path: Path of the source video.
        srt_content: Caption text in SRT format.
        subtitle_style: Value for the subtitles filter's ``force_style``.

    Returns:
        The output path, or ``None`` if any step failed (the error is
        reported in the UI via ``st.error``). Intermediate files are always
        removed; a partially written output file is removed on failure.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.srt') as temp_srt:
        temp_srt.write(srt_content.encode('utf-8'))
        temp_srt_path = temp_srt.name

    base_path = os.path.splitext(video_path)[0]
    output_path = base_path + '_with_captions.mp4'
    temp_video_path = base_path + '_temp_video.mp4'
    temp_audio_path = base_path + '_temp_audio.aac'

    try:
        # Extract video metadata
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
        if video_stream is None:
            raise ValueError("No video stream found in the input file.")

        width = int(video_stream['width'])
        height = int(video_stream['height'])

        # Extract audio
        ffmpeg.input(video_path).output(
            temp_audio_path, acodec='aac', audio_bitrate='128k'
        ).overwrite_output().run(capture_stdout=True, capture_stderr=True)

        # Process video with subtitles (without audio)
        ffmpeg.input(video_path).filter(
            'subtitles', temp_srt_path, force_style=subtitle_style
        ).output(
            temp_video_path,
            vcodec='libx264',
            video_bitrate='2000k',
            an=None,
            s=f'{width}x{height}'
        ).overwrite_output().run(capture_stdout=True, capture_stderr=True)

        # Combine video with subtitles and original audio
        ffmpeg.concat(
            ffmpeg.input(temp_video_path),
            ffmpeg.input(temp_audio_path),
            v=1,
            a=1
        ).output(output_path, vcodec='libx264', acodec='aac').overwrite_output().run(
            capture_stdout=True, capture_stderr=True
        )

        # Check if the output file was created and has both video and audio streams
        if not os.path.exists(output_path):
            raise FileNotFoundError("Output file was not created.")
        output_probe = ffmpeg.probe(output_path)
        output_video_stream = next(
            (stream for stream in output_probe['streams'] if stream['codec_type'] == 'video'), None
        )
        output_audio_stream = next(
            (stream for stream in output_probe['streams'] if stream['codec_type'] == 'audio'), None
        )
        if output_video_stream is None or output_audio_stream is None:
            raise ValueError("Output file is missing video or audio stream.")

    except (ffmpeg.Error, ValueError, FileNotFoundError) as e:
        # str(ffmpeg.Error) carries no detail; the captured stderr does.
        if isinstance(e, ffmpeg.Error) and e.stderr:
            detail = e.stderr.decode(errors='replace')
        else:
            detail = str(e)
        st.error(f"An error occurred while burning subtitles: {detail}")
        # Do not leave a partial or invalid result behind.
        if os.path.exists(output_path):
            os.unlink(output_path)
        return None
    finally:
        os.unlink(temp_srt_path)
        if os.path.exists(temp_video_path):
            os.unlink(temp_video_path)
        if os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)

    return output_path


def convert_to_web_compatible(input_path):
    """Re-encode a video as H.264/AAC MP4 for playback in the browser.

    Returns:
        The path ``<input_path stem>_web.mp4``, or ``None`` if ffmpeg failed
        (the captured stderr is reported in the UI via ``st.error``).
    """
    output_path = os.path.splitext(input_path)[0] + '_web.mp4'
    try:
        (
            ffmpeg
            .input(input_path)
            .output(output_path,
                    # Use VideoToolbox for hardware-accelerated encoding
                    # NOTE(review): presumably only available in ffmpeg builds
                    # with VideoToolbox support (macOS) -- confirm for the
                    # deployment target.
                    vcodec='h264_videotoolbox',
                    acodec='aac',
                    video_bitrate='1000k',
                    audio_bitrate='128k')
            .overwrite_output()
            .run(capture_stdout=True, capture_stderr=True)
        )
        return output_path
    except ffmpeg.Error as e:
        st.error(f"An error occurred while converting the video: {e.stderr.decode()}")
        return None


def _remove_file(path):
    """Delete ``path`` if it names an existing file; ignore None or missing paths."""
    if path and os.path.exists(path):
        os.remove(path)


st.title("Reel Caption Maker")

if 'temp_video_path' not in st.session_state:
    st.session_state.temp_video_path = None
if 'web_compatible_video_path' not in st.session_state:
    st.session_state.web_compatible_video_path = None

uploaded_file = st.file_uploader("Choose a video file", type=["mp4", "mov", "avi"])

if uploaded_file is not None:
    # Save the uploaded file to a temporary location if not already done
    if st.session_state.temp_video_path is None or not os.path.exists(st.session_state.temp_video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
            temp_video.write(uploaded_file.read())
            st.session_state.temp_video_path = temp_video.name

        # Convert the video to web-compatible format
        st.session_state.web_compatible_video_path = convert_to_web_compatible(st.session_state.temp_video_path)

    # Create two columns for layout
    col1, col2 = st.columns([2, 3])

    with col1:
        st.subheader("Video Player")
        # Display the web-compatible video
        if st.session_state.web_compatible_video_path:
            st.video(st.session_state.web_compatible_video_path)
        else:
            st.error("Failed to convert video to web-compatible format.")

    with col2:
        st.subheader("Captions")
        language = st.selectbox("Select video language", ["French", "English"])
        lang_code = "fr" if language == "French" else "en"

        if st.button("Generate Captions"):
            with st.spinner("Generating captions..."):
                video = VideoFileClip(st.session_state.temp_video_path)
                try:
                    audio = video.audio
                    if audio is None:
                        # A clip without an audio track has nothing to transcribe.
                        st.error("The uploaded video has no audio track to transcribe.")
                    else:
                        audio.write_audiofile("temp_audio.wav")
                        segments = transcribe_audio("temp_audio.wav", lang_code, chunk_length=3)
                        srt_content = format_srt(segments)
                        st.session_state.srt_content = srt_content
                        st.session_state.temp_audio_path = "temp_audio.wav"  # Store the audio path
                finally:
                    # Release the clip even if extraction or transcription fails.
                    video.close()

        if 'srt_content' in st.session_state:
            edited_srt = st.text_area("Edit Captions (SRT format)", st.session_state.srt_content, height=300)

            if st.button("Burn Captions and Download"):
                with st.spinner("Burning captions onto video..."):
                    subtitle_style = (
                        'Fontname=Arial,Fontsize=16,'
                        'PrimaryColour=&H00FFFFFF&,'
                        'SecondaryColour=&H00000000&,'
                        'OutlineColour=&H00000000&,'
                        'BackColour=&H40000000&,'
                        'BorderStyle=1,'
                        'Outline=1,'
                        'Shadow=1,'
                        'MarginV=20'
                    )
                    output_path = burn_subtitles(st.session_state.temp_video_path, edited_srt, subtitle_style)

                if output_path:
                    with open(output_path, "rb") as file:
                        st.download_button(
                            label="Download Video with Captions",
                            data=file,
                            file_name="video_with_captions.mp4",
                            mime="video/mp4"
                        )
                    os.remove(output_path)
                    # The extracted audio may already be gone (e.g. after a
                    # previous burn), so check before touching it.
                    if 'temp_audio_path' in st.session_state:
                        _remove_file(st.session_state.temp_audio_path)
                        del st.session_state.temp_audio_path
                else:
                    st.error("Failed to burn captions onto the video.")

if st.button("Reset"):
    # Remove temporary files first: their paths live in the session state,
    # so they must be read before the state is cleared.
    for path_key in ('temp_video_path', 'web_compatible_video_path', 'temp_audio_path'):
        if path_key in st.session_state:
            _remove_file(st.session_state[path_key])

    # Clear session state
    for key in list(st.session_state.keys()):
        del st.session_state[key]

    st.success("All data has been reset. You can now upload a new video.")