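"""Spam triage for research product records.

Scans the OpenSearch 'products' index for descriptions containing blacklisted
terms, asks an OpenAI-compatible LLM (via instructor) to classify each
candidate, and stores the structured verdict in a 'spam' index so
already-evaluated records are skipped on later runs.
"""
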
import asyncio
import enum
import json
import sys
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional

import instructor
from jsonargparse import ArgumentParser
from openai import AsyncOpenAI
from opensearchpy import AsyncOpenSearch, helpers
from pydantic import BaseModel, Field, SecretStr


class Topics(str, enum.Enum):
    """Correctly assign one of the predefined topics to the content"""
    SPAM = "SPAM, advertisement, promotional"
    SALES = "direct sales of goods or services"
    EXPLICIT_CONTENT = "porn, violence or harmful content"
    RESEARCH = "description of scientific research"
    DATASET = "description of a scientific dataset"
    OBJECT = "scientific description of an object"
    BIBLIOGRAPHIC = "bibliographic record"
    NA = "not available"


class ProductInfo(BaseModel):
    """
    Your task is to identify SPAM content among research product descriptions.
    """
    language: str = Field(description="The language of the content")
    topic: Topics
    reason: str = Field(description="explain why the topic was chosen")
    spam_words: list[str] = Field(description="content's spam words", min_length=0, max_length=3)


main_model_schema = ProductInfo.model_json_schema()
response_schema = json.dumps(main_model_schema, indent=None)


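# Configuration comes from CLI flags or, because of env_prefix/default_env,
# from CURATION_-prefixed environment variables (jsonargparse).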
parser = ArgumentParser(env_prefix="CURATION", default_env=True)
parser.add_argument("--opensearch.host", default='opensearch-cluster.local-dataplatform')
parser.add_argument("--opensearch.port", default=443, type=int)
parser.add_argument("--opensearch.user", default="admin", type=SecretStr)
parser.add_argument("--opensearch.password", default="admin", type=SecretStr)
parser.add_argument("--openai.host", default='localhost')
parser.add_argument("--openai.port", default=8000, type=int)
parser.add_argument("--openai.api_key", default='api_key')
parser.add_argument("--parallelism", default=36, type=int)
cfg = parser.parse_args()

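# Example invocation (illustrative values only; any of the options above can
# also be supplied through CURATION_-prefixed environment variables):
#   python <this script> --opensearch.host my-cluster --parallelism 12
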
# Blacklist of spam trigger terms, one per line (lower-cased for matching).
with open("/blacklist.txt", "r") as text_file:
    blacklist = [line.rstrip().lower() for line in text_file.readlines()]


client = AsyncOpenSearch(
    hosts=[{'host': cfg.get("opensearch.host"), 'port': cfg.get("opensearch.port")}],
    http_auth=(cfg.get("opensearch.user").get_secret_value(), cfg.get("opensearch.password").get_secret_value()),
    use_ssl=True,
    verify_certs=False,
    ssl_show_warn=False,
    pool_maxsize=20
)


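# instructor patches the OpenAI client so chat.completions.create() accepts a
# response_model and validates (and retries) the LLM output against the
# Pydantic model; JSON_SCHEMA mode requests schema-constrained JSON output.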
oai = instructor.patch(
    AsyncOpenAI(
        base_url=f"http://{cfg.get('openai.host')}:{cfg.get('openai.port')}/v1",
        api_key=cfg.get("openai.api_key"),
        timeout=2400.0 * 6.0,
    ),
    mode=instructor.Mode.JSON_SCHEMA,
)


def source_txt_value(data: Dict[str, Any], labels: List[str]) -> Optional[Any]:
    """Walk the hit's nested '_source' document along `labels` and return a text value.

    Returns the first element for list values, str(value) otherwise, and None
    when the path is missing or empty.
    """
    if not labels:
        return None
    current_value = data['_source']
    for label in labels:
        if isinstance(current_value, dict) and label in current_value:
            current_value = current_value[label]
        else:
            return None
    if current_value is None:
        return None
    if isinstance(current_value, list):
        if len(current_value) > 0:
            return current_value[0]
        return None
    return str(current_value)


async def eval_spam_candidate(hit: dict) -> dict:
    """Classify one candidate with the LLM and return the dumped ProductInfo as a dict."""
    response = await oai.chat.completions.create(
        model="suzume-multilingual",
        response_model=ProductInfo,
        messages=[
            {
                "role": "user",
                "content": hit['title']
            }
        ],
        extra_body={
            "cache_prompt": True,
            "json_schema": response_schema
        },
        temperature=0.0,
        max_retries=5,
        stream=False
    )
    return response.model_dump()


async def evaluate_hit(hit: dict):
    """Evaluate one hit and persist the verdict so it is not re-evaluated later."""
    obj = await eval_spam_candidate(hit)
    if obj['topic'] in [Topics.SPAM, Topics.EXPLICIT_CONTENT, Topics.SALES]:
        print("SPAM detected: " + hit['local_identifier'], flush=True)
        print("AI Response: " + str(obj) + " for: " + hit['title'], flush=True)
    obj['local_identifier'] = hit['local_identifier']
    obj['trigger_word'] = hit['found']
    obj['abstract'] = hit['title']
    obj['timestamp'] = datetime.now().isoformat()
    # Index every verdict so the exists() check in get_potential_spam()
    # can skip records that were already evaluated.
    await client.index(
        index='spam',
        body=obj,
        id=hit['local_identifier'],
        refresh=True
    )
    return obj


async def get_potential_spam() -> Any:
    """Yield products whose title/abstract contains a blacklisted term."""
    count = 0
    resume_from = 0  # bump this to skip already-processed records after a restart
    async for hit in helpers.async_scan(client, index="products", query={"query": {"match_all": {}}}, scroll='1d'):
        count = count + 1
        if count < resume_from:
            continue
        local_identifier = source_txt_value(hit, ["local_identifier"])
        print(f"{count}:\t{local_identifier}")
        title = source_txt_value(hit, ["titles", "none"])
        description = source_txt_value(hit, ['abstracts', 'none'])

        if title is None:
            if description is None:
                print(f"No description! {local_identifier}", flush=True)
                continue
            title = ""

        if description is not None:
            title = title + " " + description

        # Truncate to 2048 UTF-8 bytes so the prompt stays small.
        utf8_title = title.encode('utf-8')
        if len(utf8_title) > 2048:
            title = utf8_title[0:2048].decode('utf-8', 'ignore')

        # A blacklist term counts as a match if it is a single character,
        # a multi-word phrase, or a whole word of the text.
        test_string = title.lower()
        split_string = test_string.split()
        found = None
        for badword in blacklist:
            if badword in test_string:
                if len(badword) == 1 or ' ' in badword or badword in split_string:
                    found = badword
                    break
        if found is None:
            continue
        if await client.exists(index="spam", id=local_identifier):
            print("cached")
            continue
        yield {"local_identifier": local_identifier, "title": title, "found": found}


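# Producer/consumer setup: get_potential_spam() feeds a bounded queue and a
# pool of worker tasks drains it, so at most `parallelism` LLM evaluations
# are in flight at any time.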
async def worker(name, queue):
    try:
        while True:
            # Get a candidate hit out of the queue.
            hit = await queue.get()
            # Classify it and store the verdict.
            await evaluate_hit(hit)
            # Notify the queue that the work item has been processed.
            queue.task_done()
    except Exception:
        print(traceback.format_exc())
        sys.exit(-1)


async def main():
    # if await client.indices.exists("spam"):
    #     await client.indices.delete("spam")

    if not await client.indices.exists("spam"):
        await client.indices.create("spam", {
            "settings": {
                "index": {
                    "number_of_shards": 3,
                    "number_of_replicas": 0,
                    "replication.type": "SEGMENT"
                }
            },
            "mappings": {
                "properties": {
                    "local_identifier": {"type": "keyword"},
                    "language": {"type": "keyword"},
                    "topic": {"type": "keyword"},
                    "abstract": {"type": "text", "index": False},
                    "reason": {"type": "text", "index": False},
                    "spam_words": {"type": "keyword"},
                    "trigger_word": {"type": "keyword"},
                    "timestamp": {"type": "date", "format": "date_hour_minute_second_fraction"}
                }
            }
        })

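    # The queue is bounded to `parallelism`, so the producer below blocks once
    # the workers are saturated instead of buffering the whole scan in memory.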
    parallelism = cfg.get("parallelism")
    queue = asyncio.Queue(parallelism)
    tasks = []
    for i in range(parallelism):
        task = asyncio.create_task(worker(f'worker-{i}', queue))
        tasks.append(task)

    async for hit in get_potential_spam():
        await queue.put(hit)

    # Wait until every queued hit has been processed.
    await queue.join()

    # Cancel our worker tasks.
    for task in tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*tasks, return_exceptions=True)


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(main())
    loop.close()