from __future__ import annotations

import gzip
import io
import json
import os
import zipfile
from datetime import timedelta

import pendulum
from airflow.decorators import dag
from airflow.decorators import task
from airflow.models import Variable
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.file import TemporaryDirectory
from airflow.utils.helpers import chain
from opensearchpy import OpenSearch, helpers

from opensearch_indexes import mappings

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "skgif-openaire-eu")
AWS_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

OPENSEARCH_HOST = Variable.get("OPENSEARCH_URL", "opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local")
OPENSEARCH_URL = Variable.get("OPENSEARCH_URL", "https://opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local:9200")
OPENSEARCH_USER = Variable.get("OPENSEARCH_USER", "admin")
OPENSEARCH_PASSWD = Variable.get("OPENSEARCH_PASSWORD", "admin")

ENTITIES = ["dataset", "datasource", "organization", "otherresearchproduct",
            "project", "publication", "relation", "software"]
BULK_PARALLELISM = 2

default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 1)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}


def strip_prefix(s, p):
    if s.startswith(p):
        return s[len(p):]
    return s


@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_args,
    tags=["s3"],
)
def import_raw_graph():
    @task
    def create_indexes():
        """Drop and recreate one OpenSearch index per entity, tuned for bulk loading."""
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        client.cluster.put_settings(body={"persistent": {
            # "cluster.routing.allocation.balance.prefer_primary": True,
            "segrep.pressure.enabled": True
        }})
        for entity in ENTITIES:
            if client.indices.exists(entity):
                client.indices.delete(entity)
            # Replicas and refresh are disabled during the load; close_indexes refreshes at the end.
            client.indices.create(entity, {
                "settings": {
                    "index": {
                        "number_of_shards": 40,
                        "number_of_replicas": 0,
                        "refresh_interval": -1,
                        "codec": "zstd_no_dict",
                        "replication.type": "SEGMENT",
                        "translog.flush_threshold_size": "2048MB",
                        "mapping.ignore_malformed": "true"
                    }
                }
                # "mappings": mappings[entity]
            })

    def compute_batches(ds=None, **kwargs):
        """List the gzipped part files of every entity on S3 and split them into batches."""
        pieces = []
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        for entity in ENTITIES:
            keys = hook.list_keys(bucket_name=S3_BUCKET_NAME, prefix=f'00_graph_aggregator/{entity}/')
            for key in keys:
                if key.endswith('.gz'):
                    pieces.append((entity, key))

        def split_list(list_a, chunk_size):
            for i in range(0, len(list_a), chunk_size):
                yield {"files": list_a[i:i + chunk_size]}

        # Guard against a zero chunk size when there are fewer files than BULK_PARALLELISM.
        return list(split_list(pieces, max(1, len(pieces) // BULK_PARALLELISM)))

    @task
    def bulk_load(files: list[tuple[str, str]]):
        """Stream the given (entity, S3 key) files into OpenSearch via parallel_bulk."""
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})

        def _generate_data():
            # Decompress each gzipped JSON-lines file straight from S3 and yield bulk actions.
            for (entity, key) in files:
                print(f'{entity}: {key}')
                s3_obj = hook.get_key(key, bucket_name=S3_BUCKET_NAME)
                with s3_obj.get()["Body"] as body:
                    with gzip.GzipFile(fileobj=body) as gzipfile:
                        buff = io.BufferedReader(gzipfile)
                        for line in buff:
                            data = json.loads(line)
                            data['_index'] = entity
                            data['_id'] = data['id']
                            yield data

        succeeded = 0
        failed = 0
        for success, item in helpers.parallel_bulk(client, actions=_generate_data(),
                                                   raise_on_exception=False,
                                                   chunk_size=500,
                                                   max_chunk_bytes=10 * 1024 * 1024):
            if success:
                succeeded = succeeded + 1
            else:
                print(item["index"]["error"])
                failed = failed + 1

        if failed > 0:
            print(f"There were {failed} errors.")

        if succeeded > 0:
            print(f"Bulk-inserted {succeeded} items (parallel_bulk).")

    @task
    def close_indexes():
        """Refresh every index so the freshly loaded documents become searchable."""
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        for entity in ENTITIES:
            client.indices.refresh(entity)

    parallel_batches = PythonOperator(task_id="compute_parallel_batches",
                                      python_callable=compute_batches)

    chain(
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        bulk_load.expand_kwargs(parallel_batches.output),
        close_indexes.override(task_id="close_indexes")()
    )


import_raw_graph()