lot1-kickoff/airflow/dags/import_raw_graph.py

179 lines
6.0 KiB
Python
Raw Normal View History

2024-03-17 15:49:09 +01:00
from __future__ import annotations
2024-03-18 00:54:50 +01:00
import gc
2024-03-17 15:49:09 +01:00
import gzip
import io
import json
import os
import zipfile
from datetime import timedelta
import pendulum
from airflow.decorators import dag
from airflow.decorators import task
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.file import TemporaryDirectory
from airflow.utils.helpers import chain
from airflow.models import Variable
from opensearchpy import OpenSearch, helpers
from opensearch_indexes import mappings
# --- S3 source -------------------------------------------------------------
# Bucket holding the graph dump and the Airflow connection id used to read it.
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "skgif-openaire-eu")
AWS_CONN_ID = os.environ.get("S3_CONN_ID", "s3_conn")

# Per-task execution timeout, expressed in hours (consumed by default_args).
EXECUTION_TIMEOUT = int(os.environ.get("EXECUTION_TIMEOUT", 6))

# --- OpenSearch target -----------------------------------------------------
# NOTE(review): OPENSEARCH_HOST is looked up under the "OPENSEARCH_URL"
# Variable key, the same key OPENSEARCH_URL uses. If that Variable ever holds
# a full URL, the host passed to the client will be wrong — confirm whether a
# dedicated "OPENSEARCH_HOST" key was intended before changing it.
OPENSEARCH_HOST = Variable.get(
    "OPENSEARCH_URL",
    "opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local",
)
OPENSEARCH_URL = Variable.get(
    "OPENSEARCH_URL",
    "https://opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local:9200",
)
OPENSEARCH_USER = Variable.get("OPENSEARCH_USER", "admin")
OPENSEARCH_PASSWD = Variable.get("OPENSEARCH_PASSWORD", "admin")

# Entity types of the graph; each one is both an S3 sub-folder and an index name.
ENTITIES = [
    "dataset",
    "datasource",
    "organization",
    "otherresearchproduct",
    "project",
    "publication",
    "relation",
    "software",
]

# Target number of batches the input files are split into for bulk loading.
BULK_PARALLELISM = 2

# Defaults applied to every task of the DAG.
default_args = dict(
    execution_timeout=timedelta(hours=EXECUTION_TIMEOUT),
    retries=int(os.environ.get("DEFAULT_TASK_RETRIES", 1)),
    retry_delay=timedelta(
        seconds=int(os.environ.get("DEFAULT_RETRY_DELAY_SECONDS", 60)),
    ),
)
def strip_prefix(s, p):
    """Return *s* with the leading *p* removed.

    When *s* does not start with *p*, *s* is returned unchanged.
    """
    return s[len(p):] if s.startswith(p) else s
