"""Streamlit demo: find or rank reaction GIFs for a piece of text with CLIP.

Two modes are offered:

* no upload   - the text query is embedded and looked up in a prebuilt nmslib
  index of image embeddings; hits are shown from ``./jpg`` or ``./gif``.
* with upload - the uploaded images are scored against the text by the model
  and shown best-first.
"""
import os

import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from model import FlaxHybridCLIP

# Markdown shown in the sidebar "Motivation" expander.
MOTIVATION_MD = """
Reaction GIFs became an integral part of communication.
They convey complex emotions with many levels, in a short compact format.
If a picture is worth a thousand words then a GIF is worth more.
A lot of people would agree it is not always easy to find the perfect reaction GIF.
This is just a first step in the more ambitious goal of GIF/Image generation.
"""

# Markdown shown next to the upload widget.
SEARCH_SCOPE_MD = """
Searches among the validation set images if not specified
(There may be non-exact duplicates)
"""

# ---------------------------------------------------------------- sidebar ---
st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")

sc = st.sidebar.columns(2)
sc[0].image("./huggingface_explode3.png", width=150)
# Two blank writes push the tagline down next to the logo.
sc[1].write(" ")
sc[1].write(" ")
sc[1].markdown("## Researching fun")

with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(MOTIVATION_MD)

top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
col_count = 4

# Index ids returned by the nearest-neighbour search are used as positions in
# this sorted listing, so the ordering here must not change.
file_names = os.listdir("./jpg")
file_names.sort()

show_val = st.sidebar.button("show all validation set images")
if show_val:
    cols = st.sidebar.columns(col_count)
    for i, im in enumerate(file_names):
        cols[i % col_count].image("./jpg/" + im)

st.write("# Search Reaction GIFs with CLIP ")
st.write(" ")
st.write(" ")


# ----------------------------------------------------------------- models ---
@st.cache_resource()
def load_model():
    """Load the hybrid CLIP model and its processor (cached by Streamlit).

    The processor's tokenizer is replaced with the
    ``cardiffnlp/twitter-roberta-base`` tokenizer.

    Returns:
        A ``(model, processor)`` pair.
    """
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained(
        "cardiffnlp/twitter-roberta-base"
    )
    return model, processor


@st.cache_resource()
def load_image_index():
    """Load the prebuilt HNSW / cosine-similarity image index from disk.

    Returns:
        The nmslib index read from ``./features/image_embeddings``.
    """
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


image_index = load_image_index()
model, processor = load_model()


# TODO
def add_image_emb(image):
    """Embed ``image`` with the model and add the embedding to the index.

    Unfinished and not called anywhere in this script.

    Args:
        image: Path or file-like object accepted by ``PIL.Image.open``.
    """
    image = Image.open(image).convert("RGB")
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # Reorders the pixel tensor axes; presumably channels-first -> channels-last
    # for the Flax model - TODO confirm.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    # NOTE(review): nmslib's addDataPoint appears to expect (id, data); only the
    # features are passed here - confirm before wiring this function in.
    image_index.addDataPoint(features)


def query_with_images(query_images, query_text):
    """Rank uploaded images by how well they match ``query_text``.

    Args:
        query_images: Uploaded files (objects with a ``name`` attribute that
            ``PIL.Image.open`` can read).
        query_text: The text to score the images against.

    Returns:
        ``(images, probs)``: two equally long tuples, best match first.
        ``images`` holds RGB PIL images, ``probs`` the softmax over the
        per-image logits as Python floats. Both are empty when no images
        are given.
    """
    images = []
    for upload in query_images:
        picture = Image.open(upload)
        # Only the first frame of an animation is scored. The suffix test is
        # case-insensitive so "X.GIF" is handled like "x.gif".
        if upload.name.lower().endswith(".gif"):
            picture.seek(0)
        images.append(picture.convert("RGB"))

    # Nothing to rank: avoid calling the model and unpacking an empty zip.
    if not images:
        return (), ()

    inputs = processor(
        text=[query_text], images=images, return_tensors="jax", padding=True
    )
    # Same axis reordering as in add_image_emb.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image.reshape(-1)
    # Plain floats sort and render predictably, unlike device-array scalars.
    probs = np.asarray(jax.nn.softmax(logits_per_image)).tolist()

    ranked = sorted(zip(images, probs), key=lambda pair: pair[1], reverse=True)
    ranked_images, ranked_probs = zip(*ranked)
    return ranked_images, ranked_probs


# ------------------------------------------------------------- query form ---
q_cols = st.columns([5, 2, 5])

examples = [
    "OMG that is disgusting",
    "I'm so scared right now",
    " I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]

example_input = q_cols[0].radio(
    "Example Queries :",
    examples,
    index=4,
    help=(
        "These are examples I wrote off the top of my head. "
        "They don't occur in the dataset"
    ),
)
q_cols[2].markdown(SEARCH_SCOPE_MD)

query_text = q_cols[0].text_input(
    "Write text you want to get reaction for", value=example_input
)
query_images = q_cols[2].file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)

# ---------------------------------------------------------------- ranking ---
if query_images:
    st.write("Ranking your uploaded images with respect to input text:")
    with st.spinner("Calculating..."):
        ids, dists = query_with_images(query_images, query_text)
else:
    st.write("Found these images within validation set:")
    with st.spinner("Calculating..."):
        proc = processor(
            text=[query_text], images=None, return_tensors="jax", padding=True
        )
        vec = np.asarray(model.get_text_features(**proc))
        ids, dists = image_index.knnQuery(vec, k=top_k)

show_gif = st.checkbox(
    "Play GIFs",
    value=True,
    help="Will play the original animation. Only first frame is used in training!",
)
ext = "jpg" if not show_gif else "gif"

# ---------------------------------------------------------------- results ---
res_cols = st.columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        if isinstance(id_, np.integer):
            # Index hit: id_ is a position in file_names. splitext (rather
            # than slicing off four characters) also handles ".jpeg".
            stem = os.path.splitext(file_names[id_])[0]
            st.image(f"./{ext}/{stem}.{ext}")
            # Shown as 1 - distance; presumably a cosine similarity given the
            # "cosinesimil" index space - TODO confirm.
            st.write(1.0 - dist)
        else:
            # Uploaded image: id_ is the PIL image, dist its softmax score.
            st.image(id_)
            st.write(dist)

# Credits
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")