# lot1-kickoff/components/curationdb/antispam-batch.py
#
# (source-viewer metadata: 247 lines, 8.1 KiB, Python; Raw | Normal View | History)
#
# Last commit: 2024-06-15 14:23:50 +02:00
import json
import sys
import traceback
from typing import Any, Dict, List, Optional
from jsonargparse import ArgumentParser
from openai import AsyncOpenAI
import asyncio
import enum
import instructor
from pydantic import BaseModel, Field, SecretStr
from datetime import datetime
from opensearchpy import OpenSearch, helpers, AsyncOpenSearch
class Topics(str, enum.Enum):
    """Correctly assign one of the predefined topic to the content"""
    # Closed set of labels the LLM must choose from for ProductInfo.topic.
    # NOTE: the docstring above, the member values and their order are
    # presumably emitted into ProductInfo's JSON schema (pydantic uses enum
    # docstrings as schema descriptions), i.e. they are prompt text seen by
    # the model — edit them as such, not as mere comments.
    # Subclassing str makes each member compare equal to its value string.

    # Labels that evaluate_hit reports as "SPAM detected":
    SPAM = "SPAM, advertisement, promotional"
    SALES = "direct sales of goods or services"
    EXPLICIT_CONTENT = "porn, violence or Harmful content"
    # Remaining labels (not reported as spam):
    RESEARCH = "description of a scientific research"
    # NOTE(review): the value below ends with a trailing space — likely
    # unintended, but changing it alters the schema and the stored topic.
    DATASET = "description of a scientific dataset "
    OBJECT = "scientific description of an object"
    BIBLIOGRAPHIC = "bibliographic record"
    NA = "not available"
class ProductInfo(BaseModel):
    """
    Your task is to identify SPAM content among research product descriptions.
    """
    # Structured answer the LLM must return: eval_spam_candidate passes this
    # model as instructor's response_model. The docstring and the Field
    # descriptions end up in the JSON schema and act as the task instructions
    # for the model, so treat them as prompt text.
    language: str = Field(description="The language of the content")
    topic: Topics
    reason: str = Field(description="explain why the topic was chosen")
    # Between 0 and 3 entries; a reply listing more fails pydantic validation.
    spam_words: list[str] = Field(description="content's spam words", min_length=0, max_length=3)
# JSON schema of the structured LLM answer.
# NOTE(review): neither name is used in this file — kept in case external
# tooling imports them; otherwise safe to drop.
main_model_schema = ProductInfo.model_json_schema()
response_schema = json.dumps(main_model_schema, indent=None)

# Command-line configuration. With default_env=True every option can also be
# supplied through an environment variable carrying the CURATION prefix (see
# jsonargparse's env-var naming rules for the exact variable names).
parser = ArgumentParser(env_prefix="CURATION", default_env=True)
parser.add_argument("--opensearch.host", default='opensearch-cluster.local-dataplatform')
parser.add_argument("--opensearch.port", default=443, type=int)
parser.add_argument("--opensearch.user", default="admin", type=SecretStr)
parser.add_argument("--opensearch.password", default="admin", type=SecretStr)
parser.add_argument("--openai.host", default='localhost')
parser.add_argument("--openai.port", default=8000, type=int)
parser.add_argument("--openai.api_key", default='api_key')
parser.add_argument("--parallelism", default=8, type=int)
cfg = parser.parse_args()

# Trigger words/phrases, one per line, lowercased to match the lowercased
# product text. Read as UTF-8 explicitly: the list is matched against
# multilingual content, and the default encoding is locale-dependent.
with open("/blacklist.txt", "r", encoding="utf-8") as text_file:
    blacklist = [line.rstrip().lower() for line in text_file]

client = AsyncOpenSearch(
    hosts=[{'host': cfg.get("opensearch.host"), 'port': cfg.get("opensearch.port")}],
    http_auth=(cfg.get("opensearch.user").get_secret_value(),
               cfg.get("opensearch.password").get_secret_value()),
    use_ssl=True,
    # NOTE(review): TLS certificate verification is disabled and its warning
    # silenced — acceptable only for a trusted in-cluster endpoint.
    verify_certs=False,
    ssl_show_warn=False,
    pool_maxsize=20,
)

