"""Streamlit helpers for summarizing documents and chatting over embeddings."""
import os
from io import StringIO

import PyPDF2
import streamlit as st
from langchain import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings

from chat_utils import load_chat_embeddings, create_and_save_chat_embeddings, qa_from_db, doc_loader
from my_prompts import file_map, file_combine, youtube_map, youtube_combine
from summary_utils import doc_to_text, token_counter, summary_prompt_creator, doc_to_final_summary


def pdf_to_text(pdf_file):
    """
    Convert a PDF file to UTF-8 encoded text.

    :param pdf_file: The PDF file to convert (anything ``PyPDF2.PdfReader`` accepts).
    :return: The concatenated text of all pages, encoded as UTF-8 bytes.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    text = StringIO()
    for page in pdf_reader.pages:
        # Guard against a page whose extraction yields nothing, which would
        # otherwise make StringIO.write() raise.
        text.write(page.extract_text() or '')
    return text.getvalue().encode('utf-8')


def check_gpt_4():
    """
    Check if GPT-4 is usable by sending a short test prompt to the model.

    :return: True if the test call succeeds, False if it raises for any reason.
    """
    try:
        ChatOpenAI(model_name='gpt-4').call_as_llm('Hi')
        return True
    except Exception:
        # Any failure (no access, bad key, network issue) is reported as "no GPT-4".
        return False


def token_limit(doc, maximum=200000):
    """
    Check that a document does not exceed a specified maximum number of tokens.

    :param doc: The langchain Document object to check.
    :param maximum: The maximum number of tokens allowed (inclusive).
    :return: True if the token count is at most ``maximum``, False otherwise.
    """
    text = doc_to_text(doc)
    count = token_counter(text)
    print(count)
    return count <= maximum


def token_minimum(doc, minimum=2000):
    """
    Check that a document has at least a specified minimum number of tokens.

    :param doc: The langchain Document object to check.
    :param minimum: The minimum number of tokens required (inclusive).
    :return: True if the token count is at least ``minimum``, False otherwise.
    """
    text = doc_to_text(doc)
    count = token_counter(text)
    return count >= minimum


def validate_api_key(model_name='gpt-3.5-turbo'):
    """
    Validate the configured OpenAI API key by sending a short test prompt.

    On failure the exception is printed and a Streamlit warning is shown.

    :param model_name: The chat model to send the test prompt to.
    :return: True if the test call succeeds, False otherwise.
    """
    try:
        ChatOpenAI(model_name=model_name).call_as_llm('Hi')
        print('API Key is valid')
        return True
    except Exception as e:
        print(e)
        st.warning('API key is invalid or OpenAI is having issues.')
        print('Invalid API key.')
        # Explicit False instead of the previous implicit None.
        return False


def create_chat_model_for_summary(use_gpt_4):
    """
    Create the chat model used to build the summary prompts.

    The per-call output budget is larger when GPT-4 is selected (500 tokens)
    than otherwise (250 tokens).

    :param use_gpt_4: Whether the user chose to use GPT-4 or not.
    :return: A chat model.
    """
    # NOTE(review): both branches use 'gpt-3.5-turbo'; only max_tokens differs.
    # Presumably GPT-4 itself is selected downstream via the use_gpt_4 flag
    # passed to doc_to_final_summary - confirm this is intentional.
    if use_gpt_4:
        return ChatOpenAI(temperature=0, max_tokens=500, model_name='gpt-3.5-turbo')
    else:
        return ChatOpenAI(temperature=0, max_tokens=250, model_name='gpt-3.5-turbo')


def process_summarize_button(file_or_transcript, use_gpt_4, find_clusters, file=True):
    """
    Processes the summarize button, and displays the summary if input and doc size are valid.

    The summary is rendered with Streamlit and also written to
    ``summaries/<name>_summary.txt``.

    :param file_or_transcript: The file uploaded by the user or the transcript from the YouTube URL
    :param use_gpt_4: Whether to use GPT-4 or not
    :param find_clusters: Whether to find optimal clusters or not, experimental
    :param file: True if ``file_or_transcript`` is a file path, False if it is a transcript
    :return: None
    """
    if not validate_input(file_or_transcript, use_gpt_4):
        return

    with st.spinner("Summarizing... please wait..."):
        if file:
            doc = doc_loader(file_or_transcript)
            map_prompt = file_map
            combine_prompt = file_combine
            head, tail = os.path.split(file_or_transcript)
            # File name up to the first dot, used to name the saved summary.
            name = tail.split('.')[0]
        else:
            doc = file_or_transcript
            map_prompt = youtube_map
            combine_prompt = youtube_combine
            # NOTE(review): the name is a fixed slice of the transcript's string
            # form; presumably this picks out an identifying snippet - confirm.
            name = str(file_or_transcript)[30:44].strip()

        llm = create_chat_model_for_summary(use_gpt_4)
        initial_prompt_list = summary_prompt_creator(map_prompt, 'text', llm)
        final_prompt_list = summary_prompt_creator(combine_prompt, 'text', llm)

        if not validate_doc_size(doc):
            return

        # find_clusters is only forwarded when set, so doc_to_final_summary's
        # own default applies otherwise.
        if find_clusters:
            summary = doc_to_final_summary(doc, 10, initial_prompt_list, final_prompt_list, use_gpt_4, find_clusters)
        else:
            summary = doc_to_final_summary(doc, 10, initial_prompt_list, final_prompt_list, use_gpt_4)

        st.markdown(summary, unsafe_allow_html=True)

        # Make sure the output directory exists, and write with an explicit
        # encoding so non-ASCII summaries do not depend on the platform default.
        os.makedirs('summaries', exist_ok=True)
        with open(f'summaries/{name}_summary.txt', 'w', encoding='utf-8') as f:
            f.write(summary)
        st.text(f' Summary saved to summaries/{name}_summary.txt')


def validate_doc_size(doc):
    """
    Validates the size of the document, warning the user if it is out of range.

    :param doc: doc to validate
    :return: True if the doc has between 2000 and 800000 tokens (inclusive), False otherwise
    """
    if not token_limit(doc, 800000):
        st.warning('File or transcript too big!')
        return False

    if not token_minimum(doc, 2000):
        st.warning('File or transcript too small!')
        return False

    return True


def validate_input(file_or_transcript, use_gpt_4):
    """
    Validates the user input, and displays warnings if the input is invalid.

    :param file_or_transcript: The file uploaded by the user or the YouTube URL entered by the user
    :param use_gpt_4: Whether the user wants to use GPT-4
    :return: True if the input is valid, False otherwise
    """
    if file_or_transcript is None:
        st.warning("Please upload a file or enter a YouTube URL.")
        return False

    if not validate_api_key():
        st.warning('Key not valid or API is down.')
        return False

    # The GPT-4 check is only performed when GPT-4 was requested.
    if use_gpt_4 and not check_gpt_4():
        st.warning('Key not valid for GPT-4.')
        return False

    return True


def generate_answer(db=None, llm_model=None, hypothetical=False):
    """
    Answer the message in ``st.session_state.text_input`` using the given database.

    On success the user message and the answer are appended to
    ``st.session_state.history`` and ``st.session_state.sources`` is replaced
    with a one-element list holding the sources.

    :param db: The embeddings database to query; nothing is generated if it is falsy.
    :param llm_model: Forwarded to ``qa_from_db``.
    :param hypothetical: Forwarded to ``qa_from_db``.
    :return: A ``(sys_message, sources)`` tuple, or None if there is no database
        or the message is blank.
    """
    user_message = st.session_state.text_input

    if db and user_message.strip() != "":
        with st.spinner('Generating answer...'):
            print('About to call API')
            sys_message, sources = qa_from_db(user_message, db, llm_model, hypothetical)
            print('Done calling API')

        st.session_state.history.append({'message': user_message, 'is_user': True})
        st.session_state.history.append({'message': sys_message, 'is_user': False})
        # Only the sources of the latest answer are kept.
        st.session_state.sources = [sources]
        return sys_message, sources
    else:
        print(user_message)
        print('failed')
        print(db)
        return None


def load_db_from_file_and_create_if_not_exists(file_path):
    """
    Load the chat embeddings for a file, creating and saving them first if loading fails.

    :param file_path: Path of the file whose embeddings should be loaded.
    :return: The value returned by ``load_chat_embeddings``.
    """
    with st.spinner('Loading chat embeddings...'):
        try:
            db = load_chat_embeddings(file_path)
            print('success')
        except RuntimeError:
            # A RuntimeError from the loader is treated as "embeddings not found".
            print('not found')
            create_and_save_chat_embeddings(file_path)
            db = load_chat_embeddings(file_path)

    if db:
        st.success('Loaded successfully! Start a chat below.')
    else:
        st.warning('Something went wrong... failed to load chat embeddings.')
    return db


def load_dir_chat_embeddings(file_path):
    """
    Load a FAISS index from the ``directory_embeddings`` folder.

    The index name is the last path component of ``file_path`` up to its first dot.

    :param file_path: Path used to derive the index name.
    :return: The loaded FAISS database, or None if loading fails.
    """
    name = os.path.split(file_path)[1].split('.')[0]
    embeddings = OpenAIEmbeddings()
    try:
        db = FAISS.load_local(folder_path='directory_embeddings', index_name=name, embeddings=embeddings)
        st.success('Embeddings loaded successfully.')
    except Exception:
        st.warning('Loading embeddings failed. Please try again.')
        return None
    return db