Skip to content

Extract Module

virtughan.extract

ExtractProcessor

Source code in src/virtughan/extract.py
class ExtractProcessor:
    def __init__(
        self,
        bbox: list[float],
        start_date: str,
        end_date: str,
        cloud_cover: int,
        bands_list: list[str],
        output_dir: str,
        log_file: IO[str] = sys.stdout,
        workers: int = 1,
        zip_output: bool = False,
        smart_filter: bool = True,
        collection: str = "sentinel-2-l2a",
    ):
        self.bbox = bbox
        self.start_date = start_date
        self.end_date = end_date
        self.cloud_cover = cloud_cover
        self.bands_list = bands_list
        self.output_dir = output_dir
        self.console = Console(file=log_file)
        self.workers = workers
        self.zip_output = zip_output
        self.crs: Any = None
        self.transform: Any = None
        self.use_smart_filter = smart_filter
        self.collection_config = get_collection(collection)

        self._validate_bands_list()

    def _validate_bands_list(self) -> None:
        invalid_bands = self.collection_config.validate_bands(self.bands_list)
        if invalid_bands:
            available = ", ".join(self.collection_config.band_names)
            raise ValueError(
                f"Invalid band names: {', '.join(invalid_bands)}. "
                f"Band names should be one of: {available}"
            )

    def _get_band_urls(self, features: list[dict[str, Any]]) -> list[list[str]]:
        urls = []
        for feature in features:
            try:
                band_hrefs = [feature["assets"][band]["href"] for band in self.bands_list]
                urls.append(band_hrefs)
            except KeyError:
                continue
        return urls

    def _fetch_and_save_bands(self, band_urls: list[str], feature_id: str) -> str | None:
        try:
            bands: list[np.ndarray] = []
            bands_meta: list[str] = []
            resolutions: list[tuple[float, float]] = []

            for band_url in band_urls:
                with rasterio.open(band_url) as band_cog:
                    resolutions.append(band_cog.res)

            lowest_resolution = max(resolutions, key=lambda res: res[0] * res[1])

            for band_url in band_urls:
                with rasterio.open(band_url) as band_cog:
                    min_x, min_y, max_x, max_y = transform_bbox(self.bbox, band_cog.crs)
                    band_window = calculate_window(band_cog, min_x, min_y, max_x, max_y)

                    if is_window_out_of_bounds(band_window):
                        return None
                    self.crs = band_cog.crs
                    window_transform = band_cog.window_transform(band_window)

                    band_data = band_cog.read(1, window=band_window).astype(float)

                    if band_cog.res != lowest_resolution:
                        scale_factor_x = band_cog.res[0] / lowest_resolution[0]
                        scale_factor_y = band_cog.res[1] / lowest_resolution[1]
                        dst_height = int(band_data.shape[0] * scale_factor_y)
                        dst_width = int(band_data.shape[1] * scale_factor_x)
                        dst_transform = window_transform * rasterio.Affine.scale(
                            1.0 / scale_factor_x, 1.0 / scale_factor_y
                        )
                        band_data = reproject(
                            source=band_data,
                            destination=np.empty(
                                (dst_height, dst_width),
                                dtype=band_data.dtype,
                            ),
                            src_transform=window_transform,
                            src_crs=band_cog.crs,
                            dst_transform=dst_transform,
                            dst_crs=band_cog.crs,
                            resampling=Resampling.average,
                        )[0]
                        window_transform = dst_transform

                    self.transform = window_transform

                    bands.append(band_data)
                    bands_meta.append(band_url.split("/")[-1].split(".")[0])

            self.console.print(f"[bold]Stacking bands for[/bold] {feature_id}")
            stacked_bands = np.stack(bands)
            output_file = os.path.join(self.output_dir, f"{feature_id}_bands_export.tif")
            save_geotiff(
                stacked_bands, output_file, self.crs, self.transform, band_descriptions=bands_meta
            )
            return output_file
        except Exception:
            raise

    def extract(self) -> None:
        self.console.print("[bold blue]Extracting bands...[/bold blue]")
        os.makedirs(self.output_dir, exist_ok=True)

        features = search_stac(
            self.collection_config,
            self.bbox,
            self.start_date,
            self.end_date,
            self.cloud_cover,
        )
        self.console.print(f"Total scenes found: {len(features)}")
        filtered_features = filter_intersected_features(features, self.bbox)
        self.console.print(f"Scenes covering input area: {len(filtered_features)}")
        overlapping_features_removed = remove_overlapping_tiles(
            filtered_features, self.collection_config.tile_id_parser
        )
        self.console.print(f"Scenes after removing overlaps: {len(overlapping_features_removed)}")
        if self.use_smart_filter:
            overlapping_features_removed = smart_filter_images(
                overlapping_features_removed, self.start_date, self.end_date
            )
            self.console.print(f"Scenes after smart filter: {len(overlapping_features_removed)}")

        band_urls_list = self._get_band_urls(overlapping_features_removed)
        result_lists: list[str | None] = []

        if self.workers > 1:
            self.console.print("Using parallel processing...")
            with ThreadPoolExecutor(max_workers=self.workers) as executor:
                futures = [
                    executor.submit(self._fetch_and_save_bands, band_urls, feature["id"])
                    for band_urls, feature in zip(band_urls_list, overlapping_features_removed)
                ]
                with Progress(console=self.console) as progress:
                    total = len(futures)
                    task = progress.add_task("Extracting Bands", total=total)
                    for index, future in enumerate(as_completed(futures), start=1):
                        result = future.result()
                        result_lists.append(result)
                        progress.advance(task)
                        try:
                            percent = int(index / total * 100) if total else 100
                        except Exception:
                            percent = 0
                        self.console.print(f"PROGRESS: {percent}% | {index}/{total}")
        else:
            with Progress(console=self.console) as progress:
                total = len(band_urls_list)
                task = progress.add_task("Extracting Bands", total=total)
                for index, (band_urls, feature) in enumerate(
                    zip(band_urls_list, overlapping_features_removed), start=1
                ):
                    result = self._fetch_and_save_bands(band_urls, feature["id"])
                    result_lists.append(result)
                    progress.advance(task)
                    try:
                        percent = int(index / total * 100) if total else 100
                    except Exception:
                        percent = 0
                    self.console.print(f"PROGRESS: {percent}% | {index}/{total}")

        if self.zip_output:
            valid_files = [f for f in result_lists if f is not None]
            zip_files(
                valid_files,
                os.path.join(self.output_dir, "tiff_files.zip"),
            )

