"""Spam detection pipeline: scan research-product records in OpenSearch for
blacklisted words, classify the candidates with an LLM via instructor, and
index confirmed spam into a dedicated 'spam' index."""

import asyncio
import enum
import json
import sys
import traceback
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional

import instructor
from jsonargparse import ArgumentParser
from openai import AsyncOpenAI
from opensearchpy import AsyncOpenSearch, helpers
from pydantic import BaseModel, Field, SecretStr


class Topics(str, enum.Enum):
    """Correctly assign one of the predefined topics to the content."""
    SPAM = "SPAM, advertisement, promotional"
    SALES = "direct sales of goods or services"
    EXPLICIT_CONTENT = "porn, violence or harmful content"
    RESEARCH = "description of a scientific research"
    DATASET = "description of a scientific dataset"
    OBJECT = "scientific description of an object"
    BIBLIOGRAPHIC = "bibliographic record"
    NA = "not available"


class ProductInfo(BaseModel):
    """
    Your task is to identify SPAM content among research product descriptions.
    """
    language: str = Field(description="The language of the content")
    topic: Topics
    reason: str = Field(description="explain why the topic was chosen")
    spam_words: list[str] = Field(description="content's spam words",
                                  min_length=0, max_length=3)


main_model_schema = ProductInfo.model_json_schema()
response_schema = json.dumps(main_model_schema, indent=None)

parser = ArgumentParser(env_prefix="CURATION", default_env=True)
parser.add_argument("--opensearch.host", default="opensearch-cluster.local-dataplatform")
parser.add_argument("--opensearch.port", default=443, type=int)
parser.add_argument("--opensearch.user", default="admin", type=SecretStr)
parser.add_argument("--opensearch.password", default="admin", type=SecretStr)
parser.add_argument("--openai.host", default="localhost")
parser.add_argument("--openai.port", default=8000, type=int)
parser.add_argument("--openai.api_key", default="api_key")
parser.add_argument("--parallelism", default=36, type=int)
cfg = parser.parse_args()

with open("/blacklist.txt", "r") as text_file:
    blacklist = [line.rstrip().lower() for line in text_file]

client = AsyncOpenSearch(
    hosts=[{"host": cfg.get("opensearch.host"), "port": cfg.get("opensearch.port")}],
    http_auth=(cfg.get("opensearch.user").get_secret_value(),
               cfg.get("opensearch.password").get_secret_value()),
    use_ssl=True,
    verify_certs=False,
    ssl_show_warn=False,
    pool_maxsize=20
)

oai = instructor.patch(
    AsyncOpenAI(
        base_url="http://" + cfg.get("openai.host") + ":" + str(cfg.get("openai.port")) + "/v1",
        api_key=cfg.get("openai.api_key"),
        timeout=2400.0 * 6.0),
    mode=instructor.Mode.JSON_SCHEMA)


def source_txt_value(data: Dict[str, Any], labels: List[str]) -> Optional[str]:
    """Walk the nested '_source' dict along `labels`; return the first element
    of a list value, the value as a string, or None if the path is missing."""
    if len(labels) <= 0:
        return None
    current_value = data["_source"]
    for label in labels:
        if isinstance(current_value, dict) and label in current_value:
            current_value = current_value[label]
        else:
            return None
    if current_value is None:
        return None
    if isinstance(current_value, list):
        if len(current_value) > 0:
            return current_value[0]
        return None
    return str(current_value)


async def eval_spam_candidate(hit: dict) -> dict:
    """Ask the LLM to classify the candidate; instructor validates the reply
    against the ProductInfo schema and retries up to max_retries times."""
    response = await oai.chat.completions.create(
        model="suzume-multilingual",
        response_model=ProductInfo,
        messages=[
            {
                "role": "user",
                "content": hit["title"]
            }
        ],
        extra_body={
            "cache_prompt": True,
            "json_schema": response_schema
        },
        temperature=0.0,
        max_retries=5,
        stream=False
    )
    return response.model_dump()


