Cxxs commited on
Commit
c5a1bb3
·
1 Parent(s): 1f60d3b

add input check

Browse files
Files changed (1) hide show
  1. prompt_check.py +7 -1
prompt_check.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  def is_unsafe_prompt(model, tokenizer, user_prompt):
2
  prompt = f"""You are a content safety filter. Your task is to analyze the following user input and determine if it contains unsafe content.
3
 
@@ -15,6 +19,8 @@ def is_unsafe_prompt(model, tokenizer, user_prompt):
15
  messages = [
16
  {"role": "user", "content": prompt}
17
  ]
 
 
18
  text = tokenizer.apply_chat_template(
19
  messages,
20
  tokenize=False,
@@ -32,6 +38,6 @@ def is_unsafe_prompt(model, tokenizer, user_prompt):
32
 
33
  content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
34
 
35
- print(content)
36
 
37
  return "yes" in content.lower()
 
1
+ from diffusers.utils.logging import get_logger
2
+
3
+ logger = get_logger(__name__)
4
+
5
  def is_unsafe_prompt(model, tokenizer, user_prompt):
6
  prompt = f"""You are a content safety filter. Your task is to analyze the following user input and determine if it contains unsafe content.
7
 
 
19
  messages = [
20
  {"role": "user", "content": prompt}
21
  ]
22
+
23
+ logger.warning(str(messages))
24
  text = tokenizer.apply_chat_template(
25
  messages,
26
  tokenize=False,
 
38
 
39
  content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
40
 
41
+ logger.warning(content)
42
 
43
  return "yes" in content.lower()