# OpenAI-compatible endpoint (plain HTTP) patched by instructor so that
# chat.completions.create accepts response_model and validates the reply.
oai = instructor.patch(
    AsyncOpenAI(
        base_url=f"http://{cfg.get('openai.host')}:{cfg.get('openai.port')}/v1",
        api_key=cfg.get("openai.api_key"),
        timeout=2400.0,  # seconds
    ),
    mode=instructor.Mode.JSON_SCHEMA,
)
def source_txt_value(data: Dict[str, Any], labels: List[str]) -> Optional[str]:
    """Return the value at a nested key path inside a hit's ``_source`` as text.

    Walks ``data['_source']`` one key per entry of *labels*. When the value
    reached is a list, its first element is used.

    Args:
        data: a search hit; must contain a ``_source`` key.
        labels: key path to follow, e.g. ``["titles", "none"]``.

    Returns:
        The value converted to ``str``. Returns None in any of these cases:
        *labels* is empty; a key is missing or an intermediate value is not a
        dict; the final value is None; or the final value is an empty list or
        a list whose first element is None.
    """
    if not labels:
        return None
    value = data['_source']
    for label in labels:
        if not isinstance(value, dict) or label not in value:
            return None
        value = value[label]
    if isinstance(value, list):
        value = value[0] if value else None
    # Stringify list elements too, like scalars: callers concatenate and
    # .encode() the result, which a raw non-str element would break.
    return None if value is None else str(value)
async def eval_spam_candidate(hit: dict) -> Dict[str, Any]:
    """Ask the LLM to classify a candidate's text against the ProductInfo schema.

    Args:
        hit: candidate produced by ``get_potential_spam``; only ``hit['title']``
            (title plus abstract) is sent, as the single user message.

    Returns:
        The validated ``ProductInfo`` answer as a plain dict (``model_dump()``),
        not a model instance; its ``topic`` value is a ``Topics`` member.
    """
    response = await oai.chat.completions.create(
        model="suzume-multilingual",
        response_model=ProductInfo,
        messages=[{"role": "user", "content": hit['title']}],
        # Server-specific option sent in the request body. The key was
        # previously misspelled "cache_propmpt", so it could never take effect.
        # NOTE(review): "cache_prompt" is the llama.cpp server prompt-cache
        # flag — confirm the serving backend honours it.
        extra_body={"cache_prompt": True},
        temperature=0.0,  # greedy decoding for repeatable classifications
        stream=False,
    )
    return response.model_dump()
async def evaluate_hit(hit: dict) -> Dict[str, Any]:
    """Classify one candidate with the LLM and store the result in ``spam``.

    Args:
        hit: dict with ``local_identifier``, ``title`` and ``found`` as yielded
            by ``get_potential_spam``.

    Returns:
        The stored document: the LLM's ProductInfo fields plus
        ``local_identifier``, ``trigger_word``, ``abstract`` and ``timestamp``.
    """
    obj = await eval_spam_candidate(hit)
    if obj['topic'] in (Topics.SPAM, Topics.EXPLICIT_CONTENT, Topics.SALES):
        # NOTE(review): source indentation was lost; both log lines are
        # reconstructed as spam-only output — confirm.
        print("SPAM detected: " + hit['local_identifier'])
        print("AI Reponse:" + str(obj) + " for: " + hit['title'], flush=True)
    obj['local_identifier'] = hit['local_identifier']
    obj['trigger_word'] = hit['found']
    obj['abstract'] = hit['title']
    # The 'timestamp' field is mapped with date_hour_minute_second_fraction
    # (yyyy-MM-dd'T'HH:mm:ss.SSS). Plain isoformat() omits the fraction
    # entirely when microsecond == 0, which that format cannot parse; pin it
    # to exactly three fraction digits.
    obj['timestamp'] = datetime.now().isoformat(timespec='milliseconds')
    # NOTE(review): indexing is reconstructed as unconditional, so every
    # evaluated hit is recorded and then skipped as "cached" by
    # get_potential_spam on later runs — confirm against the original layout.
    await client.index(
        index='spam',
        body=obj,
        id=hit['local_identifier'],
        refresh=True,
    )
    return obj
def _find_trigger_word(text: str, words: List[str]) -> Optional[str]:
    """Return the first entry of *words* that triggers on *text*, or None.

    *text* should already be lowercased, as the blacklist entries are. An
    entry triggers when it occurs as a substring of *text* and meets one of
    three conditions:
    - it is a single character;
    - it contains a space (a phrase);
    - it equals one of the whitespace-separated tokens of *text*.
    """
    tokens = set(text.split())  # set: O(1) whole-token membership checks
    for word in words:
        if word in text and (len(word) == 1 or ' ' in word or word in tokens):
            return word
    return None


