From 4fdf50673d4c87d20d1cf1c7ccd6eff92aaf2c86 Mon Sep 17 00:00:00 2001 From: Alexander Minges Date: Thu, 10 Jul 2025 11:18:00 +0200 Subject: [PATCH] fix: Update AbstractProcessor to accept optional console parameter --- doi2dataset.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/doi2dataset.py b/doi2dataset.py index 2aac6e5..936b571 100755 --- a/doi2dataset.py +++ b/doi2dataset.py @@ -489,7 +489,7 @@ class Config: if not config_path.exists(): raise FileNotFoundError(f"Config file not found: {config_path}") - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, encoding='utf-8') as f: config_data = yaml.safe_load(f) # Validate PI email addresses @@ -769,14 +769,16 @@ class AbstractProcessor: """ Retrieves and processes abstracts from CrossRef and OpenAlex. """ - def __init__(self, api_client: APIClient): + def __init__(self, api_client: APIClient, console: Console | None = None): """ Initialize with an APIClient instance. Args: api_client (APIClient): The API client to use for requests. + console (Console | None): Rich console instance for output. """ self.api_client = api_client + self.console = console or Console() def get_abstract(self, doi: str, data: dict[str, Any], license: License) -> Abstract: """ @@ -793,26 +795,26 @@ class AbstractProcessor: license_ok = {"cc-by", "cc-by-sa", "cc-by-nc", "cc-by-nc-sa", "cc0", "pd"} if license.short in license_ok: - console.print(f"\n{ICONS['info']} License {license.name} allows derivative works. Pulling abstract from CrossRef.", style="info") + self.console.print(f"\n{ICONS['info']} License {license.name} allows derivative works. Pulling abstract from CrossRef.", style="info") crossref_abstract = self._get_crossref_abstract(doi) if crossref_abstract: return Abstract(text=crossref_abstract, source="crossref") else: - console.print(f"\n{ICONS['warning']} No abstract found in CrossRef!", style="warning") + self.console.print(f"\n{ICONS['warning']} No abstract found in CrossRef!", style="warning") else: if license.name: - console.print(f"\n{ICONS['info']} License {license.name} does not allow derivative works. Reconstructing abstract from OpenAlex!", style="info") + self.console.print(f"\n{ICONS['info']} License {license.name} does not allow derivative works. Reconstructing abstract from OpenAlex!", style="info") else: - console.print(f"\n{ICONS['info']} Custom license does not allow derivative works. Reconstructing abstract from OpenAlex!", style="info") + self.console.print(f"\n{ICONS['info']} Custom license does not allow derivative works. Reconstructing abstract from OpenAlex!", style="info") openalex_abstract = self._get_openalex_abstract(data) if openalex_abstract: return Abstract(text=openalex_abstract, source="openalex") else: - console.print(f"\n{ICONS['warning']} No abstract found in OpenAlex!", style="warning") + self.console.print(f"\n{ICONS['warning']} No abstract found in OpenAlex!", style="warning") - console.print(f"\n{ICONS['warning']} No abstract found in either CrossRef nor OpenAlex!", style="warning") + self.console.print(f"\n{ICONS['warning']} No abstract found in either CrossRef nor OpenAlex!", style="warning") return Abstract(text="", source="none") def _get_crossref_abstract(self, doi: str) -> str | None: @@ -1332,7 +1334,7 @@ class MetadataProcessor: dict[str, Any]: The complete metadata dictionary. """ license_info = LicenseProcessor.process_license(data) - abstract_processor = AbstractProcessor(self.api_client) + abstract_processor = AbstractProcessor(self.api_client, self.console) abstract = abstract_processor.get_abstract(self.doi, data, license_info) citation_builder = CitationBuilder(data, self.doi, self.pi_finder, self.ror) @@ -1593,7 +1595,8 @@ def process_doi_batch( default_subject: str = "Medicine, Health and Life Sciences", contact_mail: str | None = None, upload: bool = False, - ror: bool = False + ror: bool = False, + console: Console | None = None ) -> dict[str, list[Any]]: """ Process a batch of DOIs and return a summary of results. @@ -1606,12 +1609,17 @@ def process_doi_batch( contact_mail (str | None): Contact email address. upload (bool): Flag indicating whether to upload metadata to Dataverse. ror (bool): Flag indication whether to use ROR id for affiliation. + console (Console | None): Rich console instance for output. Returns: dict[str, list[Any]]: Dictionary with keys 'success' and 'failed'. """ results: dict[str, list[Any]] = {"success": [], "failed": []} + # Use provided console or create a new one + if console is None: + console = Console() + progress_columns = [ SpinnerColumn(), TextColumn("[bold blue]{task.description:<50}"), @@ -1791,7 +1799,8 @@ def main(): default_subject=args.subject, contact_mail=args.contact_mail, upload=args.upload, - ror=args.use_ror + ror=args.use_ror, + console=console )