bbox = bbox instance-attribute

start_date = start_date instance-attribute

end_date = end_date instance-attribute

cloud_cover = cloud_cover instance-attribute

bands_list = bands_list instance-attribute

output_dir = output_dir instance-attribute

console = Console(file=log_file) instance-attribute

workers = workers instance-attribute

zip_output = zip_output instance-attribute

crs: Any = None instance-attribute

transform: Any = None instance-attribute

use_smart_filter = smart_filter instance-attribute

collection_config = get_collection(collection) instance-attribute

__init__(bbox: list[float], start_date: str, end_date: str, cloud_cover: int, bands_list: list[str], output_dir: str, log_file: IO[str] = sys.stdout, workers: int = 1, zip_output: bool = False, smart_filter: bool = True, collection: str = 'sentinel-2-l2a')

Source code in src/virtughan/extract.py
def __init__(
    self,
    bbox: list[float],
    start_date: str,
    end_date: str,
    cloud_cover: int,
    bands_list: list[str],
    output_dir: str,
    log_file: IO[str] = sys.stdout,
    workers: int = 1,
    zip_output: bool = False,
    smart_filter: bool = True,
    collection: str = "sentinel-2-l2a",
):
    """Record the extraction settings and reject unknown band names.

    The console is bound to ``log_file``, the collection name is resolved
    through ``get_collection``, and ``_validate_bands_list`` runs last so
    that it can read both ``bands_list`` and ``collection_config``.
    """
    # Search parameters.
    self.bbox = bbox
    self.start_date, self.end_date = start_date, end_date
    self.cloud_cover = cloud_cover
    self.use_smart_filter = smart_filter

    # What to export and where to put it.
    self.bands_list = bands_list
    self.output_dir = output_dir
    self.zip_output = zip_output
    self.workers = workers

    # Georeferencing placeholders; unset until a scene has been processed.
    self.crs: Any = None
    self.transform: Any = None

    # Collaborators: console first, then the collection lookup.
    self.console = Console(file=log_file)
    self.collection_config = get_collection(collection)

    self._validate_bands_list()

