File size: 2,887 Bytes
9036c2a
a0d0642
 
9036c2a
a0d0642
9036c2a
a0d0642
9036c2a
 
 
 
a0d0642
9036c2a
a0d0642
9036c2a
 
a0d0642
 
 
 
 
9036c2a
 
 
 
a0d0642
9036c2a
 
 
 
 
 
 
 
42e6cab
 
 
 
 
 
a0d0642
42e6cab
 
 
 
 
 
 
 
a0d0642
42e6cab
 
a0d0642
42e6cab
 
 
 
 
 
 
 
 
 
 
 
a0d0642
42e6cab
 
a0d0642
42e6cab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st  # Don't forget to include `streamlit` in your `requirements.txt` file to ensure the app runs properly on Hugging Face Spaces.

from transformers import AutoProcessor, AutoModelForImageTextToText  # Updated imports to reflect changes
from PIL import Image  # Ensure the `pillow` library is included in your `requirements.txt`.

import torch  # Since PyTorch is required for this app, specify the appropriate version of `torch` in `requirements.txt` based on compatibility with the model.

import os

def load_model():
    """Load the PaliGemma2 processor and model from the Hugging Face Hub.

    The access token is read from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable and passed to both ``from_pretrained`` calls.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        ValueError: If ``HUGGINGFACEHUB_API_TOKEN`` is unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError("Hugging Face API token not found. Please set it in the environment variables.")

    # Single source of truth so processor and model can never drift apart.
    model_id = "google/paligemma2-3b-pt-224"

    # `token=` is the current keyword for Hub authentication; the older
    # `use_auth_token=` spelling is deprecated in transformers.
    processor = AutoProcessor.from_pretrained(model_id, token=token)
    model = AutoModelForImageTextToText.from_pretrained(model_id, token=token)

    return processor, model

def process_image(image, processor, model, *, max_new_tokens=128):
    """Generate text for an image with PaliGemma2 and return it decoded.

    Args:
        image: The image to read. Objects exposing a ``mode`` attribute other
            than ``"RGB"`` (e.g. RGBA or palette PIL images) are converted to
            RGB first; anything else is passed to the processor unchanged.
        processor: Callable that turns the image into model inputs and offers
            ``batch_decode``.
        model: Object offering ``generate``.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded text of the first (only) sequence, containing just the
        newly generated tokens.
    """
    # PNG uploads are frequently RGBA or palette-based; normalise to the
    # three-channel layout before preprocessing.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Remember how long the prompt is so it can be cut from the output below.
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # An explicit budget avoids relying on the generation config's default
        # length limit, which does not account for the long image-token prompt.
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # `generate` returns the prompt ids followed by the new ids; keep only the
    # continuation so the prompt is not echoed into the result.
    new_token_ids = generated_ids[:, prompt_length:]
    return processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]

def main():
    """Render the Streamlit page: load the model, accept an image, show the text.

    The page stops early (via ``st.stop()``) after showing an error message if
    the model cannot be loaded or the uploaded file cannot be opened as an
    image.
    """
    st.set_page_config(page_title="Text Reading with PaliGemma2", layout="centered")
    st.title("Text Reading from Images using PaliGemma2")

    # Streamlit re-executes the script on every widget interaction. Caching
    # the loader as a resource keeps one processor/model pair alive across
    # reruns instead of reloading the weights on each button click.
    # show_spinner=False because the explicit spinner below already covers it.
    cached_load_model = st.cache_resource(show_spinner=False)(load_model)

    with st.spinner("Loading PaliGemma2 model... This may take a few moments."):
        try:
            processor, model = cached_load_model()
        except (ValueError, OSError) as e:
            # ValueError: missing token (raised by load_model).
            # OSError: Hub download / repository access failures.
            st.error(str(e))
            st.stop()
        else:
            st.success("Model loaded successfully!")

    uploaded_image = st.file_uploader("Upload an image containing text", type=["png", "jpg", "jpeg"])

    if uploaded_image is not None:
        try:
            image = Image.open(uploaded_image)
        except OSError:
            # Covers corrupt or mislabelled uploads that PIL cannot identify.
            st.error("The uploaded file could not be read as an image.")
            st.stop()

        # NOTE(review): newer Streamlit releases deprecate `use_column_width`;
        # confirm against the pinned Streamlit version before changing it.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Extract Text"):
            with st.spinner("Processing image..."):
                extracted_text = process_image(image, processor, model)
                st.success("Text extraction complete!")
                st.subheader("Extracted Text")
                st.write(extracted_text)

    # Footer
    st.markdown("---")
    st.markdown("**Built with [PaliGemma2](https://huggingface.co/google/paligemma2-3b-pt-224) and Streamlit**")

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()