async def evaluate_hit(hit: dict):
    obj = await eval_spam_candidate(hit)
    if obj["topic"] in [Topics.SPAM, Topics.EXPLICIT_CONTENT, Topics.SALES]:
        print("SPAM detected: " + hit["local_identifier"], flush=True)
        print("AI Response: " + str(obj) + " for: " + hit["title"], flush=True)
        obj["local_identifier"] = hit["local_identifier"]
        obj["trigger_word"] = hit["found"]
        obj["abstract"] = hit["title"]
        obj["timestamp"] = datetime.now().isoformat()
        await client.index(
            index="spam",
            body=obj,
            id=hit["local_identifier"],
            refresh=True
        )
    return obj


async def get_potential_spam() -> AsyncGenerator[dict, None]:
    count = 0
    resume_from = 0  # bump after a crash to skip already-processed records
    async for hit in helpers.async_scan(client, index="products",
                                        query={"query": {"match_all": {}}},
                                        scroll="1d"):
        count = count + 1
        if count < resume_from:
            continue
        local_identifier = source_txt_value(hit, ["local_identifier"])
        print(f"{count}:\t{local_identifier}")
        title = source_txt_value(hit, ["titles", "none"])
        description = source_txt_value(hit, ["abstracts", "none"])
        if title is None:
            if description is None:
                print(f"No description! {local_identifier}", flush=True)
                continue
            title = ""
        if description is not None:
            title = title + " " + description
        # Truncate to 2048 UTF-8 bytes; 'ignore' drops a trailing multi-byte
        # character that the cut may have split.
        utf8_title = title.encode("utf-8")
        if len(utf8_title) > 2048:
            title = utf8_title[0:2048].decode("utf-8", "ignore")
        test_string = title.lower()
        split_string = test_string.split()
        found = None
        for badword in blacklist:
            if badword in test_string:
                # Single characters and multi-word phrases match as substrings;
                # single words must match a whole token to avoid false positives.
                if len(badword) == 1 or " " in badword or badword in split_string:
                    found = badword
                    break
        if found is None:
            continue
        if await client.exists(index="spam", id=local_identifier):
            print("cached")
            continue
        yield {"local_identifier": local_identifier, "title": title, "found": found}


async def worker(name, queue):
    try:
        while True:
            # Get a "work item" out of the queue.
            hit = await queue.get()
            # Classify the hit and index it if it turns out to be spam.
            await evaluate_hit(hit)
            # Notify the queue that the "work item" has been processed.
            queue.task_done()
    except Exception:
        # Abort the whole pipeline on any unexpected error.
        print(traceback.format_exc())
        sys.exit(-1)


async def main():
    # if await client.indices.exists(index="spam"):
    #     await client.indices.delete(index="spam")
    if not await client.indices.exists(index="spam"):
        await client.indices.create(index="spam", body={
            "settings": {
                "index": {
                    "number_of_shards": 3,
                    "number_of_replicas": 0,
                    "replication.type": "SEGMENT"
                }
            },
            "mappings": {
                "properties": {
                    "local_identifier": {"type": "keyword"},
                    "language": {"type": "keyword"},
                    "topic": {"type": "keyword"},
                    "abstract": {"type": "text", "index": False},
                    "reason": {"type": "text", "index": False},
                    "spam_words": {"type": "keyword"},
                    "trigger_word": {"type": "keyword"},
                    "timestamp": {"type": "date",
                                  "format": "date_hour_minute_second_fraction"}
                }
            }
        })

    parallelism = cfg.get("parallelism")
    queue = asyncio.Queue(parallelism)
    tasks = []
    for i in range(parallelism):
        task = asyncio.create_task(worker(f"worker-{i}", queue))
        tasks.append(task)

    async for hit in get_potential_spam():
        await queue.put(hit)
    await queue.join()

    # Cancel our worker tasks.
    for task in tasks:
        task.cancel()
    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*tasks, return_exceptions=True)
    # Release the OpenSearch client's underlying HTTP session.
    await client.close()


if __name__ == "__main__":
    asyncio.run(main())