Fix check for blacklisted words on free-form answers.
This commit is contained in:
8
app.py
8
app.py
@@ -480,7 +480,7 @@ def grade_speaking_task_1():
|
|||||||
|
|
||||||
response['transcript'] = answer
|
response['transcript'] = answer
|
||||||
|
|
||||||
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed_text.")
|
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Requesting fixed text.")
|
||||||
response['fixed_text'] = get_speaking_corrections(answer)
|
response['fixed_text'] = get_speaking_corrections(answer)
|
||||||
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
|
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
|
||||||
|
|
||||||
@@ -624,7 +624,7 @@ def grade_speaking_task_2():
|
|||||||
|
|
||||||
response['transcript'] = answer
|
response['transcript'] = answer
|
||||||
|
|
||||||
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed_text.")
|
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Requesting fixed text.")
|
||||||
response['fixed_text'] = get_speaking_corrections(answer)
|
response['fixed_text'] = get_speaking_corrections(answer)
|
||||||
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
|
logging.info("POST - speaking_task_2 - " + str(request_id) + " - Fixed text: " + response['fixed_text'])
|
||||||
|
|
||||||
@@ -748,11 +748,11 @@ def grade_speaking_task_3():
|
|||||||
|
|
||||||
logging.info("POST - speaking_task_3 - " + str(request_id) + " - Downloading file " + item["answer"])
|
logging.info("POST - speaking_task_3 - " + str(request_id) + " - Downloading file " + item["answer"])
|
||||||
download_firebase_file(FIREBASE_BUCKET, item["answer"], sound_file_name)
|
download_firebase_file(FIREBASE_BUCKET, item["answer"], sound_file_name)
|
||||||
logging.info("POST - speaking_task_1 - " + str(
|
logging.info("POST - speaking_task_3 - " + str(
|
||||||
request_id) + " - Downloaded file " + item["answer"] + " to " + sound_file_name)
|
request_id) + " - Downloaded file " + item["answer"] + " to " + sound_file_name)
|
||||||
|
|
||||||
answer_text = speech_to_text(sound_file_name)
|
answer_text = speech_to_text(sound_file_name)
|
||||||
logging.info("POST - speaking_task_1 - " + str(request_id) + " - Transcripted answer: " + answer_text)
|
logging.info("POST - speaking_task_3 - " + str(request_id) + " - Transcripted answer: " + answer_text)
|
||||||
|
|
||||||
text_answers.append(answer_text)
|
text_answers.append(answer_text)
|
||||||
item["answer"] = answer_text
|
item["answer"] = answer_text
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
from helper.constants import BLACKLISTED_WORDS, GPT_3_5_TURBO
|
from helper.constants import BLACKLISTED_WORDS, GPT_3_5_TURBO
|
||||||
from helper.token_counter import count_tokens
|
from helper.token_counter import count_tokens
|
||||||
@@ -54,7 +54,7 @@ def check_fields(obj, fields):
|
|||||||
return all(field in obj for field in fields)
|
return all(field in obj for field in fields)
|
||||||
|
|
||||||
|
|
||||||
def make_openai_call(model, messages, token_count, fields_to_check, temperature):
|
def make_openai_call(model, messages, token_count, fields_to_check, temperature, check_blacklisted=True):
|
||||||
global try_count
|
global try_count
|
||||||
result = client.chat.completions.create(
|
result = client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -65,6 +65,7 @@ def make_openai_call(model, messages, token_count, fields_to_check, temperature)
|
|||||||
)
|
)
|
||||||
result = result.choices[0].message.content
|
result = result.choices[0].message.content
|
||||||
|
|
||||||
|
if check_blacklisted:
|
||||||
found_blacklisted_word = get_found_blacklisted_words(result)
|
found_blacklisted_word = get_found_blacklisted_words(result)
|
||||||
|
|
||||||
if found_blacklisted_word is not None and try_count < TRY_LIMIT:
|
if found_blacklisted_word is not None and try_count < TRY_LIMIT:
|
||||||
@@ -188,7 +189,7 @@ def get_fixed_text(text):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
token_count = count_total_tokens(messages)
|
token_count = count_total_tokens(messages)
|
||||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
|
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
|
||||||
return response["fixed_text"]
|
return response["fixed_text"]
|
||||||
|
|
||||||
|
|
||||||
@@ -203,7 +204,7 @@ def get_speaking_corrections(text):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
token_count = count_total_tokens(messages)
|
token_count = count_total_tokens(messages)
|
||||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
|
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
|
||||||
return response["fixed_text"]
|
return response["fixed_text"]
|
||||||
|
|
||||||
|
|
||||||
@@ -211,6 +212,7 @@ def has_blacklisted_words(text: str):
|
|||||||
text_lower = text.lower()
|
text_lower = text.lower()
|
||||||
return any(word in text_lower for word in BLACKLISTED_WORDS)
|
return any(word in text_lower for word in BLACKLISTED_WORDS)
|
||||||
|
|
||||||
|
|
||||||
def get_found_blacklisted_words(text: str):
|
def get_found_blacklisted_words(text: str):
|
||||||
text_lower = text.lower()
|
text_lower = text.lower()
|
||||||
for word in BLACKLISTED_WORDS:
|
for word in BLACKLISTED_WORDS:
|
||||||
@@ -218,6 +220,7 @@ def get_found_blacklisted_words(text: str):
|
|||||||
return word
|
return word
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def remove_special_characters_from_beginning(string):
|
def remove_special_characters_from_beginning(string):
|
||||||
cleaned_string = string.lstrip('\n')
|
cleaned_string = string.lstrip('\n')
|
||||||
if string.startswith("'") or string.startswith('"'):
|
if string.startswith("'") or string.startswith('"'):
|
||||||
@@ -239,6 +242,7 @@ def replace_expression_in_object(obj, expression, replacement):
|
|||||||
obj[key] = replace_expression_in_object(obj[key], expression, replacement)
|
obj[key] = replace_expression_in_object(obj[key], expression, replacement)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def count_total_tokens(messages):
|
def count_total_tokens(messages):
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|||||||
Reference in New Issue
Block a user