@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_args,
    tags=["s3"],
)
def import_raw_graph():
    """Rebuild the OpenSearch indexes from the gzipped JSON-lines dump in S3.

    Pipeline (wired with ``chain`` at the bottom):

    1. ``create_indexes``  - drop and recreate one index per entry of ENTITIES.
    2. ``compute_parallel_batches`` - list the ``.gz`` objects under
       ``00_graph_aggregator/<entity>/`` and split them into batches.
    3. ``bulk_load``       - one mapped task per batch, streaming every file
       into its entity index with ``helpers.parallel_bulk``.
    4. ``close_indexes``   - refresh every index.
    """

    def _opensearch_client():
        """Build an OpenSearch client from the module-level connection settings."""
        return OpenSearch(
            hosts=[{"host": OPENSEARCH_HOST, "port": 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            # NOTE(review): TLS certificate verification is switched off here;
            # confirm this is acceptable for the target cluster.
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20,
        )

    @task
    def create_indexes():
        """Drop and recreate one index per entity.

        Destructive: an existing index with the same name is deleted first.
        """
        client = _opensearch_client()
        client.cluster.put_settings(body={"persistent": {
            # Left disabled by the original author:
            # "cluster.routing.allocation.balance.prefer_primary": True,
            "segrep.pressure.enabled": True,
        }})

        for entity in ENTITIES:
            if client.indices.exists(index=entity):
                client.indices.delete(index=entity)

            client.indices.create(index=entity, body={
                "settings": {
                    "index": {
                        "number_of_shards": 40,
                        "number_of_replicas": 0,
                        # -1 turns periodic refresh off; the final task
                        # refreshes each index explicitly after the load.
                        "refresh_interval": -1,
                        "codec": "zstd_no_dict",
                        "replication.type": "SEGMENT",
                        "translog.flush_threshold_size": "2048MB",
                        "mapping.ignore_malformed": "true",
                    }
                }
                # Explicit mappings are currently not applied:
                # "mappings": mappings[entity]
            })

    def compute_batches(ds=None, **kwargs):
        """List the ``.gz`` input files and split them into load batches.

        Returns:
            A list of ``{"files": [(entity, key), ...]}`` dicts, one per
            mapped ``bulk_load`` task. At most BULK_PARALLELISM batches are
            produced; the list is empty when no input file is found, in which
            case Airflow creates no mapped task instances.
        """
        # A single hook is enough; it does not depend on the entity.
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={"use_threads": False})

        pieces = []
        for entity in ENTITIES:
            keys = hook.list_keys(
                bucket_name=S3_BUCKET_NAME,
                prefix=f"00_graph_aggregator/{entity}/",
            )
            pieces.extend((entity, key) for key in keys if key.endswith(".gz"))

        if not pieces:
            return []

        # Ceiling division: never 0 (which would make range() raise) and never
        # more than BULK_PARALLELISM batches when the count is not divisible.
        chunk_size = (len(pieces) + BULK_PARALLELISM - 1) // BULK_PARALLELISM
        return [
            {"files": pieces[start:start + chunk_size]}
            for start in range(0, len(pieces), chunk_size)
        ]

    @task
    def bulk_load(files: list[tuple[str, str]]):
        """Stream every ``(entity, key)`` file of a batch into OpenSearch.

        Each S3 object is read as gzip-compressed JSON lines. Every document
        is indexed into the index named after its entity, using the
        document's ``id`` field as ``_id``.

        Indexing failures are printed and counted but do not fail the task
        (``raise_on_error`` / ``raise_on_exception`` are both off).

        Args:
            files: pairs of (entity, S3 key) as produced by ``compute_batches``.
        """
        client = _opensearch_client()
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={"use_threads": False})

        def _actions(stream, index_name):
            # Turn each JSON line into a bulk action targeting index_name.
            for line in io.BufferedReader(stream):
                doc = json.loads(line)
                doc["_index"] = index_name
                doc["_id"] = doc["id"]
                yield doc

        for entity, key in files:
            print(f'{entity}: {key}')
            s3_obj = hook.get_key(key, bucket_name=S3_BUCKET_NAME)
            with s3_obj.get()["Body"] as body, gzip.GzipFile(fileobj=body) as gzipfile:
                succeeded = 0
                failed = 0
                results = helpers.parallel_bulk(
                    client,
                    actions=_actions(gzipfile, entity),
                    raise_on_exception=False,
                    raise_on_error=False,
                    chunk_size=500,
                    max_chunk_bytes=10 * 1024 * 1024,
                )
                for success, item in results:
                    if success:
                        succeeded += 1
                    else:
                        print(item["index"]["error"])
                        failed += 1

                if failed > 0:
                    print(f"There were {failed} errors:")

                if succeeded > 0:
                    # Message corrected: the helper in use is parallel_bulk.
                    print(f"Bulk-inserted {succeeded} items (parallel_bulk).")

    @task
    def close_indexes():
        """Refresh every entity index so the loaded documents become searchable.

        NOTE(review): despite the name, no index is closed, and
        ``refresh_interval`` is left at the value set by ``create_indexes``.
        """
        client = _opensearch_client()
        for entity in ENTITIES:
            client.indices.refresh(index=entity)

    parallel_batches = PythonOperator(
        task_id="compute_parallel_batches",
        python_callable=compute_batches,
    )

    chain(
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        # One mapped bulk_load instance per batch dict returned above.
        bulk_load.expand_kwargs(parallel_batches.output),
        close_indexes.override(task_id="close_indexes")(),
    )
2024-03-17 15:51:07 +01:00
# Call the @dag-decorated function at module import time to instantiate the DAG.
import_raw_graph()