extract() -> None

Source code in src/virtughan/extract.py
def extract(self) -> None:
    """Run the full pipeline: search, filter, export and optionally zip.

    Scenes are processed in a thread pool when ``workers > 1`` and
    sequentially otherwise. An exception raised while processing a scene
    propagates to the caller.
    """
    self.console.print("[bold blue]Extracting bands...[/bold blue]")
    os.makedirs(self.output_dir, exist_ok=True)

    features = search_stac(
        self.collection_config,
        self.bbox,
        self.start_date,
        self.end_date,
        self.cloud_cover,
    )
    self.console.print(f"Total scenes found: {len(features)}")
    filtered_features = filter_intersected_features(features, self.bbox)
    self.console.print(f"Scenes covering input area: {len(filtered_features)}")
    overlapping_features_removed = remove_overlapping_tiles(
        filtered_features, self.collection_config.tile_id_parser
    )
    self.console.print(f"Scenes after removing overlaps: {len(overlapping_features_removed)}")
    if self.use_smart_filter:
        overlapping_features_removed = smart_filter_images(
            overlapping_features_removed, self.start_date, self.end_date
        )
        self.console.print(f"Scenes after smart filter: {len(overlapping_features_removed)}")

    # ``_get_band_urls`` skips features that lack a requested asset, so its
    # output cannot be zipped against the full feature list without shifting
    # later scenes onto the wrong id. Resolve URLs per feature instead.
    jobs: list[tuple[list[str], str]] = []
    for feature in overlapping_features_removed:
        resolved = self._get_band_urls([feature])
        if resolved:
            jobs.append((resolved[0], feature["id"]))

    total = len(jobs)
    result_lists: list[str | None] = []

    def report(progress, task, index: int) -> None:
        # Advance the bar and print the ``PROGRESS:`` status line.
        progress.advance(task)
        percent = int(index / total * 100) if total else 100
        self.console.print(f"PROGRESS: {percent}% | {index}/{total}")

    if self.workers > 1:
        self.console.print("Using parallel processing...")
        with ThreadPoolExecutor(max_workers=self.workers) as executor:
            futures = [
                executor.submit(self._fetch_and_save_bands, band_urls, feature_id)
                for band_urls, feature_id in jobs
            ]
            with Progress(console=self.console) as progress:
                task = progress.add_task("Extracting Bands", total=total)
                for index, future in enumerate(as_completed(futures), start=1):
                    result_lists.append(future.result())
                    report(progress, task, index)
    else:
        with Progress(console=self.console) as progress:
            task = progress.add_task("Extracting Bands", total=total)
            for index, (band_urls, feature_id) in enumerate(jobs, start=1):
                result_lists.append(self._fetch_and_save_bands(band_urls, feature_id))
                report(progress, task, index)

    if self.zip_output:
        # Scenes that fell outside the raster bounds produced no file.
        valid_files = [f for f in result_lists if f is not None]
        zip_files(
            valid_files,
            os.path.join(self.output_dir, "tiff_files.zip"),
        )