Fix blacklisted-word check on free-form answers.

This commit is contained in:
Cristiano Ferreira
2024-06-10 19:39:08 +01:00
parent b7c18517de
commit 3676d7ad39
2 changed files with 20 additions and 16 deletions

8
app.py
View File

@@ -480,7 +480,7 @@ def grade_speaking_task_1():
response['transcript'] = answer
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed_text.")
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed text.")
response['fixed_text'] = get_speaking_corrections(answer)
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
@@ -624,7 +624,7 @@ def grade_speaking_task_2():
response['transcript'] = answer
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed_text.")
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed text.")
response['fixed_text'] = get_speaking_corrections(answer)
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
@@ -748,11 +748,11 @@ def grade_speaking_task_3():
logging.info("POST - speaking_task_3 - " + str(request_id) + " - Downloading file " + item["answer"])
download_firebase_file(FIREBASE_BUCKET, item["answer"], sound_file_name)
logging.info("POST - speaking_task_1 - " + str(
logging.info("POST - speaking_task_3 - " + str(
request_id) + " - Downloaded file " + item["answer"] + " to " + sound_file_name)
answer_text = speech_to_text(sound_file_name)
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Transcripted answer: " + answer_text)
logging.info("POST - speaking_task_3 - " + str(request_id) + " - Transcripted answer: " + answer_text)
text_answers.append(answer_text)
item["answer"] = answer_text

View File

@@ -2,8 +2,8 @@ import json
import os
import re
from openai import OpenAI
from dotenv import load_dotenv
from openai import OpenAI
from helper.constants import BLACKLISTED_WORDS, GPT_3_5_TURBO
from helper.token_counter import count_tokens
@@ -54,7 +54,7 @@ def check_fields(obj, fields):
return all(field in obj for field in fields)
def make_openai_call(model, messages, token_count, fields_to_check, temperature):
def make_openai_call(model, messages, token_count, fields_to_check, temperature, check_blacklisted=True):
global try_count
result = client.chat.completions.create(
model=model,
@@ -65,6 +65,7 @@ def make_openai_call(model, messages, token_count, fields_to_check, temperature)
)
result = result.choices[0].message.content
if check_blacklisted:
found_blacklisted_word = get_found_blacklisted_words(result)
if found_blacklisted_word is not None and try_count < TRY_LIMIT:
@@ -188,7 +189,7 @@ def get_fixed_text(text):
}
]
token_count = count_total_tokens(messages)
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
return response["fixed_text"]
@@ -203,7 +204,7 @@ def get_speaking_corrections(text):
}
]
token_count = count_total_tokens(messages)
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
return response["fixed_text"]
@@ -211,6 +212,7 @@ def has_blacklisted_words(text: str):
text_lower = text.lower()
return any(word in text_lower for word in BLACKLISTED_WORDS)
def get_found_blacklisted_words(text: str):
text_lower = text.lower()
for word in BLACKLISTED_WORDS:
@@ -218,6 +220,7 @@ def get_found_blacklisted_words(text: str):
return word
return None
def remove_special_characters_from_beginning(string):
cleaned_string = string.lstrip('\n')
if string.startswith("'") or string.startswith('"'):
@@ -239,6 +242,7 @@ def replace_expression_in_object(obj, expression, replacement):
obj[key] = replace_expression_in_object(obj[key], expression, replacement)
return obj
def count_total_tokens(messages):
total_tokens = 0
for message in messages: