reel-caption-maker/app.py
2024-11-03 23:12:05 -05:00

276 lines
11 KiB
Python

import streamlit as st
import os
import tempfile
from moviepy.editor import VideoFileClip
import ffmpeg
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer
import tqdm
import torch # Add this line to import PyTorch
# Load Whisper model
@st.cache_resource
def _load_whisper_model_cached(model_name):
    """Load the Whisper processor and model for ``model_name``.

    Cached by Streamlit per ``model_name``. Exceptions are deliberately not
    caught here so that a failed load is not stored in the cache and can be
    retried on a later rerun.
    """
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return processor, model


def load_whisper_model(model_name="openai/whisper-medium"):
    """Return ``(processor, model)`` for ``model_name``.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        The ``(WhisperProcessor, WhisperForConditionalGeneration)`` pair, or
        ``(None, None)`` after showing an ``st.error`` if loading fails.
    """
    try:
        return _load_whisper_model_cached(model_name)
    except Exception as e:
        # Broad on purpose: any load failure (network, disk, bad weights)
        # degrades to "model not loaded", which transcribe_audio checks for.
        st.error(f"Failed to load Whisper model: {str(e)}")
        return None, None


processor, model = load_whisper_model()
def transcribe_audio(audio_file, language, chunk_length=3):
    """Transcribe an audio file with Whisper in fixed-length chunks.

    Each chunk of ``chunk_length`` seconds is transcribed independently and
    becomes one segment.

    Args:
        audio_file: Path of the audio file to transcribe.
        language: Language passed to Whisper's decoder prompt (the UI passes
            ``"fr"`` or ``"en"``).
        chunk_length: Chunk duration in seconds.

    Returns:
        A list of ``{"start", "end", "text"}`` dicts with times in seconds, or
        an empty list (after an ``st.error``) when the model is not loaded.

    Raises:
        ValueError: If ``chunk_length`` is too small to hold one sample.
    """
    if model is None or processor is None:
        st.error("Whisper model is not loaded. Cannot transcribe audio.")
        return []

    audio_input, sr = AudioLoader.load_audio(audio_file)
    samples_per_chunk = int(chunk_length * sr)
    if samples_per_chunk <= 0:
        raise ValueError("chunk_length must cover at least one audio sample.")

    # The language/task prompt is identical for every chunk: build it once,
    # from the processor's own tokenizer, instead of re-loading a tokenizer
    # from the hub on every call and rebuilding the prompt inside the loop.
    forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
    total_duration = len(audio_input) / sr

    segments = []
    for start_sample in tqdm.tqdm(range(0, len(audio_input), samples_per_chunk)):
        chunk = audio_input[start_sample:start_sample + samples_per_chunk]
        # The feature extractor pads/trims the chunk to Whisper's input window.
        # Ask it for the matching frame-level attention mask rather than using a
        # tensor shaped like the features themselves.
        inputs = processor(
            chunk,
            sampling_rate=sr,
            return_tensors="pt",
            return_attention_mask=True,
        )
        predicted_ids = model.generate(
            inputs.input_features,
            forced_decoder_ids=forced_decoder_ids,
            attention_mask=inputs.attention_mask,
        )
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        segments.append({
            "start": start_sample / sr,
            # The final chunk may be shorter than chunk_length.
            "end": min((start_sample + samples_per_chunk) / sr, total_duration),
            "text": transcription[0].strip(),
        })
    return segments
def format_srt(segments):
    """Render transcription segments as SubRip (SRT) text.

    Args:
        segments: Iterable of dicts with ``"start"``/``"end"`` times in seconds
            and a ``"text"`` caption.

    Returns:
        The SRT document. Segments whose text is empty after stripping are
        skipped, and the kept cues are numbered consecutively from 1.
    """
    cues = []
    for segment in segments:
        text = segment['text'].strip()
        if not text:  # Only add non-empty segments
            continue
        start_time = format_timestamp(segment['start'])
        end_time = format_timestamp(segment['end'])
        # Number by kept cues, not input position, so skipped segments do not
        # leave gaps in the SRT sequence.
        cues.append(f"{len(cues) + 1}\n{start_time} --> {end_time}\n{text}\n\n")
    return "".join(cues)


def format_timestamp(seconds):
    """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond. Negative inputs are
    clamped to zero.
    """
    # Convert to whole milliseconds first: splitting the float directly
    # truncates values such as 2.3 s to 2.299 s.
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
class AudioLoader:
    """Audio-loading helper used by the transcription step."""

    @staticmethod
    def load_audio(file_path):
        """Read ``file_path`` with librosa, resampled to 16 kHz.

        Returns:
            ``(samples, sample_rate)`` as produced by ``librosa.load``.
        """
        # Imported inside the method, so librosa is only needed once audio
        # is actually loaded.
        import librosa

        target_rate = 16000
        samples, rate = librosa.load(file_path, sr=target_rate)
        return samples, rate
def _find_stream(probe, codec_type):
    """Return the first stream of ``codec_type`` in an ffprobe result, or None."""
    return next(
        (stream for stream in probe.get('streams', []) if stream.get('codec_type') == codec_type),
        None,
    )


def burn_subtitles(video_path, srt_content, subtitle_style):
    """Burn SRT captions into a copy of ``video_path``.

    A single ffmpeg pass renders the captions with the ``subtitles`` filter,
    encodes the video with libx264 and encodes the original audio to AAC.

    Args:
        video_path: Path of the source video.
        srt_content: Caption text in SRT format.
        subtitle_style: Value for the ``subtitles`` filter's ``force_style``
            option.

    Returns:
        Path of the captioned MP4 (``<video stem>_with_captions.mp4``), or
        None after reporting the failure with ``st.error``.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.srt') as temp_srt:
        temp_srt.write(srt_content.encode('utf-8'))
        temp_srt_path = temp_srt.name
    output_path = os.path.splitext(video_path)[0] + '_with_captions.mp4'
    try:
        probe = ffmpeg.probe(video_path)
        if _find_stream(probe, 'video') is None:
            raise ValueError("No video stream found in the input file.")
        if _find_stream(probe, 'audio') is None:
            raise ValueError("No audio stream found in the input file.")

        source = ffmpeg.input(video_path)
        # No explicit output size: ffprobe's width/height are the coded size
        # and ignore rotation metadata, so forcing them onto autorotated
        # (e.g. portrait phone) frames would distort the picture.
        captioned = source.video.filter('subtitles', temp_srt_path, force_style=subtitle_style)
        (
            ffmpeg
            .output(
                captioned,
                source.audio,
                output_path,
                vcodec='libx264',
                video_bitrate='2000k',
                pix_fmt='yuv420p',  # 4:2:0 H.264 plays in the widest range of players
                acodec='aac',
                audio_bitrate='128k',
            )
            .overwrite_output()
            .run(capture_stdout=True, capture_stderr=True)
        )

        # Check if the output file was created and has both video and audio streams
        if not os.path.exists(output_path):
            raise FileNotFoundError("Output file was not created.")
        output_probe = ffmpeg.probe(output_path)
        if _find_stream(output_probe, 'video') is None or _find_stream(output_probe, 'audio') is None:
            raise ValueError("Output file is missing video or audio stream.")
    except (ffmpeg.Error, ValueError, FileNotFoundError) as e:
        if isinstance(e, ffmpeg.Error):
            # str(ffmpeg.Error) only says "see stderr"; the details are there.
            details = (e.stderr or b'').decode('utf-8', errors='replace') or str(e)
        else:
            details = str(e)
        # Don't leave a partial or invalid render behind.
        if os.path.exists(output_path):
            os.remove(output_path)
        st.error(f"An error occurred while burning subtitles: {details}")
        return None
    finally:
        os.unlink(temp_srt_path)
    return output_path
def convert_to_web_compatible(input_path):
    """Transcode a video to an H.264/AAC MP4 for in-browser playback.

    Tries the VideoToolbox hardware encoder first and falls back to the
    software ``libx264`` encoder, because ``h264_videotoolbox`` only exists in
    macOS builds of ffmpeg.

    Args:
        input_path: Path of the source video.

    Returns:
        Path of the converted file (``<input stem>_web.mp4``), or None after
        reporting the last ffmpeg error with ``st.error``.
    """
    output_path = os.path.splitext(input_path)[0] + '_web.mp4'
    last_error = None
    for vcodec in ('h264_videotoolbox', 'libx264'):
        try:
            (
                ffmpeg
                .input(input_path)
                .output(output_path,
                        vcodec=vcodec,
                        acodec='aac',
                        video_bitrate='1000k',
                        audio_bitrate='128k',
                        pix_fmt='yuv420p')  # 4:2:0 is what browsers decode reliably
                .overwrite_output()
                .run(capture_stdout=True, capture_stderr=True)
            )
            return output_path
        except ffmpeg.Error as e:
            last_error = e
    # ffmpeg's stderr is raw bytes and not guaranteed to be valid UTF-8.
    details = (last_error.stderr or b'').decode('utf-8', errors='replace')
    st.error(f"An error occurred while converting the video: {details}")
    return None
st.title("Reel Caption Maker")
# Seed the temp-file path keys (as None) so the upload flow below can read
# them on the first run of a session.
for _path_key in ('temp_video_path', 'web_compatible_video_path'):
    if _path_key not in st.session_state:
        st.session_state[_path_key] = None
uploaded_file = st.file_uploader("Choose a video file", type=["mp4", "mov", "avi"])
if uploaded_file is not None:
    # Identify the upload so that choosing a different file replaces the saved
    # copy. Checking only whether a temp file exists would keep reusing the
    # first video for the rest of the session.
    upload_key = (uploaded_file.name, uploaded_file.size)
    saved_path = st.session_state.get('temp_video_path')
    needs_save = (
        saved_path is None
        or not os.path.exists(saved_path)
        or st.session_state.get('uploaded_file_key') != upload_key
    )
    if needs_save:
        # Discard files and captions that belong to a previous upload.
        for stale_path in (saved_path, st.session_state.get('web_compatible_video_path')):
            if stale_path and os.path.exists(stale_path):
                os.remove(stale_path)
        if 'srt_content' in st.session_state:
            del st.session_state.srt_content
        # Keep the uploaded file's own extension instead of labelling every
        # container .mp4.
        suffix = os.path.splitext(uploaded_file.name)[1] or '.mp4'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_video:
            temp_video.write(uploaded_file.getvalue())
            st.session_state.temp_video_path = temp_video.name
        st.session_state.uploaded_file_key = upload_key
        # Convert the video to web-compatible format
        st.session_state.web_compatible_video_path = convert_to_web_compatible(st.session_state.temp_video_path)

    # Create two columns for layout
    col1, col2 = st.columns([2, 3])
    with col1:
        st.subheader("Video Player")
        # Display the web-compatible video
        if st.session_state.web_compatible_video_path:
            st.video(st.session_state.web_compatible_video_path)
        else:
            st.error("Failed to convert video to web-compatible format.")

    with col2:
        st.subheader("Captions")
        language = st.selectbox("Select video language", ["French", "English"])
        lang_code = "fr" if language == "French" else "en"
        if st.button("Generate Captions"):
            with st.spinner("Generating captions..."):
                # Use a unique temp file: a fixed "temp_audio.wav" in the working
                # directory would be shared by every concurrent session.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
                    audio_path = temp_audio.name
                video = VideoFileClip(st.session_state.temp_video_path)
                try:
                    if video.audio is None:
                        st.error("The uploaded video has no audio track to transcribe.")
                    else:
                        video.audio.write_audiofile(audio_path)
                        segments = transcribe_audio(audio_path, lang_code, chunk_length=3)
                        st.session_state.srt_content = format_srt(segments)
                finally:
                    video.close()
                    # The WAV is only needed for transcription. Delete it now so
                    # later burns (including repeated ones) have nothing to clean up.
                    if os.path.exists(audio_path):
                        os.remove(audio_path)

        if 'srt_content' in st.session_state:
            edited_srt = st.text_area("Edit Captions (SRT format)", st.session_state.srt_content, height=300)
            if st.button("Burn Captions and Download"):
                with st.spinner("Burning captions onto video..."):
                    subtitle_style = (
                        'Fontname=Arial,Fontsize=16,'
                        'PrimaryColour=&H00FFFFFF&,'
                        'SecondaryColour=&H00000000&,'
                        'OutlineColour=&H00000000&,'
                        'BackColour=&H40000000&,'
                        'BorderStyle=1,'
                        'Outline=1,'
                        'Shadow=1,'
                        'MarginV=20'
                    )
                    output_path = burn_subtitles(st.session_state.temp_video_path, edited_srt, subtitle_style)
                if output_path:
                    # Read the bytes up front so the rendered file can be
                    # removed regardless of how the download button consumes data.
                    with open(output_path, "rb") as file:
                        video_bytes = file.read()
                    os.remove(output_path)
                    st.download_button(
                        label="Download Video with Captions",
                        data=video_bytes,
                        file_name="video_with_captions.mp4",
                        mime="video/mp4"
                    )
                else:
                    st.error("Failed to burn captions onto the video.")
if st.button("Reset"):
    # Remove temporary files *before* clearing session state: once the state
    # is cleared their paths are gone and the files would be leaked.
    for path_key in ('temp_video_path', 'web_compatible_video_path', 'temp_audio_path'):
        temp_path = st.session_state.get(path_key)
        # Paths can be None, e.g. before any upload or after a failed
        # conversion; os.path.exists(None) would raise TypeError.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
    # Clear session state
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.success("All data has been reset. You can now upload a new video.")