async def get_potential_spam(resume_from: int = 0) -> Any:
    """Scan the ``products`` index and yield hits that look like spam candidates.

    The candidate text is the title plus the abstract (if any), capped at 2048
    UTF-8 bytes. A hit is yielded when that text contains a blacklist entry
    (see ``_find_trigger_word``) and the hit is not yet stored in ``spam``.

    Args:
        resume_from: 1-based scan position to resume at; earlier hits are
            skipped. The default 0 processes every hit.

    Yields:
        dicts with ``local_identifier``, ``title`` (the text for the LLM) and
        ``found`` (the triggering blacklist entry).
    """
    count = 0
    async for hit in helpers.async_scan(client, index="products",
                                        query={"query": {"match_all": {}}},
                                        scroll='1d'):
        count += 1
        # count is 1-based, so this resumes *at* hit number resume_from.
        if count < resume_from:
            continue
        local_identifier = source_txt_value(hit, ["local_identifier"])
        print(f"{count}:\t{local_identifier}")
        title = source_txt_value(hit, ["titles", "none"])
        if title is None:
            continue
        description = source_txt_value(hit, ['abstracts', 'none'])
        if description is not None:
            title = title + " " + description
        # Cap at 2048 UTF-8 bytes; 'ignore' drops a multi-byte character
        # split by the cut.
        utf8_title = title.encode('utf-8')
        if len(utf8_title) > 2048:
            title = utf8_title[:2048].decode('utf-8', 'ignore')
        found = _find_trigger_word(title.lower(), blacklist)
        if found is None:
            continue
        if not local_identifier:
            # opensearch-py rejects an empty/None document id with ValueError,
            # which would abort the whole scan; skip such hits instead.
            print(f"{count}:\tskipped, missing local_identifier")
            continue
        if await client.exists(index="spam", id=local_identifier):
            print("cached")
            continue
        yield {"local_identifier": local_identifier, "title": title, "found": found}
async def worker(name, queue):
    """Take candidates from *queue* and evaluate them one at a time, forever.

    The loop ends only by cancellation (CancelledError is not an Exception
    subclass, so it is not caught here) or by a failure. Any exception from
    processing is fatal: the traceback is printed and the process exits with
    status -1.

    Args:
        name: label for this worker; not used by the body.
        queue: asyncio.Queue of candidate dicts filled by ``main``.
    """
    try:
        while True:
            hit = await queue.get()
            await evaluate_hit(hit)
            # Mark the item processed so queue.join() in main can return.
            queue.task_done()
    except Exception:
        print(traceback.format_exc())
        # SystemExit raised inside a task propagates out of the event loop,
        # stopping the whole batch.
        sys.exit(-1)
async def _ensure_spam_index():
    """Create the ``spam`` results index with its explicit mapping if missing."""
    # To rebuild the index from scratch, delete it first:
    #     await client.indices.delete("spam")
    if await client.indices.exists("spam"):
        return
    await client.indices.create("spam", {
        "settings": {
            "index": {
                "number_of_shards": 3,
                "number_of_replicas": 0,
                "replication.type": "SEGMENT"
            }
        },
        "mappings": {
            "properties": {
                "local_identifier": {
                    "type": "keyword"
                },
                "language": {
                    "type": "keyword"
                },
                "topic": {
                    "type": "keyword"
                },
                "abstract": {
                    "type": "text",
                    "index": False,
                },
                "reason": {
                    "type": "text",
                    "index": False,
                },
                "spam_words": {
                    "type": "keyword"
                },
                "trigger_word": {
                    "type": "keyword"
                },
                "timestamp": {
                    "type": "date",
                    "format": "date_hour_minute_second_fraction"
                }
            }
        }
    })


async def main():
    """Run the batch: ensure the index exists, then fan candidates out to workers.

    ``get_potential_spam`` is the single producer. ``parallelism`` workers
    consume from a queue bounded to the same size, so the scan waits while
    the queue is full.

    Raises:
        ValueError: if the configured parallelism is below 1 (no workers would
            run and queue.join() would never return).
    """
    try:
        await _ensure_spam_index()
        parallelism = cfg.get("parallelism")
        if parallelism < 1:
            raise ValueError(f"parallelism must be >= 1, got {parallelism}")
        queue = asyncio.Queue(parallelism)
        tasks = [asyncio.create_task(worker(f'worker-{i}', queue))
                 for i in range(parallelism)]
        async for hit in get_potential_spam():
            await queue.put(hit)
        # Wait until every queued candidate has been processed...
        await queue.join()
        # ...then stop the idle workers and wait for them to finish cancelling.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        # AsyncOpenSearch holds an aiohttp-based connection; close it so its
        # HTTP session is not left open at shutdown.
        await client.close()
if __name__ == "__main__":
    # asyncio.run does several things the manual new_event_loop/close sequence
    # skipped when main() raised:
    # - it creates and installs a fresh event loop;
    # - it cancels leftover tasks and finalizes async generators;
    # - it always closes the loop.
    asyncio.run(main())