diff --git a/LICENSE_CELLDINO b/LICENSE_CELLDINO new file mode 100644 index 000000000..102013379 --- /dev/null +++ b/LICENSE_CELLDINO @@ -0,0 +1,395 @@ +Attribution 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution 4.0 International Public License ("Public License"). 
To the +extent this Public License may be interpreted as a contract, You are +granted the Licensed Rights in consideration of Your acceptance of +these terms and conditions, and the Licensor grants You such rights in +consideration of benefits the Licensor receives from making the +Licensed Material available under these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. 
Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/LICENSE_CELLDINO_WEIGHTS b/LICENSE_CELLDINO_WEIGHTS new file mode 100644 index 000000000..284ebaf65 --- /dev/null +++ b/LICENSE_CELLDINO_WEIGHTS @@ -0,0 +1,124 @@ +FAIR Noncommercial Research License +v1 Last Updated: August 18, 2025 + +“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement. + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein. + + +“Documentation” means the specifications, manuals and documentation accompanying +Research Materials distributed by Meta. + + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others. + +“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement. + +By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement. + + +1. License Rights and Redistribution. + + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials. + +b. Redistribution and Use. + i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses; + + +ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. 
If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party. + + +iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication. + + +iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement. +2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + + +a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. 
+ +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + + +8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. + + +FAIR Acceptable Use Policy + +The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. + +As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials. + +Prohibited Uses + +You agree you will not use, or allow others to use, Research Materials to: + + Violate the law or others’ rights, including to: +Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: +Violence or terrorism +Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material +Human trafficking, exploitation, and sexual violence +The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. +Sexual solicitation +Any other criminal activity + +Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals + +Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services + +Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices + +Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws + +Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials + +Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system + +2. 
Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following: + +Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State + +Guns and illegal weapons (including weapon development) + +Illegal drugs and regulated/controlled substances + +Operation of critical infrastructure, transportation technologies, or heavy machinery + +Self-harm or harm to others, including suicide, cutting, and eating disorders + +Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual + +3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following: + + Generating, promoting, or furthering fraud or the creation or promotion of disinformation + + Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content + +Generating, promoting, or further distributing spam + + Impersonating another individual without consent, authorization, or legal right + +Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated + +Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement + +4. Fail to appropriately disclose to end users any known dangers of your Research Materials. + +Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform]. \ No newline at end of file diff --git a/README.md b/README.md index be6c4e67c..716e7dc48 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ :new: [2025-08-14] *Please check out the more recent [DINOv3](https://github.com/facebookresearch/dinov3) effort continuing this line of work.* +:new: [2025-11-21] *Added ChannelAgnostic-DINO code following [Scaling Channel-Adaptive Self-Supervised Learning](https://openreview.net/forum?id=pT8sgtRVAf), and Cell-DINO code [README](https://github.com/facebookresearch/dinov2/blob/main/README_CELL-DINO_AND_CHANNEL-DINO.md)* + [2025-06-11] *Added dino.txt inference code, following [DINOv2 Meets Text: A Unified Framework for Image- and Pixel-Level Vision-Language Alignment](https://arxiv.org/abs/2412.16334).* [2023-10-26] *Added DINOv2 backbones with registers, following [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588).* diff --git a/README_CELL-DINO_AND_CHANNEL-DINO.md b/README_CELL-DINO_AND_CHANNEL-DINO.md new file mode 100644 index 000000000..ee7d1ff9c --- /dev/null +++ b/README_CELL-DINO_AND_CHANNEL-DINO.md @@ -0,0 +1,29 @@ +The contents of this repo, including the code and model weights, are intended for research use only. They are not for use in medical procedures, including any diagnostics, treatment, or curative applications. Do not use these models for any clinical purpose or as a substitute for professional medical judgement. + + +# Scaling Channel-Adaptive Self-Supervised Learning (CHANNEL-DINO) + + [[`Paper`](https://openreview.net/forum?id=pT8sgtRVAf)] [[`BibTeX`](#citing-channeladaptivedino-and-dinov2)] + +Alice V. 
De Lorenci, Seungeun Yi, Théo Moutakanni, Piotr Bojanowski, Camille Couprie, Juan C. Caicedo, Wolfgang M. Pernice, + +with special thanks to Elouan Gardes for his contributions to the codebase. + + [README](https://github.com/facebookresearch/dinov2/blob/main/docs/README_CHANNEL_ADAPTIVE_DINO.md) + + + +# Cell-DINO: Self-Supervised Image-based Embeddings for Cell Fluorescent Microscopy (CELL-DINO) + +Théo Moutakanni, Camille Couprie, Seungeun Yi, Elouan Gardes, Piotr Bojanowski, Hugo Touvron, Michael Doron, Zitong S. Chen, Nikita Moshkov, Mathilde Caron, Armand Joulin, Wolfgang M. Pernice, Juan C. Caicedo + +to appear soon. + + [README](https://github.com/facebookresearch/dinov2/blob/main/docs/README_CELL_DINO.md) + + + ## Licenses + + Code is released under the CC BY NC License. See [LICENSE_CELLDINO](LICENSE_CELLDINO) for additional details. + Model weights are released under the FAIR Noncommercial Research License. See [LICENSE_CELLDINO_WEIGHTS](LICENSE_CELLDINO_WEIGHTS) for additional details. + \ No newline at end of file diff --git a/dinov2/configs/eval/celldino.yaml b/dinov2/configs/eval/celldino.yaml new file mode 100644 index 000000000..0c31ee81c --- /dev/null +++ b/dinov2/configs/eval/celldino.yaml @@ -0,0 +1,14 @@ +student: + arch: vit_large + patch_size: 16 + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 + drop_path_rate: 0.1 + in_chans: 4 + block_chunks: 4 +teacher: + in_chans: 4 +crops: + global_crops_size: 224 + local_crops_size: 96 diff --git a/dinov2/configs/eval/channeldino_ext_chammi.yaml b/dinov2/configs/eval/channeldino_ext_chammi.yaml new file mode 100644 index 000000000..e32eb1772 --- /dev/null +++ b/dinov2/configs/eval/channeldino_ext_chammi.yaml @@ -0,0 +1,35 @@ +train: + batch_size_per_gpu: 32 + OFFICIAL_EPOCH_LENGTH: 450 + cell_augmentation: true + channel_adaptive: true +student: + arch: vit_large + patch_size: 16 + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 + drop_path_rate: 0.1 + in_chans: 1 + block_chunks: 4 + channel_adaptive: true +teacher: + momentum_teacher: 0.996 + warmup_teacher_temp_epochs: 20 + in_chans: 1 + channel_adaptive: true +crops: + global_crops_scale: + - 0.4 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.005 + - 0.4 + global_crops_size: 224 + local_crops_size: 96 +optim: + weight_decay_end: 0.2 + base_lr: 5.0e-4 + warmup_epochs: 20 + epochs: 400 \ No newline at end of file diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml index ccaae1c31..cdbea43d8 100644 --- a/dinov2/configs/ssl_default_config.yaml +++ b/dinov2/configs/ssl_default_config.yaml @@ -68,6 +68,7 @@ train: OFFICIAL_EPOCH_LENGTH: 1250 cache_dataset: true centering: "centering" # or "sinkhorn_knopp" + cell_augmentation: false student: arch: vit_large patch_size: 16 @@ -83,12 +84,16 @@ student: num_register_tokens: 0 interpolate_antialias: false interpolate_offset: 0.1 + in_chans: 3 + channel_adaptive: false teacher: momentum_teacher: 0.992 final_momentum_teacher: 1 warmup_teacher_temp: 0.04 teacher_temp: 0.07 warmup_teacher_temp_epochs: 30 + in_chans: 3 + channel_adaptive: false optim: epochs: 100 weight_decay: 0.04 diff --git a/dinov2/configs/train/hpafov_vitl16.yaml b/dinov2/configs/train/hpafov_vitl16.yaml new file mode 100644 index 000000000..59496f93d --- /dev/null +++ b/dinov2/configs/train/hpafov_vitl16.yaml @@ -0,0 +1,32 @@ +train: + batch_size_per_gpu: 16 + OFFICIAL_EPOCH_LENGTH: 450 + cell_augmentation: true +student: + arch: vit_large + 
patch_size: 16 + in_chans: 4 + drop_path_rate: 0.1 + block_chunks: 4 +teacher: + momentum_teacher: 0.996 + warmup_teacher_temp_epochs: 20 + in_chans: 4 +optim: + weight_decay_end: 0.2 + base_lr: 5.0e-4 + warmup_epochs: 20 +crops: + global_crops_scale: + - 0.4 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.005 + - 0.4 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 9000 + + \ No newline at end of file diff --git a/dinov2/configs/train/hpafov_vitl16_boc.yaml b/dinov2/configs/train/hpafov_vitl16_boc.yaml new file mode 100644 index 000000000..4520df315 --- /dev/null +++ b/dinov2/configs/train/hpafov_vitl16_boc.yaml @@ -0,0 +1,31 @@ +train: + batch_size_per_gpu: 16 + OFFICIAL_EPOCH_LENGTH: 450 + cell_augmentation: true + channel_adaptive: true +student: + arch: vit_large + patch_size: 16 + in_chans: 1 + drop_path_rate: 0.1 + block_chunks: 4 +teacher: + momentum_teacher: 0.996 + warmup_teacher_temp_epochs: 20 + in_chans: 1 +crops: + global_crops_scale: + - 0.4 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.005 + - 0.4 + global_crops_size: 224 + local_crops_size: 96 +optim: + weight_decay_end: 0.2 + base_lr: 5.0e-4 + warmup_epochs: 20 + epochs: 400 + \ No newline at end of file diff --git a/dinov2/configs/train/hpaone_vitl16.yaml b/dinov2/configs/train/hpaone_vitl16.yaml new file mode 100644 index 000000000..c6f76b1c2 --- /dev/null +++ b/dinov2/configs/train/hpaone_vitl16.yaml @@ -0,0 +1,30 @@ +train: + batch_size_per_gpu: 16 + OFFICIAL_EPOCH_LENGTH: 1756 + cell_augmentation: true +student: + arch: vit_large + patch_size: 16 + in_chans: 4 + drop_path_rate: 0.1 + block_chunks: 4 +teacher: + momentum_teacher: 0.996 + warmup_teacher_temp_epochs: 20 + in_chans: 4 +optim: + weight_decay_end: 0.2 + base_lr: 5.0e-4 + warmup_epochs: 20 +crops: + global_crops_scale: + - 0.4 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.005 + - 0.4 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 9000 \ No newline at end of file diff --git a/dinov2/data/__init__.py b/dinov2/data/__init__.py index 2ded47ea6..5bfdc802e 100644 --- a/dinov2/data/__init__.py +++ b/dinov2/data/__init__.py @@ -8,3 +8,5 @@ from .collate import collate_data_and_cast from .masking import MaskingGenerator from .augmentations import DataAugmentationDINO +from .cell_augmentations import CellAugmentationDINO +from .accumulators import NoOpAccumulator, ResultsAccumulator diff --git a/dinov2/data/accumulators.py b/dinov2/data/accumulators.py new file mode 100644 index 000000000..8dac4f322 --- /dev/null +++ b/dinov2/data/accumulators.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
+ +from collections import defaultdict +from typing import Dict, List, Optional, Any + +import torch +from torch import Tensor +from torch.nn import functional as F + +import torch.distributed as dist +from dinov2.distributed import get_global_size + + +def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]: + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + dist.all_gather(gathered_result, result, group) + return gathered_result + + +def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]: + """ + Copied from https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/utilities/distributed.py + Gather all tensors from several ddp processes onto a list that is broadcasted to all processes. + + Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case + tensors are padded, gathered and then trimmed to secure equal workload for all processes. + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + list with size equal to the process group where element i corresponds to result tensor from process i + """ + # convert tensors to contiguous format + result = result.contiguous() + + world_size = get_global_size() + dist.barrier(group=group) + + # if the tensor is scalar, things are easy + if result.ndim == 0: + return _simple_gather_all_tensors(result, group, world_size) + + # 1. Gather sizes of all tensors + local_size = torch.tensor(result.shape, device=result.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(local_sizes, local_size, group=group) + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 2. If shapes are all the same, then do a simple gather: + if all_sizes_equal: + return _simple_gather_all_tensors(result, group, world_size) + + # 3. 
If not, we need to pad each local tensor to maximum size, gather and then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = F.pad(result, pad_dims) + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + dist.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result + + +def _cat_and_gather_tensor_list(tensor_list: List[Tensor]) -> Tensor: + local_cat = torch.cat(tensor_list) + return torch.cat(gather_all_tensors(local_cat)) + + +class Accumulator: + def __init__(self) -> None: + pass + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + raise NotImplementedError + + def accumulate(self) -> Optional[Dict[str, Tensor]]: + raise NotImplementedError + + +class NoOpAccumulator(Accumulator): + def __init__(self) -> None: + pass + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + pass + + def accumulate(self) -> None: + return None + + +class ResultsAccumulator(Accumulator): + """ + Accumulate predictions and targets across processes + """ + + def __init__(self) -> None: + self._local_values: Dict[str, List[Tensor]] = defaultdict(list) + self._gathered_values: Dict[str, Tensor] = {} + self._gathered = False + + def update(self, preds: Tensor, target: Tensor, index: Tensor) -> None: + assert len(preds) == len(target) == len(index) + assert not self._gathered, "Tensors have already been gathered in this helper" + self._local_values["preds"].append(preds) + self._local_values["target"].append(target) + self._local_values["index"].append(index) + self._gathered = False + + def _gather_tensors(self): + for k, tensor_list in self._local_values.items(): + self._gathered_values[k] = _cat_and_gather_tensor_list(tensor_list) + self._gathered = True + + def accumulate(self) -> Dict[str, Tensor]: + if not self._gathered: + self._gather_tensors() + preds, target, index = [self._gathered_values[k] for k in ["preds", "target", "index"]] + assert len(preds) == len(target) == len(index) and index.min() == 0 + preds_ordered = torch.zeros((index.max() + 1, *preds.shape[1:]), dtype=preds.dtype, device=preds.device) + preds_ordered[index] = preds + target_ordered = torch.zeros((index.max() + 1, *target.shape[1:]), dtype=target.dtype, device=target.device) + target_ordered[index] = target + return {"preds": preds_ordered, "target": target_ordered} diff --git a/dinov2/data/adapters.py b/dinov2/data/adapters.py index 2097bad04..a5efe965f 100644 --- a/dinov2/data/adapters.py +++ b/dinov2/data/adapters.py @@ -3,26 +3,49 @@ # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. -from typing import Any, Tuple +from typing import Any, Tuple, Optional from torch.utils.data import Dataset class DatasetWithEnumeratedTargets(Dataset): - def __init__(self, dataset): + """ + If pad_dataset is set, pads based on torch's DistributedSampler implementation, which + with drop_last=False pads the last batch to be a multiple of the world size. 
+ https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91 + """ + + def __init__(self, dataset: Dataset, pad_dataset: bool = False, num_replicas: Optional[int] = None): self._dataset = dataset + self._size = len(self._dataset) + self._padded_size = self._size + self._pad_dataset = pad_dataset + if self._pad_dataset: + assert num_replicas is not None, "num_replicas should be set if pad_dataset is True" + self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas) + + def get_image_relpath(self, index: int) -> str: + assert self._pad_dataset or index < self._size + return self._dataset.get_image_relpath(index % self._size) def get_image_data(self, index: int) -> bytes: - return self._dataset.get_image_data(index) + assert self._pad_dataset or index < self._size + return self._dataset.get_image_data(index % self._size) def get_target(self, index: int) -> Tuple[Any, int]: - target = self._dataset.get_target(index) + target = self._dataset.get_target(index % self._size) + if index >= self._size: + assert self._pad_dataset + return (-1, target) return (index, target) def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: - image, target = self._dataset[index] + image, target = self._dataset[index % self._size] + if index >= self._size: + assert self._pad_dataset + return image, (-1, target) target = index if target is None else target return image, (index, target) def __len__(self) -> int: - return len(self._dataset) + return self._padded_size diff --git a/dinov2/data/cell_augmentations.py b/dinov2/data/cell_augmentations.py new file mode 100644 index 000000000..6a6156133 --- /dev/null +++ b/dinov2/data/cell_augmentations.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
+ +import logging +import torchvision +from torchvision import transforms + +from .transforms_cells import ( + RandomContrastProteinChannel, + RandomRemoveChannelExceptProtein, + RandomBrightness, + RandomContrast, + Div255, + SelfNormalizeNoDiv, +) + +logger = logging.getLogger("dinov2") + + +class CellAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + self.local_crops_number = local_crops_number + self.global_crops_size = global_crops_size + self.local_crops_size = local_crops_size + + logger.info("###################################") + logger.info("Using data augmentation parameters:") + logger.info(f"global_crops_scale: {global_crops_scale}") + logger.info(f"local_crops_scale: {local_crops_scale}") + logger.info(f"local_crops_number: {local_crops_number}") + logger.info(f"global_crops_size: {global_crops_size}") + logger.info(f"local_crops_size: {local_crops_size}") + logger.info("###################################") + + additional_transforms_list = [ + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.RandomVerticalFlip(), + RandomBrightness(), + RandomContrast(), + SelfNormalizeNoDiv(), + ] + + first_transforms_list = [ + Div255(), + RandomRemoveChannelExceptProtein(), + RandomContrastProteinChannel(), + ] + + global_transforms_list = first_transforms_list.copy() + global_transforms_list.append( + torchvision.transforms.RandomResizedCrop(size=global_crops_size, scale=global_crops_scale) + ) + global_transforms_list = global_transforms_list + additional_transforms_list + + local_transforms_list = first_transforms_list + local_transforms_list.append( + torchvision.transforms.RandomResizedCrop(size=local_crops_size, scale=local_crops_scale) + ) + local_transforms_list = local_transforms_list + additional_transforms_list + + self.global_transform = transforms.Compose(global_transforms_list) + self.local_transform = transforms.Compose(local_transforms_list) + + def __call__(self, image): + output = {} + + global_crop1 = self.global_transform(image) + global_crop2 = self.global_transform(image) + + output["global_crops"] = [global_crop1, global_crop2] + + local_crops = [] + for _ in range(self.local_crops_number): + local_crops.append(self.local_transform(image)) + + output["local_crops"] = local_crops + output["global_crops_teacher"] = [global_crop1, global_crop2] + output["offsets"] = () + + return output diff --git a/dinov2/data/datasets/__init__.py b/dinov2/data/datasets/__init__.py index 5550fdc5c..db9918d2c 100644 --- a/dinov2/data/datasets/__init__.py +++ b/dinov2/data/datasets/__init__.py @@ -5,3 +5,8 @@ from .image_net import ImageNet from .image_net_22k import ImageNet22k +from .hpaone import HPAone +from .hpafov import HPAFoV +from .chammi_cp import CHAMMI_CP +from .chammi_hpa import CHAMMI_HPA +from .chammi_wtc import CHAMMI_WTC diff --git a/dinov2/data/datasets/chammi_cp.py b/dinov2/data/datasets/chammi_cp.py new file mode 100644 index 000000000..d3a6d45e0 --- /dev/null +++ b/dinov2/data/datasets/chammi_cp.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
+ +import csv +from enum import Enum +import logging +import os +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .extended import ExtendedVisionDataset +from .decoders import DecoderType + +logger = logging.getLogger("dinov2") + + +METADATA_FILE = "morphem70k_v2.csv" + +CLASS_LABELS = { + "BRD-A29260609": 0, + "BRD-K04185004": 1, + "BRD-K21680192": 2, + "DMSO": 3, + "BRD-K11129031": 4, # labels only seen in TASK_FOUR + "BRD-K62310379": 5, + "BRD-K77947974": 6, +} + + +class _Split(Enum): + TRAIN = "Train" + TASK_ONE = "Task_one" + TASK_TWO = "Task_two" + TASK_THREE = "Task_three" + TASK_FOUR = "Task_four" + + +def _load_file_names_and_targets( + root: str, + split: _Split, +): + image_paths = [] + labels = [] + with open(os.path.join(root, METADATA_FILE)) as metadata: + metadata_reader = csv.DictReader(metadata) + for row in metadata_reader: + row_dataset = row["file_path"].split("/")[0] + + if row["train_test_split"].upper() == split and row_dataset == "CP": + image_paths.append(row["file_path"]) + labels.append(CLASS_LABELS[row["label"]]) + + return image_paths, labels # to debug + + +class CHAMMI_CP(ExtendedVisionDataset): + """ + Implementation of the CP (Cell-Painting) subset of the CHAMMI benchmark dataset, + following the CHAMMI paper: https://arxiv.org/pdf/2310.19224 + Github code: https://github.com/chaudatascience/channel_adaptive_models + """ + + Split = Union[_Split] + + def __init__( + self, + *, + split: "CHAMMI_CP.Split", + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder_type: DecoderType = DecoderType.XChannelsDecoder, + **kwargs: Any, + ) -> None: + super().__init__( + root, + transforms, + transform, + target_transform, + image_decoder_type=image_decoder_type, + **kwargs, + ) + self.split = split + self.root = root + self.num_additional_labels_loo_eval = 3 + self._image_paths, self._targets = _load_file_names_and_targets( + root, + split, + ) + + def get_image_relpath(self, index: int) -> str: + return self._image_paths[index] + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.get_image_relpath(index) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Any: + return self._targets[index] + + def get_targets(self) -> np.ndarray: + return np.array(self._targets) + + def __len__(self) -> int: + return len(self._image_paths) diff --git a/dinov2/data/datasets/chammi_hpa.py b/dinov2/data/datasets/chammi_hpa.py new file mode 100644 index 000000000..59e6e1c43 --- /dev/null +++ b/dinov2/data/datasets/chammi_hpa.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
+ +import csv +from enum import Enum +import logging +import os +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .extended import ExtendedVisionDataset +from .decoders import DecoderType + +logger = logging.getLogger("dinov2") + + +METADATA_FILE = "morphem70k_v2.csv" + +CLASS_LABELS = { + "golgi apparatus": 0, + "microtubules": 1, + "mitochondria": 2, + "nuclear speckles": 3, + "cytosol": 4, # labels only seen in TASK_THREE + "endoplasmic reticulum": 5, + "nucleoplasm": 6, +} + + +class _Split(Enum): + TRAIN = "Train" + TASK_ONE = "Task_one" + TASK_TWO = "Task_two" + TASK_THREE = "Task_three" + + +def _load_file_names_and_targets( + root: str, + split: _Split, +): + image_paths = [] + labels = [] + with open(os.path.join(root, METADATA_FILE)) as metadata: + metadata_reader = csv.DictReader(metadata) + for row in metadata_reader: + row_dataset = row["file_path"].split("/")[0] + if row["train_test_split"].upper() == split and row_dataset == "HPA": + image_paths.append(row["file_path"]) + labels.append(CLASS_LABELS[row["label"]]) + + return image_paths, labels + + +class CHAMMI_HPA(ExtendedVisionDataset): + """ + Implementation of the HPA (Human Protein Atlas) subset of the CHAMMI benchmark dataset, + following the CHAMMI paper: https://arxiv.org/pdf/2310.19224 + Github code: https://github.com/chaudatascience/channel_adaptive_models + """ + + Split = Union[_Split] + + def __init__( + self, + *, + split: "CHAMMI_HPA.Split", + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder_type: DecoderType = DecoderType.XChannelsDecoder, + **kwargs: Any, + ) -> None: + super().__init__( + root, + transforms, + transform, + target_transform, + image_decoder_type=image_decoder_type, + **kwargs, + ) + self.split = split + self.root = root + self.num_additional_labels_loo_eval = 3 + + self._image_paths, self._targets = _load_file_names_and_targets( + root, + split, + ) + + def get_image_relpath(self, index: int) -> str: + return self._image_paths[index] + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.get_image_relpath(index) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Any: + return self._targets[index] + + def get_targets(self) -> np.ndarray: + return np.array(self._targets) + + def __len__(self) -> int: + return len(self._image_paths) diff --git a/dinov2/data/datasets/chammi_wtc.py b/dinov2/data/datasets/chammi_wtc.py new file mode 100644 index 000000000..a34239da3 --- /dev/null +++ b/dinov2/data/datasets/chammi_wtc.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
+ +import csv +from enum import Enum +import logging +import os +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .extended import ExtendedVisionDataset +from .decoders import DecoderType + +logger = logging.getLogger("dinov2") + + +METADATA_FILE = "morphem70k_v2.csv" + +CLASS_LABELS = { + "M0": 0, + "M1M2": 1, + "M3": 2, + "M4M5": 3, + "M6M7_complete": 4, + "M6M7_single": 5, +} + + +class _Split(Enum): + TRAIN = "Train" + TASK_ONE = "Task_one" + TASK_TWO = "Task_two" + + +def _load_file_names_and_targets( + root: str, + split: _Split, +): + image_paths = [] + labels = [] + with open(os.path.join(root, METADATA_FILE)) as metadata: + metadata_reader = csv.DictReader(metadata) + for row in metadata_reader: + row_dataset = row["file_path"].split("/")[0] + if row["train_test_split"].upper() == split and row_dataset == "Allen": + image_paths.append(row["file_path"]) + labels.append(CLASS_LABELS[row["label"]]) + + return image_paths, labels + + +class CHAMMI_WTC(ExtendedVisionDataset): + """ + Implementation of the WTC (Allen Institute WTC-11) subset of the CHAMMI benchmark dataset, + following the CHAMMI paper: https://arxiv.org/pdf/2310.19224 + Github code: https://github.com/chaudatascience/channel_adaptive_models + """ + + Split = Union[_Split] + + def __init__( + self, + *, + split: "CHAMMI_WTC.Split", + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder_type: DecoderType = DecoderType.XChannelsTIFFDecoder, + **kwargs: Any, + ) -> None: + super().__init__( + root, + transforms, + transform, + target_transform, + image_decoder_type=image_decoder_type, + **kwargs, + ) + self.split = split + self.root = root + + self._image_paths, self._targets = _load_file_names_and_targets( + root, + split, + ) + + def get_image_relpath(self, index: int) -> str: + return self._image_paths[index] + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.get_image_relpath(index) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Any: + return self._targets[index] + + def get_targets(self) -> np.ndarray: + return np.array(self._targets) + + def __len__(self) -> int: + return len(self._image_paths) diff --git a/dinov2/data/datasets/decoders.py b/dinov2/data/datasets/decoders.py index 3769f7750..feb746885 100644 --- a/dinov2/data/datasets/decoders.py +++ b/dinov2/data/datasets/decoders.py @@ -4,9 +4,17 @@ # found in the LICENSE file in the root directory of this source tree. 
from io import BytesIO -from typing import Any +from typing import Any, Type from PIL import Image +import numpy as np +import torch +from enum import Enum + +try: + import tifffile +except ImportError: + print("Could not import `tifffile`, TIFFImageDataDecoder will be disabled") class Decoder: @@ -14,6 +22,23 @@ def decode(self) -> Any: raise NotImplementedError +class DecoderType(Enum): + ImageDataDecoder = "ImageDataDecoder" + XChannelsDecoder = "XChannelsDecoder" + XChannelsTIFFDecoder = "XChannelsTIFFDecoder" + ChannelSelectDecoder = "ChannelSelectDecoder" + + def get_class(self) -> Type[Decoder]: # noqa: C901 + if self == DecoderType.ImageDataDecoder: + return ImageDataDecoder + if self == DecoderType.XChannelsDecoder: + return XChannelsDecoder + if self == DecoderType.XChannelsTIFFDecoder: + return XChannelsTIFFDecoder + if self == DecoderType.ChannelSelectDecoder: + return ChannelSelectDecoder + + class ImageDataDecoder(Decoder): def __init__(self, image_data: bytes) -> None: self._image_data = image_data @@ -29,3 +54,41 @@ def __init__(self, target: Any): def decode(self) -> Any: return self._target + + +class XChannelsDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self): + im = np.asarray(Image.open(BytesIO(self._image_data))) + if len(im.shape) == 2: + im = np.reshape(im, (im.shape[0], im.shape[0], -1), order="F") + return torch.Tensor(im).permute(2, 0, 1) + + +class XChannelsTIFFDecoder(Decoder): + def __init__(self, image_data: bytes, num_channels: int = 3) -> None: + self._image_data = image_data + self._num_channels = num_channels + + def decode(self): + numpy_array = tifffile.imread(BytesIO(self._image_data)) + numpy_array = np.reshape(numpy_array, (numpy_array.shape[0], -1, self._num_channels), order="F") + return torch.Tensor(numpy_array).permute(2, 0, 1) + + +class ChannelSelectDecoder(Decoder): + def __init__(self, image_data: bytes, select_channel: bool = False) -> None: + self.select_channel = select_channel + if select_channel: + self._image_data = image_data[:-1] + self._channel = image_data[-1] + else: + self._image_data = image_data + + def decode(self): + im = np.asarray(Image.open(BytesIO(self._image_data))) + if self.select_channel: + return torch.Tensor(im).permute(2, 0, 1)[[self._channel]] + return torch.Tensor(im).permute(2, 0, 1) diff --git a/dinov2/data/datasets/extended.py b/dinov2/data/datasets/extended.py index f60b619a3..32555b57c 100644 --- a/dinov2/data/datasets/extended.py +++ b/dinov2/data/datasets/extended.py @@ -7,11 +7,17 @@ from torchvision.datasets import VisionDataset -from .decoders import TargetDecoder, ImageDataDecoder +from .decoders import DecoderType, TargetDecoder class ExtendedVisionDataset(VisionDataset): def __init__(self, *args, **kwargs) -> None: + image_decoder_type = kwargs.pop("image_decoder_type", DecoderType.ImageDataDecoder) + self._decoder_params = {} + self._image_decoder_class = image_decoder_type.get_class() + if "image_decoder_params" in kwargs: + self._decoder_params = kwargs.pop("image_decoder_params") + super().__init__(*args, **kwargs) # type: ignore def get_image_data(self, index: int) -> bytes: @@ -23,7 +29,7 @@ def get_target(self, index: int) -> Any: def __getitem__(self, index: int) -> Tuple[Any, Any]: try: image_data = self.get_image_data(index) - image = ImageDataDecoder(image_data).decode() + image = self._image_decoder_class(image_data, **self._decoder_params).decode() except Exception as e: raise RuntimeError(f"can not read image for sample 
{index}") from e target = self.get_target(index) diff --git a/dinov2/data/datasets/hpafov.py b/dinov2/data/datasets/hpafov.py new file mode 100644 index 000000000..18a59ddc3 --- /dev/null +++ b/dinov2/data/datasets/hpafov.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +import csv +from enum import Enum +import logging +import os +from typing import Any, Callable, List, Optional, Tuple, Union, Dict + +import numpy as np + +from .extended import ExtendedVisionDataset +from .decoders import DecoderType + +logger = logging.getLogger("dinov2") + +CELL_TYPE = [ + "BJ", # 1 + "LHCN-M2", + "RH-30", + "SH-SY5Y", + "U-2 OS", # 5 + "ASC TERT1", + "HaCaT", + "A-431", + "U-251 MG", + "HEK 293", # 10 + "A549", + "RT4", + "HeLa", + "MCF7", + "PC-3", # 15 + "hTERT-RPE1", + "SK-MEL-30", + "EFO-21", + "AF22", + "HEL", # 20 + "Hep G2", + "HUVEC TERT2", + "THP-1", + "CACO-2", + "JURKAT", # 25 + "RPTEC TERT1", + "SuSa", + "REH", + "HDLM-2", + "K-562", # 30 + "hTCEpi", + "NB-4", + "HAP1", + "OE19", + "SiHa", # 35 +] + +PROTEIN_LOCALIZATION = [ # matches https://www.kaggle.com/c/human-protein-atlas-image-classification/data + "nucleoplasm", + "nuclear membrane", + "nucleoli", + "nucleoli fibrillar center", + "nuclear speckles", # 5 + "nuclear bodies", + "endoplasmic reticulum", + "golgi apparatus", + "peroxisomes", + "endosomes", # 10 + "lysosomes", + "intermediate filaments", + "actin filaments", + "focal adhesion sites", + "microtubules", # 15 + "microtubule ends", + "cytokinetic bridge", + "mitotic spindle", + "microtubule organizing center", + "centrosome", # 20 + "lipid droplets", + "plasma membrane", + "cell junctions", + "mitochondria", + "aggresome", # 25 + "cytosol", + "cytoplasmic bodies", + "rods & rings", +] + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + SSL = "ssl" + + +def get_csv_fpath(split): + """ + Path to data relative to root + """ + if split == _Split.TRAIN.value.upper(): + return "2022_07_04_whole_image_train_data/whole_images_512_train.csv" + elif split == _Split.VAL.value.upper(): + return "2022_07_04_whole_image_train_data/whole_images_512_test.csv" + + +class _WildCard(Enum): + NONE = "none" + SEPARATECHANNELS = "separate_channels" # each channel from each image is treated as an independent sample, overrides chosen channel configuration + + +class _Mode(Enum): + """ + Targets: + - ALL: tuple, (one hot encoding of multilabel protein localization, categorical encoding of cell type) + - PROTEIN_LOCALIZATION: one hot encoding of multilabel protein localization + - CELL_TYPE: categorical encoding of cell type + """ + + ALL = "all" + PROTEIN_LOCALIZATION = "protein_localization" + CELL_TYPE = "cell_type" + + @property + def nb_labels(self): + if self == _Mode.CELL_TYPE: + return len(CELL_TYPE) + elif self == _Mode.PROTEIN_LOCALIZATION: + return len(PROTEIN_LOCALIZATION) + else: + return None + + +# def _list_images_from_csv(img_path, csv_path): +# L = [] +# with open(csv_path) as filename: +# reader = csv.DictReader(filename) +# for row in reader: +# breakpoint() +# L.append(os.path.join(img_path, row["ID"] + ".png")) +# return L + +def _list_ssl_images(img_rootdir): + img_list = [] + for file in os.listdir(img_rootdir): + if file.endswith(".tiff"): + img_list.append(os.path.join(img_rootdir, file)) + listofzeros = [0] * len(img_list) + return img_list, listofzeros + +def _load_file_names_and_labels_ssl( + 
root: str, +) -> Tuple[List[str], List[Any]]: + + curr_img_path = os.path.join(root, "512_whole_images") + csv_train_ssl = os.path.join(root, "whole_images_names_deduplicated.csv") + image_paths, labels = _list_ssl_images(curr_img_path) + #labels = [i for i in range(len(image_paths))] + + return image_paths, labels + + +def _load_file_names_and_labels( + root: str, + split: _Split, + mode: _Mode, +) -> Tuple[List[str], List[Any], np.ndarray]: + + data_path = os.path.join(root, "512_whole_images") + csv_fpath = os.path.join(root, get_csv_fpath(split)) + + image_paths = [] + labels = [] + + with open(csv_fpath) as filename: + reader = csv.DictReader(filename) + for row in reader: + + add_sample = True + if mode != _Mode.PROTEIN_LOCALIZATION.value.upper(): + # categorical + if row["cell_type"] in CELL_TYPE: + cell_type = CELL_TYPE.index(row["cell_type"]) + else: + cell_type = np.nan + + if mode != _Mode.CELL_TYPE.value.upper(): + # one hot encoding + prot_loc = np.zeros(len(PROTEIN_LOCALIZATION), dtype=np.int_) + for k in range(len(PROTEIN_LOCALIZATION)): + if row[PROTEIN_LOCALIZATION[k]] == "True": + prot_loc[k] = 1 + if prot_loc.max() < 0.5: + add_sample = False + + if add_sample: + if mode == _Mode.PROTEIN_LOCALIZATION.value.upper(): + labels.append(prot_loc) + elif mode == _Mode.CELL_TYPE.value.upper(): + labels.append(cell_type) + else: + labels.append({"prot_loc": prot_loc, "cell_type": cell_type}) + + candidate_path = os.path.join(data_path, row["file"].split("/")[-1]) + if os.path.exists(candidate_path): + image_paths.append(candidate_path) + else: + candidate_path = os.path.join(data_path, row["file"].split("/")[-1].split(".")[0] + "_blue.png") + # some images on the normalized_data folder have a _blue suffix on their names + if os.path.exists(candidate_path): + image_paths.append(candidate_path) + else: + raise FileNotFoundError(f"File {candidate_path} not found.") + + return image_paths, labels + + +class HPAFoV(ExtendedVisionDataset): + Split = Union[_Split] + Mode = Union[_Mode] + WildCard = Union[_WildCard] + + def __init__( + self, + *, + split: "HPAFoV.Split" = _Split.TRAIN, + mode: "HPAFoV.Mode" = _Mode.ALL, + wildcard: "HPAFoV.WildCard" = _WildCard.NONE, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder_type: DecoderType = DecoderType.ChannelSelectDecoder, + image_decoder_params: Dict[str, Any] = {}, + **kwargs: Any, + ) -> None: + super().__init__( + root, + transforms, + transform, + target_transform, + image_decoder_type=image_decoder_type, + image_decoder_params={ + "select_channel": True + if wildcard == _WildCard.SEPARATECHANNELS or wildcard == "SEPARATE_CHANNELS" + else False + }, + **kwargs, + ) + self.mode = mode + self.split = split + self.root = root + self.wildcard = wildcard + self.channel_adaptive = True + if split == _Split.SSL.value.upper() or split == _Split.SSL or split == "SSL": + self._image_paths, self._labels = _load_file_names_and_labels_ssl(root) + #self.channel_adaptive = False + else: + self._image_paths, self._labels = _load_file_names_and_labels(root, self.split, self.mode) + + self._channels = np.repeat(np.array([[0, 1, 2, 3]]), len(self._image_paths), axis=0).tolist() + + if self.wildcard == _WildCard.SEPARATECHANNELS.value.upper(): + image_paths, labels, channels = self._image_paths, self._labels, self._channels + channels = np.array(channels) + # separate and stack the columns of the channels array + C = channels.shape[1] + channels = 
np.concatenate([channels[:, i] for i in range(C)]) + self._channels = np.expand_dims(channels, 1).tolist() + self.image_paths = image_paths * C + self.labels = labels * C + + def get_image_relpath(self, index: int) -> str: + return self._image_paths[index] + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.get_image_relpath(index) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + if self.channel_adaptive: + channels = self._channels[index] + return image_data + bytes(channels) + (len(channels)).to_bytes(1, byteorder="big") + else: + return image_data + + def get_target(self, index: int) -> Any: + return self._labels[index] + + def get_targets(self) -> np.ndarray: + return np.array(self._labels) + + def __len__(self) -> int: + return len(self._image_paths) diff --git a/dinov2/data/datasets/hpaone.py b/dinov2/data/datasets/hpaone.py new file mode 100644 index 000000000..46bb51ced --- /dev/null +++ b/dinov2/data/datasets/hpaone.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +import csv +from enum import Enum +import logging +import os +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np + +from .extended import ExtendedVisionDataset +from .decoders import DecoderType + +logger = logging.getLogger("dinov2") + +PROTEIN_LOCALIZATION = [ + "actin filaments,focal adhesion sites", + "aggresome", + "centrosome,centriolar satellite", + "cytosol", + "endoplasmic reticulum", + "golgi apparatus", + "intermediate filaments", + "microtubules", + "mitochondria", + "mitotic spindle", + "no staining", + "nuclear bodies", + "nuclear membrane", + "nuclear speckles", + "nucleoli", + "nucleoli fibrillar center", + "nucleoplasm", + "plasma membrane,cell junctions", + "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies", +] # 19 + + +CELL_TYPE = [ + "A-431", # 0 + "A549", + "AF22", + "ASC TERT1", + "BJ", + "CACO-2", + "EFO-21", + "HAP1", + "HDLM-2", + "HEK 293", # 9 + "HEL", + "HUVEC TERT2", + "HaCaT", + "HeLa", + "Hep G2", + "JURKAT", + "K-562", + "MCF7", + "PC-3", + "REH", + "RH-30", # 20 + "RPTEC TERT1", + "RT4", + "SH-SY5Y", + "SK-MEL-30", + "SiHa", + "U-2 OS", + "U-251 MG", + "hTCEpi", # 28 +] # 29 cell types + + +class _Split(Enum): + VAL = "val" + TRAIN = "train" + ALL = "all" # images without labels, for encoder training + + +class _Mode(Enum): + PROTEIN_LOCALIZATION = "protein_localization" + CELL_TYPE = "cell_type" + + @property + def num_labels(self): + if self == _Mode.CELL_TYPE.value.upper(): + return len(CELL_TYPE) + return len(PROTEIN_LOCALIZATION) + + +def _simple_parse_csv(img_rootdir, csv_filepath: str): + samples = [] + with open(csv_filepath) as filename: + template = csv.DictReader(filename) + samples = [(os.path.join(img_rootdir , row["img_path"]), 0) for row in template] + return samples + + +def _parse_csv(img_rootdir, csv_labels_path: str): + nb_protein_location = len(PROTEIN_LOCALIZATION) + nb_cell_type = len(CELL_TYPE) + samples = [] + with open(csv_labels_path) as filename: + reader = csv.DictReader(filename) + for row in reader: + protein_location = np.zeros(nb_protein_location, dtype=np.int_) + for k in range(nb_protein_location): + if row[PROTEIN_LOCALIZATION[k]] == "True": + protein_location[k] = 1 + + cell_type = 0 + for k in range(nb_cell_type): + if 
row[CELL_TYPE[k]] == "True": + cell_type = k + + samples.append( + ( + img_rootdir + "/" + row["file"].rsplit("/", 1)[1], + protein_location, + cell_type, + ) + ) + return samples + + +def _load_file_names_and_labels_ssl( + root: str, +) -> Tuple[List[str], List[Any]]: + curr_dir_train = os.path.join(root, "fixed_size_masked_single_cells_HPA") + csv_all_path = os.path.join(root, "pretraining_hpa_single_cell.csv") + samples = _simple_parse_csv(curr_dir_train, csv_all_path) + image_paths, fake_labels = zip(*samples) + lab = list(fake_labels) + return image_paths, lab + + +def _load_file_names_and_labels_train_or_test( + root: str, + split: _Split, + mode: _Mode, +) -> Tuple[List[str], List[Any]]: + + if split == _Split.TRAIN.value.upper(): + csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_pretrain_20240507.csv") + elif split == _Split.VAL.value.upper(): + csv_labels_path = os.path.join(root, "fixed_size_masked_single_cells_evaluation_20240507.csv") + curr_dir_val = os.path.join(root, "fixed_size_masked_single_cells_HPA") + + samples = _parse_csv(curr_dir_val, csv_labels_path) + image_paths, protein_location, cell_type = zip(*samples) + if mode == _Mode.PROTEIN_LOCALIZATION.value.upper(): + lab = protein_location + elif mode == _Mode.CELL_TYPE.value.upper(): + lab = cell_type + else: + lab = protein_location, cell_type + image_paths = list(image_paths) + return image_paths, lab + + +class HPAone(ExtendedVisionDataset): + Split = Union[_Split] + Mode = Union[_Mode] + + def __init__( + self, + *, + split: "HPAone.Split" = _Split.ALL, + mode: "HPAone.Mode" = None, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + image_decoder_type: DecoderType = DecoderType.XChannelsDecoder, + **kwargs: Any, + ) -> None: + super().__init__( + root, + transforms, + transform, + target_transform, + image_decoder_type=image_decoder_type, + **kwargs, + ) + self.mode = mode + self.split = split + self.root = root + + if ( + split in {_Split.TRAIN.value.upper(), _Split.VAL.value.upper()} + or split == _Split.TRAIN + or split == _Split.VAL + ): + ( + self._image_paths, + self._labels, + ) = _load_file_names_and_labels_train_or_test(root, split, mode) + elif split == _Split.ALL.value.upper() or split == _Split.ALL: + self._image_paths, self._labels = _load_file_names_and_labels_ssl(root) + else: + logger.info(f"unknown split: {split}, {_Split.ALL.value.upper()}") + + def get_image_relpath(self, index: int) -> str: + return self._image_paths[index] + + def get_image_data(self, index: int) -> bytes: + image_relpath = self.get_image_relpath(index) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Any: + return self._labels[index] + + def get_targets(self) -> np.ndarray: + return np.array(self._labels) + + def __len__(self) -> int: + return len(self._image_paths) diff --git a/dinov2/data/loaders.py b/dinov2/data/loaders.py index d6a2f0210..fdf6709b8 100644 --- a/dinov2/data/loaders.py +++ b/dinov2/data/loaders.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import Sampler -from .datasets import ImageNet, ImageNet22k +from .datasets import ImageNet, ImageNet22k, HPAone, HPAFoV, CHAMMI_CP, CHAMMI_HPA, CHAMMI_WTC from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler @@ -49,7 +49,7 @@ def _parse_dataset_str(dataset_str: str): for token in 
tokens[1:]: key, value = token.split("=") - assert key in ("root", "extra", "split") + assert key in ("root", "extra", "split", "mode", "wildcard") kwargs[key] = value if name == "ImageNet": @@ -58,6 +58,16 @@ def _parse_dataset_str(dataset_str: str): kwargs["split"] = ImageNet.Split[kwargs["split"]] elif name == "ImageNet22k": class_ = ImageNet22k + elif name == "HPAone": + class_ = HPAone + elif name == "HPAFoV": + class_ = HPAFoV + elif name == "CHAMMI_CP": + class_ = CHAMMI_CP + elif name == "CHAMMI_WTC": + class_ = CHAMMI_WTC + elif name == "CHAMMI_HPA": + class_ = CHAMMI_HPA else: raise ValueError(f'Unsupported dataset "{name}"') diff --git a/dinov2/data/transforms_cells.py b/dinov2/data/transforms_cells.py new file mode 100644 index 000000000..b7c44ed19 --- /dev/null +++ b/dinov2/data/transforms_cells.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +import torch +from torchvision import transforms +import numpy as np +from enum import Enum + + +class NormalizationType(Enum): + SELF_NORM_AUG_DECODER = "self_norm_aug_decoder" + SELF_NORM_CENTER_CROP = "self_norm_center_crop" + + +class Div255(torch.nn.Module): + def forward(self, x): + x = x / 255 + return x + + +class SelfNormalizeNoDiv(torch.nn.Module): + def forward(self, x): + m = x.mean((-2, -1), keepdim=True) + s = x.std((-2, -1), unbiased=False, keepdim=True) + x -= m + x /= s + 1e-7 + return x + + +class SelfNormalize(torch.nn.Module): + def forward(self, x): + x = x / 255 + m = x.mean((-2, -1), keepdim=True) + s = x.std((-2, -1), unbiased=False, keepdim=True) + x -= m + x /= s + 1e-7 + return x + + +class RandomContrastProteinChannel(torch.nn.Module): + """ + Random constrast rescaling of the protein channel only. + RescaleProtein function in Dino4cell codebase. + """ + + def __init__(self, p=0.2): + super().__init__() + self.p = p + + def forward(self, img): + if img.max() == 0: + return img + if len(img) == 1: + return img + if np.random.rand() <= self.p: + random_factor = (np.random.rand() * 2) / img.max() # scaling + img[1] = img[1] * random_factor + return img + else: + return img + + +class RandomRemoveChannelExceptProtein(torch.nn.Module): + """ + dropping a channel at random except the channel 1, corresponding to proteins in HPA datasets. 
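+
+    A minimal usage sketch (the tensor shape below is illustrative; the transform expects a
+    channel-first image with at least 4 channels and returns smaller inputs untouched):
+
+        >>> import torch
+        >>> aug = RandomRemoveChannelExceptProtein(p=0.5)
+        >>> img = torch.rand(4, 224, 224)  # (C, H, W)
+        >>> out = aug(img)  # with probability p, one of channels 0, 2 or 3 is zeroed in place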
+ """ + + def __init__(self, p=0.2): + super().__init__() + self.p = p + + def forward(self, img): + img_size = np.array(img).shape + if img_size[0] < 4: + return img + if np.random.rand() <= self.p: + channel_to_blacken = np.random.choice(np.array([0, 2, 3])) + img[channel_to_blacken] = torch.zeros(1, *img.shape[1:]) + return img + else: + return img + + +class RandomRemoveChannel(torch.nn.Module): + """ + dropping a channel at random + """ + + def __init__(self, p=0.2): + super().__init__() + self.p = p + + def forward(self, img): + img_size = np.array(img).shape + num_channels = img_size[0] + if num_channels < 4: + return img + if np.random.rand() <= self.p: + channel_to_blacken = np.random.choice(np.array(list(range(num_channels)))) + img[channel_to_blacken] = torch.zeros(1, *img.shape[1:]) + return img + else: + return img + + +class RandomContrast(torch.nn.Module): + def __init__(self, p=0.2): + super().__init__() + self.p = p + + def forward(self, img): + if img.max() == 0: + return img + n_channels = img.shape[0] + for ind in range(n_channels): + factor = max(np.random.normal(1, self.p), 0.5) + img[ind] = transforms.functional.adjust_contrast(img[ind][None, ...], factor) + return img + + +class RandomBrightness(torch.nn.Module): + def __init__(self, p=0.2): + super().__init__() + self.p = p + + def forward(self, img): + if img.max() == 0: + return img + n_channels = img.shape[0] + for ind in range(n_channels): + factor = max(np.random.normal(1, self.p), 0.5) + img[ind] = transforms.functional.adjust_brightness(img[ind], factor) + return img + + +def make_classification_eval_cell_transform( + *, + resize_size: int = 0, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 384, + normalization_type: Enum = NormalizationType.SELF_NORM_CENTER_CROP, +) -> transforms.Compose: + + from .transforms_cells import ( + Div255, + SelfNormalizeNoDiv, + ) + + transforms_list = [Div255()] + if resize_size > 0: + transforms_list.append(transforms.Resize(resize_size, interpolation=interpolation)) + + if normalization_type == NormalizationType.SELF_NORM_AUG_DECODER: + transforms_list.extend( + [ + transforms.RandomCrop(size=crop_size, pad_if_needed=True), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ] + ) + elif normalization_type == NormalizationType.SELF_NORM_CENTER_CROP: + transforms_list.append(transforms.CenterCrop(size=crop_size)) + else: + raise ValueError("f{normalization_type}: unknown NormalizationType") + transforms_list.append(SelfNormalizeNoDiv()) + + return transforms.Compose(transforms_list) diff --git a/dinov2/eval/knn_celldino.py b/dinov2/eval/knn_celldino.py new file mode 100644 index 000000000..6e8494157 --- /dev/null +++ b/dinov2/eval/knn_celldino.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
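+#
+# Example invocation (a sketch, not a prescribed command: the dataset strings, paths and
+# hyper-parameter values below are illustrative assumptions, and --config-file /
+# --pretrained-weights are inherited from the dinov2.eval.setup argument parser):
+#
+#   python -m dinov2.eval.knn_celldino \
+#       --config-file <path/to/config.yaml> \
+#       --pretrained-weights <path/to/checkpoint.pth> \
+#       --train-dataset HPAFoV:split=TRAIN:root=<data_root>:mode=CELL_TYPE \
+#       --val-dataset HPAFoV:split=VAL:root=<data_root>:mode=CELL_TYPE \
+#       --nb_knn 1 20 --temperature 0.07 --batch-size 256 \
+#       --crop-size 384 --resize-size 0 --metric-type mean_accuracy \
+#       --bag-of-channels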
+ +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional, Any +import numpy as np + +import torch +import torch.backends.cudnn as cudnn +import pandas as pd +from sklearn.metrics import f1_score + +import dinov2.distributed as distributed +from dinov2.data import make_dataset, DatasetWithEnumeratedTargets, SamplerType, make_data_loader +from dinov2.data.transforms_cells import NormalizationType, make_classification_eval_cell_transform +from dinov2.eval.metrics import build_metric, MetricType +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model + +from dinov2.data import ResultsAccumulator +from dinov2.eval.utils import ModelWithNormalize +from dinov2.eval.utils_celldino import ( + BagOfChannelsModelWithNormalize, + extract_features_celldino, + average_metrics, + create_train_dataset_dict, + get_num_classes, + extract_features_for_dataset_dict, + evaluate_with_accumulate, +) +from dinov2.eval.knn import KnnModule, DictKeysModule +from torch.utils.data import Subset as SubsetEx +from torch.utils.data import ConcatDataset as ConcatDatasetEx + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--nb_knn", + nargs="+", + type=int, + help="Number of NN to use. 20 is usually working the best.", + ) + parser.add_argument( + "--temperature", + type=float, + help="Temperature used in the voting coefficient", + ) + parser.add_argument( + "--gather-on-cpu", + action="store_true", + help="Whether to gather the train features on cpu, slower" + "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size.", + ) + parser.add_argument( + "--n-per-class-list", + nargs="+", + type=int, + help="Number to take per class", + ) + parser.add_argument( + "--n-tries", + type=int, + help="Number of tries", + ) + parser.add_argument( + "--leave-one-out-dataset", + type=str, + help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4", + ) + parser.add_argument( + "--bag-of-channels", + action="store_true", + help='Whether to use the "bag of channels" channel adaptive strategy', + ) + parser.add_argument( + "--crop-size", + type=int, + help="crop size for train and eval", + ) + parser.add_argument( + "--resize-size", + type=int, + help="resize size for image just before crop. 
0: no resize", + ) + parser.add_argument( + "--metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--avgpool", + action="store_true", + help="Whether to use average pooling of path tokens in addition to CLS tokens", + ) + + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=[1], + temperature=0.07, + batch_size=256, + resize_size=0, + ) + return parser + + +class SequentialWithKwargs(torch.nn.Sequential): + def __init__(self, *args): + super().__init__(*args) + + def forward(self, input, **kwargs): + + input = self[0](input, **kwargs) + for module in self[1:]: + input = module(input) + return input + + +def create_train_test_dataset_dict_leave_one_out( + train_dataset, + test_dataset, +) -> dict[int, dict[int, Any]]: + """ + This function implements a train dataset dictionary with the leave-one-out (LOO) method. + Specifically, given a train dataset and test dataset, it creates a train dataset for each + test dataset point, which is a combination of train+test dataset except for this specific data point. + At the end, it contains len(test_dataset) key and value pairs. + + Format is {"nth-test-sample": dataset_without_test_sample} + """ + train_dataset_dict: dict[int, Any] = {} + test_size = len(test_dataset) + + for test_sample_index in range(test_size): + test_indices_bool = torch.ones(test_size, dtype=bool) + test_indices_bool[test_sample_index] = False + train_dataset_dict[test_sample_index] = ConcatDatasetEx( + [train_dataset, SubsetEx(test_dataset, test_indices_bool.nonzero().flatten())] + ) + + return train_dataset_dict + + +def eval_knn_with_leave_one_out( + model, leave_one_out_dataset, train_dataset, test_dataset, metric_type, nb_knn, temperature, batch_size, num_workers +): + num_classes = get_num_classes(test_dataset) + train_dataset_dict = create_train_dataset_dict(train_dataset) + test_dataset_dict = create_train_dataset_dict(test_dataset) + + logger.info("Extracting features for train set...") + train_data_dict = extract_features_for_dataset_dict( + model, train_dataset_dict, batch_size, num_workers, gather_on_cpu=True + ) + test_data_dict = extract_features_for_dataset_dict( + model, test_dataset_dict, batch_size, num_workers, gather_on_cpu=True + ) + + train_features = train_data_dict[0]["train_features"] + train_labels = train_data_dict[0]["train_labels"] + test_features = test_data_dict[0]["train_features"] + test_labels = test_data_dict[0]["train_labels"] + + metric_collection = build_metric(metric_type, num_classes=3) + + device = torch.cuda.current_device() + partial_knn_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) + + logger.info("Reading the leave-one-out label metadata.") + + leave_one_out_indices = {} + metadata = pd.read_csv(leave_one_out_dataset) + if "HPA" in leave_one_out_dataset: + metadata = metadata[metadata["Task_three"]].reset_index() + leave_one_out_label_type = "cell_type" + else: + metadata = metadata[metadata["Task_four"]].reset_index() + leave_one_out_label_type = "Plate" + leave_one_out_labels = metadata[leave_one_out_label_type].unique() + + for leave_one_out_label in leave_one_out_labels: + leave_one_out_indices[leave_one_out_label] = torch.tensor( + metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values + ) + + # ============ evaluation ... 
============ + logger.info("Start the k-NN classification.") + + eval_metrics_dict = {} + postprocessors, metrics = {k: DictKeysModule([k]) for k in nb_knn}, { + k: metric_collection.clone().to(device) for k in nb_knn + } + for metric_key in metrics.keys(): + metrics[metric_key] = metrics[metric_key].to(device) + + accumulator_class = ResultsAccumulator + accumulators = {k: accumulator_class() for k in postprocessors.keys()} + all_preds = [] + all_target = [] + + for loo_label, loo_indices in leave_one_out_indices.items(): + logger.info(f"Evaluating on test sample {loo_label}") + loo_for_training_indices = torch.ones(test_features.shape[0], dtype=bool) + loo_for_training_indices[loo_indices] = False + train_features_sample = torch.cat([train_features, test_features[loo_for_training_indices]]) + train_labels_sample = torch.cat([train_labels, test_labels[loo_for_training_indices]]) + logger.info(f"Train shape {train_features_sample.shape}, Test shape {test_features[loo_indices].shape}") + logger.info( + f"Train values {train_labels_sample.unique(return_counts=True)}, Test shape {test_labels[loo_indices].unique(return_counts=True)}" + ) + knn_module = partial_knn_module( + train_features=train_features_sample, train_labels=train_labels_sample, nb_knn=nb_knn + ) + + output = knn_module(test_features[loo_indices].to(device)) + all_preds.append(output[1]) + all_target.append(test_labels[loo_indices]) + output[1] = output[1][:, 4:] + transformed_test_labels = test_labels[loo_indices] - 4 + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](output, transformed_test_labels.to(device)) + metric.update(**metric_inputs) + accumulators[k].update( + preds=metric_inputs["preds"], target=metric_inputs["target"], index=loo_indices.to(device) + ) + + all_preds = torch.cat(all_preds).cpu().detach().numpy() + + all_preds = np.argmax(all_preds, axis=1) + all_target = torch.cat(all_target).cpu().detach().numpy() + + f1 = f1_score(all_target, all_preds, average="macro", labels=[4, 5, 6]) + logger.info(f"Real f1 score: {f1}") + eval_metrics = { + k: metric.compute() for k, metric in metrics.items() + } # next erased by the real f1 score computed above + + for k in nb_knn: + if k not in eval_metrics_dict: + eval_metrics_dict[k] = {} + eval_metrics_dict[k] = {metric: f1 * 100.0 for metric, v in eval_metrics[k].items()} + + if len(train_data_dict) > 1: + return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()} + + return {k: eval_metrics_dict[k] for k in eval_metrics_dict.keys()} + + +def eval_knn_with_model( + model, + output_dir, + train_dataset_str, + val_dataset_str, + nb_knn=(10, 20, 100, 200), + temperature=0.07, + autocast_dtype=torch.float, + metric_type=MetricType.MEAN_ACCURACY, + transform=None, + resize_size=256, + crop_size=224, + batch_size=256, + num_workers=5, + leave_one_out_dataset="", + bag_of_channels=False, + avgpool=False, +): + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + if bag_of_channels: + model = BagOfChannelsModelWithNormalize(model, autocast_ctx, avgpool) + else: + model = ModelWithNormalize(model) + if leave_one_out_dataset == "" or leave_one_out_dataset is None: + leave_one_out = False + else: + leave_one_out = True + + cudnn.benchmark = True + transform = make_classification_eval_cell_transform( + normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, resize_size=resize_size, crop_size=crop_size + ) + + train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform) + 
results_dict = {} + test_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform) + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + if leave_one_out: + results_dict_knn = eval_knn_with_leave_one_out( + model=model, + leave_one_out_dataset=leave_one_out_dataset, + train_dataset=train_dataset, + test_dataset=test_dataset, + metric_type=metric_type, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + ) + else: + results_dict_knn = eval_knn( + model=model, + train_dataset=train_dataset, + test_dataset=test_dataset, + metric_type=metric_type, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + ) + + for knn_ in results_dict_knn.keys(): + top1 = results_dict_knn[knn_]["top-1"] + results_dict[f"{val_dataset_str}_{knn_} Top 1"] = top1 + results_string = f"{val_dataset_str} {knn_} NN classifier result: Top1: {top1:.2f}" + if "top-5" in results_dict_knn[knn_]: + top5 = results_dict_knn[knn_]["top-5"] + results_dict[f"{val_dataset_str}_{knn_} Top 5"] = top5 + results_string += f"Top5: {top5:.2f}" + logger.info(results_string) + + metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + return results_dict + + +def eval_knn( + model, + train_dataset, + test_dataset, + metric_type, + nb_knn, + temperature, + batch_size, + num_workers, + few_shot_eval=False, + few_shot_k_or_percent=None, + few_shot_n_tries=1, +): + num_classes = get_num_classes(train_dataset) + train_dataset_dict = create_train_dataset_dict( + train_dataset, + few_shot_eval=few_shot_eval, + few_shot_k_or_percent=few_shot_k_or_percent, + few_shot_n_tries=few_shot_n_tries, + ) + + logger.info("Extracting features for train set...") + + train_data_dict: dict[int, dict[str, torch.Tensor]] = {} + for try_n, dataset in train_dataset_dict.items(): + features, labels = extract_features_celldino(model, dataset, batch_size, num_workers, gather_on_cpu=True) + train_data_dict[try_n] = {"train_features": features, "train_labels": labels} + + test_data_loader = make_data_loader( + dataset=DatasetWithEnumeratedTargets( + test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size() + ), + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + collate_fn=None, + ) + metric_collection = build_metric(metric_type, num_classes=num_classes) + + device = torch.cuda.current_device() + partial_knn_module = partial( + KnnModule, + T=temperature, + device=device, + num_classes=num_classes, + ) + + # ============ evaluation ... 
============ + logger.info("Start the k-NN classification.") + eval_metrics_dict = {} + + for try_ in train_data_dict.keys(): + train_features, train_labels = train_data_dict[try_]["train_features"], train_data_dict[try_]["train_labels"] + k_list = sorted(set([el if el < len(train_features) else len(train_features) for el in nb_knn])) + knn_module = partial_knn_module(train_features=train_features, train_labels=train_labels, nb_knn=k_list) + postprocessors, metrics = {k: DictKeysModule([k]) for k in k_list}, { + k: metric_collection.clone() for k in k_list + } + _, eval_metrics, _ = evaluate_with_accumulate( + SequentialWithKwargs(model, knn_module), + test_data_loader, + postprocessors, + metrics, + device, + accumulate_results=False, + ) + for k in k_list: + if k not in eval_metrics_dict: + eval_metrics_dict[k] = {} + eval_metrics_dict[k][try_] = {metric: v.item() * 100.0 for metric, v in eval_metrics[k].items()} + + if len(train_data_dict) > 1: + return {k: average_metrics(eval_metrics_dict[k]) for k in eval_metrics_dict.keys()} + + return {k: eval_metrics_dict[k][0] for k in eval_metrics_dict.keys()} + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_knn_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + nb_knn=args.nb_knn, + temperature=args.temperature, + autocast_dtype=autocast_dtype, + transform=None, + metric_type=args.metric_type, + batch_size=args.batch_size, + num_workers=5, + leave_one_out_dataset=args.leave_one_out_dataset, + resize_size=args.resize_size, + crop_size=args.crop_size, + avgpool=args.avgpool, + bag_of_channels=args.bag_of_channels, + ) + return 0 + + +if __name__ == "__main__": + description = "k-NN evaluation on models trained with bag of channel strategy or cell dino" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/dinov2/eval/linear_celldino.py b/dinov2/eval/linear_celldino.py new file mode 100644 index 000000000..15148978a --- /dev/null +++ b/dinov2/eval/linear_celldino.py @@ -0,0 +1,1049 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
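+#
+# Example invocation (a sketch, not a prescribed command: the dataset strings, paths and
+# hyper-parameter values below are illustrative assumptions, and --config-file /
+# --pretrained-weights / --output-dir are inherited from the dinov2.eval.setup argument
+# parser):
+#
+#   python -m dinov2.eval.linear_celldino \
+#       --config-file <path/to/config.yaml> \
+#       --pretrained-weights <path/to/checkpoint.pth> \
+#       --output-dir <path/to/output_dir> \
+#       --train-dataset HPAone:split=TRAIN:root=<data_root>:mode=CELL_TYPE \
+#       --val-dataset HPAone:split=VAL:root=<data_root>:mode=CELL_TYPE \
+#       --epochs 30 --batch-size 64 --n-last-blocks 4 --avgpool \
+#       --learning-rates 1e-4 1e-3 1e-2 --weight_decays 0.0 1e-5 \
+#       --val-metric-type mean_accuracy
+#
+# One linear head is trained for every (learning rate, weight decay) combination on top of the
+# frozen backbone features; learning rates are scaled by (batch size x world size) / 256, and
+# the head with the best validation accuracy is reported.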
+ +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import Any, Callable, Dict, Optional, Tuple, List +from enum import Enum +from dataclasses import dataclass + +from sklearn.metrics import f1_score +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset +from torch.nn.parallel import DistributedDataParallel + + +from dinov2.data import SamplerType, make_data_loader, make_dataset, DatasetWithEnumeratedTargets +from dinov2.data.transforms_cells import NormalizationType, make_classification_eval_cell_transform +import dinov2.distributed as distributed +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils_celldino import ( + evaluate_with_accumulate, + LossType, + average_metrics, + create_train_dataset_dict, + get_num_classes, + extract_features_for_dataset_dict, +) +from dinov2.eval.utils import ModelWithIntermediateLayers +from dinov2.logging import MetricLogger +from dinov2.utils.checkpoint import build_periodic_checkpointer, resume_or_load + +logger = logging.getLogger("dinov2") + +""" +List of changes with respect to the standard linear evaluation script: + +bag of channel option : SCALE ADAPTIVE STRATEGY + +Adam optimizer instead of SGD +Scheduler : two options : onecycleLR or CosineAnnealingLR +the transforms/normalization are different, now calling make_classification_eval_cell_transform +add binary cross entropy loss option for protein localization +change the definition of the num_classes using get_num_classes +change of some default parameters (batch_size, epoch_length, epochs, lrs) +defined n_last_blocks option +avgpool option +leave one out strategy for CHAMMI evaluation +grid search for optimal weight decay +""" + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--test-datasets", + dest="test_dataset_strs", + type=str, + nargs="+", + help="Test datasets, none to reuse the validation dataset", + ) + parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch Size (per GPU)", + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number de Workers", + ) + parser.add_argument( + "--epoch-length", + type=int, + help="Length of an epoch in number of iterations", + ) + parser.add_argument( + "--save-checkpoint-frequency", + type=int, + help="Number of epochs between two named checkpoint saves.", + ) + parser.add_argument( + "--eval-period-iterations", + type=int, + help="Number of iterations between two evaluations.", + ) + parser.add_argument( + "--learning-rates", + nargs="+", + type=float, + help="Learning rates to grid search.", + ) + parser.add_argument( + "--weight_decays", + nargs="+", + type=float, + 
help="Weight decays to grid search.", + ) + parser.add_argument( + "--n-last-blocks", + type=int, + help="number of backbone last blocks used for the linear classifier", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not resume from existing checkpoints", + ) + parser.add_argument( + "--val-metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--test-metric-types", + type=MetricType, + choices=list(MetricType), + nargs="+", + help="Evaluation metric", + ) + parser.add_argument( + "--classifier-fpath", + type=str, + help="Path to a file containing pretrained linear classifiers", + ) + parser.add_argument( + "--val-class-mapping-fpath", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--test-class-mapping-fpaths", + nargs="+", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--loss-type", + type=LossType, + help="Cross Entropy or Binary Cross Entropy, default cross entropy loss", + ) + parser.add_argument( + "--bag-of-channels", + action="store_true", + help='Whether to use the "bag of channels" channel adaptive strategy', + ) + parser.add_argument( + "--leave-one-out-dataset", + type=str, + help="Path with indexes to use the leave one out strategy for CHAMMI_CP task 3 and CHAMMI_HPA task 4", + ) + parser.add_argument( + "--crop-size", + type=int, + help="crop size for train and eval", + ) + parser.add_argument( + "--resize-size", + type=int, + help="resize size for image just before crop. 0: no resize", + ) + parser.add_argument( + "--avgpool", + action="store_true", + help="Whether to use average pooling of path tokens in addition to CLS tokens", + ) + parser.add_argument( + "--scheduler", + type=SchedulerType, + help="Scheduler type", + ) + + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + test_dataset_strs=None, + epochs=30, + batch_size=64, + num_workers=8, + epoch_length=145, + save_checkpoint_frequency=1250, + eval_period_iterations=1250, + learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, 1.0], + weight_decays=[0.0, 0.0001, 1.0e-05], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + loss_type=LossType.CROSS_ENTROPY, + crop_size=384, + resize_size=0, + n_last_blocks=4, + avgpool=False, + scheduler=SchedulerType.COSINE_ANNEALING, + ) + return parser + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool, bag_of_channels): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if bag_of_channels: + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=-2).reshape(intermediate_output[-1][0].shape[0], -1), + # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model + ), + dim=-1, + ) # concatenate average pooling of patch tokens to concatenated patch tokens + else: + if use_avgpool: + output = torch.cat( + ( + output, + 
torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__( + self, out_dim, use_n_blocks, use_avgpool, num_classes=1000, bag_of_channels=False, leave_one_out=False + ): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.bag_of_channels = bag_of_channels + self.leave_one_out = leave_one_out + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + if self.leave_one_out: + return self.linear(x_tokens_list) + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool, self.bag_of_channels) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 + + +def setup_linear_classifiers( + sample_output, + n_last_blocks_list, + learning_rates, + weight_decays, + batch_size, + num_classes=1000, + bag_of_channels=False, + leave_one_out=False, + avgpool=False, +): + linear_classifiers_dict = nn.ModuleDict() + avgpool_value = avgpool + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [avgpool_value]: + for _lr in learning_rates: + for wd in weight_decays: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input( + sample_output, use_n_blocks=n, use_avgpool=avgpool, bag_of_channels=bag_of_channels + ).shape[1] + linear_classifier = LinearClassifier( + out_dim, + use_n_blocks=n, + use_avgpool=avgpool, + num_classes=num_classes, + bag_of_channels=bag_of_channels, + leave_one_out=leave_one_out, + ) + linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}_wd_{wd:.2E}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr, "weight_decay": wd}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +def make_eval_data_loader( + *, + test_dataset_str_or_path_or_loo_dataset, + config, + batch_size, + num_workers, +): + if isinstance(test_dataset_str_or_path_or_loo_dataset, str): + logger.info(f"Loading dataset {test_dataset_str_or_path_or_loo_dataset}") + transform = make_classification_eval_cell_transform( + 
normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, + resize_size=config["resize_size"], + crop_size=config["crop_size"], + ) + print("transform", transform) + test_dataset = make_dataset(dataset_str=test_dataset_str_or_path_or_loo_dataset, transform=transform) + collate_fn = None + else: + logger.info(f'Making data loader for feature dataset (typical in leave one out evaluation)') + test_dataset = test_dataset_str_or_path_or_loo_dataset + collate_fn = None + class_mapping = None + if hasattr(test_dataset, "get_imagenet_class_mapping"): + class_mapping = test_dataset.get_imagenet_class_mapping() + + test_data_loader = make_data_loader( + dataset=DatasetWithEnumeratedTargets( + test_dataset, pad_dataset=True, num_replicas=distributed.get_global_size() + ), + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=collate_fn, + ) + return test_data_loader, class_mapping + + +@dataclass +class Evaluator: + batch_size: int + num_workers: int + dataset_str_or_path: str + config: Dict + metric_type: MetricType + metrics_file_path: str + training_num_classes: int + save_results_func: Optional[Callable] + val_dataset_loo: Optional[TensorDataset] = None + + def __post_init__(self): + self.main_metric_name = f"{self.dataset_str_or_path}_accuracy" + + if self.val_dataset_loo is not None: + self.dataset_str_or_path = self.val_dataset_loo + + self.data_loader, self.class_mapping = make_eval_data_loader( + test_dataset_str_or_path_or_loo_dataset=self.dataset_str_or_path, + batch_size=self.batch_size, + num_workers=self.num_workers, + config=self.config, + ) + + @torch.no_grad() + def _evaluate_linear_classifiers( + self, + *, + feature_model, + linear_classifiers, + iteration, + prefixstring="", + best_classifier_on_val=None, + accumulate_results=False, + test_mode=False, + ) -> Tuple[Dict[str, Any], Optional[Dict[str, torch.Tensor]]]: + logger.info("running validation !") + + num_classes = len(self.class_mapping) if self.class_mapping is not None else self.training_num_classes + metric = build_metric(self.metric_type, num_classes=num_classes) + postprocessors = { + k: LinearPostprocessor(v, self.class_mapping) for k, v in linear_classifiers.classifiers_dict.items() + } + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp, accumulated_results = evaluate_with_accumulate( + feature_model, + self.data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + accumulate_results=accumulate_results, + leave_one_out=self.config["leave_one_out"], + test_mode=test_mode, + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for _, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + accumulated_best_results = None + if test_mode: + accumulated_best_results = accumulated_results + elif accumulated_results is not None: + accumulated_best_results = accumulated_results[best_classifier] + + if distributed.is_main_process(): + with 
open(self.metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict, accumulated_best_results + + def evaluate_and_maybe_save( + self, + feature_model, + linear_classifiers, + iteration: int, + best_classifier_on_val: Optional[Any] = None, + save_filename_suffix: str = "", + prefixstring: str = "", + test_mode: bool = False, + ): + logger.info(f"Testing on {self.dataset_str_or_path}") + save_results = self.save_results_func is not None + full_results_dict, accumulated_best_results = self._evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + iteration=iteration, + prefixstring=prefixstring, + best_classifier_on_val=best_classifier_on_val, + accumulate_results=save_results, + test_mode=test_mode, + ) + if self.save_results_func is not None: + self.save_results_func( + filename_suffix=f"{self.dataset_str_or_path}{save_filename_suffix}", **accumulated_best_results + ) + + results_dict = { + self.main_metric_name: 100.0 * full_results_dict["best_classifier"]["accuracy"], + "best_classifier": full_results_dict["best_classifier"]["name"], + } + return results_dict, accumulated_best_results + + +def make_evaluators( + config: Dict, + val_metric_type: MetricType, + val_dataset: str, + metric_type: MetricType, + metrics_file_path: str, + training_num_classes: int, + save_results_func: Optional[Callable], + val_dataset_loo: Optional[TensorDataset] = None, +): + test_metric_types = config["test_metric_types"] + test_dataset_strs = config["test_datasets"] + if test_dataset_strs is None: + test_dataset_strs = (config["val_dataset"],) + if test_metric_types is None: + test_metric_types = (val_metric_type,) + else: + assert len(test_metric_types) == len(config["test_datasets"]) + + val_evaluator, *test_evaluators = [ + Evaluator( + dataset_str_or_path=dataset_str_or_path, + batch_size=config["batch_size"], + num_workers=config["num_workers"], + config=config, + metric_type=metric_type, + metrics_file_path=metrics_file_path, + training_num_classes=training_num_classes, + save_results_func=save_results_func, + val_dataset_loo=val_dataset_loo, + ) + for dataset_str_or_path, metric_type in zip( + (val_dataset,) + tuple(test_dataset_strs), + (val_metric_type,) + tuple(test_metric_types), + ) + ] + return val_evaluator, test_evaluators + + +class SchedulerType(Enum): + COSINE_ANNEALING = "cosine_annealing" + ONE_CYCLE = "one_cycle" + + def get_scheduler(self, optimizer, optim_param_groups, epoch_length, epochs, max_iter): + if self == SchedulerType.ONE_CYCLE: + lr_list = [optim_param_groups[i]["lr"] for i in range(len(optim_param_groups))] + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=lr_list, steps_per_epoch=epoch_length, epochs=epochs + ) + else: + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + print("CosineAnnealingLR scheduler") + return scheduler + + +def setup_linear_training( + *, + config: Dict, + sample_output: torch.Tensor, + training_num_classes: int, + checkpoint_output_dir: str, +): + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + config["n_last_blocks_list"], + config["learning_rates"], + config["weight_decays"], + config["batch_size"], + training_num_classes, + config["bag_of_channels"], + config["leave_one_out"], + config["avgpool"], + ) + max_iter = config["epochs"] * config["epoch_length"] + optimizer = 
torch.optim.AdamW(optim_param_groups, weight_decay=0) + + scheduler = config["scheduler"].get_scheduler( + optimizer=optimizer, + optim_param_groups=optim_param_groups, + epoch_length=config["epoch_length"], + epochs=config["epochs"], + max_iter=max_iter, + ) + checkpoint_period = config["save_checkpoint_iterations"] or config["epoch_length"] + periodic_checkpointer = build_periodic_checkpointer( + linear_classifiers, + checkpoint_output_dir, + optimizer=optimizer, + scheduler=scheduler, + period=checkpoint_period, + max_iter=max_iter, + max_to_keep=None, + ) + checkpoint = resume_or_load(periodic_checkpointer, config["classifier_fpath"] or "", resume=config["resume"]) + + start_iter = checkpoint.get("iteration", -1) + 1 + best_accuracy = checkpoint.get("best_accuracy", -1) + + if config["loss_type"] == LossType.BINARY_CROSS_ENTROPY: + criterion = nn.BCEWithLogitsLoss() + else: + criterion = nn.CrossEntropyLoss() + + return ( + linear_classifiers, + start_iter, + max_iter, + criterion, + optimizer, + scheduler, + periodic_checkpointer, + best_accuracy, + ) + + +def train_linear_classifiers( + *, + feature_model, + train_dataset, + train_config: Dict, + training_num_classes: int, + val_evaluator: Evaluator, + checkpoint_output_dir: str, + sample_output: Optional[torch.Tensor] = None, +): + + if train_config["leave_one_out"]: + assert sample_output is not None, "sample_output should be passed as argument when using leave_one_out." + else: + sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) + + ( + linear_classifiers, + start_iter, + max_iter, + criterion, + optimizer, + scheduler, + periodic_checkpointer, + best_accuracy, + ) = setup_linear_training( + config=train_config, + sample_output=sample_output, + training_num_classes=training_num_classes, + checkpoint_output_dir=checkpoint_output_dir, + ) + + sampler_type = SamplerType.INFINITE + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=train_config["batch_size"], + num_workers=train_config["num_workers"], + shuffle=True, + seed=0, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + eval_period = train_config["eval_period_iterations"] or train_config["epoch_length"] + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + header = "Training" + + for data, labels in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + if not train_config["leave_one_out"]: + in_classifier = feature_model(data) + else: + in_classifier = data + + outputs = linear_classifiers(in_classifier) + + if len(labels.shape) > 1: + labels = labels.float() + losses = {f"loss_{k}": criterion(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + optimizer.zero_grad() + loss.backward() + + optimizer.step() + scheduler.step() + + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + periodic_checkpointer.step(iteration=iteration, best_accuracy=best_accuracy) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + val_results_dict, _ = val_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + prefixstring=f"ITER: {iteration}", + iteration=iteration, + ) + 
val_accuracy = val_results_dict[val_evaluator.main_metric_name] + if val_accuracy >= best_accuracy: + best_accuracy = val_accuracy + periodic_checkpointer.save_best(iteration=iteration, best_accuracy=best_accuracy) + torch.distributed.barrier() + + iteration = iteration + 1 + + return feature_model, linear_classifiers, iteration, periodic_checkpointer + + +def eval_linear_with_model( + model, + output_dir, + train_dataset_str, + val_dataset_str, + batch_size, + epochs, + epoch_length, + num_workers, + save_checkpoint_frequency, + eval_period_iterations, + learning_rates, + weight_decays, + autocast_dtype, + test_dataset_strs=None, + resume=True, + classifier_fpath=None, + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + loss_type=LossType.CROSS_ENTROPY, + bag_of_channels=False, + leave_one_out_dataset="", + resize_size=0, + crop_size=384, + n_last_blocks=4, + avgpool=False, + scheduler=SchedulerType.COSINE_ANNEALING, +): + + if leave_one_out_dataset == "" or leave_one_out_dataset is None: + leave_one_out = False + else: + logger.info("Reading the leave-one-out label metadata.") + + leave_one_out_indices = {} + metadata = pd.read_csv(leave_one_out_dataset) + if "HPA" in leave_one_out_dataset: + metadata = metadata[metadata["Task_three"]].reset_index() + leave_one_out_label_type = "cell_type" + else: + metadata = metadata[metadata["Task_four"]].reset_index() + leave_one_out_label_type = "Plate" + leave_one_out_labels = metadata[leave_one_out_label_type].unique() + + for leave_one_out_label in leave_one_out_labels: + leave_one_out_indices[leave_one_out_label] = np.array( + metadata[metadata[leave_one_out_label_type] == leave_one_out_label].index.values + ) + + leave_one_out = True + + train_transform = make_classification_eval_cell_transform( + normalization_type=NormalizationType.SELF_NORM_AUG_DECODER, crop_size=crop_size, resize_size=resize_size + ) + print("train_transform", train_transform) + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=train_transform, + ) + + training_num_classes = get_num_classes(train_dataset) + if leave_one_out: + training_num_classes += train_dataset.num_additional_labels_loo_eval + train_dataset_dict = create_train_dataset_dict(train_dataset) + n_last_blocks_list = [n_last_blocks] + n_last_blocks = max(n_last_blocks_list) + dataset_use_cache = True + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + + if bag_of_channels: + sample = train_dataset[0][0].unsqueeze(0) + sample_output = feature_model(sample.cuda()) + + if leave_one_out: + loo_dict = {} + train_data_dict = extract_features_for_dataset_dict( + feature_model, + train_dataset_dict, + batch_size, + num_workers, + gather_on_cpu=True, + avgpool=avgpool, + ) + val_dataset = make_dataset( + dataset_str=val_dataset_str, + transform=make_classification_eval_cell_transform( + normalization_type=NormalizationType.SELF_NORM_CENTER_CROP, crop_size=crop_size, resize_size=resize_size + ), + ) + val_dataset_dict = create_train_dataset_dict(val_dataset) + val_data_dict = extract_features_for_dataset_dict( + feature_model, + val_dataset_dict, + batch_size, + num_workers, + gather_on_cpu=True, + avgpool=avgpool, + ) + + train_features = train_data_dict[0]["train_features"] + train_labels = train_data_dict[0]["train_labels"] + val_features = val_data_dict[0]["train_features"] + val_labels = val_data_dict[0]["train_labels"] + + for loo_label, loo_indices in 
leave_one_out_indices.items(): + loo_for_training_indices = torch.ones(val_features.shape[0], dtype=bool) + loo_for_training_indices[loo_indices] = False + loo_for_val_indices = torch.zeros(val_features.shape[0], dtype=bool) + loo_for_val_indices[loo_indices] = True + + loo_dict[loo_label] = { + "train_features": torch.cat([train_features, val_features[loo_for_training_indices]]), + "train_labels": torch.cat([train_labels, val_labels[loo_for_training_indices]]), + "val_features": val_features[loo_indices], + "val_labels": val_labels[loo_indices], + } + save_results_func = None + # if config.save_results: + # save_results_func = partial(default_save_results_func, output_dir=output_dir) + + metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") + periodic_checkpointers: list = [] + + train_config = { + "learning_rates": learning_rates, + "weight_decays": weight_decays, + "batch_size": batch_size, + "num_workers": num_workers, + "dataset_use_cache": dataset_use_cache, + "eval_period_iterations": eval_period_iterations, + "epoch_length": epoch_length, + "leave_one_out": leave_one_out, + "bag_of_channels": bag_of_channels, + "n_last_blocks_list": n_last_blocks_list, + "epochs": epochs, + "loss_type": loss_type, + "resume": resume, + "save_checkpoint_iterations": save_checkpoint_frequency, + "classifier_fpath": classifier_fpath, + "avgpool": avgpool, + "scheduler": scheduler, + } + config = { + "test_metric_types": test_metric_types, + "test_datasets": test_dataset_strs, + "val_metric_types": val_metric_type, + "val_dataset": val_dataset_str, + "batch_size": batch_size, + "num_workers": num_workers, + "leave_one_out": leave_one_out, + "crop_size": crop_size, + "resize_size": resize_size, + } + if not leave_one_out: + val_evaluator, test_evaluators = make_evaluators( + config=config, + val_metric_type=val_metric_type, + val_dataset=val_dataset_str, + metric_type=test_metric_types, + metrics_file_path=metrics_file_path, + training_num_classes=training_num_classes, + save_results_func=save_results_func, + ) + results_dict = {} + + for _try in train_dataset_dict.keys(): + if len(train_dataset_dict) > 1: + checkpoint_output_dir = os.path.join(output_dir, f"checkpoints_{_try}") + save_filename_suffix = f"_{_try}" + else: + checkpoint_output_dir, save_filename_suffix = output_dir, "" + os.makedirs(checkpoint_output_dir, exist_ok=True) + + feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers( + train_config=train_config, + feature_model=feature_model, + train_dataset=train_dataset_dict[_try], + training_num_classes=training_num_classes, + val_evaluator=val_evaluator, + checkpoint_output_dir=checkpoint_output_dir, + ) + periodic_checkpointers.append(periodic_checkpointer) + results_dict[_try], _ = val_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + iteration=iteration, + save_filename_suffix=save_filename_suffix, + ) + for test_evaluator in test_evaluators: + eval_results_dict, _ = test_evaluator.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + iteration=iteration, + best_classifier_on_val=results_dict[_try]["best_classifier"], + save_filename_suffix=save_filename_suffix, + ) + results_dict[_try] = {**eval_results_dict, **results_dict[_try]} + if len(train_dataset_dict) > 1: + results_dict = average_metrics(results_dict, ignore_keys=["best_classifier"]) + else: + results_dict = {**results_dict[_try]} + else: # if leave one out is True + 
test_results_dict = {} + for loo_label in loo_dict.keys(): + + checkpoint_output_dir, save_filename_suffix = os.path.join(output_dir, f"checkpoints_{loo_label}"), "" + os.makedirs(checkpoint_output_dir, exist_ok=True) + + train_dataset_loo = TensorDataset( + loo_dict[loo_label]["train_features"], loo_dict[loo_label]["train_labels"] + ) + + logger.info(f"Creating leave_one_out evaluators. loo_label: {loo_label}") + val_dataset_loo = TensorDataset(loo_dict[loo_label]["val_features"], loo_dict[loo_label]["val_labels"]) + val_evaluators_loo, _ = make_evaluators( + config=config, + val_metric_type=val_metric_type, + val_dataset="loo", + metric_type=test_metric_types, + metrics_file_path=metrics_file_path, + training_num_classes=training_num_classes, + save_results_func=save_results_func, + val_dataset_loo=val_dataset_loo, + ) + feature_model, linear_classifiers, iteration, periodic_checkpointer = train_linear_classifiers( + feature_model=feature_model, + train_dataset=train_dataset_loo, + train_config=train_config, + training_num_classes=training_num_classes, + val_evaluator=val_evaluators_loo, + checkpoint_output_dir=checkpoint_output_dir, + sample_output=sample_output, + ) + periodic_checkpointers.append(periodic_checkpointer) + _, test_results_dict[loo_label] = val_evaluators_loo.evaluate_and_maybe_save( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + iteration=iteration, + save_filename_suffix=save_filename_suffix, + test_mode=True, + ) + classifier_names = test_results_dict[loo_label].keys() + results_dict = {k: [[], []] for k in classifier_names} + for ll in test_results_dict.keys(): + for k in classifier_names: + results_dict[k][0].append(test_results_dict[ll][k][0]) + results_dict[k][1].append(test_results_dict[ll][k][1]) + for k in classifier_names: + results_dict[k] = [ + np.argmax(torch.cat(results_dict[k][0]).cpu().detach().numpy(), axis=1), + torch.cat(results_dict[k][1]).cpu().detach().numpy(), + ] + results_dict[k] = f1_score(results_dict[k][1], results_dict[k][0], average="macro", labels=[4, 5, 6]) + logger.info( + f"Best performance is for {max(results_dict, key=results_dict.get)}, with F1-Score of {results_dict[max(results_dict, key=results_dict.get)]}" + ) + + logger.info("Test Results Dict " + str(results_dict)) + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_linear_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + test_dataset_strs=args.test_dataset_strs, + batch_size=args.batch_size, + epochs=args.epochs, + epoch_length=args.epoch_length, + num_workers=args.num_workers, + save_checkpoint_frequency=args.save_checkpoint_frequency, + eval_period_iterations=args.eval_period_iterations, + learning_rates=args.learning_rates, + weight_decays=args.weight_decays, + autocast_dtype=autocast_dtype, + resume=not args.no_resume, + classifier_fpath=args.classifier_fpath, + val_metric_type=args.val_metric_type, + test_metric_types=args.test_metric_types, + loss_type=args.loss_type, + bag_of_channels=args.bag_of_channels, + leave_one_out_dataset=args.leave_one_out_dataset, + crop_size=args.crop_size, + resize_size=args.resize_size, + n_last_blocks=args.n_last_blocks, + avgpool=args.avgpool, + scheduler=args.scheduler, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 linear_celldino evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + 
sys.exit(main(args)) diff --git a/dinov2/eval/metrics.py b/dinov2/eval/metrics.py index 52be81a85..c26db7b46 100644 --- a/dinov2/eval/metrics.py +++ b/dinov2/eval/metrics.py @@ -10,7 +10,7 @@ import torch from torch import Tensor from torchmetrics import Metric, MetricCollection -from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MultilabelF1Score from torchmetrics.utilities.data import dim_zero_cat, select_topk @@ -22,6 +22,8 @@ class MetricType(Enum): MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" PER_CLASS_ACCURACY = "per_class_accuracy" IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" + MEAN_PER_CLASS_MULTICLASS_F1 = "mean_per_class_multiclass_f1" + MEAN_PER_CLASS_MULTILABEL_F1 = "mean_per_class_multilabel_f1" @property def accuracy_averaging(self): @@ -52,6 +54,10 @@ def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tupl num_classes=num_classes, ks=(1, 5) if ks is None else ks, ) + elif metric_type == MetricType.MEAN_PER_CLASS_MULTILABEL_F1: + return MetricCollection({"top-1": MultilabelF1Score(num_labels=int(num_classes), average="macro")}) + elif metric_type == MetricType.MEAN_PER_CLASS_MULTICLASS_F1: + return MetricCollection({"top-1": MulticlassF1Score(num_classes=int(num_classes), average="macro")}) raise ValueError(f"Unknown metric type {metric_type}") diff --git a/dinov2/eval/utils_celldino.py b/dinov2/eval/utils_celldino.py new file mode 100644 index 000000000..866f27196 --- /dev/null +++ b/dinov2/eval/utils_celldino.py @@ -0,0 +1,451 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
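
A brief aside before the new utility module: the two F1 metric types added to `dinov2/eval/metrics.py` above simply wrap macro-averaged torchmetrics classifiers. The following is a minimal sketch of what `build_metric` returns for them, using made-up logits and targets (the batch size and class count are illustrative, not taken from the repo):

```python
# Illustrative sketch of the two new metric types added to build_metric above.
# Dummy logits/targets only; shapes are assumptions for the example.
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score

num_classes = 4

# mean_per_class_multiclass_f1: preds are (N, C) logits, targets are (N,) class indices
multiclass_f1 = MetricCollection({"top-1": MulticlassF1Score(num_classes=num_classes, average="macro")})
logits = torch.randn(8, num_classes)
targets = torch.randint(0, num_classes, (8,))
print(multiclass_f1(logits, targets))            # {'top-1': tensor(...)}

# mean_per_class_multilabel_f1: preds are (N, C) logits, targets are (N, C) in {0, 1}
multilabel_f1 = MetricCollection({"top-1": MultilabelF1Score(num_labels=num_classes, average="macro")})
multilabel_targets = torch.randint(0, 2, (8, num_classes))
print(multilabel_f1(logits, multilabel_targets))
```

These correspond to the `mean_per_class_multiclass_f1` and `mean_per_class_multilabel_f1` values passed to `--val-metric-type` / `--metric-type` in the evaluation commands later in this patch.
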
+ +import logging +from typing import Callable, Dict, Optional, Any, List + +import torch +from torch import nn +from torchmetrics import MetricCollection + +from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +from dinov2.data import NoOpAccumulator, ResultsAccumulator +import dinov2.distributed as distributed +from dinov2.logging import MetricLogger +from enum import Enum +from torch.utils.data import Subset +from torchvision.datasets.vision import StandardTransform +import numpy as np + +logger = logging.getLogger("dinov2") + + +class LossType(Enum): + CROSS_ENTROPY = "cross_entropy" + BINARY_CROSS_ENTROPY = "binary_cross_entropy" + + +class BagOfChannelsModelWithNormalize(nn.Module): + def __init__(self, model, autocast_ctx, avgpool, n_last_blocks=1): + super().__init__() + self.model = model + self.autocast_ctx = autocast_ctx + self.n_last_blocks = n_last_blocks + self.avgpool = avgpool + + def forward(self, samples): + with self.autocast_ctx(): + features = self.model.get_intermediate_layers(samples, self.n_last_blocks, return_class_token=True) + output = create_linear_input(features, self.avgpool, use_n_blocks=self.n_last_blocks) + return nn.functional.normalize(output, dim=1, p=2) + + +@torch.inference_mode() +def evaluate_with_accumulate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, MetricCollection], + device: torch.device, + criterion: Optional[nn.Module] = None, + test_mode: bool = False, + accumulate_results: bool = False, + leave_one_out: bool = False, +): + model.eval() + + if test_mode: + output_tensor = {k: [] for k in postprocessors.keys()} + target_tensor = {k: [] for k in postprocessors.keys()} + + if criterion is not None: + criterion.eval() + + accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator + accumulators = {k: accumulator_class() for k in postprocessors.keys()} + + for metric in metrics.values(): + metric = metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): + if isinstance(targets, list): + index = targets[0] + targets = targets[1] + samples, targets, index = samples[index >= 0], targets[index >= 0], index[index >= 0] + if len(index) == 0: + continue + + outputs = samples.to(device) if leave_one_out else model(samples.to(device)) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + if test_mode: + output_tensor[k].append(metric_inputs["preds"]) + target_tensor[k].append(metric_inputs["target"]) + accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + # accumulator.accumulate() returns None for the NoOpAccumulator + accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()} + if test_mode: + for k in postprocessors.keys(): + output_tensor[k] = torch.cat(output_tensor[k]) + target_tensor[k] = torch.cat(target_tensor[k]) + accumulated_results = {k: [output_tensor[k], target_tensor[k]] for k in 
postprocessors.keys()} + + if accumulate_results: + return metric_logger_stats, stats + return metric_logger_stats, stats, accumulated_results + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_global_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features_celldino( + model, dataset, batch_size, num_workers, gather_on_cpu=False, shuffle=False, avgpool=False +): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=shuffle, + ) + return extract_features_with_dataloader_celldino(model, data_loader, sample_count, gather_on_cpu, avgpool=avgpool) + + +@torch.inference_mode() +def extract_features_with_dataloader_celldino(model, data_loader, sample_count, gather_on_cpu=False, avgpool=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + feat = model(samples) + if isinstance(samples, list) or isinstance(feat, tuple): + features_rank = create_linear_input(feat, avgpool=avgpool) + else: + features_rank = feat + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features shape: {tuple(features.shape)}") + logger.info(f"Labels shape: {tuple(all_labels.shape)}") + + assert torch.all(all_labels > -1) + + return features, all_labels + + +def create_linear_input(x_tokens_list, avgpool=False, use_n_blocks=1): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat( + [class_token for _, class_token in intermediate_output], dim=-1 + ) # concatenate class tokens of the last n blocks + if avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=-2).reshape( + intermediate_output[-1][0].shape[0], -1 + ), # average pooling of patch tokens: average over N, then concatenate channels if single-channel patch model + ), + dim=-1, + ) # concatenate average pooling of patch tokens to concatenated patch tokens + output = output.reshape(output.shape[0], -1) + + return output.float() + + +def get_target_transform(dataset) -> 
Optional[Callable]: + if hasattr(dataset, "transforms"): + if isinstance(dataset.transforms, StandardTransform): + return dataset.transforms.target_transform + raise ValueError("Dataset has a non-standard .transforms property") + if hasattr(dataset, "target_transform"): + return dataset.target_transform + return None + + +def get_labels(dataset) -> torch.Tensor: + """ + Get the labels of a classification dataset, as a Tensor, using the `get_targets` method + if it is present or loading the labels one by one with `get_target`, if it exists. + If the dataset has a target transform, iterate over the whole dataset to get the + transformed labels for each element, then stack them as a torch tensor. + """ + logger.info("Getting dataset labels ...") + if hasattr(dataset, "get_targets") or hasattr(dataset, "get_target"): + if hasattr(dataset, "get_targets"): # Returns a np.array + labels = dataset.get_targets() + elif hasattr(dataset, "get_target"): + labels = [dataset.get_target(i) for i in range(len(dataset))] + target_transform = get_target_transform(dataset) + if target_transform is not None: + labels = [target_transform(label) for label in labels] + else: + # Target transform is applied in this case + labels = [dataset[i][1] for i in range(len(dataset))] + return torch.stack([torch.tensor(label, dtype=int) for label in labels]) + + +def get_num_classes(dataset) -> int: + """ + Get the labels of a dataset and compute the number of classes + """ + labels = get_labels(dataset) + if len(labels.shape) > 1: + return int(labels.shape[1]) + return int(labels.max() + 1) + + +def average_metrics(eval_metrics_dict: dict[Any, dict[str, torch.Tensor]], ignore_keys: List[str] = []): + """ + Function that computes the average and the std on a metrics dict. + A linear evaluation dictionary contains "best_classifier", + so this specific key is removed for computing aggregated metrics. + """ + output_metrics_dict = {} + metrics = [metric for metric in eval_metrics_dict[0].keys() if metric not in ignore_keys] + for metric in metrics: + stats_tensor = torch.tensor([stat[metric] for stat in eval_metrics_dict.values()]) + output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item() + output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item() + + return output_metrics_dict + + +def create_class_indices_mapping(labels: torch.Tensor) -> dict[int, torch.Tensor]: + """ + Efficiently creates a mapping between the labels and tensors containing + the indices of all the dataset elements that share this label. + In the case of multiple labels, it is not guaranteed that there + will be exactly the specified percentage of labels. 
+ """ + if len(labels.shape) > 1: # labels are a one-hot encoding + assert len(labels.shape) == 2 + sorted_labels, indices = torch.nonzero(labels.T, as_tuple=True) + else: + sorted_labels, indices = torch.sort(labels, stable=True) + unique_labels, counts = torch.unique_consecutive(sorted_labels, return_counts=True) + mapping = dict(zip(unique_labels.tolist(), torch.split(indices, counts.tolist()))) + return mapping + + +def _shuffle_dataset(dataset: torch.Tensor, seed: int = 0): + """ + Shuffling a dataset by subsetting it with a random permutation of its indices + """ + random_generator = torch.Generator() + random_generator.manual_seed(seed) + random_indices = torch.randperm(len(dataset), generator=random_generator) + return Subset(dataset, random_indices) + + +def _subset_dataset_per_class( + class_indices_mapping: dict[int, torch.Tensor], + n_or_percent_per_class: float, + dataset_size: int, + seed: int = 0, + is_percent: bool = False, +) -> torch.Tensor: + """ + Helper function to select a percentage of a dataset, equally distributed across classes, + or to take the same number of elements from each class of the dataset. + Returns a boolean mask tensor being True at indices of selected elements + """ + + random_generator = torch.Generator() + random_generator.manual_seed(seed) + + final_indices_bool = torch.zeros(dataset_size, dtype=bool) + for class_indices in class_indices_mapping.values(): + # Select at least one element + n_for_class = max(int(len(class_indices) * n_or_percent_per_class), 1) if is_percent else n_or_percent_per_class + assert isinstance(n_for_class, int) + filtered_index = torch.randperm(len(class_indices), generator=random_generator)[:n_for_class] + final_indices_bool[class_indices[filtered_index]] = True + return final_indices_bool + + +def _multilabel_rebalance_subset( + class_indices_mapping: dict[int, torch.Tensor], + n_or_percent_per_class: float, + labels: torch.Tensor, + indices_bool: torch.Tensor, + dataset_size: int, + seed: int = 0, +) -> torch.Tensor: + """ + Helper function to refine a subset of a multi-label dataset (indices_bool) + to better match a target percentage of labels. + Returns a boolean mask tensor being True at indices of selected elements. + """ + + # Compute the number of selected labels in indices_bool + num_total_labels = labels.sum() + num_wanted_labels = int(num_total_labels * n_or_percent_per_class) + num_selected_labels = (labels[indices_bool] > 0).sum() + logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") + + # Compute a new percentage and new set selecting less images, therefore less labels, to match approximatelly the exact percentage of labels selected + n_or_percent_per_class = n_or_percent_per_class / (num_selected_labels / num_wanted_labels) + final_indices_bool = _subset_dataset_per_class( + class_indices_mapping, n_or_percent_per_class, dataset_size, seed, True + ) + + # Compute the number of labels finally used + num_selected_labels = (labels[final_indices_bool] > 0).sum() + logger.info(f" {num_selected_labels} labels instead of {num_wanted_labels}") + + return final_indices_bool + + +def split_train_val_datasets(train_dataset, split_percentage: float = 0.1, shuffle_train: bool = True): + """ + Splitting a percent of the train dataset to choose hyperparameters, taking the same percentage for each class. + If `shuffle` is False, taking the first elements of each class as the validaton set. 
+ """ + assert 0 < split_percentage < 1 + logger.info(f"Selecting {int(split_percentage * 100)}% of the train dataset as the validation set") + if shuffle_train: + logger.info("Shuffling train dataset before splitting in train and validation sets") + train_dataset = _shuffle_dataset(train_dataset) + train_labels = get_labels(train_dataset) + class_indices_mapping = create_class_indices_mapping(train_labels) + val_mask = torch.zeros(len(train_labels), dtype=bool) + for class_indices in class_indices_mapping.values(): + # If there is only one element, it goes in the train set + n_for_val = max(1, int(split_percentage * len(class_indices))) if len(class_indices) > 1 else 0 + val_mask[class_indices[:n_for_val]] = True + + val_dataset = Subset(train_dataset, val_mask.nonzero().flatten()) + train_dataset = Subset(train_dataset, (~val_mask).nonzero().flatten()) + return train_dataset, val_dataset + + +def create_train_dataset_dict( + train_dataset, + few_shot_eval: bool = False, + few_shot_k_or_percent=None, + few_shot_n_tries: int = 1, +) -> dict[int, dict[int, Any]]: + """ + Randomly split a dataset for few-shot evaluation, with `few_shot_k_or_percent` being + n elements or x% of a class. Produces a dict, which keys are number of random "tries" + and values are the dataset subset for this "try". + + Format is {"nth-try": dataset} + """ + if few_shot_eval is False: + assert few_shot_k_or_percent is None + assert few_shot_n_tries == 1 + return {0: train_dataset} + + assert few_shot_k_or_percent is not None + train_labels = get_labels(train_dataset) + class_indices_mapping = create_class_indices_mapping(train_labels) + train_dataset_dict: dict[int, Any] = {} + is_percent = few_shot_k_or_percent < 1 + if not is_percent: + few_shot_k_or_percent = int(few_shot_k_or_percent) + + for t in range(few_shot_n_tries): + t_subset_bool = _subset_dataset_per_class( + class_indices_mapping=class_indices_mapping, + n_or_percent_per_class=few_shot_k_or_percent, + dataset_size=len(train_labels), + is_percent=is_percent, + seed=t, + ) + if len(train_labels.shape) > 1 and is_percent: + t_subset_bool = _multilabel_rebalance_subset( + class_indices_mapping=class_indices_mapping, + n_or_percent_per_class=few_shot_k_or_percent, + dataset_size=len(train_labels), + labels=train_labels, + indices_bool=t_subset_bool, + seed=t, + ) + train_dataset_dict[t] = Subset(train_dataset, t_subset_bool.nonzero().flatten()) + return train_dataset_dict + + +def extract_features_for_dataset_dict( + model, + dataset_dict: dict[int, dict[int, Any]], + batch_size: int, + num_workers: int, + gather_on_cpu=False, + avgpool=False, +) -> dict[int, dict[str, torch.Tensor]]: + """ + Extract features for each subset of dataset in the context of few-shot evaluations + """ + few_shot_data_dict: dict[int, dict[str, torch.Tensor]] = {} + for try_n, dataset in dataset_dict.items(): + features, labels = extract_features_celldino( + model, dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu, avgpool=avgpool + ) + few_shot_data_dict[try_n] = {"train_features": features, "train_labels": labels} + return few_shot_data_dict + + +def pad_multilabel_and_collate(batch, pad_value=-1): + """ + This method pads and collates a batch of (image, (index, target)) tuples, coming from + DatasetWithEnumeratedTargets, with targets that are list of potentially varying sizes. + The targets are padded to the length of the longest target list in the batch. 
+ """ + maxlen = max(len(targets) for _, (_, targets) in batch) + padded_batch = [ + (image, (index, np.pad(targets, (0, maxlen - len(targets)), constant_values=pad_value))) + for image, (index, targets) in batch + ] + return torch.utils.data.default_collate(padded_batch) diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index 53fe83719..bd0a730c2 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -4,7 +4,7 @@ # found in the LICENSE file in the root directory of this source tree. from enum import Enum -from typing import Union +from typing import Optional, Union import torch @@ -13,6 +13,7 @@ class Weights(Enum): LVD142M = "LVD142M" + CELL_DINO = "CELL-DINO" def _make_dinov2_model( @@ -28,6 +29,8 @@ def _make_dinov2_model( interpolate_offset: float = 0.1, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, + pretrained_url: Optional[str] = None, + pretrained_path: Optional[str] = None, **kwargs, ): from ..models import vision_transformer as vits @@ -53,9 +56,16 @@ def _make_dinov2_model( model = vits.__dict__[arch_name](**vit_kwargs) if pretrained: - model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) - url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" - state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, map_location="cpu") + else: + if pretrained_url is None: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + else: + url = pretrained_url + assert url is not None + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") model.load_state_dict(state_dict, strict=True) return model @@ -154,3 +164,61 @@ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = interpolate_offset=0.0, **kwargs, ) + + +def celldino_hpa_vitl16( + *, + pretrained_url: Optional[str] = None, + pretrained_path: Optional[str] = None, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.CELL_DINO, + in_channels: int = 4, + **kwargs, +): + """ + Cell-DINO ViT-L/16 model dataset pretrained on HPA single cell dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + patch_size=16, + img_size=224, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + block_chunks=4, + pretrained_url=pretrained_url, + pretrained_path=pretrained_path, + pretrained=pretrained, + weights=weights, + in_chans=in_channels, + **kwargs, + ) + + +def celldino_cp_vits8( + *, + pretrained_url: Optional[str] = None, + pretrained_path: Optional[str] = None, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.CELL_DINO, + in_channels: int = 5, + **kwargs, +): + """ + Cell-DINO ViT-S/8 model dataset pretrained on the combined cell painting dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_small", + patch_size=8, + img_size=128, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + block_chunks=4, + pretrained_url=pretrained_url, + pretrained_path=pretrained_path, + pretrained=pretrained, + weights=weights, + in_chans=in_channels, + **kwargs, + ) diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py index 3fdff20ba..817a63aeb 100644 --- a/dinov2/models/__init__.py +++ b/dinov2/models/__init__.py @@ -26,6 +26,8 @@ def build_model(args, only_teacher=False, img_size=224): num_register_tokens=args.num_register_tokens, interpolate_offset=args.interpolate_offset, interpolate_antialias=args.interpolate_antialias, + in_chans=args.in_chans, + channel_adaptive=args.channel_adaptive, ) teacher = vits.__dict__[args.arch](**vit_kwargs) if only_teacher: diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index 74df767eb..34694244a 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -66,6 +66,7 @@ def __init__( num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1, + channel_adaptive=False, ): """ Args: @@ -103,6 +104,7 @@ def __init__( self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset + self.bag_of_channels = channel_adaptive self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches @@ -304,6 +306,11 @@ def get_intermediate_layers( return_class_token: bool = False, norm=True, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + + if self.bag_of_channels: + B, C, H, W = x.shape + x = x.reshape(B * C, 1, H, W) # passing channels to batch dimension to get encodings for each channel + if self.chunked_blocks: outputs = self._get_intermediate_layers_chunked(x, n) else: @@ -318,6 +325,22 @@ def get_intermediate_layers( out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() for out in outputs ] + + if self.bag_of_channels: + output = tuple(zip(outputs, class_tokens)) + output = list( + zip(*output) + ) # unzip the tuple: (list of patch_tokens per block, list of class tokens per block) + patch_tokens_per_block = output[0] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, N, D + cls_tokens_per_block = output[1] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, D + patch_tokens_per_block = [ + patch_tokens.reshape(B, C, patch_tokens.shape[-2], patch_tokens.shape[-1]) + for patch_tokens in patch_tokens_per_block + ] # [BLOCK1, BLOCK2, ...] 
where BLOCK1.shape: B, C, N, D + cls_tokens_per_block = [cls_tokens.reshape(B, -1) for cls_tokens in cls_tokens_per_block] + output = tuple(zip(patch_tokens_per_block, cls_tokens_per_block)) + return output + if return_class_token: return tuple(zip(outputs, class_tokens)) return tuple(outputs) @@ -338,7 +361,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""): nn.init.zeros_(module.bias) -def vit_small(patch_size=16, num_register_tokens=0, **kwargs): +def vit_small(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, @@ -347,12 +370,14 @@ def vit_small(patch_size=16, num_register_tokens=0, **kwargs): mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, + in_chans=in_chans, + channel_adaptive=channel_adaptive, **kwargs, ) return model -def vit_base(patch_size=16, num_register_tokens=0, **kwargs): +def vit_base(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, @@ -361,12 +386,14 @@ def vit_base(patch_size=16, num_register_tokens=0, **kwargs): mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, + in_chans=in_chans, + channel_adaptive=channel_adaptive, **kwargs, ) return model -def vit_large(patch_size=16, num_register_tokens=0, **kwargs): +def vit_large(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, @@ -375,12 +402,14 @@ def vit_large(patch_size=16, num_register_tokens=0, **kwargs): mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, + in_chans=in_chans, + channel_adaptive=channel_adaptive, **kwargs, ) return model -def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): +def vit_giant2(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ @@ -392,6 +421,8 @@ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, + in_chans=in_chans, + channel_adaptive=channel_adaptive, **kwargs, ) return model diff --git a/dinov2/run/eval/knn_celldino.py b/dinov2/run/eval/knn_celldino.py new file mode 100644 index 000000000..cd128a1d3 --- /dev/null +++ b/dinov2/run/eval/knn_celldino.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. 
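
Before the submitit launchers below, here is a shape-only sketch of the `channel_adaptive` (bag of channels) path added to `get_intermediate_layers` in `dinov2/models/vision_transformer.py` above: channels are folded into the batch dimension before the backbone, and the resulting tokens are folded back per image. All tensors here are random stand-ins, and the ViT-S/8 dimensions are only an example:

```python
# Shape-only sketch of the bag-of-channels rearrangement implemented in
# get_intermediate_layers above. Random tensors stand in for the backbone's
# outputs; dimensions mimic a ViT-S/8 on 128x128 crops.
import torch

B, C, H, W = 2, 5, 128, 128          # e.g. a 5-channel cell painting crop
patch_size, D = 8, 384
N = (H // patch_size) * (W // patch_size)

x = torch.randn(B, C, H, W)
x = x.reshape(B * C, 1, H, W)        # each channel becomes its own 1-channel sample

# Pretend the backbone produced (patch_tokens, class_token) for one block:
patch_tokens = torch.randn(B * C, N, D)
class_tokens = torch.randn(B * C, D)

# Fold the channel dimension back out of the batch dimension, as in the diff:
patch_tokens = patch_tokens.reshape(B, C, N, D)   # per-channel patch tokens
class_tokens = class_tokens.reshape(B, -1)        # per-channel class tokens concatenated: (B, C * D)

print(patch_tokens.shape, class_tokens.shape)     # torch.Size([2, 5, 256, 384]) torch.Size([2, 1920])
```

The `(B, C * D)` class-token layout is what `create_linear_input` in `dinov2/eval/utils_celldino.py` then consumes when `--bag-of-channels` is set.
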
+ +import logging +import os +import sys + +from dinov2.eval.knn_celldino import get_args_parser as get_knn_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.knn_celldino import main as knn_main + + self._setup_args() + knn_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for k-NN evaluation on models trained with bag of channel strategy or cell dino" + knn_args_parser = get_knn_args_parser(add_help=False) + parents = [knn_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:knn") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/run/eval/linear_celldino.py b/dinov2/run/eval/linear_celldino.py new file mode 100644 index 000000000..f616539ef --- /dev/null +++ b/dinov2/run/eval/linear_celldino.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.linear_celldino import get_args_parser as get_linear_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.linear_celldino import main as linear_main + + self._setup_args() + linear_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 linear_celldino evaluation" + linear_args_parser = get_linear_args_parser(add_help=False) + parents = [linear_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" 
+ submit_jobs(Evaluator, args, name="dinov2:linear_celldino") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 473b8d014..4e86b8daa 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -13,7 +13,7 @@ import torch from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +from dinov2.data import collate_data_and_cast, DataAugmentationDINO, CellAugmentationDINO, MaskingGenerator import dinov2.distributed as distributed from dinov2.fsdp import FSDPCheckpointer from dinov2.logging import MetricLogger @@ -172,13 +172,22 @@ def do_train(cfg, model, resume=False): max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, ) - data_transform = DataAugmentationDINO( - cfg.crops.global_crops_scale, - cfg.crops.local_crops_scale, - cfg.crops.local_crops_number, - global_crops_size=cfg.crops.global_crops_size, - local_crops_size=cfg.crops.local_crops_size, - ) + if cfg.train.cell_augmentation: + data_transform = CellAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + else: + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) collate_fn = partial( collate_data_and_cast, diff --git a/dinov2/utils/checkpoint.py b/dinov2/utils/checkpoint.py new file mode 100644 index 000000000..ea1ebe10b --- /dev/null +++ b/dinov2/utils/checkpoint.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +from typing import Any + +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer +from torch import nn + +import dinov2.distributed as dist + + +class PeriodicCheckpointerWithCleanup(PeriodicCheckpointer): + @property + def does_write(self) -> bool: + """See https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L114""" + return self.checkpointer.save_dir and self.checkpointer.save_to_disk + + def save_best(self, **kwargs: Any) -> None: + """Same argument as `Checkpointer.save`, to save a model named like `model_best.pth`""" + self.checkpointer.save(f"{self.file_prefix}_best", **kwargs) + + def has_checkpoint(self) -> bool: + return self.checkpointer.has_checkpoint() + + def get_checkpoint_file(self) -> str: # returns "" if the file does not exist + return self.checkpointer.get_checkpoint_file() + + def load(self, path: str, checkpointables=None) -> dict[str, Any]: + return self.checkpointer.load(path=path, checkpointables=checkpointables) + + def step(self, iteration: int, **kwargs: Any) -> None: + if not self.does_write: # step also removes files, so should be deactivated when object does not write + return + super().step(iteration=iteration, **kwargs) + + +def resume_or_load(checkpointer: Checkpointer, path: str, *, resume: bool = True) -> dict[str, Any]: + """ + If `resume` is True, this method attempts to resume from the last + checkpoint, if exists. Otherwise, load checkpoint from the given path. 
+ Similar to Checkpointer.resume_or_load in fvcore + https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py#L208 + but always reload checkpointables, in case we want to resume the training in a new job. + """ + if resume and checkpointer.has_checkpoint(): + path = checkpointer.get_checkpoint_file() + return checkpointer.load(path) + + +def build_periodic_checkpointer( + model: nn.Module, + save_dir="", + *, + period: int, + max_iter=None, + max_to_keep=None, + **checkpointables: Any, +) -> PeriodicCheckpointerWithCleanup: + """Util to build a `PeriodicCheckpointerWithCleanup`.""" + checkpointer = Checkpointer(model, save_dir, **checkpointables, save_to_disk=dist.is_main_process()) + return PeriodicCheckpointerWithCleanup(checkpointer, period, max_iter=max_iter, max_to_keep=max_to_keep) diff --git a/dinov2/utils/cluster.py b/dinov2/utils/cluster.py index 3df87dc3e..855a5268b 100644 --- a/dinov2/utils/cluster.py +++ b/dinov2/utils/cluster.py @@ -64,8 +64,8 @@ def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[ return None SLURM_PARTITIONS = { - ClusterType.AWS: "learnlab", - ClusterType.FAIR: "learnlab", + ClusterType.AWS: "learnaccel", + ClusterType.FAIR: "learnaccel", ClusterType.RSC: "learn", } return SLURM_PARTITIONS[cluster_type] diff --git a/docs/Cell-DINO.png b/docs/Cell-DINO.png new file mode 100644 index 000000000..82cec520a Binary files /dev/null and b/docs/Cell-DINO.png differ diff --git a/docs/README_CELLDINO.md b/docs/README_CELLDINO.md new file mode 100644 index 000000000..26ef1e321 --- /dev/null +++ b/docs/README_CELLDINO.md @@ -0,0 +1,153 @@ + +# Cell-DINO: Self-Supervised Image-based Embeddings for Cell Fluorescent Microscopy + +Théo Moutakanni*, Camille Couprie*, Seungeun Yi*, Elouan Gardes*, Piotr Bojanowski*, Hugo Touvron*, Michael Doron, Zitong S. Chen, Nikita Moshkov, Mathilde Caron, Armand Joulin, Wolfgang M. Pernice, Juan C. Caicedo + +[[`BibTeX`](#citing-cell-dino)] + +**[*Meta AI Research, FAIR](https://ai.facebook.com/research/)** + +PyTorch implementation and pretrained models for Cell-DINO. + +The contents of this repo, including the code and model weights, are intended for research use only. It is not for use in medical procedures, including any diagnostics, treatment, or curative applications. Do not use this model for any clinical purpose or as a substitute for professional medical judgement. + +![teaser](Cell-DINO.png) + +## Pretrained models + +One model pretrained on HPA single cell, one model pretrained on HPA Field of View and one model pretrained on the combined cell painting dataset will be released soon. + +## Installation + +Follow instructions in the DINOv2 README or build the following environment: +```shell +conda create -n py39 python=3.9 +conda activate py39 +pip install -r requirements.txt +pip install -U scikit-learn +``` + +## Data preparation + +Instructions how to prepare HPA single cell and HPA Field of view data will be added soon. 
+ +The HPA-FoV and HPA single cell (HPAone) datasets are available [here](https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2443) + +Required files for HPA single cell (dataloader : HPAone.py): + +train_data/varied_size_masked_single_cells_HPA +train_data/fixed_size_masked_single_cells_HPA +train_data/varied_size_masked_single_cells_HPA +varied_size_masked_single_cells_pretrain_20240507.csv +fixed_size_masked_single_cells_evaluation_20240507.csv +fixed_size_masked_single_cells_pretrain_20240507.csv + +:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`. + +## Training + +### Fast setup: training DINOv2 ViT-L/16 on HPA single cell dataset + +Run CellDINO training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit: + +```shell +python dinov2/run/train/train.py \ + --nodes 4 \ + --config-file dinov2/configs/train/hpaone_vitl16.yaml \ + --output-dir \ + train.dataset_path=HPAone:split=ALL:root= +``` + +Training time is approximately 2 days on 4 A100 GPU nodes and the resulting checkpoint should reach 78.5 F1 accuracy for protein localization with a linear evaluation. + +The training code saves the weights of the teacher in the `eval` folder every 9000 iterations for evaluation. + +## Evaluation + +The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node: + +### Linear classification with data augmentation on HPAone: + +```shell +PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ + --config-file dinov2/configs/eval/celldino.yaml \ + --pretrained-weights /eval/training_44999/teacher_checkpoint.pth \ + --output-dir /eval/training_44999/linear \ + --train-dataset HPAone:split=TRAIN:mode=PROTEIN_LOCALIZATION:root= \ + --val-dataset HPAone:split=VAL:mode=PROTEIN_LOCALIZATION:root=/large_experiments/dinov2/datasets/HPAone \ + --val-metric-type mean_per_class_multilabel_f1 \ + --loss-type binary_cross_entropy \ + --avgpool \ +``` + +We release the weights from evaluating the different models: + +The performance of the provided pretrained model weights can be evaluated as follows on HPAone for the protein localization task: + +```shell +PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ + --config-file dinov2/configs/eval/celldino.yaml \ + --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/celldino/HPA_single_cell.pth \ + --output-dir \ + --train-dataset HPAone:split=TRAIN:mode=PROTEIN_LOCALIZATION:root= \ + --val-dataset HPAone:split=VAL:mode=PROTEIN_LOCALIZATION:root=/large_experiments/dinov2/datasets/HPAone \ + --val-metric-type mean_per_class_multilabel_f1 \ + --loss-type binary_cross_entropy \ + --avgpool \ +``` + +and + +```shell +PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ + --config-file dinov2/configs/eval/celldino.yaml \ + --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/celldino/HPA_single_cell.pth \ + --output-dir \ + --train-dataset HPAone:split=TRAIN:mode=CELL_TYPE:root= \ + --val-dataset HPAone:split=VAL:mode=CELL_TYPE:root=/large_experiments/dinov2/datasets/HPAone \ + --val-metric-type mean_per_class_multiclass_f1 \ + --avgpool \ +``` + +for the cell line classification task. 
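
Outside the SLURM launchers, the hub entry points added in `dinov2/hub/backbones.py` can load the same released weights directly for ad-hoc feature extraction. A minimal sketch (the URL is the one used in the commands above; the input tensor is a random stand-in for a normalized 4-channel, 384x384 crop):

```python
import torch
from dinov2.hub.backbones import celldino_hpa_vitl16

# Load the HPA single-cell checkpoint through the new hub entry point.
model = celldino_hpa_vitl16(
    pretrained_url="https://dl.fbaipublicfiles.com/dinov2/celldino/HPA_single_cell.pth",
)
model.cuda().eval()

# Random stand-in for a normalized 4-channel HPA single-cell crop.
img = torch.randn(1, 4, 384, 384, device="cuda")
with torch.inference_mode():
    features = model(img)
print(features.shape)
```

For an end-to-end example with real images and the normalization used at training time, see `docs/test_inference_celldino.py` added later in this patch.
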
+ +### knn evaluation on HPAone: + +```shell +PYTHONPATH=.:dinov2/data python dinov2/run/eval/knn_celldino.py \ +--config-file dinov2/configs/eval/celldino.yaml \ +--pretrained-weights https://dl.fbaipublicfiles.com/dinov2/celldino/HPA_single_cell.pt \ +--output-dir \ +--train-dataset HPAone:split=TRAIN:mode=CELL_TYPE:root= \ +--val-dataset HPAone:split=VAL:mode=CELL_TYPE:root= \ +--metric-type mean_per_class_multiclass_f1 \ +--crop-size 384 \ +--batch-size 256 \ +--resize-size 0 \ +--nb_knn 10 \ +``` + +For the knn evaluation on HPAFoV, replace 'HPAone' by 'HPAFoV' in the command above. + +## License + +Cell-DINO code is released under the CC by NC licence See [LICENSE_CELLDINO](LICENSE_CELLDINO) for additional details. +Model weights will be released under the FAIR Non-Commercial Research License. + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Citing Cell-DINO + +If you find this repository useful, please consider giving a star :star: and citation :t-rex:: + +``` +@misc{, + title={Cell-DINO: Self-Supervised Image-based Embeddings for Cell Fluorescent Microscopy}, + author={Moutakanni, Th\'eo and Couprie, Camille and Yi, Seungeun and Gardes, Elouan Gardes and Bojanowski, Piotr and Touvron, Hugo and Doron, Michael and Chen, Zitong S. and Moshkov, Nikita and Caron, Mathilde and Joulin, Armand and Pernice, Wolfgang M. and Caicedo, Juan C.}, + journal={in review to PloS One on Computational Biology}, + year={2025} +} +``` + diff --git a/docs/README_CHANNEL_ADAPTIVE_DINO.md b/docs/README_CHANNEL_ADAPTIVE_DINO.md index 99eddcd63..c1a0ce335 100644 --- a/docs/README_CHANNEL_ADAPTIVE_DINO.md +++ b/docs/README_CHANNEL_ADAPTIVE_DINO.md @@ -7,15 +7,13 @@ Alice V. De Lorenci, Seungeun Yi, Théo Moutakanni, Piotr Bojanowski, Camille Couprie, Juan C. Caicedo, Wolfgang M. Pernice, -with special thanks to Elouan Gardes for his contributions to the codebase. +with special thanks to Elouan Gardes for his contributions to the codebase. -:warning: This is just the README, the code is coming soon (in July 2025). - -PyTorch implementation and pretrained model for ChannelAdaptive-DINO. +PyTorch implementation and pretrained model for ChannelAdaptive-DINO. The contents of this repo, including the code and model weights, are intended for research use only. It is not for use in medical procedures, including any diagnostics, treatment, or curative applications. Do not use this model for any clinical purpose or as a substitute for professional medical judgement. -![teaser](ChannelAdaptiveDINO.png) +![teaser](https://github.com/se-yi/dinov2-bio/blob/rebased_channel_adaptive_dino/ChannelAdaptiveDINO.png) ## Pretrained model @@ -33,8 +31,13 @@ This repository includes the Bag of Channel implementation, not the Hierarchical The CHAMMI dataset is available [here](https://github.com/chaudatascience/channel_adaptive_models). -The HPA-FoV dataset is available [here]() TODO! +The HPA-FoV dataset is available [here](https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD2443) + +Content: a directory new_512_whole_images and two csv files: TODO + +"2022_07_04_whole_image_train_data/whole_images_512_test.csv" +"2022_07_04_whole_image_train_data/whole_images_512_train.csv" :warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`. 
@@ -49,7 +52,7 @@ python dinov2/run/train/train.py \ --nodes 4 \ --config-file dinov2/configs/train/hpafov_vitl16_boc.yaml \ --output-dir \ - train.dataset_path=HPAFoV:split=LARGE_REPRODUCE:root=:wildcard=SEPARATE_CHANNELS" + train.dataset_path=HPAFoV:split=TRAIN:root=:wildcard=SEPARATE_CHANNELS" ``` Training time is approximately 2 days. @@ -67,8 +70,8 @@ PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ --config-file dinov2/configs/eval/channeldino_ext_chammi.yaml \ --pretrained-weights /eval/training_359999/teacher_checkpoint.pth \ --output-dir /eval/training_359999/linear \ - --train-dataset HPAFoV:split=LARGE_REPRODUCE:mode=PROTEIN_LOCALIZATION:root= \ - --val-dataset HPAFoV:split=SMALL_REPRODUCE:mode=PROTEIN_LOCALIZATION:root= \ + --train-dataset HPAFoV:split=TRAIN:mode=PROTEIN_LOCALIZATION:root= \ + --val-dataset HPAFoV:split=VAL:mode=PROTEIN_LOCALIZATION:root= \ --val-metric-type mean_per_class_multilabel_f1 \ --loss-type binary_cross_entropy \ --bag-of-channels \ @@ -82,20 +85,24 @@ PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ ### KNN classification on CHAMMI +Go to the docs directory, modifify some paths in launcher_knn_eval_on_chammi.sh and run + ```shell -./launcher_CHAMMI_knn_eval.sh WTC TASK_ONE ; -./launcher_CHAMMI_knn_eval.sh WTC TASK_TWO ; -./launcher_CHAMMI_knn_eval.sh HPA TASK_ONE ; -./launcher_CHAMMI_knn_eval.sh HPA TASK_TWO ; -./launcher_CHAMMI_knn_eval.sh HPA TASK_THREE ; -./launcher_CHAMMI_knn_eval.sh CP TASK_ONE ; -./launcher_CHAMMI_knn_eval.sh CP TASK_TWO ; -./launcher_CHAMMI_knn_eval.sh CP TASK_THREE ; -./launcher_CHAMMI_knn_eval.sh CP TASK_FOUR ; +./launcher_knn_eval_on_chammi.sh WTC TASK_ONE ; +./launcher_knn_eval_on_chammi.sh WTC TASK_TWO ; +./launcher_knn_eval_on_chammi.sh HPA TASK_ONE ; +./launcher_knn_eval_on_chammi.sh HPA TASK_TWO ; +./launcher_knn_eval_on_chammi.sh HPA TASK_THREE ; +./launcher_knn_eval_on_chammi.sh CP TASK_ONE ; +./launcher_knn_eval_on_chammi.sh CP TASK_TWO ; +./launcher_knn_eval_on_chammi.sh CP TASK_THREE ; +./launcher_knn_eval_on_chammi.sh CP TASK_FOUR ; ``` ### Linear classification on CHAMMI +Go to the docs directory, modifify some paths in launcher_CHAMMI_eval.sh and run + ```shell ./launcher_CHAMMI_eval.sh WTC TASK_ONE ; ./launcher_CHAMMI_eval.sh WTC TASK_TWO ; @@ -110,7 +117,7 @@ PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ | | WTC - Task 1 | WTC - Task 2 | HPA - Task 1 | HPA - Task 2 | HPA - Task 3 | CP - Task 1 | CP - Task 2 | CP - Task 3 | CP - Task 4 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| knn reproduced | 80.3 | 79.3 | 91.6 | 61.4 | 28.5 | 89.8 | 57.6 | 23.4 | 18.4 | +| knn reproduced | 80.3 | 79.3 | 91.6 | 61.4 | 29.0 | 89.8 | 57.6 | 23.4 | 18.4 | | knn paper | 79.4 | 79.0 | 86.6 | 59.3 | 29.6 | 92.6 | 57.6 | 22.1 | 18.5 | | Linear reproduced | 89.9 | 87.9 | 92.7 | 87.2 | 66.2 | 89.9 | 59.8 | 26.6 | 32.5| | Linear paper | 90.5 | 89.2 | 88.3 | 84.7 | 65.0 | 90.5 | 60.5 | 25.8 | 32.7| @@ -121,6 +128,9 @@ PYTHONPATH=.:dinov2/data python dinov2/run/eval/linear_celldino.py \ CellDINO code is released under the CC by NC licence See [LICENSE_CELL_DINO](LICENSE_CELL_DINO) for additional details. Model weights will be released under the FAIR Non-Commercial Research License. +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 
## Citing ChannelAdaptiveDINO and DINOv2 diff --git a/docs/launcher_CHAMMI_eval.sh b/docs/launcher_CHAMMI_eval.sh new file mode 100755 index 000000000..20a1b69f6 --- /dev/null +++ b/docs/launcher_CHAMMI_eval.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# 1 : modify CHANNEL_AGNOSTIC_CELL_MODEL, CHAMMI_DATA_PATH and OUTPUT_DIR below +# 2 : call this script with the two arguments specified below + +#Arguments: +# $1 : dataset, e.g CP +# $2 : task number, e.g TASK_TWO + +CHAMMI_DATA_PATH="" +CHANNEL_AGNOSTIC_CELL_MODEL="path_to_model/model.pth" +OUTPUTDIR=YOUR_OUTPUT_PATH_$1_$2 + +if [ "$2" == "TASK_FOUR" ]; then + OTHER_ARG="--leave-one-out-dataset $CHAMMI_DATA_PATH/CP/enriched_meta.csv " +elif [ "$1" == "HPA" -a "$2" == "TASK_THREE" ]; then + OTHER_ARG="--leave-one-out-dataset $CHAMMI_DATA_PATH/HPA/enriched_meta.csv " +else + OTHER_ARG="" +fi + +if [ $1 != "CP" ]; then + OTHER_ARG="$OTHER_ARG --resize-size 256" +fi + +PYTHONPATH=..:../dinov2/data python ../dinov2/run/eval/linear_celldino.py \ +--config-file ../dinov2/configs/eval/channeldino_ext_chammi.yaml \ +--pretrained-weights $CHANNEL_AGNOSTIC_CELL_MODEL \ +--output-dir $OUTPUTDIR \ +--train-dataset CHAMMI_$1:split=TRAIN:root=$CHAMMI_DATA_PATH \ +--val-dataset CHAMMI_$1:split=$2:root=$CHAMMI_DATA_PATH \ +--val-metric-type mean_per_class_multiclass_f1 \ +--bag-of-channels \ +--crop-size 224 \ +--n-last-blocks 1 \ +--avgpool \ +--batch-size 128 \ +--epoch-length 30 \ +--epochs 10 \ +$OTHER_ARG \ diff --git a/docs/launcher_knn_eval_on_chammi.sh b/docs/launcher_knn_eval_on_chammi.sh new file mode 100755 index 000000000..29ab7798c --- /dev/null +++ b/docs/launcher_knn_eval_on_chammi.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# 1 : modify CHANNEL_AGNOSTIC_CELL_MODEL, CHAMMI_DATA_PATH and OUTPUT_DIR below +# 2 : call this script with the two arguments specified below + +#Arguments: +# $1 : dataset, e.g CP +# $2 : task number, e.g TASK_TWO + +CHAMMI_DATA_PATH="" +CHANNEL_AGNOSTIC_CELL_MODEL="path_to_model/model.pth" +OUTPUT_DIR=YOUR_OUTPUT_PATH_$1_$2 + +if [ "$2" == "TASK_FOUR" ]; then + OTHER_ARG="--leave-one-out-dataset $CHAMMI_DATA_PATH/CP/enriched_meta.csv " +elif [ "$1" == "HPA" -a "$2" == "TASK_THREE" ]; then + OTHER_ARG="--leave-one-out-dataset $CHAMMI_DATA_PATH/CHAMMI/HPA/enriched_meta.csv " +else + OTHER_ARG="" +fi +echo $OTHER_ARG + +PYTHONPATH=..:../dinov2/data python ../dinov2/run/eval/knn_celldino.py \ +--config-file ../dinov2/configs/eval/channeldino_ext_chammi.yaml \ +--pretrained-weights $CHANNEL_AGNOSTIC_CELL_MODEL \ +--output-dir $OUTPUT_DIR \ +--train-dataset CHAMMI_$1:split=TRAIN:root=$CHAMMI_DATA_PATH \ +--val-dataset CHAMMI_$1:split=$2:root=$CHAMMI_DATA_PATH \ +--metric-type mean_per_class_multiclass_f1 \ +--crop-size 224 \ +--batch-size 32 \ +--resize-size 256 \ +--bag-of-channels \ +$OTHER_ARG \ diff --git a/docs/test_inference_celldino.py b/docs/test_inference_celldino.py new file mode 100644 index 000000000..263942baa --- /dev/null +++ b/docs/test_inference_celldino.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC-by-NC licence, +# found in the LICENSE_CELLDINO file in the root directory of this source tree. + +import torch +import torchvision +from dinov2.hub.backbones import celldino_hpa_vitl16, celldino_cp_vits8 +from functools import partial +from dinov2.eval.utils import ModelWithIntermediateLayers + +DEVICE = "cuda:0" +SAMPLE_IMAGES_DIR = "" # path to directory with cell images. +MODELS_DIR = "" # path to directory with pretrained models. 
+ + +class self_normalize(object): + def __call__(self, x): + x = x / 255 + m = x.mean((-2, -1), keepdim=True) + s = x.std((-2, -1), unbiased=False, keepdim=True) + x -= m + x /= s + 1e-7 + return x + + +normalize = self_normalize() + +# ---------------------- Example inference on HPA-FoV dataset -------------------------- + +# 1- Read one human protein atlas HPA-FoV image (4 channels) +img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + "HPA_FoV_00070df0-bbc3-11e8-b2bc-ac1f6b6435d0.png") + +# 2- Normalise image as it was done for training +img_hpa_fov = img.unsqueeze(0).to(device=DEVICE) +img_hpa_fov = normalize(img_hpa_fov) + +# 3- Load model +cell_dino_model = celldino_hpa_vitl16( + pretrained_path=MODELS_DIR + "celldino_hpa_fov.pth", +) +print(cell_dino_model) +cell_dino_model.to(device=DEVICE) +cell_dino_model.eval() + +# 4- Inference +features = cell_dino_model(img_hpa_fov) +print(features) + +# 5- [Optional] feature extractor as used for linear evaluation +autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float) +model_with_interm_layers = ModelWithIntermediateLayers(cell_dino_model, 4, autocast_ctx) +features_with_interm_layers = model_with_interm_layers(img_hpa_fov) + +# ---------------------- Example inference on cell painting data -------------------------- + +# 1- Read one cell painting image (5 channels) +img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + "CP_BBBC036_24277_a06_1_976@140x149.png") +img5_channels = torch.zeros([1, 5, 160, 160]) +for c in range(5): + img5_channels[0, c] = img[0, :, 160 * c : 160 * (c + 1)] +img5_channels = img5_channels.to(device=DEVICE) + +# 2- Normalise image as it was done for training +img5_channels = normalize(img5_channels) + +# 3- Load model +cell_dino_model = celldino_cp_vits8( + pretrained_path=MODELS_DIR + "celldino_cp.pth", +) +print(cell_dino_model) +cell_dino_model.to(device=DEVICE) +cell_dino_model.eval() + +# 4- Inference +features = cell_dino_model(img5_channels) +print(features) + +# ---------------------- Example inference on HPA single cell dataset -------------------------- + +# Read one human protein atlas HPA single cell image (4 channels) +img = torchvision.io.read_image(SAMPLE_IMAGES_DIR + "HPA_single_cell_00285ce4-bba0-11e8-b2b9-ac1f6b6435d0_15.png") + +# 2- Normalise image as it was done for training +img_hpa = img.unsqueeze(0).to(device=DEVICE) +img_hpa = normalize(img_hpa) + +# 3- Load model +cell_dino_model = celldino_hpa_vitl16( + pretrained_path=MODELS_DIR + "celldino_hpa_sc.pth", +) +print(cell_dino_model) +cell_dino_model.to(device=DEVICE) +cell_dino_model.eval() + +# 4- Inference +features = cell_dino_model(img_hpa) +print(features) + +torch.save(features.cpu(), "sample_features_hpa.pt")
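
The script above covers the fixed-channel Cell-DINO backbones. For the channel-adaptive (bag of channels) variant, the `BagOfChannelsModelWithNormalize` wrapper added in `dinov2/eval/utils_celldino.py` plays the same role. The sketch below builds a randomly initialized channel-adaptive ViT-S/8 just to show the shapes; the single-channel patch embedding, `block_chunks=0`, and the commented-out checkpoint path are assumptions for the example, not repo defaults:

```python
# Hedged sketch: running a channel-adaptive backbone through the
# BagOfChannelsModelWithNormalize wrapper from dinov2/eval/utils_celldino.py.
# The backbone is randomly initialized here; in practice load pretrained
# weights (the path in the comment is a placeholder).
from functools import partial

import torch
from dinov2.eval.utils_celldino import BagOfChannelsModelWithNormalize
from dinov2.models import vision_transformer as vits

# Channels are folded into the batch dimension, so the patch embedding sees
# single-channel inputs (in_chans=1 is an assumption consistent with that).
backbone = vits.vit_small(patch_size=8, img_size=128, in_chans=1, channel_adaptive=True, block_chunks=0)
# backbone.load_state_dict(torch.load("path/to/channel_adaptive_model.pth", map_location="cpu"))
backbone.cuda().eval()

autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float)
model = BagOfChannelsModelWithNormalize(backbone, autocast_ctx, avgpool=False, n_last_blocks=1)

# Any channel count works: per-channel class tokens are concatenated and L2-normalized.
# The random tensor stands in for a normalized 5-channel cell painting crop.
img = torch.randn(2, 5, 128, 128, device="cuda")
with torch.inference_mode():
    embedding = model(img)
print(embedding.shape)  # (2, 5 * 384) = (2, 1920)
```
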