diff --git a/main.py b/main.py
index 229388c..60063b4 100644
--- a/main.py
+++ b/main.py
@@ -68,6 +68,7 @@ def chat():
     dir_or_doc = st.radio('Select a chat method', ('Document', 'Directory'))
     st.title('Chat')
     model_name = st.radio('Select a model', ('gpt-3.5-turbo', 'gpt-4'))
+    hypothetical = st.checkbox('Use hypothetical embeddings', value=False)
     if dir_or_doc == 'Document':
         if 'text_input' not in st.session_state:
             st.session_state.text_input = ''
@@ -111,7 +112,7 @@ def chat():
     user_input = st.text_input('Enter your question', key='text_input')
 
     if st.button('Ask') and 'db' in st.session_state and validate_api_key(model_name):
-        answer = generate_answer(st.session_state.db, model_name)
+        answer = generate_answer(st.session_state.db, model_name, hypothetical)
 
         if 'history' not in st.session_state:
             st.session_state.history = []
@@ -137,17 +138,73 @@ def documents():
         st.write('No documents found in documents folder. Add some documents first!')
 
 
+def compare_results():
+    """Render the 'Compare' page.
+
+    Lets the user pick a file from the documents folder, stores the db built
+    for it in the session state, and answers one question twice: once with
+    hypothetical embeddings and once with normal embeddings. Both answers and
+    their sources are shown side by side in two columns.
+    """
+    st.title('Compare')
+    st.write("Compare retrieval results using hypothetical embeddings vs. normal embeddings. Support for comparing multiple models coming soon.")
+    model_name = 'gpt-3.5-turbo'
+
+    if 'text_input' not in st.session_state:
+        st.session_state.text_input = ''
+    directory = 'documents'
+    files = [file for file in os.listdir(directory) if file.endswith(tuple(accepted_filetypes))]
+    if not files:
+        # st.selectbox returns None when it has no options, which would break
+        # the string concatenation and path join below.
+        st.write('No documents found in documents folder. Add some documents first!')
+        return
+    selected_file = st.selectbox('Select a file', files)
+    st.write('You selected: ' + selected_file)
+    selected_file_path = os.path.join(directory, selected_file)
+
+    if st.button('Load file (first time might take a second...) \npressing this button will reset the chat history'):
+        st.session_state.db = load_db_from_file_and_create_if_not_exists(selected_file_path)
+        st.session_state.history = []
+
+    user_input = st.text_input('Enter your question', key='text_input')
+
+    if st.button('Ask') and 'db' in st.session_state and validate_api_key(model_name):
+        st.markdown('Question: ' + user_input)
+        answer_a, sources_a = generate_answer(st.session_state.db, model_name, hypothetical=True)
+        answer_b, sources_b = generate_answer(st.session_state.db, model_name, hypothetical=False)
+
+        col1, col2 = st.columns(2)
+
+        with col1:
+            st.header('Hypothetical embeddings')
+            st.markdown(answer_a)
+            with st.expander('Sources', expanded=False):
+                st.markdown(sources_a)
+        with col2:
+            st.header('Normal embeddings')
+            st.markdown(answer_b)
+            with st.expander('Sources', expanded=False):
+                st.markdown(sources_b)
+
+        # Reset chat history and sources so both start empty after a comparison.
+        st.session_state.history = []
+        st.session_state.sources = []
+
+
 PAGES = {
     "Chat": chat,
     "Summarize": summarize,
     "Documents": documents,
+    "Compare": compare_results
 }
 
 
 st.sidebar.title("Navigation")
 selection = st.sidebar.radio("Go to", list(PAGES.keys()))
 st.sidebar.markdown(' [Contact author](mailto:ethanujohnston@gmail.com)')
 st.sidebar.markdown(' [Github](https://github.com/e-johnstonn/docGPT)')
+st.sidebar.markdown('[More info on hypothetical embeddings here](https://arxiv.org/abs/2212.10496)', unsafe_allow_html=True)
 
 page = PAGES[selection]
 page()