Fix check for blacklisted on free form answers.
This commit is contained in:
@@ -2,8 +2,8 @@ import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from openai import OpenAI
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from helper.constants import BLACKLISTED_WORDS, GPT_3_5_TURBO
|
||||
from helper.token_counter import count_tokens
|
||||
@@ -54,7 +54,7 @@ def check_fields(obj, fields):
|
||||
return all(field in obj for field in fields)
|
||||
|
||||
|
||||
def make_openai_call(model, messages, token_count, fields_to_check, temperature):
|
||||
def make_openai_call(model, messages, token_count, fields_to_check, temperature, check_blacklisted=True):
|
||||
global try_count
|
||||
result = client.chat.completions.create(
|
||||
model=model,
|
||||
@@ -65,15 +65,16 @@ def make_openai_call(model, messages, token_count, fields_to_check, temperature)
|
||||
)
|
||||
result = result.choices[0].message.content
|
||||
|
||||
found_blacklisted_word = get_found_blacklisted_words(result)
|
||||
if check_blacklisted:
|
||||
found_blacklisted_word = get_found_blacklisted_words(result)
|
||||
|
||||
if found_blacklisted_word is not None and try_count < TRY_LIMIT:
|
||||
from app import app
|
||||
app.logger.warning("Result contains blacklisted words: " + str(found_blacklisted_word))
|
||||
try_count = try_count + 1
|
||||
return make_openai_call(model, messages, token_count, fields_to_check, temperature)
|
||||
elif found_blacklisted_word is not None and try_count >= TRY_LIMIT:
|
||||
return ""
|
||||
if found_blacklisted_word is not None and try_count < TRY_LIMIT:
|
||||
from app import app
|
||||
app.logger.warning("Result contains blacklisted words: " + str(found_blacklisted_word))
|
||||
try_count = try_count + 1
|
||||
return make_openai_call(model, messages, token_count, fields_to_check, temperature)
|
||||
elif found_blacklisted_word is not None and try_count >= TRY_LIMIT:
|
||||
return ""
|
||||
|
||||
if fields_to_check is None:
|
||||
return json.loads(result)
|
||||
@@ -188,7 +189,7 @@ def get_fixed_text(text):
|
||||
}
|
||||
]
|
||||
token_count = count_total_tokens(messages)
|
||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
|
||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
|
||||
return response["fixed_text"]
|
||||
|
||||
|
||||
@@ -203,7 +204,7 @@ def get_speaking_corrections(text):
|
||||
}
|
||||
]
|
||||
token_count = count_total_tokens(messages)
|
||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2)
|
||||
response = make_openai_call(GPT_3_5_TURBO, messages, token_count, ["fixed_text"], 0.2, False)
|
||||
return response["fixed_text"]
|
||||
|
||||
|
||||
@@ -211,6 +212,7 @@ def has_blacklisted_words(text: str):
|
||||
text_lower = text.lower()
|
||||
return any(word in text_lower for word in BLACKLISTED_WORDS)
|
||||
|
||||
|
||||
def get_found_blacklisted_words(text: str):
|
||||
text_lower = text.lower()
|
||||
for word in BLACKLISTED_WORDS:
|
||||
@@ -218,6 +220,7 @@ def get_found_blacklisted_words(text: str):
|
||||
return word
|
||||
return None
|
||||
|
||||
|
||||
def remove_special_characters_from_beginning(string):
|
||||
cleaned_string = string.lstrip('\n')
|
||||
if string.startswith("'") or string.startswith('"'):
|
||||
@@ -239,6 +242,7 @@ def replace_expression_in_object(obj, expression, replacement):
|
||||
obj[key] = replace_expression_in_object(obj[key], expression, replacement)
|
||||
return obj
|
||||
|
||||
|
||||
def count_total_tokens(messages):
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
|
||||
Reference in New Issue
Block a user