Use argparse to setup spark (#2082)

Ayaz Salikhov
2024-01-17 15:07:15 +04:00
committed by GitHub
parent bf33945b9e
commit afe30f0c9a
2 changed files with 22 additions and 17 deletions

Dockerfile

@@ -41,11 +41,11 @@ ENV SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M
 COPY setup_spark.py /opt/setup-scripts/
 
 # Setup Spark
-RUN SPARK_VERSION="${spark_version}" \
-    HADOOP_VERSION="${hadoop_version}" \
-    SCALA_VERSION="${scala_version}" \
-    SPARK_DOWNLOAD_URL="${spark_download_url}" \
-    /opt/setup-scripts/setup_spark.py
+RUN /opt/setup-scripts/setup_spark.py \
+    --spark-version="${spark_version}" \
+    --hadoop-version="${hadoop_version}" \
+    --scala-version="${scala_version}" \
+    --spark-download-url="${spark_download_url}"
 
 # Configure IPython system-wide
 COPY ipython_kernel_config.py "/etc/ipython/"
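For illustration, here is a hedged sketch of what the new RUN line executes, reproduced with Python's subprocess; the version values are placeholders, not pinned by this commit:

# Hypothetical reproduction of the new RUN line (example values only).
# check=True mirrors Docker's behavior of failing the build when the
# script exits non-zero.
import subprocess

subprocess.run(
    [
        "/opt/setup-scripts/setup_spark.py",
        "--spark-version=3.5.0",
        "--hadoop-version=3",
        "--scala-version=",
        "--spark-download-url=https://archive.apache.org/dist/spark/",
    ],
    check=True,
)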

setup_spark.py

@@ -4,9 +4,9 @@
 
 # Requirements:
 # - Run as the root user
-# - Required env variables: SPARK_HOME, HADOOP_VERSION, SPARK_DOWNLOAD_URL
-# - Optional env variables: SPARK_VERSION, SCALA_VERSION
+# - Required env variable: SPARK_HOME
 
+import argparse
 import logging
 import os
 import subprocess
@@ -27,13 +27,10 @@ def get_all_refs(url: str) -> list[str]:
     return [a["href"] for a in soup.find_all("a", href=True)]
 
 
-def get_spark_version() -> str:
+def get_latest_spark_version() -> str:
     """
-    If ${SPARK_VERSION} env variable is non-empty, simply returns it
-    Otherwise, returns the last stable version of Spark using spark archive
+    Returns the last stable version of Spark using spark archive
     """
-    if (version := os.environ["SPARK_VERSION"]) != "":
-        return version
     LOGGER.info("Downloading Spark versions information")
     all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
     stable_versions = [
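The removed environment-variable shortcut now lives in get_latest_spark_version() alone. Below is a self-contained sketch of that lookup, assuming the archive index lists releases as spark-X.Y.Z/ links; the hunk cuts off at stable_versions, so the filtering regex and the sort key are assumptions, not code from this commit:

# Minimal sketch of the latest-version lookup (regex and sorting are
# assumptions; the hunk above ends before the filtering logic).
import re

import requests
from bs4 import BeautifulSoup


def get_all_refs(url: str) -> list[str]:
    # Same helper as in setup_spark.py: all <a href> targets on the page
    soup = BeautifulSoup(requests.get(url).text, "html.parser")
    return [a["href"] for a in soup.find_all("a", href=True)]


all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
stable_versions = [
    ref.removeprefix("spark-").removesuffix("/")
    for ref in all_refs
    if re.fullmatch(r"spark-\d+\.\d+\.\d+/", ref)
]
# Pick the numerically highest version, e.g. "3.5.0"
print(max(stable_versions, key=lambda v: tuple(map(int, v.split(".")))))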
@@ -106,12 +103,20 @@ def configure_spark(spark_dir_name: str, spark_home: Path) -> None:
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
 
-    spark_version = get_spark_version()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--spark-version", required=True)
+    arg_parser.add_argument("--hadoop-version", required=True)
+    arg_parser.add_argument("--scala-version", required=True)
+    arg_parser.add_argument("--spark-download-url", type=Path, required=True)
+    args = arg_parser.parse_args()
+
+    args.spark_version = args.spark_version or get_latest_spark_version()
+
     spark_dir_name = download_spark(
-        spark_version=spark_version,
-        hadoop_version=os.environ["HADOOP_VERSION"],
-        scala_version=os.environ["SCALA_VERSION"],
-        spark_download_url=Path(os.environ["SPARK_DOWNLOAD_URL"]),
+        spark_version=args.spark_version,
+        hadoop_version=args.hadoop_version,
+        scala_version=args.scala_version,
+        spark_download_url=args.spark_download_url,
     )
     configure_spark(
         spark_dir_name=spark_dir_name, spark_home=Path(os.environ["SPARK_HOME"])
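A subtlety in this hunk: every flag is declared required=True, yet the script still falls back to get_latest_spark_version(). This works because argparse only requires the flag to be present, not non-empty; the Dockerfile's "${spark_version}" may expand to an empty string, which parses fine and is falsy. A minimal standalone sketch of that behavior:

# Sketch of the empty-value fallback (standalone, not the full script).
# required=True demands the flag's presence; "--spark-version=" still
# parses, yielding an empty (falsy) string that triggers the fallback.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--spark-version", required=True)

args = parser.parse_args(["--spark-version="])
assert args.spark_version == ""
spark_version = args.spark_version or "3.5.0"  # stand-in for get_latest_spark_version()
print(spark_version)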