initial stage

Giambattista Bloisi 2024-03-25 17:54:23 +01:00
parent 4e1955b673
commit f79eb140eb
10 changed files with 41 additions and 791 deletions

View File

@@ -0,0 +1,25 @@
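# Map a raw manifestation access_right code to a display label; unknown codes
# map to '' and are filtered out downstream.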
def map_access_right(ar: str) -> str:
    match ar:
        case 'open':
            return 'Open Access'
        case 'closed':
            return 'Closed'
        case 'embargo':
            return 'Embargo'
        case 'restricted':
            return 'Restricted'
        case _:
            return ''
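
# Collect the access-right labels of all of a product's manifestations into an
# 'accessRights' list, dropping unrecognised codes.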
def transform_product(p: dict) -> dict:
    p['accessRights'] = list(
        filter(lambda ar: ar != '',
               map(lambda m: map_access_right(m.get('access_right')),
                   p.get('manifestations') or [])))  # tolerate products without manifestations
    return p

transform_entities = {
    # 'products': transform_product
}

View File

@@ -334,9 +334,9 @@ mappings['products'] = {
                }
            }
        },
        "accessRight": {
            "type": "keyword"
        },
        # "accessRights": {
        #     "type": "keyword"
        # },
        "contributions": {
            "type": "object",
            "properties": {

View File

@@ -1,81 +0,0 @@
import os
from datetime import datetime, timedelta

from airflow import settings
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.s3 import (
    S3CreateBucketOperator,
)
from airflow.providers.amazon.aws.transfers.http_to_s3 import HttpToS3Operator

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "zenodo-bucket")
S3_BUCKET_KEY = os.getenv("S3_BUCKET_KEY", "test")
S3_BUCKET_KEY_LIST = os.getenv("S3_BUCKET_KEY_LIST", "test2")
S3_BUCKET_WILDCARD_KEY = os.getenv("S3_BUCKET_WILDCARD_KEY", "test*")
PREFIX = os.getenv("S3_PREFIX", "test")
INACTIVITY_PERIOD = float(os.getenv("INACTIVITY_PERIOD", 5))
AWS_DEFAULT_REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-1")
LOCAL_FILE_PATH = os.getenv("LOCAL_FILE_PATH", "/usr/local/airflow/dags/example_s3_test_file.txt")
AWS_CONN_ID = os.getenv("ASTRO_AWS_S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

DATA = os.environ.get(
    "DATA",
    """
apple,0.5
milk,2.5
bread,4.0
""",
)

default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}
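
# Register an Airflow HTTP connection pointing at zenodo.org so the
# HttpToS3Operator below can fetch files from the Zenodo API.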
@task
def create_connection(conn_id_name: str):
    conn = Connection(
        conn_id=conn_id_name,
        conn_type="https",
        host="zenodo.org",
        port=80,
    )
    session = settings.Session()
    session.add(conn)
    session.commit()

with DAG(
    dag_id="zenodo_download_to_s3",
    schedule=None,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    default_args=default_args,
    tags=["example", "async", "s3"],
) as dag:
    conn_id_name = "zenodo"
    set_up_connection = create_connection(conn_id_name)

    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        region_name=AWS_DEFAULT_REGION,
        bucket_name=S3_BUCKET_NAME,
        aws_conn_id=AWS_CONN_ID,
    )

    http_to_s3_task = HttpToS3Operator(
        task_id="http_to_s3_task",
        http_conn_id=conn_id_name,
        endpoint="/records/8223812/files/organization.tar",
        s3_bucket=S3_BUCKET_NAME,
        s3_key="organization.tar",
        replace=True,
        aws_conn_id=AWS_CONN_ID,
    )

    chain(set_up_connection, create_bucket, http_to_s3_task)

View File

@@ -1,39 +0,0 @@
apiVersion: "sparkoperator.k8s.io/v1beta2"
kind: SparkApplication
metadata:
  name: spark-pi
  namespace: lot1-spark-jobs
spec:
  type: Scala
  mode: cluster
  image: "apache/spark:v3.1.3"
  imagePullPolicy: Always
  mainClass: org.apache.spark.examples.SparkPi
  mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.12-3.1.3.jar"
  sparkVersion: "3.1.3"
  restartPolicy:
    type: Never
  volumes:
    - name: "test-volume"
      hostPath:
        path: "/tmp"
        type: Directory
  driver:
    cores: 1
    coreLimit: "1200m"
    memory: "512m"
    labels:
      version: 3.1.3
    serviceAccount: spark
    volumeMounts:
      - name: "test-volume"
        mountPath: "/tmp"
  executor:
    cores: 1
    instances: 1
    memory: "512m"
    labels:
      version: 3.1.3
    volumeMounts:
      - name: "test-volume"
        mountPath: "/tmp"

View File

@@ -4,7 +4,6 @@ import gzip
import io
import json
import os
import zipfile
from datetime import timedelta
import pendulum
@@ -12,15 +11,15 @@ from airflow.decorators import dag
from airflow.decorators import task
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.file import TemporaryDirectory
from airflow.utils.helpers import chain
from airflow.models import Variable
from opensearchpy import OpenSearch, helpers
from opensearch_indexes import mappings
from EOSC_indexes import mappings
S3_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EOSC_CATALOG_BUCKET = os.getenv("EOSC_CATALOG_BUCKET", "eosc-catalog")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "eosc-catalog")
AWS_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
OPENSEARCH_HOST = Variable.get("OPENSEARCH_URL", "opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local")
OPENSEARCH_URL = Variable.get("OPENSEARCH_URL", "https://opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local:9200")
@@ -40,13 +39,6 @@ default_args = {
}

def strip_prefix(s, p):
    if s.startswith(p):
        return s[len(p):]
    else:
        return s

@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
@@ -55,26 +47,6 @@ def strip_prefix(s, p):
    tags=["lot1"],
)
def eosc_catalog_import():
    @task
    def unzip_to_s3(key: str, bucket: str):
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        with TemporaryDirectory() as dwl_dir:
            with TemporaryDirectory() as tmp_dir:
                archive = f'{dwl_dir}/{key}'
                hook.download_file(key=key, bucket_name=bucket, local_path=dwl_dir, preserve_file_name=True,
                                   use_autogenerated_subdir=False)
                with zipfile.ZipFile(archive, 'r') as zip_ref:
                    zip_ref.extractall(tmp_dir)
                for root, _, files in os.walk(tmp_dir):
                    for file in files:
                        if file == key:
                            continue
                        local_file_path = os.path.join(root, file)
                        hook.load_file(local_file_path, strip_prefix(local_file_path, tmp_dir), S3_BUCKET_NAME,
                                       replace=True)
        return ""

    @task
    def create_indexes():
@@ -112,8 +84,8 @@ def eosc_catalog_import():
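    # Enumerate the dump files for each entity so they can be indexed in parallel batches.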
    def compute_batches(ds=None, **kwargs):
        pieces = []
        for entity in ENTITIES:
            hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
            keys = hook.list_keys(bucket_name=S3_BUCKET_NAME, prefix=f'{entity}/')
            hook = S3Hook(S3_CONN_ID, transfer_config_args={'use_threads': False})
            keys = hook.list_keys(bucket_name=EOSC_CATALOG_BUCKET, prefix=f'{entity}/')
            for key in keys:
                pieces.append((entity, key))
@@ -133,12 +105,12 @@ def eosc_catalog_import():
            ssl_show_warn=False,
            pool_maxsize=20
        )
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        hook = S3Hook(S3_CONN_ID, transfer_config_args={'use_threads': False})
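
        # Stream gzipped JSON lines from S3 and yield them as OpenSearch bulk actions.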
        def _generate_data():
            for (entity, key) in files:
                print(f'{entity}: {key}')
                s3_obj = hook.get_key(key, bucket_name=S3_BUCKET_NAME)
                s3_obj = hook.get_key(key, bucket_name=EOSC_CATALOG_BUCKET)
                with gzip.GzipFile(fileobj=s3_obj.get()["Body"]) as gzipfile:
                    buff = io.BufferedReader(gzipfile)
                    for line in buff:
@@ -181,7 +153,6 @@ def eosc_catalog_import():
    parallel_batches = PythonOperator(task_id="compute_parallel_batches", python_callable=compute_batches)

    chain(
        unzip_to_s3.override(task_id="unzip_to_s3")("dump.zip", S3_BUCKET_NAME),
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        bulk_load.expand_kwargs(parallel_batches.output),

View File

@@ -19,7 +19,9 @@ from airflow.utils.helpers import chain
from airflow.models import Variable
from opensearchpy import OpenSearch, helpers
from opensearch_indexes import mappings
from EOSC_indexes import mappings
from EOSC_entity_trasform import transform_entities
from common import strip_prefix
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "skgif-eosc-eu")
AWS_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
@@ -43,37 +45,6 @@ default_args = {
}

def strip_prefix(s, p):
    if s.startswith(p):
        return s[len(p):]
    else:
        return s

def map_access_right(ar: str) -> str:
    match ar:
        case 'open':
            return 'Open Access'
        case 'closed':
            return 'Closed'
        case 'embargo':
            return 'Embargo'
        case 'restricted':
            return 'Restricted'
        case _:
            return ''

def map_product(p: dict) -> dict:
    p['accessRight'] = list(
        filter(lambda ar: ar != '', map(lambda m: map_access_right(m.get('access_right')), p.get('manifestations'))))
    return p

map_entities = {
    'products': map_product
}

@dag(
    schedule=None,
@@ -204,8 +175,8 @@ def import_EOSC_graph():
                            data = json.loads(line)
                            data['_index'] = entity
                            data['_id'] = data['local_identifier']
                            if entity in map_entities:
                                data = map_entities[entity](data)
                            if entity in transform_entities:
                                data = transform_entities[entity](data)
                            yield data
                    # disable success post logging

View File

@@ -1,178 +0,0 @@
from __future__ import annotations

import gc
import gzip
import io
import json
import os
import zipfile
from datetime import timedelta

import pendulum
from airflow.decorators import dag
from airflow.decorators import task
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.file import TemporaryDirectory
from airflow.utils.helpers import chain
from airflow.models import Variable
from opensearchpy import OpenSearch, helpers

from opensearch_indexes import mappings

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "skgif-openaire-eu")
AWS_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
OPENSEARCH_HOST = Variable.get("OPENSEARCH_URL", "opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local")
OPENSEARCH_URL = Variable.get("OPENSEARCH_URL", "https://opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local:9200")
OPENSEARCH_USER = Variable.get("OPENSEARCH_USER", "admin")
OPENSEARCH_PASSWD = Variable.get("OPENSEARCH_PASSWORD", "admin")
ENTITIES = ["dataset", "datasource", "organization", "otherresearchproduct",
            "project", "publication", "relation", "software"]
BULK_PARALLELISM = 2
#
default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 1)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

def strip_prefix(s, p):
    if s.startswith(p):
        return s[len(p):]
    else:
        return s

@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_args,
    tags=["s3"],
)
def import_raw_graph():
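    # Recreate one OpenSearch index per entity, tuned for bulk ingest:
    # 40 shards, no replicas, refresh disabled, zstd codec, segment replication.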
    @task
    def create_indexes():
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        client.cluster.put_settings(body={"persistent": {
            # "cluster.routing.allocation.balance.prefer_primary": True,
            "segrep.pressure.enabled": True
        }})
        for entity in ENTITIES:
            if client.indices.exists(entity):
                client.indices.delete(entity)
            client.indices.create(entity, {
                "settings": {
                    "index": {
                        "number_of_shards": 40,
                        "number_of_replicas": 0,
                        "refresh_interval": -1,
                        "codec": "zstd_no_dict",
                        "replication.type": "SEGMENT",
                        "translog.flush_threshold_size": "2048MB",
                        "mapping.ignore_malformed": "true"
                    }
                }
                # "mappings": mappings[entity]
            })
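
    # Enumerate the .gz dump files under 00_graph_aggregator/<entity>/ and
    # split them into roughly BULK_PARALLELISM equal batches.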
    def compute_batches(ds=None, **kwargs):
        pieces = []
        for entity in ENTITIES:
            hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
            keys = hook.list_keys(bucket_name=S3_BUCKET_NAME, prefix=f'00_graph_aggregator/{entity}/')
            for key in keys:
                if key.endswith('.gz'):
                    pieces.append((entity, key))

        def split_list(list_a, chunk_size):
            for i in range(0, len(list_a), chunk_size):
                yield {"files": list_a[i:i + chunk_size]}

        return list(split_list(pieces, len(pieces) // BULK_PARALLELISM))
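
    # Stream each gzipped dump from S3 and index its JSON lines with parallel_bulk,
    # counting successes and failures instead of raising.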
    @task
    def bulk_load(files: list[tuple[str, str]]):
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        for (entity, key) in files:
            print(f'{entity}: {key}')
            s3_obj = hook.get_key(key, bucket_name=S3_BUCKET_NAME)
            with s3_obj.get()["Body"] as body:
                with gzip.GzipFile(fileobj=body) as gzipfile:
                    def _generate_data():
                        buff = io.BufferedReader(gzipfile)
                        for line in buff:
                            data = json.loads(line)
                            data['_index'] = entity
                            data['_id'] = data['id']
                            yield data

                    succeeded = 0
                    failed = 0
                    for success, item in helpers.parallel_bulk(client, actions=_generate_data(), raise_on_exception=False,
                                                               raise_on_error=False,
                                                               chunk_size=500, max_chunk_bytes=10 * 1024 * 1024):
                        if success:
                            succeeded = succeeded + 1
                        else:
                            print(item["index"]["error"])
                            failed = failed + 1

                    if failed > 0:
                        print(f"There were {failed} errors:")
                    if succeeded > 0:
                        print(f"Bulk-inserted {succeeded} items (streaming_bulk).")

    @task
    def close_indexes():
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        for entity in ENTITIES:
            client.indices.refresh(entity)

    parallel_batches = PythonOperator(task_id="compute_parallel_batches", python_callable=compute_batches)

    chain(
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        bulk_load.expand_kwargs(parallel_batches.output),
        close_indexes.override(task_id="close_indexes")()
    )

import_raw_graph()

View File

@@ -1,243 +0,0 @@
from __future__ import annotations

import gzip
import io
import json
import logging
import os
import zipfile
from datetime import timedelta

from kubernetes.client import models as k8s
import pendulum
from airflow.decorators import dag
from airflow.decorators import task
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.file import TemporaryDirectory
from airflow.utils.helpers import chain
from airflow.models import Variable
from opensearchpy import OpenSearch, helpers

from opensearch_indexes import mappings

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "lot1-openaire-eu")
AWS_CONN_ID = os.getenv("S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
OPENSEARCH_HOST = Variable.get("OPENSEARCH_URL", "opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local")
OPENSEARCH_URL = Variable.get("OPENSEARCH_URL", "https://opensearch-cluster.lot1-opensearch-cluster.svc.cluster.local:9200")
OPENSEARCH_USER = Variable.get("OPENSEARCH_USER", "admin")
OPENSEARCH_PASSWD = Variable.get("OPENSEARCH_PASSWORD", "admin")
ENTITIES = ["datasources", "grants", "organizations", "persons", "products", "topics", "venues"]
BULK_PARALLELISM = 10
#
default_args = {
    "execution_timeout": timedelta(days=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 1)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

def strip_prefix(s, p):
    if s.startswith(p):
        return s[len(p):]
    else:
        return s

@dag(
    schedule=None,
    dagrun_timeout=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_args,
    tags=["lot1"],
)
def skg_if_pipeline():
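    # Download a zip dump from S3, extract it locally, and upload the extracted
    # members back to the bucket under their relative paths.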
    @task
    def unzip_to_s3(key: str, bucket: str):
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        with TemporaryDirectory() as dwl_dir:
            with TemporaryDirectory() as tmp_dir:
                archive = f'{dwl_dir}/{key}'
                hook.download_file(key=key, bucket_name=bucket, local_path=dwl_dir, preserve_file_name=True,
                                   use_autogenerated_subdir=False)
                with zipfile.ZipFile(archive, 'r') as zip_ref:
                    zip_ref.extractall(tmp_dir)
                for root, _, files in os.walk(tmp_dir):
                    for file in files:
                        if file == key:
                            continue
                        local_file_path = os.path.join(root, file)
                        hook.load_file(local_file_path, strip_prefix(local_file_path, tmp_dir), S3_BUCKET_NAME,
                                       replace=True)
        return ""
    @task
    def create_indexes():
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        client.cluster.put_settings(body={
            "persistent": {
                "cluster.routing.allocation.balance.prefer_primary": True,
                "segrep.pressure.enabled": True
            }
        })
        for entity in ENTITIES:
            if client.indices.exists(entity):
                client.indices.delete(entity)
            client.indices.create(entity, {
                "settings": {
                    "index": {
                        "number_of_shards": 40,
                        "number_of_replicas": 0,
                        "refresh_interval": -1,
                        "translog.flush_threshold_size": "2048MB",
                        "codec": "zstd_no_dict",
                        "replication.type": "SEGMENT"
                    }
                },
                "mappings": mappings[entity]
                # "mappings": {
                #     "dynamic": False,
                #     "properties": {
                #         "local_identifier": {
                #             "type": "keyword"
                #         }
                #     }
                # }
            })
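
    # Build (entity, key) batches; stale .PROCESSED markers are deleted first so
    # every dump file is reconsidered on a fresh run.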
    def compute_batches(ds=None, **kwargs):
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        pieces = []
        for entity in ENTITIES:
            keys = hook.list_keys(bucket_name=S3_BUCKET_NAME, prefix=f'{entity}/')
            to_delete = list(filter(lambda key: key.endswith('.PROCESSED'), keys))
            hook.delete_objects(bucket=S3_BUCKET_NAME, keys=to_delete)
            for key in keys:
                if key.endswith('.gz'):
                    pieces.append((entity, key))

        def split_list(list_a, chunk_size):
            for i in range(0, len(list_a), chunk_size):
                yield {"files": list_a[i:i + chunk_size]}

        return list(split_list(pieces, len(pieces) // BULK_PARALLELISM))
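
    # Request explicit CPU/memory for the pod that runs each bulk_load task instance.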
    @task(executor_config={
        "pod_override": k8s.V1Pod(
            spec=k8s.V1PodSpec(
                containers=[
                    k8s.V1Container(
                        name="base",
                        resources=k8s.V1ResourceRequirements(
                            requests={
                                "cpu": "550m",
                                "memory": "256Mi"
                            }
                        )
                    )
                ]
            )
        )
    })
    def bulk_load(files: list[tuple[str, str]]):
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20
        )
        hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
        for (entity, key) in files:
            if hook.check_for_key(key=f"{key}.PROCESSED", bucket_name=S3_BUCKET_NAME):
                print(f'Skipping {entity}: {key}')
                continue
            print(f'Processing {entity}: {key}')
            s3_obj = hook.get_key(key, bucket_name=S3_BUCKET_NAME)
            with s3_obj.get()["Body"] as body:
                with gzip.GzipFile(fileobj=body) as gzipfile:
                    def _generate_data():
                        buff = io.BufferedReader(gzipfile)
                        for line in buff:
                            data = json.loads(line)
                            data['_index'] = entity
                            data['_id'] = data['local_identifier']
                            yield data

                    # disable success post logging
                    logging.getLogger("opensearch").setLevel(logging.WARN)
                    succeeded = 0
                    failed = 0
                    for success, item in helpers.parallel_bulk(client, actions=_generate_data(),
                                                               raise_on_exception=False,
                                                               raise_on_error=False,
                                                               chunk_size=5000,
                                                               max_chunk_bytes=50 * 1024 * 1024,
                                                               timeout=180):
                        if success:
                            succeeded = succeeded + 1
                        else:
                            print(item["index"]["error"])
                            failed = failed + 1

                    if failed > 0:
                        print(f"There were {failed} errors:")
                    else:
                        hook.load_string(
                            "",
                            f"{key}.PROCESSED",
                            bucket_name=S3_BUCKET_NAME,
                            replace=False
                        )
                    if succeeded > 0:
                        print(f"Bulk-inserted {succeeded} items (streaming_bulk).")

    @task
    def close_indexes():
        client = OpenSearch(
            hosts=[{'host': OPENSEARCH_HOST, 'port': 9200}],
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWD),
            use_ssl=True,
            verify_certs=False,
            ssl_show_warn=False,
            pool_maxsize=20,
            timeout=180
        )
        for entity in ENTITIES:
            client.indices.refresh(entity)

    parallel_batches = PythonOperator(task_id="compute_parallel_batches", python_callable=compute_batches)

    chain(
        # unzip_to_s3.override(task_id="unzip_to_s3")("dump.zip", S3_BUCKET_NAME),
        create_indexes.override(task_id="create_indexes")(),
        parallel_batches,
        bulk_load.expand_kwargs(parallel_batches.output),
        close_indexes.override(task_id="close_indexes")()
    )

skg_if_pipeline()

View File

@@ -1,113 +0,0 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example DAG which uses SparkKubernetesOperator and SparkKubernetesSensor.
In this example, we create two tasks which execute sequentially.
The first task submits a sparkApplication to the Kubernetes cluster (the example uses the spark-pi application),
and the second task checks the final state of the sparkApplication submitted in the first task.
The spark-on-k8s operator must already be installed on Kubernetes:
https://github.com/GoogleCloudPlatform/spark-on-k8s-operator
"""
from os import path
from datetime import timedelta, datetime

# [START import_module]
# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG
# Operators; we need this to operate!
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor
from airflow.utils.dates import days_ago
# [END import_module]

# [START default_args]
# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': days_ago(1),
    'email': ['airflow@example.com'],
    'email_on_failure': False,
    'email_on_retry': False,
    'max_active_runs': 1,
    'retries': 3
}
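
# SparkApplication manifest passed to SparkKubernetesOperator; the templated name
# gives each run (and retry) a distinct application name.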
spec = {'apiVersion': 'sparkoperator.k8s.io/v1beta2',
        'kind': 'SparkApplication',
        'metadata': {
            'name': 'spark-pi-{{ ds }}-{{ task_instance.try_number }}',
            'namespace': 'lot1-spark-jobs'
        },
        'spec': {
            'type': 'Scala',
            'mode': 'cluster',
            'image': 'apache/spark:v3.1.3',
            'imagePullPolicy': 'Always',
            'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.12-3.1.3.jar',
            'mainClass': 'org.apache.spark.examples.SparkPi',
            'sparkVersion': '3.1.3',
            'restartPolicy': {'type': 'Never'},
            # 'arguments': ['{{ds}}'],
            'driver': {
                'coreLimit': '1200m',
                'cores': 1,
                'labels': {'version': '3.1.3'},
                'memory': '1g',
                'serviceAccount': 'spark',
            },
            'executor': {
                'cores': 1,
                'instances': 1,
                'memory': '512m',
                'labels': {'version': '3.1.3'}
            }
        }}

dag = DAG(
    'spark_pi',
    default_args=default_args,
    schedule_interval=None,
    tags=['example', 'spark']
)

submit = SparkKubernetesOperator(
    task_id='spark_pi_submit',
    namespace='lot1-spark-jobs',
    template_spec=spec,
    kubernetes_conn_id="kubernetes_default",
    # do_xcom_push=True,
    # delete_on_termination=True,
    base_container_name="spark-kubernetes-driver",
    dag=dag
)

# sensor = SparkKubernetesSensor(
#     task_id='spark_pi_monitor',
#     namespace='lot1-spark-jobs',
#     application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}",
#     kubernetes_conn_id="kubernetes_default",
#     dag=dag,
#     attach_log=False
# )

submit

View File

@@ -1,63 +0,0 @@
import os
import tarfile
from datetime import datetime, timedelta
from io import BytesIO

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "zenodo-bucket")
S3_BUCKET_KEY = os.getenv("S3_BUCKET_KEY", "test")
S3_BUCKET_KEY_LIST = os.getenv("S3_BUCKET_KEY_LIST", "test2")
S3_BUCKET_WILDCARD_KEY = os.getenv("S3_BUCKET_WILDCARD_KEY", "test*")
PREFIX = os.getenv("S3_PREFIX", "test")
INACTIVITY_PERIOD = float(os.getenv("INACTIVITY_PERIOD", 5))
AWS_DEFAULT_REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-1")
LOCAL_FILE_PATH = os.getenv("LOCAL_FILE_PATH", "/usr/local/airflow/dags/example_s3_test_file.txt")
AWS_CONN_ID = os.getenv("ASTRO_AWS_S3_CONN_ID", "s3_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

DATA = os.environ.get(
    "DATA",
    """
apple,0.5
milk,2.5
bread,4.0
""",
)

default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}
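
# Stream the tar archive directly from S3 and upload each regular-file member as
# its own object, without materialising the whole archive locally.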
@task
def untar_to_s3(key: str, bucket: str):
    hook = S3Hook(AWS_CONN_ID, transfer_config_args={'use_threads': False})
    tarball_obj = hook.get_key(key, bucket_name=bucket)
    with tarfile.open(name=None, mode="r|", fileobj=tarball_obj.get()['Body']) as tarball:
        for member in tarball:
            if not member.isfile():
                continue
            fd = tarball.extractfile(member)
            hook.load_file_obj(BytesIO(fd.read()), member.path, S3_BUCKET_NAME)

with DAG(
    dag_id="untar_zenodo_organization",
    schedule=None,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    default_args=default_args,
    tags=["example", "async", "s3"],
) as dag:
    untar_task = untar_to_s3("organization.tar", S3_BUCKET_NAME)

    chain(untar_task)