"""Airflow DAG that imports gzipped EOSC graph entity dumps from S3 into OpenSearch."""

from __future__ import annotations

import gzip
import io
import json
import logging
import os
from datetime import timedelta

import pendulum
from kubernetes.client import models as k8s

from airflow.decorators import dag, task
from airflow.hooks.base import BaseHook
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.helpers import chain
from opensearchpy import OpenSearch, helpers

from EOSC_indexes import mappings
from EOSC_entity_trasform import transform_entities

S3_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

ENTITIES = ["datasource", "grants", "organizations", "persons", "products", "topics", "venues"]
BULK_PARALLELISM = 10

default_args = {
    "execution_timeout": timedelta(days=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 1)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}


@dag(
    schedule=None,
    dagrun_timeout=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_args,
    params={
        "S3_CONN_ID": "s3_conn",
        "OPENSEARCH_CONN_ID": "opensearch_default",
        "EOSC_CATALOG_BUCKET": "eosc-portal-import"
    },
    tags=["lot1"]
)
def import_EOSC_graph():
    @task
    def create_indexes(**kwargs):
        """(Re)create one OpenSearch index per entity with bulk-load friendly settings."""
        conn = BaseHook.get_connection(kwargs["params"]["OPENSEARCH_CONN_ID"])
        client = OpenSearch(
            hosts=[{'host': conn.host, 'port': conn.port}],
            http_auth=(conn.login, conn.password),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        # Cluster-level tuning: prefer primaries when balancing shards and enable
        # segment-replication backpressure.
        client.cluster.put_settings(body={
            "persistent": {
                "cluster.routing.allocation.balance.prefer_primary": True,
                "segrep.pressure.enabled": True
            }
        })

        for entity in ENTITIES:
            # Drop any leftover index so every run starts from a clean slate.
            if client.indices.exists(entity):
                client.indices.delete(entity)
            client.indices.create(entity, {
                "settings": {
                    "index": {
                        "number_of_shards": 40,
                        "number_of_replicas": 0,
                        "refresh_interval": -1,  # refreshed once in close_indexes
                        "translog.flush_threshold_size": "2048MB",
                        "codec": "zstd_no_dict",
                        "replication.type": "SEGMENT"
                    }
                },
                "mappings": mappings[entity]
            })

    def compute_batches(ds=None, **kwargs):
        """List the .gz dumps per entity and split them into chunks for the mapped bulk_load tasks."""
        hook = S3Hook(S3_CONN_ID, transfer_config_args={'use_threads': False})
        pieces = []
        for entity in ENTITIES:
            keys = hook.list_keys(bucket_name=kwargs["params"]["EOSC_CATALOG_BUCKET"], prefix=f'{entity}/')
            # Remove stale .PROCESSED markers so every file is (re)loaded in this run.
            to_delete = list(filter(lambda k: k.endswith('.PROCESSED'), keys))
            for obj in to_delete:
                hook.get_conn().delete_object(Bucket=kwargs["params"]["EOSC_CATALOG_BUCKET"], Key=obj)
            for key in keys:
                if key.endswith('.gz'):
                    pieces.append((entity, key))

        def split_list(list_a, chunk_size):
            for i in range(0, len(list_a), chunk_size):
                yield {"files": list_a[i:i + chunk_size]}

        # Guard against a zero chunk size when there are fewer files than BULK_PARALLELISM.
        return list(split_list(pieces, max(1, len(pieces) // BULK_PARALLELISM)))

    @task(executor_config={
        "pod_override": k8s.V1Pod(
            spec=k8s.V1PodSpec(
                containers=[
                    k8s.V1Container(
                        name="base",
                        resources=k8s.V1ResourceRequirements(
                            requests={
                                "cpu": "550m",
                                "memory": "256Mi"
                            }
                        )
                    )
                ]
            )
        )
    })
    def bulk_load(files: list[tuple[str, str]], **kwargs):
        """Stream each gzipped NDJSON dump from S3 and index it with parallel_bulk."""
        conn = BaseHook.get_connection(kwargs["params"]["OPENSEARCH_CONN_ID"])
        client = OpenSearch(
            hosts=[{'host': conn.host, 'port': conn.port}],
            http_auth=(conn.login, conn.password),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        hook = S3Hook(S3_CONN_ID, transfer_config_args={'use_threads': False})

        for (entity, key) in files:
            # A .PROCESSED marker means the file was already loaded successfully.
            if hook.check_for_key(key=f"{key}.PROCESSED", bucket_name=kwargs["params"]["EOSC_CATALOG_BUCKET"]):
                print(f'Skipping {entity}: {key}')
                continue

            print(f'Processing {entity}: {key}')
            s3_obj = hook.get_key(key, bucket_name=kwargs["params"]["EOSC_CATALOG_BUCKET"])
            with s3_obj.get()["Body"] as body:
                with gzip.GzipFile(fileobj=body) as gzipfile:
                    def _generate_data():
                        # One bulk action per JSON line, with an optional per-entity transform.
                        buff = io.BufferedReader(gzipfile)
                        for line in buff:
                            data = json.loads(line)
                            data['_index'] = entity
                            data['_id'] = data['local_identifier']
                            if entity in transform_entities:
                                data = transform_entities[entity](data)
                            yield data

                    # Disable per-document success logging.
                    logging.getLogger("opensearch").setLevel(logging.WARNING)
                    succeeded = 0
                    failed = 0
                    for success, item in helpers.parallel_bulk(client, actions=_generate_data(),
                                                               raise_on_exception=False,
                                                               raise_on_error=False,
                                                               chunk_size=5000,
                                                               max_chunk_bytes=50 * 1024 * 1024,
                                                               timeout=180):
                        if success:
                            succeeded += 1
                        else:
                            print(item["index"]["error"])
                            failed += 1

                    if failed > 0:
                        print(f"There were {failed} errors.")
                    else:
                        # Mark the file as done only if every action was indexed.
                        hook.load_string(
                            "",
                            f"{key}.PROCESSED",
                            bucket_name=kwargs["params"]["EOSC_CATALOG_BUCKET"],
                            replace=False
                        )

                    if succeeded > 0:
                        print(f"Bulk-inserted {succeeded} items (parallel_bulk).")

    @task
    def close_indexes(**kwargs):
        """Refresh the indexes once loading is done (refresh_interval was set to -1)."""
        conn = BaseHook.get_connection(kwargs["params"]["OPENSEARCH_CONN_ID"])
        client = OpenSearch(
            hosts=[{'host': conn.host, 'port': conn.port}],
            http_auth=(conn.login, conn.password),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20,
            timeout=180
        )
        for entity in ENTITIES:
            client.indices.refresh(entity)

    parallel_batches = PythonOperator(task_id="compute_parallel_batches", python_callable=compute_batches)

    chain(
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        bulk_load.expand_kwargs(parallel_batches.output),
        close_indexes.override(task_id="close_indexes")()
    )


import_EOSC_graph()