From 3676d7ad391e4054a4ca187c2bcf5c5d4b2dbd51 Mon Sep 17 00:00:00 2001 From: Cristiano Ferreira Date: Mon, 10 Jun 2024 19:39:08 +0100 Subject: [PATCH] Fix check for blacklisted on free form answers. --- app.py | 8 ++++---- helper/openai_interface.py | 28 ++++++++++++++++------------ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/app.py b/app.py index 9712e09..66cfab7 100644 --- a/app.py +++ b/app.py @@ -480,7 +480,7 @@ def grade_speaking_task_1(): response['transcript'] = answer - logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed_text.") + logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed text.") response['fixed_text'] = get_speaking_corrections(answer) logging.info("POST - speaking_task_1 - " + str(request_id) + " - Fixed text: " + response['fixed_text']) @@ -624,7 +624,7 @@ def grade_speaking_task_2(): response['transcript'] = answer - logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed_text.") + logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed text.") response['fixed_text'] = get_speaking_corrections(answer) logging.info("POST - speaking_task_2 - " + str(request_id) + " - Fixed text: " + response['fixed_text']) @@ -748,11 +748,11 @@ def grade_speaking_task_3(): logging.info("POST - speaking_task_3 - " + str(request_id) + " - Downloading file " + item["answer"]) download_firebase_file(FIREBASE_BUCKET, item["answer"], sound_file_name) - logging.info("POST - speaking_task_1 - " + str( + logging.info("POST - speaking_task_3 - " + str( request_id) + " - Downloaded file " + item["answer"] + " to " + sound_file_name) answer_text = speech_to_text(sound_file_name) - logging.info("POST - speaking_task_1 - " + str(request_id) + " - Transcripted answer: " + answer_text) + logging.info("POST - speaking_task_3 - " + str(request_id) + " - Transcripted answer: " + answer_text) text_answers.append(answer_text) item["answer"] = answer_text diff --git a/helper/openai_interface.py b/helper/openai_interface.py index 77fa05d..2d50a64 100644 --- a/helper/openai_interface.py +++ b/helper/openai_interface.py @@ -2,8 +2,8 @@ import json import os import re -from openai import OpenAI from dotenv import load_dotenv +from openai import OpenAI from helper.constants import BLACKLISTED_WORDS, GPT_3_5_TURBO from helper.token_counter import count_tokens @@ -54,7 +54,7 @@ def check_fields(obj, fields): return all(field in obj for field in fields) -def make_openai_call(model, messages, token_count, fields_to_check, temperature): +def make_openai_call(model, messages, token_count, fields_to_check, temperature, check_blacklisted=True): global try_count result = client.chat.completions.create( model=model, @@ -65,15 +65,16 @@ def make_openai_call(model, messages, token_count, fields_to_check, temperature) ) result = result.choices[0].message.content - found_blacklisted_word = get_found_blacklisted_words(result) + if check_blacklisted: + found_blacklisted_word = get_found_blacklisted_words(result) - if found_blacklisted_word is not None and try_count < TRY_LIMIT: - from app import app - app.logger.warning("Result contains blacklisted words: " + str(found_blacklisted_word)) - try_count = try_count + 1 - return make_openai_call(model, messages, token_count, fields_to_check, temperature) - elif found_blacklisted_word is not None and try_count >= TRY_LIMIT: - return "" + if found_blacklisted_word is not None and try_count < TRY_LIMIT: + from app import app + app.logger.warning("Result contains blacklisted words: " + str(found_blacklisted_word)) + try_count = try_count + 1 + return make_openai_call(model, messages, token_count, fields_to_check, temperature) + elif found_blacklisted_word is not None and try_count >= TRY_LIMIT: + return "" if fields_to_check is None: return json.loads(result) @@ -188,7 +189,7 @@ def get_fixed_text(text): } ] token_count = count_total_tokens(messages) - response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2) + response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False) return response["fixed_text"] @@ -203,7 +204,7 @@ def get_speaking_corrections(text): } ] token_count = count_total_tokens(messages) - response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2) + response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False) return response["fixed_text"] @@ -211,6 +212,7 @@ def has_blacklisted_words(text: str): text_lower = text.lower() return any(word in text_lower for word in BLACKLISTED_WORDS) + def get_found_blacklisted_words(text: str): text_lower = text.lower() for word in BLACKLISTED_WORDS: @@ -218,6 +220,7 @@ def get_found_blacklisted_words(text: str): return word return None + def remove_special_characters_from_beginning(string): cleaned_string = string.lstrip('\n') if string.startswith("'") or string.startswith('"'): @@ -239,6 +242,7 @@ def replace_expression_in_object(obj, expression, replacement): obj[key] = replace_expression_in_object(obj[key], expression, replacement) return obj + def count_total_tokens(messages): total_tokens = 0 for message in messages: