diff --git a/.cgcignore b/.cgcignore index fd51031..74f42df 100644 --- a/.cgcignore +++ b/.cgcignore @@ -4,18 +4,23 @@ /tmp/ bin/ obj/ +TestResults/ cache/ [Gg]enerated/ *[Aa]rtifacts/ +.vitepress/ # Ignore dependencies node_modules/ # Ignore logs *.log +*.txt # Git /.git/ +/.github/ + # IDE /.vscode/ diff --git a/.github/agents/DevOps.agent.md b/.github/agents/DevOps.agent.md index 4e3fd23..2c51f78 100644 --- a/.github/agents/DevOps.agent.md +++ b/.github/agents/DevOps.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: managing CI/CD pipelines, GitHub Actions workflows, build/test/pack/publish automation, NuGet Trusted Publishing, or release processes. Handles .github/workflows/ files and DevOps configuration." model: GPT-5.4 (copilot) -tools: [vscode/memory, vscode/askQuestions, execute/getTerminalOutput, execute/runInTerminal, read, agent, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', edit, search, web, github/get_copilot_job_status, github/get_file_contents, github/get_latest_release, github/get_release_by_tag, github/get_tag, github/issue_read, github/list_branches, github/list_releases, github/list_tags, github/pull_request_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, github.vscode-pull-request-github/notification_fetch, todo] +tools: [vscode/askQuestions, vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/runInTerminal, read, agent, edit, search, web, github/get_copilot_job_status, github/get_file_contents, github/get_latest_release, github/get_release_by_tag, github/get_tag, github/issue_read, github/list_branches, github/list_releases, github/list_tags, github/pull_request_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/notification_fetch] agents: ["Explore"] user-invocable: true argument-hint: "Describe the CI/CD change: add workflow, fix pipeline, update publish config, etc." @@ -43,11 +43,12 @@ git push origin main && git push origin ioc-v0.9.1-alpha # triggers pu **Post-release**: bump `version.json` via `nbgv prepare-release --project ` or manual edit. -Follow the **parent agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Approach -1. **Explore First (Required)** — Delegate to `Explore` to gather workflow and release context. +0. **Capture Goal (Required)** — Distill the user's request into a concise goal statement and save it to `/memories/session/goal.md` via #tool:vscode/memory before any research. +1. **Explore First (Required)** — Delegate to `Explore` to gather workflow and release context. Provide the goal from `goal.md` alongside the research question. 2. **Create Plan.md (Required)** — Build `plan.md` from Explore findings (goal, scope, files, validation checks). 3. **Save & Verify Plan (Required)** — Follow the parent agent protocol in plan memory policy. 4. **Approve** — Present the plan and wait for user approval before risky or broad changes. @@ -58,7 +59,7 @@ Follow the **parent agent protocol** in `.github/instructions/plan-memory-policy ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` - Read workflow files before editing - Follow the three-job pattern: `build -> publish -> release` - Pin explicit stable action major versions (examples: `actions/checkout@v6`, `actions/setup-dotnet@v5`, `actions/upload-artifact@v7`, `actions/download-artifact@v8`, `NuGet/login@v1`; never `@latest` or branch refs) diff --git a/.github/agents/Doc.agent.md b/.github/agents/Doc.agent.md index 01b5598..82cc763 100644 --- a/.github/agents/Doc.agent.md +++ b/.github/agents/Doc.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: writing or updating user-facing documentation files (docs/ folder). Creates progressive, beginner-friendly guides with generated code examples for the SourceGen repository." model: Claude Opus 4.6 (copilot) -tools: [vscode/memory, vscode/askQuestions, execute/getTerminalOutput, execute/runInTerminal, read, agent, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', edit, search, web, todo] +tools: [vscode/askQuestions, vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/runInTerminal, read, agent, edit, search, web, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] agents: ["Explore", "DocReview"] user-invocable: true argument-hint: "Provide the documentation topic or feature to document, and which doc files to create or update" @@ -10,11 +10,12 @@ You are an expert technical writer for the SourceGen repository. You specialize Follow the project principles in `AGENTS.md`. -Follow the **parent agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Approach -1. Run `Explore` first to gather context for the requested documentation work. +0. Capture the user's request into a concise goal statement and save it to `/memories/session/goal.md` via #tool:vscode/memory before any research. +1. Run `Explore` first to gather context for the requested documentation work. Provide the goal from `goal.md` alongside the research question. 2. Create `plan.md` from Explore findings (goal, scope, target files, acceptance checks). 3. Follow the parent agent protocol in plan memory policy: save, verify, and gate on failure. 4. Read existing docs under `docs/` to understand current structure and conventions. @@ -65,7 +66,7 @@ For every source generator feature, **always** include a generated code example ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` - Write new files to `docs/` following the existing numbering scheme - Follow the style conventions already established in existing docs - Include `
` generated code sections for every source-generator feature @@ -93,10 +94,11 @@ Return a structured completion report: #### Preconditions - ExploreCompleted: true | false +- MemoryGoalSaved: true | false - MemoryPlanSaved: true | false - MemoryPlanVerified: true | false - MemoryPlanLoaded: true | false -- MemoryPath: /memories/session/plan.md +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md - PlanMode: draft | approved - Blocker: (empty or reason) diff --git a/.github/agents/DocReview.agent.md b/.github/agents/DocReview.agent.md index 684df8d..81909f9 100644 --- a/.github/agents/DocReview.agent.md +++ b/.github/agents/DocReview.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: reviewing completed documentation updates under docs/ for accuracy, consistency, links, and generated code examples." model: GPT-5.4 (copilot) -tools: [vscode/memory, read, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', search, web, todo] +tools: [vscode/memory, vscode/resolveMemoryFileUri, read, search, web, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] agents: [] user-invocable: false argument-hint: "Provide changed docs files and related source/spec paths to validate" @@ -10,13 +10,13 @@ You are a documentation reviewer for the SourceGen repository. You perform read- Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Approach -1. **Load plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: - Call `memory({ command: "view", path: "/memories/session/plan.md" })` as your very first tool call. - - If plan is present and non-empty → proceed to step 2. - - If plan is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. +1. **Load goal and plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: + Use #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`. These must be your very first tool calls. + - If both are present and non-empty → proceed to step 2. + - If either is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. 2. Read all changed documentation files provided in the prompt 3. Validate technical accuracy against relevant source/spec files @@ -35,7 +35,7 @@ Follow the **child agent protocol** in `.github/instructions/plan-memory-policy. ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` - Read and cross-reference all changed docs against source code and specs - Verify internal links resolve correctly - Check that `
` generated code sections exist for source-generator features @@ -55,8 +55,9 @@ Return a structured report in this format: ### Documentation Review Report #### Preconditions +- MemoryGoalLoaded: true | false - MemoryPlanLoaded: true | false -- MemoryPath: /memories/session/plan.md +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md - Blocker: (empty or reason) #### Findings diff --git a/.github/agents/Explore.agent.md b/.github/agents/Explore.agent.md index 41a06e0..295588a 100644 --- a/.github/agents/Explore.agent.md +++ b/.github/agents/Explore.agent.md @@ -1,7 +1,7 @@ --- description: "Fast read-only codebase exploration and Q&A subagent. Prefer over manually chaining multiple search and file-reading operations to avoid cluttering the main conversation. Safe to call in parallel. Specify thoroughness: quick, medium, or thorough." model: Claude Haiku 4.5 (copilot) -tools: [vscode/memory, execute/getTerminalOutput, execute/testFailure, read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/execute_cypher_query, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'microsoft/markitdown/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', search, web, github/get_commit, github/get_file_contents, github/issue_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest] +tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/testFailure, read, search, web, github/get_commit, github/get_file_contents, github/issue_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/execute_cypher_query, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'microsoft/markitdown/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest] agents: [] user-invocable: false argument-hint: "Describe WHAT you're looking for and desired thoroughness (quick/medium/thorough)" diff --git a/.github/agents/Implement.agent.md b/.github/agents/Implement.agent.md index e09f71b..8ba168f 100644 --- a/.github/agents/Implement.agent.md +++ b/.github/agents/Implement.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: implementing approved plan from /memories/session/plan.md. Executes code changes, runs tests, and follows project conventions." -model: GPT-5.4 (copilot) -tools: [vscode/memory, execute, read, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', edit, search, web, todo] +model: GPT-5.3-Codex (copilot) +tools: [vscode/memory, vscode/resolveMemoryFileUri, execute, read, edit, search, web, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] agents: [] user-invocable: false argument-hint: "Implement the approved plan stored in /memories/session/plan.md" @@ -10,7 +10,7 @@ You are an implementation specialist for the SourceGen C# source generator proje Follow the project principles in `AGENTS.md` and the relevant domain `AGENTS.md` for the affected code. -Follow the **child agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Commands @@ -23,27 +23,57 @@ Refer to the relevant domain `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) for domain- ## Approach -1. **Load plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: - Call `memory({ command: "view", path: "/memories/session/plan.md" })` as your very first tool call. - - If plan is present and non-empty → proceed to step 2. - - If plan is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. +1. **Load goal and plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: + Use #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`. These must be your very first tool calls. + - If both are present and non-empty → proceed to step 2. + - If either is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. 2. Create the full todo list from plan steps via #tool:todo 3. For each step: mark **in-progress** → implement → mark **completed** (do not batch) 4. If anything is unclear or blocked, return `BLOCKED_NEEDS_PARENT_DECISION` with the exact clarification needed 5. Run all related tests after implementation 6. Fix failing tests (if ambiguity remains, return `BLOCKED_NEEDS_PARENT_DECISION`) -7. Report completion +7. **Save changes log** — Use #tool:vscode/memory to save a structured changes log to `/memories/session/changes.md` (see [Changes Log Format](#changes-log-format) below). This MUST be done before reporting completion. +8. Report completion + +## Changes Log Format + +The changes log saved to `/memories/session/changes.md` via #tool:vscode/memory MUST follow this structure: + +```markdown +## Changes Log + +### Changed Files +| # | File | Action | Description | +|---|------|--------|-------------| + +### Decisions Made +- {Decision made during implementation and rationale} + +### Issues Discovered +- {Issue found during implementation — unexpected behavior, missing API, code smell, etc.} + +### Concerns +- {Remaining concerns or risks — potential regressions, edge cases not covered, etc.} +``` + +- **Changed Files**: Every file created, modified, or deleted with a brief description of the change. +- **Decisions Made**: Any implementation choices not explicitly dictated by the plan (e.g., choosing between two valid approaches, naming decisions, handling an edge case). +- **Issues Discovered**: Problems found during implementation — bugs in existing code, spec gaps, unexpected constraints. +- **Concerns**: Lingering risks or open questions that the parent agent should be aware of. + +If a section has no entries, write "None." ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the memory policy in `.github/instructions/memory-policy.instructions.md` - Follow C# 14 conventions: file-scoped namespaces, `#nullable enable`, .NET naming - Use `readonly record struct` or `sealed record class` for generator data models - Follow domain-specific rules from the relevant `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) - Run all related tests after implementation and fix failures - Track progress with #tool:todo (mark in-progress → completed per step) + - Save a changes log to `/memories/session/changes.md` via #tool:vscode/memory before reporting completion - ⚠️ **Ask first:** - When the plan is ambiguous or a design decision is needed — return `BLOCKED_NEEDS_PARENT_DECISION` @@ -55,6 +85,7 @@ Refer to the relevant domain `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) for domain- - Use `dotnet test --filter` for TUnit projects - Modify secrets, CI/CD configs, or NuGet publishing settings - Remove existing tests that are failing — fix them or ask + - Modify `/memories/session/plan.md` (owned by parent agents) ## Output Format @@ -63,6 +94,7 @@ Refer to the relevant domain `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) for domain- #### Preconditions - MemoryPlanLoaded: true | false - MemoryPath: /memories/session/plan.md +- ChangesLogSaved: true | false - Blocker: (empty or reason) #### Changed Files @@ -73,5 +105,11 @@ Refer to the relevant domain `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) for domain- - **Status**: Pass / Fail - **Details**: (brief summary) -#### Notes -(Any deviations, issues, or follow-ups) +#### Decisions Made +- (decisions made during implementation) + +#### Issues Discovered +- (issues found during implementation, or "None") + +#### Concerns +- (remaining concerns or risks, or "None") diff --git a/.github/agents/Orchestrator.agent.md b/.github/agents/Orchestrator.agent.md index 0d98763..1d48a3d 100644 --- a/.github/agents/Orchestrator.agent.md +++ b/.github/agents/Orchestrator.agent.md @@ -1,8 +1,8 @@ --- description: "Use when: implementing features, fixing bugs, or making code changes that require planning, approval, and review. Analyzes requirements, writes plan.md, and delegates to subagents." model: Claude Opus 4.6 (copilot) -tools: [vscode/memory, vscode/askQuestions, execute/getTerminalOutput, execute/testFailure, read, agent, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', search, web, github/add_reply_to_pull_request_comment, github/get_commit, github/get_copilot_job_status, github/issue_read, github/pull_request_read, github/search_issues, github/search_pull_requests, vscode.mermaid-chat-features/renderMermaidDiagram, github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest, todo] -agents: ["Explore", "Implement", "Review", "Spec", "Doc", "DocReview", "DevOps"] +tools: [vscode/memory, vscode/resolveMemoryFileUri, vscode/askQuestions, execute/getTerminalOutput, execute/testFailure, read, agent, search, web, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', github/add_reply_to_pull_request_comment, github/get_commit, github/get_copilot_job_status, github/issue_read, github/pull_request_read, github/search_issues, github/search_pull_requests, vscode.mermaid-chat-features/renderMermaidDiagram, github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest, github.vscode-pull-request-github/resolveReviewThread, todo] +agents: ["Explore", "Implement", "Review", "PlanReview", "Spec", "Doc", "DocReview", "DevOps"] user-invocable: true disable-model-invocation: true --- @@ -24,16 +24,21 @@ Follow the project principles in `AGENTS.md`. | `Doc` | Write/update user-facing docs under `docs/` | Plan includes documentation work | | `DocReview` | Read-only docs review | After documentation updates | | `DevOps` | CI/CD workflows under `.github/workflows/` | Plan includes CI/CD or release workflow changes | +| `PlanReview` | Read-only plan review against codebase | After drafting plan, before presenting to user | -Follow the **parent agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Workflow Cycle through these phases based on user input. This is **iterative, not linear**. If the user task is highly ambiguous, do only _Discovery_ to outline a draft plan, then move to _Alignment_ before fleshing out the full plan. +### Phase 0 — Capture Goal + +0. **Record goal** — Before any research, distill the user's request into a concise goal statement and save it to `/memories/session/goal.md` via #tool:vscode/memory. This file is the single source of truth for *what* we are trying to achieve. Include it (or reference it) when delegating to every subagent so they can verify their work against the original intent. + ### Phase 1 — Discovery -1. **Explore** — Delegate to `Explore` with a clear research question. Include what you already know and what you need to find out. When the task spans multiple independent areas (e.g., generator + analyzer, different features), launch **2–3 `Explore` subagents in parallel** — one per area — to speed up discovery. +1. **Explore** — Delegate to `Explore` with a clear research question. Include the goal from `/memories/session/goal.md`, what you already know, and what you need to find out. When the task spans multiple independent areas (e.g., generator + analyzer, different features), launch **2–3 `Explore` subagents in parallel** — one per area — to speed up discovery. 2. **Analyze** — Combine the user's request with Explore findings. Identify affected files, public API changes, test coverage gaps, analogous existing features to use as implementation templates, and potential blockers or ambiguities. ### Phase 2 — Alignment @@ -61,11 +66,24 @@ If Discovery reveals ambiguities, multiple valid approaches, or unvalidated assu Once context is clear and ambiguities are resolved: -4. **Draft plan** — Write a comprehensive plan following the [plan format](#plan-format) below. +4. **Draft plan** — Write a comprehensive plan following the [plan format](#plan-format) below. For each step, list the specific files it will modify (`Files:` field). This is required input for the parallelism analysis. + +5. **Parallelism analysis** (mandatory, never skip) — After drafting steps, analyze which can run in parallel: + - Compare each step's `Files:` list — no two parallel steps may modify the same file. + - Each parallel step must be **independently compilable** — after applying only that step's changes, `dotnet build` must succeed. + - Each parallel step must be **independently testable** — its related tests must pass without depending on changes from other parallel steps. + - Steps that share a modified file, introduce types consumed by another step, or require a specific application order are **sequential** — mark them with *depends on step N*. + - Group truly independent steps into the same wave and mark them with *parallel with step N*. + - If unsure whether two steps are independent, treat them as sequential. + - Record the result in the **Parallelism Schedule** table (always required — if all steps are sequential, include the table and state why). -5. **Save draft** — Save the plan to `/memories/session/plan.md` via #tool:vscode/memory immediately after drafting, **before** presenting to the user. This is a persistence checkpoint — the file is not a substitute for showing the plan to the user. +6. **Save draft** — Save the plan to `/memories/session/plan.md` via #tool:vscode/memory immediately after drafting, **before** presenting to the user. This is a persistence checkpoint — the file is not a substitute for showing the plan to the user. -6. **Present & Approve** — Show the full plan to the user in the conversation. **Do not proceed to execution until the user explicitly approves.** The plan MUST be presented inline — don't just reference the plan file. +7. **Delegate to PlanReview** — Delegate to `PlanReview` subagent to verify the plan against the codebase. After it completes, read `/memories/session/plan-review.md` via #tool:vscode/memory to retrieve the review report. + - If the report contains **High** severity findings → revise the plan to fix the issues, re-save to memory, then re-delegate to `PlanReview`. Repeat until no High severity findings remain. + - If the report contains only Medium/Low findings or no findings → proceed to Present & Approve. + +8. **Present & Approve** — Show the full plan to the user in the conversation. **Do not proceed to execution until the user explicitly approves.** The plan MUST be presented inline — don't just reference the plan file. If PlanReview surfaced Medium/Low findings, summarize them for the user alongside the plan. ### Phase 4 — Refinement @@ -78,23 +96,33 @@ On user input after showing the plan: ### Phase 5 — Execute -7. **Verify plan in memory** — Read `/memories/session/plan.md` via #tool:vscode/memory and confirm it matches the approved plan. If it doesn't match or is missing, re-save and verify before proceeding. If save fails, stop and return `BLOCKED_NO_PLAN_MEMORY_WRITE`. -8. **Spec** (if needed) — Delegate to `Spec` to update specification documents. -9. **Implement** — Delegate to `Implement` with the approved plan. Review its report: - - If tests pass and report is clean → proceed to Review. - - If issues found → provide specific feedback and re-delegate to `Implement`. -10. **Review** — Delegate to `Review` with the plan and the list of changed files. After Review completes, read `/memories/session/review.md` via #tool:vscode/memory to retrieve the structured review report. If Review finds high-severity issues, delegate back to `Implement` to fix, then re-review. +9. **Verify plan in memory** — Read `/memories/session/plan.md` via #tool:vscode/memory and confirm it matches the approved plan. If it doesn't match or is missing, re-save and verify before proceeding. If save fails, stop and return `BLOCKED_NO_PLAN_MEMORY_WRITE`. +10. **Spec** (if needed) — Delegate to `Spec` to update specification documents. +11. **Implement** — Execute per the plan's **Parallelism Schedule**, processing one wave at a time: + + **For each wave (sequential across waves):** + 1. Identify all steps assigned to this wave from the Parallelism Schedule. + 2. Delegate one `Implement` subagent **per step** — launch all subagents for the current wave **in parallel**. Each subagent receives: the full plan, only its assigned step number(s), and the goal. + 3. Each subagent writes its results to `/memories/session/changes-step-{step_number}.md`. + 4. Wait for **all** subagents in the wave to complete. + 5. If any subagent fails → fix the failure (delegate a new `Implement` for just the failing step) before proceeding. + 6. After the wave succeeds, merge all `changes-step-*.md` from this wave into `/memories/session/changes.md`. + 7. Proceed to the next wave. + + **Single-step plans:** If the Parallelism Schedule has only one wave with one step, delegate a single `Implement` subagent with the entire plan. + +12. **Review** — Delegate to `Review` with the plan and the list of changed files (from `changes.md`). After Review completes, read `/memories/session/review.md` via #tool:vscode/memory to retrieve the structured review report. If Review finds high-severity issues, delegate back to `Implement` to fix, then re-review. ### Phase 6 — Verify & Complete -11. **Doc** (if needed) — Delegate to `Doc` for documentation updates, then `DocReview` to verify. -12. **Complete** — Summarize: +13. **Doc** (if needed) — Delegate to `Doc` for documentation updates, then `DocReview` to verify. +14. **Complete** — Summarize: - What changed (list of files) - Test results - Review outcome - Any follow-ups or known limitations -Handle `BLOCKED_*` codes per the [plan memory policy](../instructions/plan-memory-policy.instructions.md). +Handle `BLOCKED_*` codes per the [memory policy](../instructions/memory-policy.instructions.md). ## Plan Format @@ -110,7 +138,16 @@ Plans saved to `/memories/session/plan.md` and presented to the user MUST follow **Steps** 1. {Implementation step — note dependency ("*depends on step N*") or parallelism ("*parallel with step N*") when applicable} + **Files:** `path/to/file1.cs`, `path/to/file2.cs` 2. {For plans with 5+ steps, group into named phases that are each independently verifiable} + **Files:** `path/to/file3.cs` + +**Parallelism Schedule** +| Wave | Steps | Rationale | +|------|-------|-----------| +| 1 | {step numbers} | {why these are independent: disjoint file sets, no shared types, each compiles & tests alone} | +| 2 | {step numbers} | {depends on wave 1 because …} | +*If all steps must be sequential, include this table with a single wave and explain why parallelism is not possible.* **Relevant Files** - `{full/path/to/file}` — {what to modify or reuse, referencing specific functions, types, or patterns} @@ -135,12 +172,20 @@ Plans saved to `/memories/session/plan.md` and presented to the user MUST follow - Step-by-step with explicit dependencies — mark which steps can run in parallel vs. which block on prior steps - Reference critical architecture to reuse — specific functions, types, or patterns, not just file names - Explicit scope boundaries — what's included and what's deliberately excluded +- **Parallelism independence guarantee** — steps marked *parallel* MUST satisfy ALL of: + 1. **Disjoint files** — no two parallel steps modify the same file + 2. **Independent compilation** — each step's changes compile on their own (`dotnet build` succeeds) + 3. **Independent tests** — each step's related tests pass without changes from sibling parallel steps + 4. **No type coupling** — a parallel step must not introduce a type, interface, or method that another parallel step consumes ## Memory Protocol +> **Goal**: `/memories/session/goal.md` — created in Phase 0, read-only afterwards. Provide to every subagent delegation. +> > **Current plan**: `/memories/session/plan.md` — read and write exclusively via #tool:vscode/memory . **When to SAVE (write):** +- `/memories/session/goal.md` — once, in Phase 0, before Discovery - After drafting the plan in the Design phase — **before** presenting to the user (persistence checkpoint) - After the user requests changes — update the file to keep it in sync with the presented plan - After approval, if the file doesn't match the approved version @@ -150,6 +195,7 @@ Plans saved to `/memories/session/plan.md` and presented to the user MUST follow - Before delegating to any subagent after the initial Explore — confirm the plan exists and is current - Before starting the Execute phase — confirm the saved plan matches the approved version - After every save — read back to verify content is complete and matches intent +- After delegating to `PlanReview` — read `/memories/session/plan-review.md` to retrieve review findings **When to BLOCK:** - If memory write or verification fails → `BLOCKED_NO_PLAN_MEMORY_WRITE` @@ -159,9 +205,11 @@ Plans saved to `/memories/session/plan.md` and presented to the user MUST follow ## Boundaries - ✅ **Always:** + - Save `/memories/session/goal.md` before any research or delegation - Delegate to `Explore` before drafting any plan - Use #tool:vscode/askQuestions during Alignment to resolve ambiguities **before** finalizing the plan - Save plan to memory immediately after drafting, before presenting to user + - Delegate to `PlanReview` after saving the draft plan, before presenting to user - Wait for explicit user approval before execution - Verify plan in memory before delegating to any post-Explore subagent - Re-save plan to memory whenever scope changes diff --git a/.github/agents/PlanReview.agent.md b/.github/agents/PlanReview.agent.md new file mode 100644 index 0000000..5af9238 --- /dev/null +++ b/.github/agents/PlanReview.agent.md @@ -0,0 +1,85 @@ +--- +description: "Use when: verifying a drafted plan against the actual codebase before presenting to user. Checks assumptions, goal achievability, architecture descriptions, and step feasibility." +model: GPT-5.4 (copilot) +tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, read, search, web, github/get_file_contents, github/issue_read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest] +agents: [] +user-invocable: false +argument-hint: "Invoked by Orchestrator after drafting and saving plan to /memories/session/plan.md. No additional input required — reads plan automatically." +--- +You are a plan reviewer for the SourceGen C# source generator repository. You verify the drafted plan against the actual codebase and the stated goal before it is presented to the user. You check that the plan's assumptions are correct, the plan can achieve the stated goal, architecture descriptions match the codebase, and implementation steps are feasible. You never edit files or run commands — you analyze and report. + +Follow the project principles in `AGENTS.md`. + +Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. + +## Approach + +1. **Load goal and plan from memory (MANDATORY FIRST ACTION)** — Read `/memories/session/goal.md` first, then `/memories/session/plan.md`. Follow the Memory Protocol section below. +2. Parse the plan for: assumptions, architecture descriptions, implementation steps, and relevant files +3. For each **assumption** stated or implied in the plan: verify it holds against the actual codebase (e.g., “this class uses pattern X”, “no existing callers rely on Y”) +4. For each **architecture description**: validate it matches the actual codebase structure +5. For each **implementation step**: assess feasibility — are the referenced APIs available? Do dependencies exist? Are analogous patterns actually present? +6. **Goal achievability**: compare the plan’s steps and acceptance criteria against the goal in `goal.md` — will completing all steps actually satisfy the stated goal? Are there gaps or misalignments? +7. Produce a structured Plan Review Report +8. Use #tool:vscode/memory to save the report to `/memories/session/plan-review.md` so the Orchestrator can read it and decide next steps + +## Review Checklist + +- **Assumptions**: Are the plan’s explicit and implicit assumptions correct? (e.g., “class X uses pattern Y”, “no existing callers depend on Z”, “this API supports feature W”) +- **Goal Achievability**: Will completing all plan steps actually satisfy the goal stated in `/memories/session/goal.md`? Are there gaps, misalignments, or missing steps? +- **Architecture**: Do architectural descriptions (generator pipeline, registration patterns, naming conventions, etc.) match the actual codebase? +- **Feasibility**: Are implementation steps achievable given the existing code structure and dependencies? +- **Analogues**: If the plan references "analogous to" or "use as template" for an existing pattern, does that pattern actually exist at the stated location? + +## Boundaries + +- ✅ **Always do:** + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` + - Load goal from `/memories/session/goal.md` and plan from `/memories/session/plan.md` as the very first actions + - Verify the plan’s assumptions against the actual codebase + - Verify the plan can achieve the goal stated in `goal.md` + - Save the review report to `/memories/session/plan-review.md` for the Orchestrator to read + - Order findings by severity: High first, then Medium, then Low + +- ⚠️ **Ask first:** + - N/A — this agent is non-interactive and report-only + +- 🚫 **Never do:** + - Edit or create source files + - Run commands or tests + - Modify `/memories/session/plan.md` (owned by parent agents) + - Suggest scope expansions or architectural improvements — only report accuracy/feasibility issues + +## Memory Protocol + +1. **Load goal (mandatory first action)** — Use #tool:vscode/memory to read `/memories/session/goal.md` as your very first tool call. + - If goal is present and non-empty → proceed to load plan. + - If goal is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN` (goal is a prerequisite). + - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. +2. **Load plan (mandatory second action)** — Use #tool:vscode/memory to read `/memories/session/plan.md`. + - If plan is present and non-empty → proceed. + - If plan is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. + - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. +3. **Save review result** — After completing the review, use `memory` to save the Plan Review Report to `/memories/session/plan-review.md` so the Orchestrator can read it and decide next steps. + +## Output Format + +Return a structured report in this exact format: + +### Plan Review Report + +#### Preconditions +- MemoryGoalLoaded: true | false +- MemoryPlanLoaded: true | false +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md +- Blocker: (empty or reason) + +#### Findings +| # | Severity | File / Symbol | Issue | Suggested Fix | +|---|----------|---------------|-------|---------------| +| (list issues or write "None found") | + +*Severity: **High** = plan will fail if unaddressed; **Medium** = likely confusion or incorrect behavior; **Low** = minor inaccuracy* + +#### Summary +(One of: **Pass** / **Pass with suggestions** / **Needs revision** — followed by a one-sentence rationale) diff --git a/.github/agents/Review.agent.md b/.github/agents/Review.agent.md index 936cc69..81f831f 100644 --- a/.github/agents/Review.agent.md +++ b/.github/agents/Review.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: reviewing completed implementation against spec. Performs read-only code review for spec compliance, refactoring opportunities, and performance optimization." model: GPT-5.4 (copilot) -tools: [vscode/memory, execute/getTerminalOutput, read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', search, web, github/get_file_contents, github/issue_read, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, todo] +tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, read, search, web, github/get_file_contents, github/issue_read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest] agents: [] user-invocable: false argument-hint: "Provide the spec/plan and list of changed files to review" @@ -10,13 +10,13 @@ You are a senior code reviewer specializing in C# source generators. You perform Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Approach -1. **Load plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: - Call `memory({ command: "view", path: "/memories/session/plan.md" })` as your very first tool call. - - If plan is present and non-empty → proceed to step 2. - - If plan is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. +1. **Load goal and plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: + Use #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`. These must be your very first tool calls. + - If both are present and non-empty → proceed to step 2. + - If either is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. 2. Read all changed/created files listed in the prompt 3. For each file, compare the implementation against the spec @@ -33,7 +33,7 @@ Follow the **child agent protocol** in `.github/instructions/plan-memory-policy. ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` - Compare every changed file against spec requirements - Check for source-generator-specific anti-patterns (symbol capture, mutable models) - Save the review report to `/memories/session/review.md` for the parent agent to read @@ -54,8 +54,9 @@ Return a structured report in this exact format: ### Review Report #### Preconditions +- MemoryGoalLoaded: true | false - MemoryPlanLoaded: true | false -- MemoryPath: /memories/session/plan.md +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md - Blocker: (empty or reason) #### 1. Spec Compliance Issues diff --git a/.github/agents/Spec.agent.md b/.github/agents/Spec.agent.md index d2c47e6..f5e7715 100644 --- a/.github/agents/Spec.agent.md +++ b/.github/agents/Spec.agent.md @@ -1,7 +1,7 @@ --- description: "Use when: updating or creating specification documents (file under Spec/). Writes clear specs targeting both human developers and AI agents." -model: GPT-5.4 (copilot) -tools: [vscode/memory, read, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', edit, search, web, todo] +model: Claude Opus 4.6 (copilot) +tools: [vscode/memory, vscode/resolveMemoryFileUri, read, edit, search, web, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] agents: [] user-invocable: false argument-hint: "Implement spec updates from the approved plan stored in /memories/session/plan.md" @@ -10,7 +10,7 @@ You are a specification writer for the SourceGen C# source generator project. Yo Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/plan-memory-policy.instructions.md`. +Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. ## Writing Guidelines @@ -19,12 +19,44 @@ Follow the **child agent protocol** in `.github/instructions/plan-memory-policy. - **Examples**: At least one C# example per major feature; show valid and invalid usage; keep minimal - **Consistency**: Follow existing spec format, reuse table structures, place new sections in logical order +## Keyword Reference + +> Sources: [RFC 2119 — Key words for use in RFCs to Indicate Requirement Levels](https://www.rfc-editor.org/rfc/rfc2119) + [RFC 8174 — Ambiguity of Uppercase vs Lowercase in RFC 2119 Key Words](https://www.rfc-editor.org/rfc/rfc8174) (BCP 14) + +When writing specs, include this statement near the beginning: + +> The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "NOT RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in [BCP 14](https://www.rfc-editor.org/bcp/bcp14) \[[RFC 2119](https://www.rfc-editor.org/rfc/rfc2119)\] \[[RFC 8174](https://www.rfc-editor.org/rfc/rfc8174)\] when, and only when, they appear in all capitals, as shown here. + +### Capitalization Rule (RFC 8174) + +RFC 8174 clarifies RFC 2119: + +- Keywords have their defined special meanings **only when written in ALL CAPITALS**. +- When not capitalized (e.g., "must", "should"), they carry their normal English meaning and are not normative. +- Using these keywords is not required — normative text can be written without them. They are used for **clarity and consistency**. + +### Keyword Definitions (RFC 2119) + +| Keyword | Synonyms | Meaning | +|---------|----------|---------| +| **MUST** | REQUIRED, SHALL | An absolute requirement of the specification. | +| **MUST NOT** | SHALL NOT | An absolute prohibition of the specification. | +| **SHOULD** | RECOMMENDED | Valid reasons may exist to ignore the item in particular circumstances, but the full implications must be understood and carefully weighed before choosing a different course. | +| **SHOULD NOT** | NOT RECOMMENDED | Valid reasons may exist when the behavior is acceptable or even useful in particular circumstances, but the full implications should be understood and the case carefully weighed before implementing. | +| **MAY** | OPTIONAL | The item is truly optional. An implementation that does not include a particular option MUST be prepared to interoperate with one that does, and vice versa. | + +### Usage Guidance + +- Use these keywords **sparingly** and only where actually required for correctness or to limit harmful behavior. +- Do **not** use them to impose a particular implementation method when it is not required. +- When a spec says MUST or SHOULD, elaborate the implications of not following the requirement — especially security implications. + ## Approach -1. **Load plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: - Call `memory({ command: "view", path: "/memories/session/plan.md" })` as your very first tool call. - - If plan is present and non-empty → proceed to step 2. - - If plan is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. +1. **Load goal and plan from memory (MANDATORY FIRST ACTION — do this before anything else)**: + Use #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`. These must be your very first tool calls. + - If both are present and non-empty → proceed to step 2. + - If either is missing or empty → STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. - If memory tool fails → STOP and return `BLOCKED_NO_PLAN_MEMORY`. 2. Read all existing spec files referenced in the plan 3. Create the full todo list from plan steps via #tool:todo @@ -35,7 +67,7 @@ Follow the **child agent protocol** in `.github/instructions/plan-memory-policy. ## Boundaries - ✅ **Always do:** - - Follow the plan memory policy in `.github/instructions/plan-memory-policy.instructions.md` + - Follow the plan memory policy in `.github/instructions/memory-policy.instructions.md` - Follow existing spec format and table structures - Use RFC 2119 keywords (MUST/SHOULD/MAY) for precision - Include at least one C# example per major feature (valid and invalid usage) @@ -58,8 +90,9 @@ Follow the **child agent protocol** in `.github/instructions/plan-memory-policy. ### Spec Update Report #### Preconditions +- MemoryGoalLoaded: true | false - MemoryPlanLoaded: true | false -- MemoryPath: /memories/session/plan.md +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md - Blocker: (empty or reason) #### Changed Files diff --git a/.github/instructions/csharp-source-generator.instructions.md b/.github/instructions/csharp-source-generator.instructions.md index 19bfcb6..7041c1c 100644 --- a/.github/instructions/csharp-source-generator.instructions.md +++ b/.github/instructions/csharp-source-generator.instructions.md @@ -1,6 +1,6 @@ --- description: "Use when writing or reviewing C# source generators. Covers incremental generator architecture, pipeline design, PolyType.Roslyn, and performance." -applyTo: "src/**/SourceGenerator/**/*.cs" +applyTo: "src/**/*SourceGenerator/**/*.cs" --- # C# Source Generator Best Practices diff --git a/.github/instructions/plan-memory-policy.instructions.md b/.github/instructions/memory-policy.instructions.md similarity index 65% rename from .github/instructions/plan-memory-policy.instructions.md rename to .github/instructions/memory-policy.instructions.md index 98a97e1..7a26e68 100644 --- a/.github/instructions/plan-memory-policy.instructions.md +++ b/.github/instructions/memory-policy.instructions.md @@ -1,17 +1,27 @@ --- -description: "Use when executing agent workflows that coordinate through /memories/session/plan.md. Defines the plan memory protocol for parent and child agents." -applyTo: "src/**" +description: "Use when executing agent workflows that coordinate through /memories/session/. Defines the memory protocol for parent and child agents." --- -# Plan Memory Policy +# Memory Policy -All agents that participate in the plan→approve→implement→review workflow MUST follow this protocol for `/memories/session/plan.md`. +All agents that participate in the plan→approve→implement→review workflow MUST follow this protocol for `/memories/session/` paths. + +## Session Memory Paths + +| Path | Owner | Purpose | +|------|-------|---------| +| `/memories/session/goal.md` | Parent agents (Orchestrator, Doc, DevOps) | Requirement goal — created before Discovery, read by all subagents | +| `/memories/session/plan.md` | Parent agents (Orchestrator, Doc, DevOps) | Approved plan — read by all child agents | +| `/memories/session/plan-review.md` | PlanReview agent | Structured plan review report — read by parent agent before presenting plan | +| `/memories/session/changes.md` | Implement agent | Changed files, decisions, issues, concerns from implementation | +| `/memories/session/review.md` | Review agent | Structured review report | ## Memory Access Rules -- **ONLY** use #tool:vscode/memory (the `memory` tool) to read and write `/memories/session/plan.md`. -- Do NOT use #tool:read (`read_file`) for `/memories/session/plan.md`; this path is memory-only. -- Do NOT use #tool:edit (`replace_string_in_file`) for `/memories/session/plan.md`; this path is memory-only. +- **ONLY** use #tool:vscode/memory (the `memory` tool) to read and write `/memories/session/` paths. +- Use #tool:vscode/resolveMemoryFileUri to resolve a `/memories/` path to a file URI when another tool requires a real file path instead of a memory abstraction. This is read-only and does not replace #tool:vscode/memory for content operations. +- Do NOT use #tool:read for `/memories/session/` paths; these are memory-only. +- Do NOT use #tool:edit for `/memories/session/` paths; these are memory-only. ### Exact Tool Call Syntax @@ -25,11 +35,16 @@ memory({ command: "view", path: "/memories/session/plan.md" }) memory({ command: "create", path: "/memories/session/plan.md", file_text: "" }) ``` +``` +memory({ command: "str_replace", path: "/memories/session/plan.md", old_str: "", new_str: "" }) +``` + ## Parent Agent Protocol -Parent agents (Orchestrator, DevOps, Doc) create, save, and maintain the plan: +Parent agents (Orchestrator, DevOps, Doc) create, save, and maintain the goal and plan: -1. **Explore First** — The first subagent call in every task MUST be `Explore` to gather context. +0. **Capture Goal** — Before any research, distill the user's request into a concise goal statement and save it to `/memories/session/goal.md` via #tool:vscode/memory. This file is the single source of truth for *what* we are trying to achieve. Include it (or reference it) when delegating to every subagent. +1. **Explore First** — The first subagent call in every task MUST be `Explore` to gather context. Provide the goal from `goal.md` alongside the research question. 2. **Clarify if Needed** — After `Explore` returns, resolve any material ambiguity before finalizing the plan. Use #tool:vscode/askQuestions when requirements are incomplete, multiple valid approaches exist, public API or dependency changes are involved, or a user decision is required. Do not ask questions that can be answered from the codebase. 3. **Create Plan** — After `Explore` and any necessary clarification, create `plan.md` using the format defined by the active parent agent. The plan MUST be structured, complete, and current, and MUST include the equivalent of: goal/outcome, implementation approach or steps, scope or relevant files, and acceptance criteria or verification. 4. **Save Draft Plan** — After drafting the plan, and before delegating to any subagent after the initial `Explore` call, save the current plan to `/memories/session/plan.md` via #tool:vscode/memory or an update command if the file already exists. @@ -43,10 +58,10 @@ Parent agents (Orchestrator, DevOps, Doc) create, save, and maintain the plan: **CRITICAL**: The VERY FIRST action of any child agent MUST be to load and validate the plan. Do NOT skip this step. Do NOT proceed to any other work until the plan is loaded and confirmed non-empty. -Child agents (Implement, Review, DocReview, Spec) load and validate the plan: +Child agents (Implement, Review, DocReview, Spec, PlanReview) load and validate the plan: -1. **Load Plan (FIRST ACTION — mandatory, non-skippable)** — Call #tool:vscode/memory as your very first tool call. No other tool call may precede this. -2. **Validate Content** — Confirm the plan content is present and non-empty. If valid, proceed to work. +1. **Load Goal and Plan (FIRST ACTION — mandatory, non-skippable)** — Call #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`, as your very first tool calls. No other tool call may precede these. +2. **Validate Content** — Confirm both goal and plan content are present and non-empty. If valid, proceed to work. 3. **Block if Missing** — If memory read fails or plan is missing/empty, stop immediately and return `BLOCKED_NEEDS_PARENT_PLAN` with a brief reason requesting the parent agent to save a complete plan. Do NOT attempt to guess the plan or proceed without it. 4. **Block on Tool Failure** — If memory is unavailable due to tool/runtime issues, stop and return `BLOCKED_NO_PLAN_MEMORY`. 5. **Block on Ambiguity** — If anything in the plan is unclear or a design decision is needed, return `BLOCKED_NEEDS_PARENT_DECISION` with the exact clarification needed. @@ -67,10 +82,11 @@ If an agent definition requires a structured completion report, include plan-mem ``` #### Preconditions +- MemoryGoalLoaded: true | false - MemoryPlanLoaded: true | false - MemoryPlanSaved: true | false (parent agents only) - MemoryPlanVerified: true | false (parent agents only) -- MemoryPath: /memories/session/plan.md +- MemoryPath: /memories/session/goal.md, /memories/session/plan.md - Blocker: (empty or BLOCKED_* code with reason) ``` @@ -78,10 +94,11 @@ If the active agent definition does not require a structured preconditions block ## Boundaries -- ✅ **Always:** Access `/memories/session/plan.md` exclusively via #tool:vscode/memory +- ✅ **Always:** Access `/memories/session/` paths exclusively via #tool:vscode/memory +- ✅ **Always:** Use #tool:vscode/resolveMemoryFileUri when another tool needs a file URI for a memory path - ✅ **Always:** Verify plan content after every save operation - ✅ **Always:** Handle all `BLOCKED_*` responses at the appropriate level -- 🚫 **Never:** Use #tool:read for `/memories/session/plan.md` -- 🚫 **Never:** Use #tool:edit for `/memories/session/plan.md` +- 🚫 **Never:** Use #tool:read for `/memories/session/` paths +- 🚫 **Never:** Use #tool:edit for `/memories/session/` paths - 🚫 **Never:** Delegate to any subagent (after initial Explore) before saving and verifying plan - 🚫 **Never:** Have child agents ask users directly for plan content or approvals diff --git a/AGENTS.md b/AGENTS.md index 3b35618..29782bd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,6 +38,7 @@ Orchestrator (parent) ├── Spec — update specs under **/Spec/ ├── Implement — write code, run tests ├── Review — read-only code review after every Implement +├── PlanReview — read-only plan review before presenting plan to user ├── Doc (parent) — write/update docs under docs/ │ ├── Explore │ └── DocReview — read-only docs review @@ -45,9 +46,9 @@ Orchestrator (parent) └── Explore ``` -- **Parent agents** (Orchestrator, Doc, DevOps) manage `/memories/session/plan.md` -- **Child agents** (Implement, Review, DocReview, Spec) load plan before work; return `BLOCKED_*` if missing -- Protocol: [plan-memory-policy.instructions.md](.github/instructions/plan-memory-policy.instructions.md) +- **Parent agents** (Orchestrator, Doc, DevOps) manage `/memories/session/goal.md` and `/memories/session/plan.md` +- **Child agents** (Implement, Review, PlanReview, DocReview, Spec) load goal and plan before work; return `BLOCKED_*` if missing +- Protocol: [memory-policy.instructions.md](.github/instructions/memory-policy.instructions.md) ## Hierarchy diff --git a/docs/Ioc/04_Field_Property_Method_Injection.md b/docs/Ioc/04_Field_Property_Method_Injection.md index e309213..1925c58 100644 --- a/docs/Ioc/04_Field_Property_Method_Injection.md +++ b/docs/Ioc/04_Field_Property_Method_Injection.md @@ -127,14 +127,107 @@ services.AddSingleton((global::System.IServicePr > - If no need to generate factory method (no field/property/method injection or decorator), will let `IServiceProvider` select constructor. > - If factory method generation is needed (due to field/property/method injection or decorator), will use primary constructor, then the constructor with the most parameters. +## Async Method Injection + +Use `[IocInject]` on a method that returns `Task` to perform async initialization after construction. The method is awaited after all synchronous injection steps. + +> [!WARNING] +> `AsyncMethodInject` is **not** enabled by default. Add it to `SourceGenIocFeatures` in your project file: +> +> ```xml +> +> Register,Container,PropertyInject,MethodInject,AsyncMethodInject +> +> ``` +> +> `AsyncMethodInject` requires `MethodInject` to be enabled. If `AsyncMethodInject` is enabled without `MethodInject`, the analyzer reports `SGIOC026`. + +### Classification Rules + +|Return Type|Classification| +|:---|:---| +|`void`|Synchronous method injection (`InjectionMemberType.Method`)| +|`Task` (non-generic)|Async method injection (`InjectionMemberType.AsyncMethod`)| +|`Task`|Not supported — not a valid injection method return type| +|`ValueTask` / `ValueTask`|Not supported — not a valid injection method return type| + +### Injection Stage Order + +The generator emits member injection in a fixed stage order. Source declaration order applies within each stage: + +|Stage|Members| +|:----|:------| +|1|Properties| +|2|Fields| +|3|Synchronous methods (`void`)| +|4|Async methods (`Task`) — awaited last| + +### Example + +```csharp +using System.Threading.Tasks; + +public interface IMyService; +public interface IDependency1; +public interface IDependency2; +public interface IDependency3; + +[IocRegister(ServiceLifetime.Singleton)] +internal class MyService : IMyService +{ + [IocInject] + public IDependency1 Dep1 { get; set; } = default!; + + [IocInject] + public void SyncInit(IDependency2 dep2) + { + } + + [IocInject] + public async Task AsyncInit(IDependency3 dep3) + { + await Task.CompletedTask; + } +} +``` + +
+Generated Code + +```csharp +// +services.AddSingleton>((global::System.IServiceProvider sp) => +{ + async global::System.Threading.Tasks.Task Init() + { + var s0_p0 = sp.GetRequiredService(); // Stage 1: properties + var s0_m1 = sp.GetRequiredService(); + var s0_m2 = sp.GetRequiredService(); + var s0 = new global::MyNamespace.MyService() { Dep1 = s0_p0 }; + s0.SyncInit(s0_m1); // Stage 3: sync methods + await s0.AsyncInit(s0_m2); // Stage 4: async methods + return s0; + } + return Init(); +}); +// Forwarding registration: Task → Task +services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); +``` + +
+ +> [!NOTE] +> When a service has async inject methods, the registration type changes from `T` to `Task`. Consumers that depend on this service should inject `Task` and `await` the result. See [Wrapper Types — Task\](10_Wrapper.md#taskt) for consumer-side usage. + ## Diagnostics |ID|Severity|Description| |:---|:---|:---| -|SGIOC007|Error|Invalid `[IocInject]` usage. The attribute cannot be applied to static members, non-accessible members (`private`, `protected`, `private protected` — but `protected internal` is accepted), properties without a setter or with an inaccessible setter, `readonly` fields, generic methods, non-ordinary methods (e.g., constructors, operators), or methods that do not return `void`.| -|SGIOC022|Warning|`[IocInject]` is ignored when the corresponding feature (`PropertyInject`, `FieldInject`, or `MethodInject`) is disabled in `SourceGenIocFeatures`.| +|SGIOC007|Error|Invalid `[IocInject]` usage. The attribute cannot be applied to static members, non-accessible members (`private`, `protected`, `private protected` — but `protected internal` is accepted), properties without a setter or with an inaccessible setter, `readonly` fields, generic methods, non-ordinary methods (e.g., constructors, operators), or methods with unsupported return types (only `void` and non-generic `Task` when `AsyncMethodInject` is enabled are accepted).| +|SGIOC022|Warning|`[IocInject]` is ignored when the corresponding feature (`PropertyInject`, `FieldInject`, `MethodInject`, or `AsyncMethodInject`) is disabled in `SourceGenIocFeatures`.| +|SGIOC026|Error|`AsyncMethodInject` feature requires `MethodInject` to be enabled.| |SGIOC023|Error|An element in the `InjectMembers` array is not in a recognized format. Each element must be `nameof(member)` or `new object[] { nameof(member), key [, KeyType] }`.| -|SGIOC024|Error|A member specified via `InjectMembers` is not injectable (e.g., static, non-accessible members (`private`, `protected`, `private protected` — but `protected internal` is accepted), no setter or inaccessible setter, `readonly` field, generic method, non-ordinary method, or method that does not return `void`).| +|SGIOC024|Error|A member specified via `InjectMembers` is not injectable (e.g., static, non-accessible members (`private`, `protected`, `private protected` — but `protected internal` is accepted), no setter or inaccessible setter, `readonly` field, generic method, non-ordinary method, or method with unsupported return type).| ## InjectMembers: Attribute-Level Injection Without `[IocInject]` diff --git a/docs/Ioc/10_Wrapper.md b/docs/Ioc/10_Wrapper.md index f0b6df1..2ac7394 100644 --- a/docs/Ioc/10_Wrapper.md +++ b/docs/Ioc/10_Wrapper.md @@ -18,6 +18,7 @@ When a registered service depends on a wrapper type, the generator emits factory |`T[]`|`GetServices().ToArray()`| |`IDictionary`|Dictionary built from keyed service entries| |`KeyValuePair`|Single keyed service entry| +|`Task`|Async-init wrapper — resolves `Task` directly or wraps sync service via `Task.FromResult`| ## `Lazy` @@ -264,6 +265,118 @@ services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescrip > [!NOTE] > `KeyValuePair` is a struct, so the generator uses `ServiceDescriptor` directly with a boxing factory instead of the generic `AddSingleton` overload (which has a `class` constraint). When injecting a single `KeyValuePair`, the resolved entry depends on `GetServices` ordering — prefer `IDictionary` when you need all keyed entries. +## `Task` + +`Task` is the async-init wrapper. When a consumer depends on `Task`, the generator routes resolution based on whether the inner service `T` uses async method injection or not. + +> [!NOTE] +> When the inner service uses async method injection (requires the `AsyncMethodInject` feature flag), the generator emits a `Task` registration automatically. For sync-only services, the consumer’s `Task` dependency is wrapped with `Task.FromResult`. See [MSBuild Configuration](13_MSBuild_Configuration.md#sourcegeniocfeatures) for how to enable `AsyncMethodInject`. + +### Async-init Service + +When the inner service `T` has `[IocInject]` on an async `Task` method, the generator emits a `Task` registration plus a forwarding `Task` registration. The consumer resolves `Task` directly: + +```csharp +using System.Threading.Tasks; + +public interface IMyService; +public interface IDependency; + +[IocRegister(ServiceLifetime.Singleton)] +internal class MyService : IMyService +{ + [IocInject] + public async Task InitAsync(IDependency dep) + { + await Task.CompletedTask; + } +} + +[IocRegister] +internal class Consumer(Task serviceTask) +{ + public async Task GetServiceAsync() => await serviceTask; +} +``` + +
+Generated Code + +```csharp +// +services.AddSingleton(); +services.AddSingleton>((global::System.IServiceProvider sp) => +{ + async global::System.Threading.Tasks.Task Init() + { + var s0_m0 = sp.GetRequiredService(); + var s0 = new global::MyNamespace.MyService(); + await s0.InitAsync(s0_m0); + return s0; + } + return Init(); +}); +// Forwarding registration: Task → Task +services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); +services.AddSingleton((global::System.IServiceProvider sp) => +{ + var p0 = sp.GetRequiredService>(); + var s0 = new global::MyNamespace.Consumer(p0); + return s0; +}); +``` + +
+ +> [!NOTE] +> The generator follows a fixed injection stage order: properties → fields → synchronous methods → async methods. See [Async Method Injection](04_Field_Property_Method_Injection.md#async-method-injection) for details. + +### Sync-only Service + +When the inner service `T` does **not** have async inject methods, the generator wraps the synchronous resolution with `Task.FromResult`: + +```csharp +public interface ISyncService; + +[IocRegister] +internal class SyncService : ISyncService; + +[IocRegister] +internal class Consumer(Task serviceTask) +{ + public async Task GetServiceAsync() => await serviceTask; +} +``` + +
+Generated Code + +```csharp +// +services.AddSingleton(); +services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + +// Consumer — sync-only service wrapped with Task.FromResult +services.AddSingleton((global::System.IServiceProvider sp) => +{ + var p0 = global::System.Threading.Tasks.Task.FromResult(sp.GetRequiredService()); + var s0 = new global::MyNamespace.Consumer(p0); + return s0; +}); +``` + +
+ +### Unsupported Nesting + +Nested `Task` wrappers are **not** supported. The following patterns will not be recognized: + +|Pattern|Status| +|:---|:---| +|`Task>`|Not supported| +|`Lazy>`|Not supported| +|`IEnumerable>`|Not supported| + ## Wrapper Nesting Wrapper types can be nested. For example, `IEnumerable>` is valid: @@ -305,6 +418,39 @@ services.AddSingleton>((global
+### Nesting Limits + +Wrapper nesting behavior is **shape-dependent**: + +- **Non-collection outer wrappers** (`Lazy`, `Func`): recursively resolved to arbitrary depth via inline construction. +- **Collection outer wrappers** (`IEnumerable`, `IList`, etc.): support at most **1 level of inner wrapping** (2 levels total). Deeper nesting falls back to default behavior. + +#### Supported + +| Pattern | Behavior | +| :--- | :--- | +| `Lazy>` | Inline factory: `new Lazy>(() => new Func(...))` | +| `Func>` | Inline factory: `new Func>(() => new Lazy(...))` | +| `Lazy>` | Inline: `new Lazy>(() => sp.GetServices())` | +| `Lazy>>` | Recursively inline (non-collection outer wrapper) | +| `IEnumerable>` | Resolved via standalone `Lazy` registrations | +| `IEnumerable>` | Resolved via standalone `Func` registrations | + +#### Not supported (collection outer wrapper with 3+ levels) + +| Pattern | Behavior | +| :--- | :--- | +| `IEnumerable>>` | No wrapper registrations emitted | +| `IEnumerable>>` | No wrapper registrations emitted | + +When a collection outer wrapper contains 3+ levels of nesting: + +- **Register pipeline**: The consumer is registered with a plain `AddXXX()` call. No wrapper registrations are emitted, so the nested wrapper parameter depends on MS.DI runtime resolution. +- **Container pipeline**: The parameter falls back to `IServiceProvider` resolution via `GetRequiredService(typeof(...))`. + +> [!NOTE] +> `ValueTask` is **not** a recognized wrapper type in any context. Only `Task` is supported for async-init wrapping. When used as a partial accessor return type: if the target service uses async-init, diagnostic **`SGIOC029`** is reported; otherwise diagnostic **`SGIOC021`** (unresolvable type) is reported. + ## With Open Generics Wrapper dependencies can trigger closed generic discovery for open generic registrations. diff --git a/docs/Ioc/13_MSBuild_Configuration.md b/docs/Ioc/13_MSBuild_Configuration.md index c73d975..a2e6c3c 100644 --- a/docs/Ioc/13_MSBuild_Configuration.md +++ b/docs/Ioc/13_MSBuild_Configuration.md @@ -94,9 +94,10 @@ Enable/disable generated outputs and injection member kinds: |`PropertyInject`|Enable `[IocInject]` support on properties.| |`FieldInject`|Enable `[IocInject]` support on fields.| |`MethodInject`|Enable `[IocInject]` support on methods.| +|`AsyncMethodInject`|Enable `[IocInject]` support on async methods returning `Task`. Requires `MethodInject`.| > [!NOTE] -> `FieldInject` is available but not included in the default feature set. Add it explicitly when you want field injection generation. +> `FieldInject` and `AsyncMethodInject` are available but not included in the default feature set. Add them explicitly when needed. Parsing behavior: diff --git a/samples/Ioc/IocSample/AsyncInject.cs b/samples/Ioc/IocSample/AsyncInject.cs new file mode 100644 index 0000000..94ca4b1 --- /dev/null +++ b/samples/Ioc/IocSample/AsyncInject.cs @@ -0,0 +1,19 @@ +namespace IocSample; + +public interface IAsyncDependency; + +[IocRegister(ServiceLifetime.Scoped)] +internal sealed class AsyncDependency : IAsyncDependency +{ + [IocInject] + public async Task InitAsync() + { + await Task.Delay(1000); + } +} + +[IocRegister] +internal sealed class AsyncDependentClass(Task dependency) +{ + private readonly Task _dependency = dependency; +} diff --git a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/IocSample.ServiceRegistration.g.cs b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/IocSample.ServiceRegistration.g.cs index caf9381..a4e83c8 100644 --- a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/IocSample.ServiceRegistration.g.cs +++ b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/IocSample.ServiceRegistration.g.cs @@ -21,6 +21,12 @@ public static class IocSampleServiceCollectionExtensions { if (!tags.Any()) { + services.AddTransient((global::System.IServiceProvider sp) => + { + var p0 = sp.GetRequiredService>(); + var s0 = new global::IocSample.AsyncDependentClass(p0); + return s0; + }); services.AddTransient(); services.AddTransient((global::System.IServiceProvider sp) => sp.GetRequiredService()); services.AddTransient((global::System.IServiceProvider sp) => sp.GetRequiredService()); @@ -34,10 +40,10 @@ public static class IocSampleServiceCollectionExtensions services.AddTransient((global::System.IServiceProvider sp) => { var p0 = sp.GetRequiredService(); - var s0_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>(); - var s0_p1 = sp.GetRequiredService, global::System.Collections.Generic.List>>(); + var s0_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>(); + var s0_p1 = sp.GetRequiredService, global::System.Collections.Generic.List>>(); var s0_m2 = sp.GetRequiredService, global::System.Collections.Generic.List>>(); - var s0 = new global::IocSample.ViewModel2(p0) { Handler = s0_p0, Handler2 = s0_p1 }; + var s0 = new global::IocSample.ViewModel2(p0) { Handler2 = s0_p0, Handler = s0_p1 }; s0.Initialize(s0_m2); return s0; }); @@ -69,6 +75,17 @@ public static class IocSampleServiceCollectionExtensions var s0 = new global::IocSample.Consumer(p0, p1, p2, p3, p4, p5); return s0; }); + services.AddScoped>((global::System.IServiceProvider sp) => + { + async global::System.Threading.Tasks.Task Init() + { + var s0 = new global::IocSample.AsyncDependency(); + await s0.InitAsync(); + return s0; + } + return Init(); + }); + services.AddScoped>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); services.AddScoped(); services.AddScoped((global::System.IServiceProvider sp) => sp.GetRequiredService()); services.AddScoped((global::System.IServiceProvider sp) => @@ -159,13 +176,13 @@ public static class IocSampleServiceCollectionExtensions var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); return s2; }); - services.AddSingleton, global::IocSample.GenericRequestHandler2>(); - services.AddSingleton, global::System.Collections.Generic.List>>((global::System.IServiceProvider sp) => + services.AddSingleton, global::IocSample.GenericRequestHandler>(); + services.AddSingleton, global::System.Collections.Generic.List>>((global::System.IServiceProvider sp) => { - var s0 = sp.GetRequiredService>(); - var s1 = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(s0); - var s2_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>>(); - var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); + var s0 = sp.GetRequiredService>(); + var s1 = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(s0); + var s2_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>>(); + var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); return s2; }); services.AddSingleton, global::IocSample.GenericRequestHandler>(); @@ -177,13 +194,13 @@ public static class IocSampleServiceCollectionExtensions var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); return s2; }); - services.AddSingleton, global::IocSample.GenericRequestHandler>(); - services.AddSingleton, global::System.Collections.Generic.List>>((global::System.IServiceProvider sp) => + services.AddSingleton, global::IocSample.GenericRequestHandler2>(); + services.AddSingleton, global::System.Collections.Generic.List>>((global::System.IServiceProvider sp) => { - var s0 = sp.GetRequiredService>(); - var s1 = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(s0); - var s2_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>>(); - var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); + var s0 = sp.GetRequiredService>(); + var s1 = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(s0); + var s2_p0 = sp.GetRequiredService, global::System.Collections.Generic.List>>>(); + var s2 = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(s1, s2_p0); return s2; }); services.AddSingleton, global::IocSample.GenericRequestHandler2>(); diff --git a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs index c82cf40..9d1262a 100644 --- a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs +++ b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs @@ -42,9 +42,9 @@ public Module(IServiceProvider? fallbackProvider) _iocSample_Consumer = GetIocSample_Consumer(); _iocSample_External = GetIocSample_External(); _iocSample_GenericRequestHandler_IocSample_Entity_ = GetIocSample_GenericRequestHandler_IocSample_Entity_(); - _iocSample_GenericRequestHandler2_IocSample_Entity3_ = GetIocSample_GenericRequestHandler2_IocSample_Entity3_(); - _iocSample_GenericRequestHandler_IocSample_Entity2_ = GetIocSample_GenericRequestHandler_IocSample_Entity2_(); _iocSample_GenericRequestHandler_IocSample_Entity3_ = GetIocSample_GenericRequestHandler_IocSample_Entity3_(); + _iocSample_GenericRequestHandler_IocSample_Entity2_ = GetIocSample_GenericRequestHandler_IocSample_Entity2_(); + _iocSample_GenericRequestHandler2_IocSample_Entity3_ = GetIocSample_GenericRequestHandler2_IocSample_Entity3_(); _iocSample_GenericRequestHandler2_IocSample_Entity_ = GetIocSample_GenericRequestHandler2_IocSample_Entity_(); _iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create = GetIocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create(); _iocSample_GenericRequestHandler2_IocSample_Entity2_ = GetIocSample_GenericRequestHandler2_IocSample_Entity2_(); @@ -65,9 +65,9 @@ private Module(Module parent) _iocSample_Consumer = parent._iocSample_Consumer; _iocSample_External = parent._iocSample_External; _iocSample_GenericRequestHandler_IocSample_Entity_ = parent._iocSample_GenericRequestHandler_IocSample_Entity_; - _iocSample_GenericRequestHandler2_IocSample_Entity3_ = parent._iocSample_GenericRequestHandler2_IocSample_Entity3_; - _iocSample_GenericRequestHandler_IocSample_Entity2_ = parent._iocSample_GenericRequestHandler_IocSample_Entity2_; _iocSample_GenericRequestHandler_IocSample_Entity3_ = parent._iocSample_GenericRequestHandler_IocSample_Entity3_; + _iocSample_GenericRequestHandler_IocSample_Entity2_ = parent._iocSample_GenericRequestHandler_IocSample_Entity2_; + _iocSample_GenericRequestHandler2_IocSample_Entity3_ = parent._iocSample_GenericRequestHandler2_IocSample_Entity3_; _iocSample_GenericRequestHandler2_IocSample_Entity_ = parent._iocSample_GenericRequestHandler2_IocSample_Entity_; _iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create = parent._iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create; _iocSample_GenericRequestHandler2_IocSample_Entity2_ = parent._iocSample_GenericRequestHandler2_IocSample_Entity2_; @@ -137,17 +137,17 @@ private Module(Module parent) return instance; } - private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> _iocSample_GenericRequestHandler2_IocSample_Entity3_ = null!; - private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> GetIocSample_GenericRequestHandler2_IocSample_Entity3_() + private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> _iocSample_GenericRequestHandler_IocSample_Entity3_ = null!; + private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> GetIocSample_GenericRequestHandler_IocSample_Entity3_() { - if(_iocSample_GenericRequestHandler2_IocSample_Entity3_ is not null) return _iocSample_GenericRequestHandler2_IocSample_Entity3_; + if(_iocSample_GenericRequestHandler_IocSample_Entity3_ is not null) return _iocSample_GenericRequestHandler_IocSample_Entity3_; - global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> instance = new global::IocSample.GenericRequestHandler2((global::IocSample.Shared.ILogger>)GetRequiredService(typeof(global::IocSample.Shared.ILogger>))); + global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> instance = new global::IocSample.GenericRequestHandler((global::IocSample.Shared.ILogger>)GetRequiredService(typeof(global::IocSample.Shared.ILogger>))); - instance = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(instance); - instance = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(instance, (global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>)GetRequiredService(typeof(global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>))); + instance = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(instance); + instance = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(instance, (global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>)GetRequiredService(typeof(global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>))); - _iocSample_GenericRequestHandler2_IocSample_Entity3_ = instance; + _iocSample_GenericRequestHandler_IocSample_Entity3_ = instance; return instance; } @@ -165,17 +165,17 @@ private Module(Module parent) return instance; } - private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> _iocSample_GenericRequestHandler_IocSample_Entity3_ = null!; - private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> GetIocSample_GenericRequestHandler_IocSample_Entity3_() + private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> _iocSample_GenericRequestHandler2_IocSample_Entity3_ = null!; + private global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> GetIocSample_GenericRequestHandler2_IocSample_Entity3_() { - if(_iocSample_GenericRequestHandler_IocSample_Entity3_ is not null) return _iocSample_GenericRequestHandler_IocSample_Entity3_; + if(_iocSample_GenericRequestHandler2_IocSample_Entity3_ is not null) return _iocSample_GenericRequestHandler2_IocSample_Entity3_; - global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> instance = new global::IocSample.GenericRequestHandler((global::IocSample.Shared.ILogger>)GetRequiredService(typeof(global::IocSample.Shared.ILogger>))); + global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List> instance = new global::IocSample.GenericRequestHandler2((global::IocSample.Shared.ILogger>)GetRequiredService(typeof(global::IocSample.Shared.ILogger>))); - instance = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(instance); - instance = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(instance, (global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>)GetRequiredService(typeof(global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>))); + instance = new global::IocSample.Shared.HandlerDecorator2, global::System.Collections.Generic.List>(instance); + instance = new global::IocSample.Shared.HandlerDecorator1, global::System.Collections.Generic.List>(instance, (global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>)GetRequiredService(typeof(global::IocSample.Shared.ILogger, global::System.Collections.Generic.List>>))); - _iocSample_GenericRequestHandler_IocSample_Entity3_ = instance; + _iocSample_GenericRequestHandler2_IocSample_Entity3_ = instance; return instance; } @@ -235,6 +235,36 @@ private Module(Module parent) } } + private global::System.Threading.Tasks.Task? _iocSample_AsyncDependency; + private readonly global::System.Threading.SemaphoreSlim _iocSample_AsyncDependencySemaphore = new(1, 1); + + private async global::System.Threading.Tasks.Task GetIocSample_AsyncDependencyAsync() + { + if(_iocSample_AsyncDependency is not null) + return await _iocSample_AsyncDependency; + + await _iocSample_AsyncDependencySemaphore.WaitAsync(); + try + { + if(_iocSample_AsyncDependency is null) + { + _iocSample_AsyncDependency = CreateIocSample_AsyncDependencyAsync(); + } + } + finally + { + _iocSample_AsyncDependencySemaphore.Release(); + } + return await _iocSample_AsyncDependency; + } + + private async global::System.Threading.Tasks.Task CreateIocSample_AsyncDependencyAsync() + { + var instance = new global::IocSample.AsyncDependency(); + await instance.InitAsync(); + return instance; + } + private global::IocSample.Basic? _iocSample_Basic; private readonly Lock _iocSample_BasicLock = new(); private global::IocSample.Basic GetIocSample_Basic() @@ -286,6 +316,11 @@ private Module(Module parent) } } + private global::IocSample.AsyncDependentClass GetIocSample_AsyncDependentClass() + { + return new global::IocSample.AsyncDependentClass(((global::System.Func>)(async () => (global::IocSample.IAsyncDependency)(await GetIocSample_AsyncDependencyAsync())))()); + } + private global::IocSample.Basic2 GetIocSample_Basic2() { return new global::IocSample.Basic2(); @@ -310,8 +345,8 @@ private Module(Module parent) { var instance = new global::IocSample.ViewModel2(GetIocSample_CustomMessenger()) { - Handler = GetIocSample_GenericRequestHandler_IocSample_Entity2_(), Handler2 = GetIocSample_GenericRequestHandler_IocSample_Entity3_(), + Handler = GetIocSample_GenericRequestHandler_IocSample_Entity2_(), }; instance.Initialize(GetIocSample_GenericRequestHandler2_IocSample_Entity3_()); return instance; @@ -491,6 +526,8 @@ private Module(Module parent) public partial global::IocSample.IKeyed GetKeyEnum() => GetIocSample_KeyedEnum_IocSample_KeyEnum_Key0(); + public partial async global::System.Threading.Tasks.Task GetAsyncDependencyTask() => await GetIocSample_AsyncDependencyAsync(); + #endregion #region IServiceProvider @@ -637,6 +674,7 @@ public IServiceScope CreateScope() new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), new(new ServiceIdentifier(typeof(global::IocSample.Module), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::IocSample.AsyncDependentClass), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_AsyncDependentClass()), new(new ServiceIdentifier(typeof(global::IocSample.Basic2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_Basic2()), new(new ServiceIdentifier(typeof(global::IocSample.IBasic), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_Basic()), new(new ServiceIdentifier(typeof(global::IocSample.IBasic2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_Basic2()), @@ -652,6 +690,8 @@ public IServiceScope CreateScope() new(new ServiceIdentifier(typeof(global::IocSample.DependentClass), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_DependentClass()), new(new ServiceIdentifier(typeof(global::IocSample.DependentClass2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_DependentClass2()), new(new ServiceIdentifier(typeof(global::IocSample.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_Consumer!), + new(new ServiceIdentifier(typeof(global::IocSample.AsyncDependency), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_AsyncDependencyAsync()), + new(new ServiceIdentifier(typeof(global::IocSample.IAsyncDependency), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_AsyncDependencyAsync()), new(new ServiceIdentifier(typeof(global::IocSample.Basic), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_Basic()), new(new ServiceIdentifier(typeof(global::IocSample.FactoryService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_FactoryService_IocSample_Factory_Create()), new(new ServiceIdentifier(typeof(global::IocSample.IFactoryService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_FactoryService2_IocSample_FactoryService2_Create()), @@ -679,12 +719,12 @@ public IServiceScope CreateScope() new(new ServiceIdentifier(typeof(global::IocSample.Default4), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_Default4()), new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity_!), new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity_!), - new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler2_IocSample_Entity3_!), - new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler2_IocSample_Entity3_!), - new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity2_!), - new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity2_!), new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity3_!), new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity3_!), + new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity2_!), + new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler_IocSample_Entity2_!), + new(new ServiceIdentifier(typeof(global::IocSample.GenericRequestHandler2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler2_IocSample_Entity3_!), + new(new ServiceIdentifier(typeof(global::IocSample.Shared.IRequestHandler, global::System.Collections.Generic.List>), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._iocSample_GenericRequestHandler2_IocSample_Entity3_!), new(new ServiceIdentifier(typeof(global::IocSample.WrapperService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_WrapperService_int_()), new(new ServiceIdentifier(typeof(global::IocSample.IWrapperService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_WrapperService_int_()), new(new ServiceIdentifier(typeof(global::IocSample.WrapperService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetIocSample_WrapperService_string_()), @@ -801,6 +841,8 @@ public void Dispose() DisposeService(_iocSample_FactoryService2_IocSample_FactoryService2_Create); DisposeService(_iocSample_FactoryService_IocSample_Factory_Create); DisposeService(_iocSample_Basic); + DisposeService(_iocSample_AsyncDependency); + _iocSample_AsyncDependencySemaphore.Dispose(); DisposeService(_iocSample_Default2); _iocSample_Shared_SharedModule.Dispose(); return; @@ -809,9 +851,9 @@ public void Dispose() DisposeService(_iocSample_GenericRequestHandler2_IocSample_Entity2_); DisposeService(_iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create); DisposeService(_iocSample_GenericRequestHandler2_IocSample_Entity_); - DisposeService(_iocSample_GenericRequestHandler_IocSample_Entity3_); - DisposeService(_iocSample_GenericRequestHandler_IocSample_Entity2_); DisposeService(_iocSample_GenericRequestHandler2_IocSample_Entity3_); + DisposeService(_iocSample_GenericRequestHandler_IocSample_Entity2_); + DisposeService(_iocSample_GenericRequestHandler_IocSample_Entity3_); DisposeService(_iocSample_GenericRequestHandler_IocSample_Entity_); DisposeService(_iocSample_External); DisposeService(_iocSample_Consumer); @@ -828,6 +870,8 @@ public async ValueTask DisposeAsync() await DisposeServiceAsync(_iocSample_FactoryService2_IocSample_FactoryService2_Create); await DisposeServiceAsync(_iocSample_FactoryService_IocSample_Factory_Create); await DisposeServiceAsync(_iocSample_Basic); + await DisposeServiceAsync(_iocSample_AsyncDependency); + _iocSample_AsyncDependencySemaphore.Dispose(); await DisposeServiceAsync(_iocSample_Default2); await _iocSample_Shared_SharedModule.DisposeAsync(); return; @@ -836,9 +880,9 @@ public async ValueTask DisposeAsync() await DisposeServiceAsync(_iocSample_GenericRequestHandler2_IocSample_Entity2_); await DisposeServiceAsync(_iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create); await DisposeServiceAsync(_iocSample_GenericRequestHandler2_IocSample_Entity_); - await DisposeServiceAsync(_iocSample_GenericRequestHandler_IocSample_Entity3_); - await DisposeServiceAsync(_iocSample_GenericRequestHandler_IocSample_Entity2_); await DisposeServiceAsync(_iocSample_GenericRequestHandler2_IocSample_Entity3_); + await DisposeServiceAsync(_iocSample_GenericRequestHandler_IocSample_Entity2_); + await DisposeServiceAsync(_iocSample_GenericRequestHandler_IocSample_Entity3_); await DisposeServiceAsync(_iocSample_GenericRequestHandler_IocSample_Entity_); await DisposeServiceAsync(_iocSample_External); await DisposeServiceAsync(_iocSample_Consumer); @@ -862,5 +906,35 @@ private static void DisposeService(object? service) if(service is IDisposable disposable) disposable.Dispose(); } + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + #endregion } diff --git a/samples/Ioc/IocSample/IocSample.csproj b/samples/Ioc/IocSample/IocSample.csproj index c114e34..e274c98 100644 --- a/samples/Ioc/IocSample/IocSample.csproj +++ b/samples/Ioc/IocSample/IocSample.csproj @@ -6,7 +6,7 @@ Exe true - Register,Container,PropertyInject,FieldInject,MethodInject + Register,Container,PropertyInject,FieldInject,MethodInject,AsyncMethodInject diff --git a/samples/Ioc/IocSample/Module.cs b/samples/Ioc/IocSample/Module.cs index 3b7bc76..c1ca388 100644 --- a/samples/Ioc/IocSample/Module.cs +++ b/samples/Ioc/IocSample/Module.cs @@ -10,4 +10,6 @@ public sealed partial class Module [IocInject(Key = KeyEnum.Key0)] public partial IKeyed GetKeyEnum(); + + public partial Task GetAsyncDependencyTask(); } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs index 02c62c0..a974aa4 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs @@ -38,7 +38,7 @@ public static bool IsWellKnownServiceType(ITypeSymbol typeSymbol) public static bool IsAlwaysResolvable(INamedTypeSymbol serviceType) { // Well-known service types - if (IsWellKnownServiceType(serviceType)) + if(IsWellKnownServiceType(serviceType)) return true; return false; @@ -62,7 +62,7 @@ public static bool IsIEnumerableOfT(INamedTypeSymbol serviceType) /// The element type T, or null if not applicable. public static INamedTypeSymbol? GetEnumerableElementType(INamedTypeSymbol enumerableType) { - if (!IsIEnumerableOfT(enumerableType)) + if(!IsIEnumerableOfT(enumerableType)) return null; return enumerableType.TypeArguments.FirstOrDefault() as INamedTypeSymbol; @@ -77,7 +77,7 @@ public static bool IsIEnumerableOfT(INamedTypeSymbol serviceType) /// True if the attribute class matches the target symbol. public static bool IsAttributeMatch(INamedTypeSymbol? attributeClass, INamedTypeSymbol? targetSymbol) { - if (attributeClass is null || targetSymbol is null) + if(attributeClass is null || targetSymbol is null) return false; // For generic types, get the original unbound definition for comparison @@ -91,7 +91,7 @@ public static bool IsAttributeMatch(INamedTypeSymbol? attributeClass, INamedType /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is an IoC registration attribute. - public static bool IsIoCRegistrationAttribute(INamedTypeSymbol attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsIoCRegistrationAttribute(INamedTypeSymbol attributeClass, IocAttributeSymbols attributeSymbols) { return IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterAttribute) || IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterAttribute_T1) @@ -105,7 +105,7 @@ public static bool IsIoCRegistrationAttribute(INamedTypeSymbol attributeClass, I /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is an IoCRegisterForAttribute. - public static bool IsIoCRegisterForAttribute(INamedTypeSymbol attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsIoCRegisterForAttribute(INamedTypeSymbol attributeClass, IocAttributeSymbols attributeSymbols) { return IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterForAttribute) || IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterForAttribute_T1); @@ -117,7 +117,7 @@ public static bool IsIoCRegisterForAttribute(INamedTypeSymbol attributeClass, Io /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is an IoCRegisterAttribute. - public static bool IsIoCRegisterAttribute(INamedTypeSymbol attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsIoCRegisterAttribute(INamedTypeSymbol attributeClass, IocAttributeSymbols attributeSymbols) { return IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterAttribute) || IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterAttribute_T1); @@ -129,7 +129,7 @@ public static bool IsIoCRegisterAttribute(INamedTypeSymbol attributeClass, IoCAt /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is an IoCRegisterDefaultsAttribute. - public static bool IsIoCRegisterDefaultsAttribute(INamedTypeSymbol attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsIoCRegisterDefaultsAttribute(INamedTypeSymbol attributeClass, IocAttributeSymbols attributeSymbols) { return IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterDefaultsAttribute) || IsAttributeMatch(attributeClass, attributeSymbols.IocRegisterDefaultsAttribute_T1); @@ -141,9 +141,9 @@ public static bool IsIoCRegisterDefaultsAttribute(INamedTypeSymbol attributeClas /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is an IocImportModuleAttribute. - public static bool IsIocImportModuleAttribute(INamedTypeSymbol? attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsIocImportModuleAttribute(INamedTypeSymbol? attributeClass, IocAttributeSymbols attributeSymbols) { - if (attributeClass is null) + if(attributeClass is null) return false; return IsAttributeMatch(attributeClass, attributeSymbols.IocImportModuleAttribute) @@ -156,7 +156,7 @@ public static bool IsIocImportModuleAttribute(INamedTypeSymbol? attributeClass, /// The attribute class to check. /// The attribute symbols context. /// True if the attribute is any IoC attribute. - public static bool IsAnyIoCAttribute(INamedTypeSymbol attributeClass, IoCAttributeSymbols attributeSymbols) + public static bool IsAnyIoCAttribute(INamedTypeSymbol attributeClass, IocAttributeSymbols attributeSymbols) { return IsIoCRegistrationAttribute(attributeClass, attributeSymbols) || IsIoCRegisterDefaultsAttribute(attributeClass, attributeSymbols); @@ -169,49 +169,333 @@ public static bool IsAnyIoCAttribute(INamedTypeSymbol attributeClass, IoCAttribu /// An enumerable of service type symbols. public static IEnumerable GetServiceTypesFromAttribute(AttributeData attribute) { - foreach (var namedArg in attribute.NamedArguments) + foreach(var namedArg in attribute.NamedArguments) { - if (namedArg.Key is not "ServiceTypes") + if(namedArg.Key is not "ServiceTypes") continue; - if (namedArg.Value.Kind is not TypedConstantKind.Array) + if(namedArg.Value.Kind is not TypedConstantKind.Array) continue; - foreach (var element in namedArg.Value.Values) + foreach(var element in namedArg.Value.Values) { - if (element.Value is INamedTypeSymbol serviceType) + if(element.Value is INamedTypeSymbol serviceType) yield return serviceType; } } } /// - /// Gets the target implementation type from an IoCRegisterForAttribute. - /// Supports both generic and non-generic variants. + /// Returns when the implementation contains at least one async inject method. + /// Mirrors the generator's async-init classification by looking for instance ordinary methods + /// marked with [IocInject]/[Inject] that return non-generic Task. /// - /// The attribute data. - /// The target type symbol, or null if not found. - public static INamedTypeSymbol? GetTargetTypeFromRegisterFor(AttributeData attribute) + public static bool IsAsyncInitImplementation(INamedTypeSymbol implType, IocFeatures features) { + if((features & IocFeatures.AsyncMethodInject) == 0) + return false; + + var typeToInspect = implType.IsUnboundGenericType ? implType.OriginalDefinition : implType; + + foreach(var member in typeToInspect.GetMembers()) + { + if(member is not IMethodSymbol { MethodKind: MethodKind.Ordinary, IsStatic: false, IsGenericMethod: false } method) + continue; + + if(!IsNonGenericTaskType(method.ReturnType)) + continue; + + if(method.GetAttributes().Any(static attr => attr.AttributeClass?.IsInject == true)) + return true; + } + + return false; + } + + /// + /// Extracts the keyed service key from a member attribute. + /// [FromKeyedServices] takes precedence over [IocInject]/[Inject]. + /// + public static string? GetServiceKeyFromMember(ISymbol member) + { + string? serviceKey = null; + + foreach(var attribute in member.GetAttributes()) + { + var attrClass = attribute.AttributeClass; + if(attrClass is null) + continue; + + if(attrClass.Name == "FromKeyedServicesAttribute" + && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") + { + if(attribute.ConstructorArguments.Length > 0) + { + var keyArg = attribute.ConstructorArguments[0]; + if(!keyArg.IsNull && keyArg.Value is not null) + return keyArg.GetPrimitiveConstantString(); + } + + return null; + } + + if(attrClass.IsInject && serviceKey is null) + { + var (key, _, _) = attribute.GetKeyInfo(); + serviceKey = key; + } + } + + return serviceKey; + } + + /// + /// Comparer for (service type, key) tuples that uses symbol equality for the type component. + /// + public static IEqualityComparer<(INamedTypeSymbol ServiceType, string? Key)> ServiceTypeAndKeyComparer { get; } + = new ServiceTypeAndKeySymbolComparer(); + + private sealed class ServiceTypeAndKeySymbolComparer : IEqualityComparer<(INamedTypeSymbol ServiceType, string? Key)> + { + public bool Equals((INamedTypeSymbol ServiceType, string? Key) x, (INamedTypeSymbol ServiceType, string? Key) y) + => SymbolEqualityComparer.Default.Equals(x.ServiceType, y.ServiceType) + && StringComparer.Ordinal.Equals(x.Key, y.Key); + + public int GetHashCode((INamedTypeSymbol ServiceType, string? Key) obj) + => unchecked((SymbolEqualityComparer.Default.GetHashCode(obj.ServiceType) * 397) + ^ StringComparer.Ordinal.GetHashCode(obj.Key ?? string.Empty)); + } + + /// + /// Enumerates the service types exposed by a registration attribute for the specified implementation type. + /// Includes self-registration plus any explicit aliases or register-all flags. + /// + public static IEnumerable EnumerateRegisteredServiceTypes( + INamedTypeSymbol implementationType, + AttributeData attribute, + IocAttributeSymbols attributeSymbols) + { + yield return implementationType; + var attrClass = attribute.AttributeClass; - if (attrClass is null) + if(attrClass is null) + yield break; + + if(IsIoCRegisterAttribute(attrClass, attributeSymbols) && attrClass.IsGenericType) + { + foreach(var typeArg in attrClass.TypeArguments) + { + if(typeArg is INamedTypeSymbol serviceType) + yield return serviceType; + } + } + + foreach(var serviceType in GetServiceTypesFromAttribute(attribute)) + yield return serviceType; + + var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); + if(registerAllInterfaces) + { + foreach(var interfaceType in implementationType.AllInterfaces) + yield return interfaceType; + } + + var (_, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); + if(!registerAllBaseClasses) + yield break; + + var baseType = implementationType.BaseType; + while(baseType is not null && baseType.SpecialType is not SpecialType.System_Object) + { + yield return baseType; + baseType = baseType.BaseType; + } + } + + /// + /// Enumerates the implicit service aliases used by the analyzers for an implementation type. + /// + public static IEnumerable EnumerateImplicitServiceTypes(INamedTypeSymbol implementationType) + { + yield return implementationType; + + foreach(var interfaceType in implementationType.AllInterfaces) + yield return interfaceType; + + var baseType = implementationType.BaseType; + while(baseType is not null && baseType.SpecialType is not SpecialType.System_Object) + { + yield return baseType; + baseType = baseType.BaseType; + } + } + + /// + /// Enumerates all assembly-level [IocRegisterFor] / [IocRegisterFor<T>] attributes + /// and yields (attribute, targetType) tuples. Does not perform abstract/private validation. + /// + /// The compilation to query. + /// The attribute symbols context. + /// The cancellation token. + /// An enumerable of (AttributeData, INamedTypeSymbol) tuples. + public static IEnumerable<(AttributeData Attribute, INamedTypeSymbol TargetType)> EnumerateAssemblyLevelRegisterForAttributes( + Compilation compilation, + IocAttributeSymbols attributeSymbols, + CancellationToken cancellationToken) + { + if(attributeSymbols.IocRegisterForAttribute is null && attributeSymbols.IocRegisterForAttribute_T1 is null) + yield break; + + foreach(var attribute in compilation.Assembly.GetAttributes()) + { + cancellationToken.ThrowIfCancellationRequested(); + + var attrClass = attribute.AttributeClass; + if(attrClass is null) + continue; + + if(!IsIoCRegisterForAttribute(attrClass, attributeSymbols)) + continue; + + var targetType = attribute.GetTargetTypeFromRegisterForAttribute(); + if(targetType is null) + continue; + + yield return (attribute, targetType); + } + } + + /// + /// Returns when is the non-generic + /// class. + /// + public static bool IsNonGenericTaskType(ITypeSymbol? type) + => UnwrapNullableValueType(type) is INamedTypeSymbol { Arity: 0, Name: "Task" } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + + /// + /// Unwraps Nullable<T> to T for value-type async wrapper analysis. + /// + public static ITypeSymbol? UnwrapNullableValueType(ITypeSymbol? type) + => type is INamedTypeSymbol + { + OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, + TypeArguments.Length: 1 + } nullableType + ? nullableType.TypeArguments[0] + : type; + + /// + /// Returns the single generic argument of Task<T> or ValueTask<T>, if present. + /// + public static INamedTypeSymbol? TryGetAsyncWrapperElementType(ITypeSymbol? type) + { + type = UnwrapNullableValueType(type); + + if(type is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 1 } namedType) return null; - // Generic variant: IocRegisterForAttribute - if (attrClass.IsGenericType && attrClass.TypeArguments.Length >= 1) + if(namedType.ContainingNamespace.ToDisplayString() != "System.Threading.Tasks") + return null; + + if(namedType.Name is not ("Task" or "ValueTask")) + return null; + + return namedType.TypeArguments[0] as INamedTypeSymbol; + } + + /// + /// Unwraps any Generator-supported wrapper type to extract the inner service type for partial accessor resolution. + /// Supported wrappers: Task<T>, Lazy<T>, Func<T>, IEnumerable<T>, IReadOnlyCollection<T>, ICollection<T>, + /// IReadOnlyList<T>, IList<T>, T[], IDictionary<K,V>, IReadOnlyDictionary<K,V>, Dictionary<K,V>, KeyValuePair<K,V>. + /// + /// The inner service type, or null if the type is not a recognized wrapper. + public static INamedTypeSymbol? TryUnwrapWrapperElementType(ITypeSymbol type) + { + // Array: T[] + if(type.TypeKind == TypeKind.Array) + return (type as IArrayTypeSymbol)?.ElementType as INamedTypeSymbol; + + if(type is not INamedTypeSymbol named) + return null; + + // Arity-1 wrappers + if(named.Arity == 1) { - return attrClass.TypeArguments[0] as INamedTypeSymbol; + var typeArg = named.TypeArguments[0] as INamedTypeSymbol; + + // IEnumerable + if(named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T) + return typeArg; + + var ns = named.ContainingNamespace.ToDisplayString(); + + if(ns == "System.Collections.Generic" + && named.Name is "IReadOnlyCollection" or "ICollection" or "IReadOnlyList" or "IList") + return typeArg; + + if(ns == "System" && named.Name == "Lazy") + return typeArg; + + if(ns == "System.Threading.Tasks" && named.Name is "Task" or "ValueTask") + return typeArg; } - // Non-generic variant: IocRegisterForAttribute(typeof(T)) - if (attribute.ConstructorArguments.Length >= 1 && attribute.ConstructorArguments[0].Value is INamedTypeSymbol argType) + // Arity-2 wrappers — return the value type (TypeArguments[1]) + if(named.Arity == 2) { - return argType; + var ns = named.ContainingNamespace.ToDisplayString(); + + if(ns == "System.Collections.Generic" + && named.Name is "IDictionary" or "IReadOnlyDictionary" or "Dictionary" or "KeyValuePair") + return named.TypeArguments[1] as INamedTypeSymbol; } + // Func, Func, Func, etc. + // The last type argument is always the return type (the service type to resolve). + if(named.Arity >= 1 + && named.ContainingNamespace.ToDisplayString() == "System" + && named.Name == "Func") + return named.TypeArguments[named.TypeArguments.Length - 1] as INamedTypeSymbol; + return null; } + /// + /// Returns true when the return type is not supported for partial accessor resolution. + /// Unsupported: non-generic Task, non-generic ValueTask. + /// Note: ValueTask<T> is handled separately via the async-init path (SGIOC029). + /// + public static bool IsUnsupportedPartialAccessorReturnType(ITypeSymbol type) + { + // Non-generic Task + if(IsNonGenericTaskType(type)) + return true; + + if(type is not INamedTypeSymbol { Name: "ValueTask", Arity: 0 } named) + return false; + + return named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + } + + /// + /// Returns true when the type is ValueTask<T> (generic, arity 1). + /// + public static bool IsGenericValueTaskType(ITypeSymbol type) + => type is INamedTypeSymbol { Name: "ValueTask", Arity: 1 } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + + /// + /// Returns when is a direct Task<T> + /// for the specified service type. + /// + public static bool IsTaskOfServiceType(ITypeSymbol? type, INamedTypeSymbol serviceType) + => TryGetAsyncWrapperElementType(type) is { } wrappedType + && UnwrapNullableValueType(type) is INamedTypeSymbol { Name: "Task" } + && SymbolEqualityComparer.Default.Equals( + wrappedType.WithNullableAnnotation(NullableAnnotation.NotAnnotated), + serviceType.WithNullableAnnotation(NullableAnnotation.NotAnnotated)); + /// /// Checks if a field type is resolvable for injection. /// @@ -223,11 +507,11 @@ public static bool IsFieldAlwaysResolvable(IFieldSymbol field, AttributeData inj var fieldType = field.Type; // Skip well-known service types - if (IsWellKnownServiceType(fieldType)) + if(IsWellKnownServiceType(fieldType)) return true; // Check if always resolvable (well-known types) - if (fieldType is INamedTypeSymbol namedType && IsAlwaysResolvable(namedType)) + if(fieldType is INamedTypeSymbol namedType && IsAlwaysResolvable(namedType)) return true; // Note: IEnumerable is handled separately by the caller @@ -235,7 +519,7 @@ public static bool IsFieldAlwaysResolvable(IFieldSymbol field, AttributeData inj // Check if [IocInject] has a Key specified - this makes it resolvable var (key, _, _) = injectAttribute.GetKeyInfo(); - if (key is not null) + if(key is not null) return true; return false; @@ -249,14 +533,14 @@ public static bool IsFieldAlwaysResolvable(IFieldSymbol field, AttributeData inj public static (bool IsValid, string? ErrorReason) ValidateFactoryOrInstanceSymbol(ISymbol symbol) { // Check if the symbol is static - if (!symbol.IsStatic) + if(!symbol.IsStatic) { return (false, "not static"); } // Check accessibility - must be at least internal to be accessible // Private members cannot be accessed from the generated code - switch (symbol.DeclaredAccessibility) + switch(symbol.DeclaredAccessibility) { case Accessibility.Private: return (false, "private"); @@ -269,9 +553,9 @@ public static (bool IsValid, string? ErrorReason) ValidateFactoryOrInstanceSymbo // Also check containing type accessibility var containingType = symbol.ContainingType; - while (containingType is not null) + while(containingType is not null) { - if (containingType.DeclaredAccessibility is Accessibility.Private) + if(containingType.DeclaredAccessibility is Accessibility.Private) { return (false, "declared in a private type"); } @@ -296,39 +580,39 @@ public static bool IsParameterAlwaysResolvable(IParameterSymbol param) var paramType = param.Type; // Skip if parameter has default value - if (param.HasExplicitDefaultValue) + if(param.HasExplicitDefaultValue) return true; // Skip well-known service types - if (IsWellKnownServiceType(paramType)) + if(IsWellKnownServiceType(paramType)) return true; // Note: IEnumerable is handled separately by the caller // as it depends on whether T is registered when IntegrateServiceProvider = false // Check for special attributes that make the parameter resolvable - foreach (var attribute in param.GetAttributes()) + foreach(var attribute in param.GetAttributes()) { var attrClass = attribute.AttributeClass; - if (attrClass is null) + if(attrClass is null) continue; var attrNamespace = attrClass.ContainingNamespace?.ToDisplayString(); // [IocInject] or [Inject] with Key - check if it has a key - if (attrClass.IsInject) + if(attrClass.IsInject) { var (key, _, _) = attribute.GetKeyInfo(); - if (key is not null) + if(key is not null) return true; } // [ServiceKey] - injects the registration key - if (attrClass.Name == "ServiceKeyAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") + if(attrClass.Name == "ServiceKeyAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") return true; // [FromKeyedServices] - MS.DI handles this automatically - if (attrClass.Name == "FromKeyedServicesAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") + if(attrClass.Name == "FromKeyedServicesAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") return true; } @@ -346,11 +630,11 @@ public static bool IsPropertyAlwaysResolvable(IPropertySymbol property, Attribut var propertyType = property.Type; // Skip well-known service types - if (IsWellKnownServiceType(propertyType)) + if(IsWellKnownServiceType(propertyType)) return true; // Check if IEnumerable - always resolvable - if (propertyType is INamedTypeSymbol namedType + if(propertyType is INamedTypeSymbol namedType && namedType.IsGenericType && namedType.OriginalDefinition.SpecialType is SpecialType.System_Collections_Generic_IEnumerable_T) { @@ -359,7 +643,7 @@ public static bool IsPropertyAlwaysResolvable(IPropertySymbol property, Attribut // Check if [IocInject] has a Key specified - this makes it resolvable var (key, _, _) = injectAttribute.GetKeyInfo(); - if (key is not null) + if(key is not null) return true; return false; @@ -373,24 +657,24 @@ public static bool IsPropertyAlwaysResolvable(IPropertySymbol property, Attribut /// True if the member has a resolvable attribute; otherwise, false. public static bool HasResolvableAttribute(ImmutableArray attributes) { - foreach (var attribute in attributes) + foreach(var attribute in attributes) { var attrClass = attribute.AttributeClass; - if (attrClass is null) + if(attrClass is null) continue; var attrNamespace = attrClass.ContainingNamespace?.ToDisplayString(); // [IocInject] or [Inject] - user explicitly handles this - if (attrClass.IsInject) + if(attrClass.IsInject) return true; // [ServiceKey] - injects the registration key - if (attrClass.Name == "ServiceKeyAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") + if(attrClass.Name == "ServiceKeyAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") return true; // [FromKeyedServices] - MS.DI handles this automatically - if (attrClass.Name == "FromKeyedServicesAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") + if(attrClass.Name == "FromKeyedServicesAttribute" && attrNamespace == "Microsoft.Extensions.DependencyInjection") return true; } @@ -410,30 +694,30 @@ public static ImmutableArray GetEffectiveTags(IEnumerable tags) return tagArray.IsEmpty ? [""] : tagArray; } -/// -/// Checks if a target type is invalid for IoC registration. -/// A type is invalid if it is private or abstract (unless it's an interface). -/// -/// The type to check. -/// A tuple indicating whether the type is invalid and the reason if so. -public static (bool IsInvalid, string? Reason) GetRegistrationInvalidReason(INamedTypeSymbol targetType) -{ - // Check if target type is private - if (targetType.DeclaredAccessibility is Accessibility.Private) - return (true, "private"); + /// + /// Checks if a target type is invalid for IoC registration. + /// A type is invalid if it is private or abstract (unless it's an interface). + /// + /// The type to check. + /// A tuple indicating whether the type is invalid and the reason if so. + public static (bool IsInvalid, string? Reason) GetRegistrationInvalidReason(INamedTypeSymbol targetType) + { + // Check if target type is private + if(targetType.DeclaredAccessibility is Accessibility.Private) + return (true, "private"); - // Check if target type is abstract (but not interface) - if (targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) - return (true, "abstract"); + // Check if target type is abstract (but not interface) + if(targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) + return (true, "abstract"); - return (false, null); -} + return (false, null); + } } /// /// Holds cached IoC attribute type symbols for efficient comparison in analyzers. /// -internal sealed class IoCAttributeSymbols +internal sealed class IocAttributeSymbols { public INamedTypeSymbol? IocContainerAttribute { get; } public INamedTypeSymbol? IocRegisterAttribute { get; } @@ -445,7 +729,7 @@ internal sealed class IoCAttributeSymbols public INamedTypeSymbol? IocImportModuleAttribute { get; } public INamedTypeSymbol? IocImportModuleAttribute_T1 { get; } - public IoCAttributeSymbols(Compilation compilation) + public IocAttributeSymbols(Compilation compilation) { IocContainerAttribute = compilation.GetTypeByMetadataName(Constants.IocContainerAttributeFullName); IocRegisterAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterAttributeFullName); diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs index f250a98..d59301a 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs @@ -73,6 +73,32 @@ public sealed class ContainerAnalyzer : DiagnosticAnalyzer description: "Circular module imports create static initializer deadlocks. Remove the circular dependency.", customTags: [WellKnownDiagnosticTags.CompilationEnd]); + /// + /// SGIOC027: Partial accessor must return Task<T> for an async-init service. + /// + public static readonly DiagnosticDescriptor PartialAccessorMustReturnTask = new( + id: "SGIOC027", + title: "Partial accessor must return Task for async-init service", + messageFormat: "Partial accessor '{0}' returns '{1}' but the implementation has async inject methods. Use 'Task<{1}>'.", + category: Constants.Category_Design, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "When a registered implementation has async inject methods (methods returning Task decorated with [IocInject]), partial accessors targeting that service must return Task so the generator can emit an awaitable resolver.", + customTags: [WellKnownDiagnosticTags.CompilationEnd]); + + /// + /// SGIOC029: Unsupported async partial accessor type. + /// + public static readonly DiagnosticDescriptor UnsupportedAsyncPartialAccessorType = new( + id: "SGIOC029", + title: "Unsupported async partial accessor type", + messageFormat: "Partial accessor '{0}' returns '{1}' which is not a supported async type. Only 'Task' is supported.", + category: Constants.Category_Design, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "When an async-init service is targeted by a partial accessor, only Task is a valid return type. Other wrapper types (Lazy, Func, ValueTask, collections, etc.) and nested wrapper shapes are not supported.", + customTags: [WellKnownDiagnosticTags.CompilationEnd]); + private static readonly SymbolDisplayFormat s_qualifiedFormat = new( typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); @@ -82,7 +108,9 @@ public sealed class ContainerAnalyzer : DiagnosticAnalyzer ContainerMustBePartialAndNotStatic, UseSwitchStatementIgnoredWithImportedModules, UnableToResolvePartialAccessor, - CircularModuleImport + CircularModuleImport, + PartialAccessorMustReturnTask, + UnsupportedAsyncPartialAccessorType ]; public override void Initialize(AnalysisContext context) @@ -95,25 +123,42 @@ public override void Initialize(AnalysisContext context) private static void OnCompilationStart(CompilationStartAnalysisContext context) { - var attributeSymbols = new IoCAttributeSymbols(context.Compilation); + var features = ParseIocFeatures(context.Options); + var attributeSymbols = new IocAttributeSymbols(context.Compilation); - if (!attributeSymbols.HasContainerAttribute) + if(!attributeSymbols.HasContainerAttribute) return; // Collect registered services for SGIOC018 analysis var registeredServiceTypes = new ConcurrentDictionary(SymbolEqualityComparer.Default); + // Maps (service type, key) -> implementation type for async partial accessor validation. + var serviceImplementationTypes = new ConcurrentDictionary<(INamedTypeSymbol ServiceType, string? Key), INamedTypeSymbol>(AnalyzerHelpers.ServiceTypeAndKeyComparer); + + // Tracks every registration entry so async-init checks can evaluate all implementations for a (service type, key) pair. + var registrations = new ConcurrentBag(); + // Collect containers with IntegrateServiceProvider = false for SGIOC018 analysis var containersWithNoFallback = new ConcurrentBag(); + // Collect all [IocContainer] types for SGIOC027/029 analysis + var allContainers = new ConcurrentBag(); + // Collect import edges for SGIOC025 circular import analysis var importEdges = new ConcurrentBag<(INamedTypeSymbol Container, INamedTypeSymbol Module)>(); var analyzerContext = new ContainerAnalyzerContext( attributeSymbols, registeredServiceTypes, + serviceImplementationTypes, + registrations, containersWithNoFallback, - importEdges); + allContainers, + importEdges, + features); + + // Collect assembly-level [IocRegisterFor] / [IocRegisterFor] registrations + CollectAssemblyLevelRegistrations(context.Compilation, analyzerContext, context.CancellationToken); // SGIOC019: Check for partial modifier and static modifier on container classes // Also collect containers with IntegrateServiceProvider = false for SGIOC018 @@ -134,29 +179,29 @@ private static void OnCompilationStart(CompilationStartAnalysisContext context) /// private static void AnalyzeContainerClass(SymbolAnalysisContext context, ContainerAnalyzerContext analyzerContext) { - if (context.Symbol is not INamedTypeSymbol typeSymbol) + if(context.Symbol is not INamedTypeSymbol typeSymbol) return; // Check if the type has IocContainerAttribute var containerAttribute = typeSymbol.GetAttributes() .FirstOrDefault(attr => AnalyzerHelpers.IsAttributeMatch(attr.AttributeClass, analyzerContext.AttributeSymbols.IocContainerAttribute)); - if (containerAttribute is null) + if(containerAttribute is null) return; // SGIOC019: Check for partial modifier and static modifier - foreach (var syntaxRef in typeSymbol.DeclaringSyntaxReferences) + foreach(var syntaxRef in typeSymbol.DeclaringSyntaxReferences) { context.CancellationToken.ThrowIfCancellationRequested(); var syntax = syntaxRef.GetSyntax(context.CancellationToken); - if (syntax is not ClassDeclarationSyntax classDeclaration) + if(syntax is not ClassDeclarationSyntax classDeclaration) continue; var hasPartial = classDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); var hasStatic = classDeclaration.Modifiers.Any(SyntaxKind.StaticKeyword); - if (!hasPartial || hasStatic) + if(!hasPartial || hasStatic) { var location = containerAttribute.ApplicationSyntaxReference?.GetSyntax(context.CancellationToken).GetLocation() ?? classDeclaration.Identifier.GetLocation(); @@ -171,35 +216,37 @@ private static void AnalyzeContainerClass(SymbolAnalysisContext context, Contain // Collect containers with IntegrateServiceProvider = false for SGIOC018 var integrateServiceProvider = true; var useSwitchStatement = false; - foreach (var namedArg in containerAttribute.NamedArguments) + foreach(var namedArg in containerAttribute.NamedArguments) { - if (namedArg.Key is "IntegrateServiceProvider" && namedArg.Value.Value is bool integrateValue) + if(namedArg.Key is "IntegrateServiceProvider" && namedArg.Value.Value is bool integrateValue) { integrateServiceProvider = integrateValue; } - else if (namedArg.Key is "UseSwitchStatement" && namedArg.Value.Value is bool switchValue) + else if(namedArg.Key is "UseSwitchStatement" && namedArg.Value.Value is bool switchValue) { useSwitchStatement = switchValue; } } - if (!integrateServiceProvider) + if(!integrateServiceProvider) { analyzerContext.ContainersWithNoFallback.Add(typeSymbol); } + analyzerContext.AllContainers.Add(typeSymbol); + // SGIOC020: Check for UseSwitchStatement = true with imported modules - if (useSwitchStatement) + if(useSwitchStatement) { var hasImportedModules = typeSymbol.GetAttributes() .Any(attr => AnalyzerHelpers.IsIocImportModuleAttribute(attr.AttributeClass, analyzerContext.AttributeSymbols)); - if (hasImportedModules) + if(hasImportedModules) { var location = containerAttribute.ApplicationSyntaxReference?.GetSyntax(context.CancellationToken).GetLocation() ?? typeSymbol.Locations.FirstOrDefault(); - if (location is not null) + if(location is not null) { context.ReportDiagnostic(Diagnostic.Create( UseSwitchStatementIgnoredWithImportedModules, @@ -210,83 +257,78 @@ private static void AnalyzeContainerClass(SymbolAnalysisContext context, Contain } // Collect import edges for SGIOC025 circular import detection - foreach (var attr in typeSymbol.GetAttributes()) + foreach(var attr in typeSymbol.GetAttributes()) { var importedModuleType = GetImportedModuleType(attr, analyzerContext.AttributeSymbols); - if (importedModuleType is not null) + if(importedModuleType is not null) { analyzerContext.ImportEdges.Add((typeSymbol, importedModuleType)); } } } + /// + /// Collects assembly-level [IocRegisterFor] / [IocRegisterFor<T>] registrations into the analyzer context. + /// + private static void CollectAssemblyLevelRegistrations( + Compilation compilation, + ContainerAnalyzerContext analyzerContext, + CancellationToken cancellationToken) + { + foreach(var (attribute, targetType) in AnalyzerHelpers.EnumerateAssemblyLevelRegisterForAttributes( + compilation, analyzerContext.AttributeSymbols, cancellationToken)) + { + RegisterServiceTypes(analyzerContext, targetType, attribute); + } + } + /// /// Collects all registered service types for SGIOC018 analysis. /// private static void CollectRegisteredServices(SymbolAnalysisContext context, ContainerAnalyzerContext analyzerContext) { - if (context.Symbol is not INamedTypeSymbol typeSymbol) + if(context.Symbol is not INamedTypeSymbol typeSymbol) return; - foreach (var attribute in typeSymbol.GetAttributes()) + foreach(var attribute in typeSymbol.GetAttributes()) { context.CancellationToken.ThrowIfCancellationRequested(); var attrClass = attribute.AttributeClass; - if (attrClass is null) + if(attrClass is null) continue; // Check for IocRegisterAttribute - if (AnalyzerHelpers.IsIoCRegisterAttribute(attrClass, analyzerContext.AttributeSymbols)) + if(AnalyzerHelpers.IsIoCRegisterAttribute(attrClass, analyzerContext.AttributeSymbols)) { - // Register the implementation type itself - analyzerContext.RegisteredServiceTypes.TryAdd(typeSymbol, true); - - // Register service types from generic type arguments - if (attrClass.IsGenericType) - { - foreach (var typeArg in attrClass.TypeArguments) - { - if (typeArg is INamedTypeSymbol serviceType) - analyzerContext.RegisteredServiceTypes.TryAdd(serviceType, true); - } - } - - // Register service types from ServiceTypes property - foreach (var serviceType in AnalyzerHelpers.GetServiceTypesFromAttribute(attribute)) - { - analyzerContext.RegisteredServiceTypes.TryAdd(serviceType, true); - } + RegisterServiceTypes(analyzerContext, typeSymbol, attribute); } // Check for IocRegisterForAttribute - if (AnalyzerHelpers.IsIoCRegisterForAttribute(attrClass, analyzerContext.AttributeSymbols)) + if(AnalyzerHelpers.IsIoCRegisterForAttribute(attrClass, analyzerContext.AttributeSymbols)) { // Get the target implementation type from the attribute - var targetType = AnalyzerHelpers.GetTargetTypeFromRegisterFor(attribute); - if (targetType is not null) - analyzerContext.RegisteredServiceTypes.TryAdd(targetType, true); - - // Register service types from ServiceTypes property - foreach (var serviceType in AnalyzerHelpers.GetServiceTypesFromAttribute(attribute)) - { - analyzerContext.RegisteredServiceTypes.TryAdd(serviceType, true); - } + var targetType = attribute.GetTargetTypeFromRegisterForAttribute(); + if(targetType is not null) + RegisterServiceTypes(analyzerContext, targetType, attribute); } } } /// - /// SGIOC018: Analyzes container dependencies when IntegrateServiceProvider = false. + /// Analyzes container dependencies for SGIOC018, SGIOC021, SGIOC027, and SGIOC029. /// private static void AnalyzeContainerDependencies(CompilationAnalysisContext context, ContainerAnalyzerContext analyzerContext) { - // Skip if no containers with IntegrateServiceProvider = false - if (analyzerContext.ContainersWithNoFallback.IsEmpty) - return; + // SGIOC027/029: async-init semantic validation runs on ALL containers + foreach(var containerSymbol in analyzerContext.AllContainers) + { + context.CancellationToken.ThrowIfCancellationRequested(); + AnalyzeAsyncPartialAccessors(context, analyzerContext, containerSymbol); + } - // Analyze dependencies for each container with no fallback - foreach (var containerSymbol in analyzerContext.ContainersWithNoFallback) + // SGIOC018/021: registration resolution checks only when IntegrateServiceProvider = false + foreach(var containerSymbol in analyzerContext.ContainersWithNoFallback) { context.CancellationToken.ThrowIfCancellationRequested(); AnalyzeContainerServiceDependencies(context, analyzerContext, containerSymbol); @@ -298,7 +340,7 @@ private static void AnalyzeContainerServiceDependencies( ContainerAnalyzerContext analyzerContext, INamedTypeSymbol containerSymbol) { - foreach (var kvp in analyzerContext.RegisteredServiceTypes) + foreach(var kvp in analyzerContext.RegisteredServiceTypes) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -306,17 +348,17 @@ private static void AnalyzeContainerServiceDependencies( // Analyze constructor dependencies var constructor = serviceType.SpecifiedOrPrimaryOrMostParametersConstructor; - if (constructor is not null) + if(constructor is not null) { AnalyzeParameterDependencies(context, analyzerContext, containerSymbol, constructor.Parameters); } // Analyze injected property dependencies - foreach (var (member, injectAttribute) in serviceType.GetInjectedMembers()) + foreach(var (member, injectAttribute) in serviceType.GetInjectedMembers()) { context.CancellationToken.ThrowIfCancellationRequested(); - switch (member) + switch(member) { case IPropertySymbol property: AnalyzePropertyDependency(context, analyzerContext, containerSymbol, property, injectAttribute); @@ -338,6 +380,92 @@ private static void AnalyzeContainerServiceDependencies( AnalyzePartialAccessorDependencies(context, analyzerContext, containerSymbol); } + /// + /// SGIOC027/029: Validates async-init partial accessor return types for ALL containers. + /// Fires regardless of IntegrateServiceProvider because async-init access requires the correct return type. + /// + private static void AnalyzeAsyncPartialAccessors( + CompilationAnalysisContext context, + ContainerAnalyzerContext analyzerContext, + INamedTypeSymbol containerSymbol) + { + foreach(var member in containerSymbol.GetMembers()) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + if(member.IsStatic) + continue; + + ITypeSymbol? returnType = null; + string? memberName = null; + + switch(member) + { + case IMethodSymbol method + when method.IsPartialDefinition + && !method.ReturnsVoid + && method.Parameters.Length == 0 + && !method.IsGenericMethod + && method.MethodKind == MethodKind.Ordinary: + returnType = method.ReturnType; + memberName = method.Name; + break; + + case IPropertySymbol property + when property.IsPartialDefinition + && property.GetMethod is not null: + returnType = property.Type; + memberName = property.Name; + break; + } + + if(returnType is null || memberName is null) + continue; + + var serviceKey = AnalyzerHelpers.GetServiceKeyFromMember(member); + + var normalizedReturnType = AnalyzerHelpers.UnwrapNullableValueType( + returnType.WithNullableAnnotation(NullableAnnotation.NotAnnotated)); + + if(normalizedReturnType is null) + continue; + + var serviceType = TryGetInnermostServiceType(normalizedReturnType); + + if(serviceType is null) + continue; + + if(!IsAnyImplementationAsyncInit(analyzerContext, serviceType, serviceKey)) + continue; + + // Async-init service found: validate return type. + var location = member.Locations.FirstOrDefault(); + + // Task is the only valid return type → no diagnostic + if(AnalyzerHelpers.IsTaskOfServiceType(normalizedReturnType, serviceType)) + continue; + + // Sync TService return → SGIOC027 + if(normalizedReturnType is INamedTypeSymbol namedReturnType + && SymbolEqualityComparer.Default.Equals(namedReturnType, serviceType)) + { + context.ReportDiagnostic(Diagnostic.Create( + PartialAccessorMustReturnTask, + location, + memberName, + serviceType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat))); + continue; + } + + // Any other return type (wrappers, ValueTask, collections, nested shapes, etc.) → SGIOC029 + context.ReportDiagnostic(Diagnostic.Create( + UnsupportedAsyncPartialAccessorType, + location, + memberName, + normalizedReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat))); + } + } + /// /// SGIOC021: Analyzes partial method/property accessors in a container class to ensure their return types are registered. /// Only applies when IntegrateServiceProvider = false. @@ -347,17 +475,17 @@ private static void AnalyzePartialAccessorDependencies( ContainerAnalyzerContext analyzerContext, INamedTypeSymbol containerSymbol) { - foreach (var member in containerSymbol.GetMembers()) + foreach(var member in containerSymbol.GetMembers()) { context.CancellationToken.ThrowIfCancellationRequested(); - if (member.IsStatic) + if(member.IsStatic) continue; ITypeSymbol? returnType = null; string? memberName = null; - switch (member) + switch(member) { case IMethodSymbol method when method.IsPartialDefinition @@ -377,51 +505,229 @@ when property.IsPartialDefinition break; } - if (returnType is null || memberName is null) + if(returnType is null || memberName is null) continue; + var serviceKey = AnalyzerHelpers.GetServiceKeyFromMember(member); + // Strip nullable annotation for type lookup - var unwrappedType = returnType.WithNullableAnnotation(NullableAnnotation.NotAnnotated); + var normalizedReturnType = AnalyzerHelpers.UnwrapNullableValueType( + returnType.WithNullableAnnotation(NullableAnnotation.NotAnnotated)); var isNullable = returnType.NullableAnnotation == NullableAnnotation.Annotated; - // Nullable accessors are optional — skip the check - if (isNullable) + if(normalizedReturnType is null) continue; - if (unwrappedType is INamedTypeSymbol namedReturnType - && !IsServiceRegistered(namedReturnType, analyzerContext)) + var serviceType = GetAccessorServiceType(normalizedReturnType); + + // Guard: if the innermost type (ignoring downgrade rules) is an async-init service, + // all diagnostic reporting is owned by AnalyzeAsyncPartialAccessors (SGIOC027/029). + // Skip SGIOC021 entirely to prevent double-reporting on downgraded wrapper shapes. + var innermostType = TryGetInnermostServiceType(normalizedReturnType); + if(innermostType is not null) + { + if(IsAnyImplementationAsyncInit(analyzerContext, innermostType, serviceKey)) + continue; + } + + // Report SGIOC021 for unsupported return types (non-generic Task, non-generic ValueTask) + if(serviceType is null && !isNullable) { var location = member.Locations.FirstOrDefault(); context.ReportDiagnostic(Diagnostic.Create( UnableToResolvePartialAccessor, location, - returnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + normalizedReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), memberName, containerSymbol.Name)); + continue; } + + if(serviceType is not null + && !isNullable + && !IsPartialAccessorServiceRegistered(serviceType, serviceKey, analyzerContext)) + { + var location = member.Locations.FirstOrDefault(); + // For ValueTask (not a generator-supported recursive wrapper), report the full return type + // because the shape is considered downgraded/unsupported per spec. + var diagnosticType = AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType) + ? normalizedReturnType + : serviceType; + context.ReportDiagnostic(Diagnostic.Create( + UnableToResolvePartialAccessor, + location, + diagnosticType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + memberName, + containerSymbol.Name)); + } + + // Report SGIOC021 for ValueTask when the service is registered but not an async-init service. + // (When it IS async-init, SGIOC029 is reported by AnalyzeAsyncPartialAccessors and we already continued.) + if(serviceType is not null + && !isNullable + && IsPartialAccessorServiceRegistered(serviceType, serviceKey, analyzerContext) + && AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType)) + { + var location = member.Locations.FirstOrDefault(); + context.ReportDiagnostic(Diagnostic.Create( + UnableToResolvePartialAccessor, + location, + normalizedReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + memberName, + containerSymbol.Name)); + continue; + } + } + } + + private static void RegisterServiceTypes( + ContainerAnalyzerContext analyzerContext, + INamedTypeSymbol implementationType, + AttributeData attribute) + { + var (serviceKey, _, _) = attribute.GetKeyInfo(); + + foreach(var serviceType in AnalyzerHelpers.EnumerateRegisteredServiceTypes( + implementationType, + attribute, + analyzerContext.AttributeSymbols)) + { + analyzerContext.RegisteredServiceTypes.TryAdd(serviceType, true); + analyzerContext.Registrations.Add(new ServiceRegistration(serviceType, serviceKey, implementationType)); + analyzerContext.ServiceImplementationTypes[(serviceType, serviceKey)] = implementationType; } } + private static INamedTypeSymbol? GetAccessorServiceType(ITypeSymbol returnType) + { + var normalizedReturnType = AnalyzerHelpers.UnwrapNullableValueType( + returnType.WithNullableAnnotation(NullableAnnotation.NotAnnotated)); + + if(normalizedReturnType is null) + return null; + + // Unsupported types (non-generic Task, non-generic ValueTask) → return null so caller reports SGIOC021 + if(AnalyzerHelpers.IsUnsupportedPartialAccessorReturnType(normalizedReturnType)) + return null; + + // ValueTask is not a generator-supported recursive wrapper; return T directly so callers + // can check registration and async-init. SGIOC029 / SGIOC021 are reported separately. + if(AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType)) + return AnalyzerHelpers.TryUnwrapWrapperElementType(normalizedReturnType); + + // Recursively unwrap generator-supported wrappers with downgrade detection. + // Mirrors TransformExtensions.cs downgrade rules (nested Task shapes, collection-at-top). + ITypeSymbol current = normalizedReturnType; + var isFirst = true; + + while(AnalyzerHelpers.TryUnwrapWrapperElementType(current) is { } element) + { + // Downgrade rule 1: Task — outer is Task AND inner type is itself a wrapper + if(IsGenericTask(current) && AnalyzerHelpers.TryUnwrapWrapperElementType(element) is not null) + return null; + + // Downgrade rule 2: Wrapper — outer is non-Task wrapper AND inner type is Task + if(!IsGenericTask(current) && IsGenericTask(element)) + return null; + + // Downgrade rule 3: ValueTask encountered during recursion (at top level it is handled above) + if(!isFirst && AnalyzerHelpers.IsGenericValueTaskType(current)) + return null; + + // Downgrade rule 4: Collection-at-top — outermost is a collection AND inner is a non-collection wrapper + if(isFirst && IsCollectionWrapper(current) && AnalyzerHelpers.TryUnwrapWrapperElementType(element) is not null && !IsCollectionWrapper(element)) + return null; + + current = element; + isFirst = false; + } + + return current as INamedTypeSymbol; + + static bool IsGenericTask(ITypeSymbol type) + => type is INamedTypeSymbol { Name: "Task", Arity: 1 } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + + static bool IsCollectionWrapper(ITypeSymbol type) + { + if(type.TypeKind == TypeKind.Array) + return true; + + if(type is not INamedTypeSymbol named || named.Arity != 1) + return false; + + if(named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T) + return true; + + var ns = named.ContainingNamespace.ToDisplayString(); + return ns == "System.Collections.Generic" + && named.Name is "IReadOnlyCollection" or "ICollection" or "IReadOnlyList" or "IList"; + } + } + + /// + /// Recursively unwraps all generator-supported wrapper types (without applying downgrade rules) to find + /// the innermost concrete service type. Used by diagnostic analysis to classify async-init services + /// regardless of whether the generator would downgrade the shape to IServiceProvider fallback. + /// + private static INamedTypeSymbol? TryGetInnermostServiceType(ITypeSymbol input) + { + ITypeSymbol current = input; + + while(AnalyzerHelpers.TryUnwrapWrapperElementType(current) is { } element) + { + current = element; + } + + return current as INamedTypeSymbol; + } + + private static bool IsAnyImplementationAsyncInit( + ContainerAnalyzerContext analyzerContext, + INamedTypeSymbol serviceType, + string? serviceKey) + { + foreach(var registration in analyzerContext.Registrations) + { + if(!SymbolEqualityComparer.Default.Equals(registration.ServiceType, serviceType)) + continue; + + if(!StringComparer.Ordinal.Equals(registration.Key, serviceKey)) + continue; + + if(AnalyzerHelpers.IsAsyncInitImplementation(registration.ImplementationType, analyzerContext.Features)) + return true; + } + + return false; + } + + private static IocFeatures ParseIocFeatures(AnalyzerOptions options) + { + options.AnalyzerConfigOptionsProvider.GlobalOptions.TryGetValue(Constants.SourceGenIocFeaturesProperty, out var rawFeatures); + return IocFeaturesHelper.Parse(rawFeatures); + } + private static void AnalyzeParameterDependencies( CompilationAnalysisContext context, ContainerAnalyzerContext analyzerContext, INamedTypeSymbol containerSymbol, ImmutableArray parameters) { - foreach (var param in parameters) + foreach(var param in parameters) { context.CancellationToken.ThrowIfCancellationRequested(); // Skip if parameter is always resolvable - if (AnalyzerHelpers.IsParameterAlwaysResolvable(param)) + if(AnalyzerHelpers.IsParameterAlwaysResolvable(param)) continue; var paramType = param.Type; // Check if the dependency type is registered - if (paramType is INamedTypeSymbol namedParamType) + if(paramType is INamedTypeSymbol namedParamType) { - if (!IsServiceRegistered(namedParamType, analyzerContext)) + if(!IsServiceRegistered(namedParamType, analyzerContext)) { var location = param.Locations.FirstOrDefault(); context.ReportDiagnostic(Diagnostic.Create( @@ -442,15 +748,15 @@ private static void AnalyzePropertyDependency( AttributeData injectAttribute) { // Skip if property is always resolvable - if (AnalyzerHelpers.IsPropertyAlwaysResolvable(property, injectAttribute)) + if(AnalyzerHelpers.IsPropertyAlwaysResolvable(property, injectAttribute)) return; var propertyType = property.Type; // Check if the dependency type is registered - if (propertyType is INamedTypeSymbol namedPropertyType) + if(propertyType is INamedTypeSymbol namedPropertyType) { - if (!IsServiceRegistered(namedPropertyType, analyzerContext)) + if(!IsServiceRegistered(namedPropertyType, analyzerContext)) { var location = injectAttribute.ApplicationSyntaxReference?.GetSyntax(context.CancellationToken).GetLocation() ?? property.Locations.FirstOrDefault(); @@ -471,15 +777,15 @@ private static void AnalyzeFieldDependency( AttributeData injectAttribute) { // Skip if field is always resolvable - if (AnalyzerHelpers.IsFieldAlwaysResolvable(field, injectAttribute)) + if(AnalyzerHelpers.IsFieldAlwaysResolvable(field, injectAttribute)) return; var fieldType = field.Type; // Check if the dependency type is registered - if (fieldType is INamedTypeSymbol namedFieldType) + if(fieldType is INamedTypeSymbol namedFieldType) { - if (!IsServiceRegistered(namedFieldType, analyzerContext)) + if(!IsServiceRegistered(namedFieldType, analyzerContext)) { var location = injectAttribute.ApplicationSyntaxReference?.GetSyntax(context.CancellationToken).GetLocation() ?? field.Locations.FirstOrDefault(); @@ -495,18 +801,18 @@ private static void AnalyzeFieldDependency( private static bool IsServiceRegistered(INamedTypeSymbol serviceType, ContainerAnalyzerContext analyzerContext) { // Direct match - if (analyzerContext.RegisteredServiceTypes.ContainsKey(serviceType)) + if(analyzerContext.RegisteredServiceTypes.ContainsKey(serviceType)) return true; // Check if it's always resolvable (well-known types like IServiceProvider) - if (AnalyzerHelpers.IsAlwaysResolvable(serviceType)) + if(AnalyzerHelpers.IsAlwaysResolvable(serviceType)) return true; // Handle IEnumerable - check if T is registered - if (AnalyzerHelpers.IsIEnumerableOfT(serviceType)) + if(AnalyzerHelpers.IsIEnumerableOfT(serviceType)) { var elementType = AnalyzerHelpers.GetEnumerableElementType(serviceType); - if (elementType is not null) + if(elementType is not null) { // IEnumerable is resolvable if T is registered return analyzerContext.RegisteredServiceTypes.ContainsKey(elementType); @@ -516,18 +822,39 @@ private static bool IsServiceRegistered(INamedTypeSymbol serviceType, ContainerA return false; } + private static bool IsPartialAccessorServiceRegistered( + INamedTypeSymbol serviceType, + string? serviceKey, + ContainerAnalyzerContext analyzerContext) + { + if(analyzerContext.ServiceImplementationTypes.ContainsKey((serviceType, serviceKey))) + return true; + + if(AnalyzerHelpers.IsAlwaysResolvable(serviceType)) + return true; + + if(serviceKey is null && AnalyzerHelpers.IsIEnumerableOfT(serviceType)) + { + var elementType = AnalyzerHelpers.GetEnumerableElementType(serviceType); + if(elementType is not null) + return analyzerContext.RegisteredServiceTypes.ContainsKey(elementType); + } + + return false; + } + /// /// Extracts the imported module type from an [IocImportModule] or [IocImportModule<T>] attribute. /// Returns null if the attribute is not an import module attribute or the type cannot be resolved. /// - private static INamedTypeSymbol? GetImportedModuleType(AttributeData attr, IoCAttributeSymbols attributeSymbols) + private static INamedTypeSymbol? GetImportedModuleType(AttributeData attr, IocAttributeSymbols attributeSymbols) { var attrClass = attr.AttributeClass; - if (attrClass is null) + if(attrClass is null) return null; // Non-generic form: [IocImportModule(typeof(T))] - if (AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute)) + if(AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute)) { return attr.ConstructorArguments.Length > 0 ? attr.ConstructorArguments[0].Value as INamedTypeSymbol @@ -535,7 +862,7 @@ private static bool IsServiceRegistered(INamedTypeSymbol serviceType, ContainerA } // Generic form: [IocImportModule] — OriginalDefinition comparison is handled inside IsAttributeMatch - if (AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute_T1)) + if(AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute_T1)) { return attrClass.IsGenericType && attrClass.TypeArguments.Length > 0 ? attrClass.TypeArguments[0] as INamedTypeSymbol @@ -550,14 +877,14 @@ private static bool IsServiceRegistered(INamedTypeSymbol serviceType, ContainerA /// private static void AnalyzeCircularImports(CompilationAnalysisContext context, ContainerAnalyzerContext analyzerContext) { - if (analyzerContext.ImportEdges.IsEmpty) + if(analyzerContext.ImportEdges.IsEmpty) return; // Build adjacency list: container → list of imported module types var graph = new Dictionary>(SymbolEqualityComparer.Default); - foreach (var (container, module) in analyzerContext.ImportEdges) + foreach(var (container, module) in analyzerContext.ImportEdges) { - if (!graph.TryGetValue(container, out var edges)) + if(!graph.TryGetValue(container, out var edges)) { edges = []; graph[container] = edges; @@ -570,11 +897,11 @@ private static void AnalyzeCircularImports(CompilationAnalysisContext context, C var inStack = new HashSet(SymbolEqualityComparer.Default); var reported = new HashSet(SymbolEqualityComparer.Default); - foreach (var node in graph.Keys) + foreach(var node in graph.Keys) { context.CancellationToken.ThrowIfCancellationRequested(); - if (!visited.Contains(node)) + if(!visited.Contains(node)) { var path = new List(); DetectCycles(context, graph, node, visited, inStack, reported, path, analyzerContext); @@ -596,39 +923,39 @@ private static void DetectCycles( inStack.Add(node); path.Add(node); - if (graph.TryGetValue(node, out var neighbors)) + if(graph.TryGetValue(node, out var neighbors)) { - foreach (var neighbor in neighbors) + foreach(var neighbor in neighbors) { context.CancellationToken.ThrowIfCancellationRequested(); - if (!visited.Contains(neighbor)) + if(!visited.Contains(neighbor)) { DetectCycles(context, graph, neighbor, visited, inStack, reported, path, analyzerContext); } - else if (inStack.Contains(neighbor)) + else if(inStack.Contains(neighbor)) { // Back-edge found — locate where the cycle starts in the current path var cycleStartIdx = -1; - for (var i = 0; i < path.Count; i++) + for(var i = 0; i < path.Count; i++) { - if (SymbolEqualityComparer.Default.Equals(path[i], neighbor)) + if(SymbolEqualityComparer.Default.Equals(path[i], neighbor)) { cycleStartIdx = i; break; } } - if (cycleStartIdx < 0) + if(cycleStartIdx < 0) continue; var cycleStr = string.Join(" → ", path.Skip(cycleStartIdx).Append(neighbor).Select(s => s.ToDisplayString(s_qualifiedFormat))); // Report a diagnostic for every container in the cycle - for (var i = cycleStartIdx; i < path.Count; i++) + for(var i = cycleStartIdx; i < path.Count; i++) { var containerInCycle = path[i]; - if (!reported.Add(containerInCycle)) + if(!reported.Add(containerInCycle)) continue; var location = GetContainerLocation(containerInCycle); @@ -649,16 +976,29 @@ private static void DetectCycles( private static Location? GetContainerLocation(INamedTypeSymbol containerSymbol) => containerSymbol.Locations.FirstOrDefault(); + private readonly record struct ServiceRegistration( + INamedTypeSymbol ServiceType, + string? Key, + INamedTypeSymbol ImplementationType); + private sealed class ContainerAnalyzerContext( - IoCAttributeSymbols attributeSymbols, + IocAttributeSymbols attributeSymbols, ConcurrentDictionary registeredServiceTypes, + ConcurrentDictionary<(INamedTypeSymbol ServiceType, string? Key), INamedTypeSymbol> serviceImplementationTypes, + ConcurrentBag registrations, ConcurrentBag containersWithNoFallback, - ConcurrentBag<(INamedTypeSymbol Container, INamedTypeSymbol Module)> importEdges) + ConcurrentBag allContainers, + ConcurrentBag<(INamedTypeSymbol Container, INamedTypeSymbol Module)> importEdges, + IocFeatures features) { - public IoCAttributeSymbols AttributeSymbols { get; } = attributeSymbols; + public IocAttributeSymbols AttributeSymbols { get; } = attributeSymbols; public ConcurrentDictionary RegisteredServiceTypes { get; } = registeredServiceTypes; + public ConcurrentDictionary<(INamedTypeSymbol ServiceType, string? Key), INamedTypeSymbol> ServiceImplementationTypes { get; } = serviceImplementationTypes; + public ConcurrentBag Registrations { get; } = registrations; public ConcurrentBag ContainersWithNoFallback { get; } = containersWithNoFallback; + public ConcurrentBag AllContainers { get; } = allContainers; public ConcurrentBag<(INamedTypeSymbol Container, INamedTypeSymbol Module)> ImportEdges { get; } = importEdges; + public IocFeatures Features { get; } = features; } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.AttributeUsage.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.AttributeUsage.cs index f289574..680faa4 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.AttributeUsage.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.AttributeUsage.cs @@ -73,10 +73,24 @@ private static void AnalyzeInjectAttribute(SymbolAnalysisContext context, Analyz var location = injectAttribute.ApplicationSyntaxReference?.GetSyntax(context.CancellationToken).GetLocation() ?? member.Locations.FirstOrDefault(); + // SGIOC028: async void methods cannot be awaited - report before any other method check + if (member is IMethodSymbol { IsAsync: true, ReturnsVoid: true }) + { + context.ReportDiagnostic(Diagnostic.Create( + AsyncVoidInjectMethod, + location, + member.Name)); + return; + } + + var asyncMethodInjectEnabled = (analyzerContext.Features & IocFeatures.AsyncMethodInject) != 0; + var isTaskReturningMethod = member is IMethodSymbol taskMethod && IsNonGenericTaskType(taskMethod.ReturnType); + var (requiredFeature, featureName) = member switch { IPropertySymbol => (IocFeatures.PropertyInject, nameof(IocFeatures.PropertyInject)), IFieldSymbol => (IocFeatures.FieldInject, nameof(IocFeatures.FieldInject)), + IMethodSymbol when isTaskReturningMethod => (IocFeatures.AsyncMethodInject, nameof(IocFeatures.AsyncMethodInject)), IMethodSymbol => (IocFeatures.MethodInject, nameof(IocFeatures.MethodInject)), _ => (IocFeatures.None, string.Empty) }; @@ -88,6 +102,9 @@ private static void AnalyzeInjectAttribute(SymbolAnalysisContext context, Analyz location, member.Name, featureName)); + // For Task-returning methods: SGIOC022 already fired; do NOT also report SGIOC007 return-type error + if (isTaskReturningMethod) + return; } // Container-class partial definitions are valid injection targets (keyed service accessors) @@ -96,7 +113,7 @@ private static void AnalyzeInjectAttribute(SymbolAnalysisContext context, Analyz if (member is IMethodSymbol { IsPartialDefinition: true } && IsInContainerClass(member.ContainingType)) return; - var reason = GetMemberInjectabilityIssue(member); + var reason = GetMemberInjectabilityIssue(member, asyncMethodInjectEnabled); if (reason is not null) { context.ReportDiagnostic(Diagnostic.Create( @@ -158,7 +175,7 @@ private static void AnalyzeFactoryAndInstanceOnAttribute( // SGIOC023 + SGIOC024: Validate InjectMembers elements - only for IoCRegisterFor attributes if (!isDefaultsAttribute) { - AnalyzeInjectMembersOnAttribute(context, argumentList); + AnalyzeInjectMembersOnAttribute(context, argumentList, (analyzerContext.Features & IocFeatures.AsyncMethodInject) != 0); } } @@ -466,7 +483,8 @@ private static bool IsInContainerClass(INamedTypeSymbol? containingType) /// private static void AnalyzeInjectMembersOnAttribute( SyntaxNodeAnalysisContext context, - AttributeArgumentListSyntax argumentList) + AttributeArgumentListSyntax argumentList, + bool asyncMethodInjectEnabled = false) { AttributeArgumentSyntax? injectMembersArg = null; foreach (var arg in argumentList.Arguments) @@ -547,7 +565,7 @@ private static void AnalyzeInjectMembersOnAttribute( continue; // unresolvable — a compile error will already be reported // SGIOC024: Check if member is injectable - var (isInjectable, reason) = ValidateInjectableMember(symbol); + var (isInjectable, reason) = ValidateInjectableMember(symbol, asyncMethodInjectEnabled); if (!isInjectable) { context.ReportDiagnostic(Diagnostic.Create( @@ -563,7 +581,11 @@ private static void AnalyzeInjectMembersOnAttribute( /// Returns the reason a member is not injectable, or if it is valid. /// Shared by SGIOC007 () and SGIOC024 (). /// - private static string? GetMemberInjectabilityIssue(ISymbol symbol) + /// + /// When , methods returning non-generic are + /// considered valid injection targets (allowed by AsyncMethodInject feature). + /// + private static string? GetMemberInjectabilityIssue(ISymbol symbol, bool asyncMethodInjectEnabled = false) { if (symbol.IsStatic) return "it is static"; @@ -589,19 +611,29 @@ IMethodSymbol m when m.DeclaredAccessibility is not (Accessibility.Public or Acc => "method is not accessible", IMethodSymbol m when m.MethodKind != MethodKind.Ordinary => "method is not an ordinary method", - IMethodSymbol { ReturnsVoid: false } => "method does not return void", + // Allow non-generic Task return type only when AsyncMethodInject feature is enabled + IMethodSymbol { ReturnsVoid: false } m when !(asyncMethodInjectEnabled && IsNonGenericTaskType(m.ReturnType)) + => asyncMethodInjectEnabled ? "method does not return void or non-generic Task" : "method does not return void", IMethodSymbol { IsGenericMethod: true } => "method is generic", IPropertySymbol or IFieldSymbol or IMethodSymbol => null, _ => "member is not a property, field, or method" }; } - private static (bool IsInjectable, string Reason) ValidateInjectableMember(ISymbol symbol) + private static (bool IsInjectable, string Reason) ValidateInjectableMember(ISymbol symbol, bool asyncMethodInjectEnabled = false) { - var reason = GetMemberInjectabilityIssue(symbol); + var reason = GetMemberInjectabilityIssue(symbol, asyncMethodInjectEnabled); return reason is null ? (true, string.Empty) : (false, reason); } + /// + /// Returns when is the non-generic + /// class (i.e., Task with arity 0). + /// + private static bool IsNonGenericTaskType(ITypeSymbol? type) + => type is INamedTypeSymbol { Arity: 0, Name: "Task" } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + private static ExpressionSyntax[]? GetInjectMembersElements(ExpressionSyntax expression) => expression switch { diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs index dce7d13..d42df3f 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs @@ -15,38 +15,20 @@ private static ImmutableHashSet CollectAssemblyLevelRegistrations( { var syntaxTreesBuilder = ImmutableHashSet.CreateBuilder(); - // Check if any IoCRegisterForAttribute variant is available - if (analyzerContext.AttributeSymbols.IocRegisterForAttribute is null && analyzerContext.AttributeSymbols.IocRegisterForAttribute_T1 is null) - return syntaxTreesBuilder.ToImmutable(); - - foreach (var attribute in compilation.Assembly.GetAttributes()) + foreach(var (attribute, targetType) in AnalyzerHelpers.EnumerateAssemblyLevelRegisterForAttributes( + compilation, analyzerContext.AttributeSymbols, cancellationToken)) { - cancellationToken.ThrowIfCancellationRequested(); - - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - // Check if this is an IoCRegisterForAttribute (non-generic or generic) - if (!AnalyzerHelpers.IsIoCRegisterForAttribute(attributeClass, analyzerContext.AttributeSymbols)) - continue; - // Track which syntax tree contains this attribute var syntaxReference = attribute.ApplicationSyntaxReference; - if (syntaxReference?.SyntaxTree is { } syntaxTree) + if(syntaxReference?.SyntaxTree is { } syntaxTree) { syntaxTreesBuilder.Add(syntaxTree); } - // Get target type from attribute (constructor arg for non-generic, type parameter for generic) - var targetType = attribute.GetTargetTypeFromRegisterForAttribute(); - if (targetType is null) - continue; - // Skip invalid types - if (targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) + if(targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) continue; - if (targetType.DeclaredAccessibility is Accessibility.Private) + if(targetType.DeclaredAccessibility is Accessibility.Private) continue; var location = syntaxReference?.GetSyntax(cancellationToken).GetLocation(); @@ -60,7 +42,9 @@ private static ImmutableHashSet CollectAssemblyLevelRegistrations( // Check for Factory and Instance (used by ServiceInfo) var (hasFactory, hasInstance) = attribute.HasFactoryOrInstance(); - RegisterServiceWithIndex(analyzerContext, targetType, lifetime, location, keyTypeSymbol, key is not null, hasFactory, hasInstance); + var assemblyServiceTypes = AnalyzerHelpers.EnumerateRegisteredServiceTypes( + targetType, attribute, analyzerContext.AttributeSymbols).ToList(); + RegisterServiceWithIndex(analyzerContext, targetType, lifetime, location, key, keyTypeSymbol, key is not null, hasFactory, hasInstance, assemblyServiceTypes); } return syntaxTreesBuilder.ToImmutable(); @@ -71,34 +55,16 @@ private static void AnalyzeAssemblyLevelRegistrations( AnalyzerContext analyzerContext, ImmutableHashSet assemblyAttributeSyntaxTrees) { - // Check if any IoCRegisterForAttribute variant is available - if (analyzerContext.AttributeSymbols.IocRegisterForAttribute is null && analyzerContext.AttributeSymbols.IocRegisterForAttribute_T1 is null) - return; - // Only analyze if this syntax tree contains assembly-level attributes - if (!assemblyAttributeSyntaxTrees.Contains(context.SemanticModel.SyntaxTree)) + if(!assemblyAttributeSyntaxTrees.Contains(context.SemanticModel.SyntaxTree)) return; - foreach (var attribute in context.SemanticModel.Compilation.Assembly.GetAttributes()) + foreach(var (attribute, targetType) in AnalyzerHelpers.EnumerateAssemblyLevelRegisterForAttributes( + context.SemanticModel.Compilation, analyzerContext.AttributeSymbols, context.CancellationToken)) { - context.CancellationToken.ThrowIfCancellationRequested(); - // Only process attributes from the current syntax tree var syntaxReference = attribute.ApplicationSyntaxReference; - if (syntaxReference?.SyntaxTree != context.SemanticModel.SyntaxTree) - continue; - - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - // Check if this is an IoCRegisterForAttribute (non-generic or generic) - if (!AnalyzerHelpers.IsIoCRegisterForAttribute(attributeClass, analyzerContext.AttributeSymbols)) - continue; - - // Get target type from attribute (constructor arg for non-generic, type parameter for generic) - var targetType = attribute.GetTargetTypeFromRegisterForAttribute(); - if (targetType is null) + if(syntaxReference?.SyntaxTree != context.SemanticModel.SyntaxTree) continue; var location = syntaxReference.GetSyntax(context.CancellationToken).GetLocation(); @@ -123,23 +89,23 @@ private static void AnalyzeAssemblyLevelRegistrations( /// private static void CollectAndValidateNamedType(SymbolAnalysisContext context, AnalyzerContext analyzerContext) { - if (context.Symbol is not INamedTypeSymbol typeSymbol) + if(context.Symbol is not INamedTypeSymbol typeSymbol) return; - foreach (var attribute in typeSymbol.GetAttributes()) + foreach(var attribute in typeSymbol.GetAttributes()) { context.CancellationToken.ThrowIfCancellationRequested(); - if (!TryGetIoCAttribute(attribute, analyzerContext, out var isIoCRegisterFor)) + if(!TryGetIoCAttribute(attribute, analyzerContext, out var isIoCRegisterFor)) continue; INamedTypeSymbol targetType; - if (isIoCRegisterFor) + if(isIoCRegisterFor) { // Use extension method to get target type (supports both generic and non-generic variants) var target = attribute.GetTargetTypeFromRegisterForAttribute(); - if (target is null) + if(target is null) continue; targetType = target; @@ -165,9 +131,9 @@ private static void CollectAndValidateNamedType(SymbolAnalysisContext context, A AnalyzeDuplicatedRegistration(context.ReportDiagnostic, analyzerContext, attribute, targetType, fullyQualifiedTypeName, location); // Skip registration if type is invalid - if (targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) + if(targetType.IsAbstract && targetType.TypeKind is not TypeKind.Interface) continue; - if (targetType.DeclaredAccessibility is Accessibility.Private) + if(targetType.DeclaredAccessibility is Accessibility.Private) continue; // Get lifetime of current service (considering default settings) @@ -182,7 +148,9 @@ private static void CollectAndValidateNamedType(SymbolAnalysisContext context, A // Register service with index for faster lookup // Dependency analysis will be done in CompilationEnd after all services are collected - RegisterServiceWithIndex(analyzerContext, targetType, currentLifetime, location, keyTypeSymbol, key is not null, hasFactory, hasInstance); + var serviceTypes = AnalyzerHelpers.EnumerateRegisteredServiceTypes( + targetType, attribute, analyzerContext.AttributeSymbols).ToList(); + RegisterServiceWithIndex(analyzerContext, targetType, currentLifetime, location, key, keyTypeSymbol, key is not null, hasFactory, hasInstance, serviceTypes); } } @@ -194,36 +162,36 @@ private static void ResolveCsharpKeyTypes( SyntaxNodeAnalysisContext context, AnalyzerContext analyzerContext) { - if (context.Node is not AttributeSyntax attributeSyntax) + if(context.Node is not AttributeSyntax attributeSyntax) return; var attributeSymbol = context.SemanticModel.GetSymbolInfo(attributeSyntax, context.CancellationToken).Symbol; - if (attributeSymbol is not IMethodSymbol attributeConstructor) + if(attributeSymbol is not IMethodSymbol attributeConstructor) return; var attributeClass = attributeConstructor.ContainingType; - if (attributeClass is null) + if(attributeClass is null) return; - if (!AnalyzerHelpers.IsIoCRegistrationAttribute(attributeClass, analyzerContext.AttributeSymbols)) + if(!AnalyzerHelpers.IsIoCRegistrationAttribute(attributeClass, analyzerContext.AttributeSymbols)) return; var attributeData = GetAttributeDataFromSyntax(context, attributeSyntax, attributeClass); - if (attributeData is null) + if(attributeData is null) return; - if (attributeData.GetNamedArgument("KeyType", 0) != 1) + if(attributeData.GetNamedArgument("KeyType", 0) != 1) return; var (_, _, resolvedKeyType) = attributeData.GetKeyInfo(context.SemanticModel); - if (resolvedKeyType is null) + if(resolvedKeyType is null) return; var targetType = AnalyzerHelpers.IsIoCRegisterForAttribute(attributeClass, analyzerContext.AttributeSymbols) ? attributeData.GetTargetTypeFromRegisterForAttribute() : GetTypeLevelTargetType(context, attributeSyntax); - if (targetType is null) + if(targetType is null) return; analyzerContext.ResolvedCsharpKeyTypes[(targetType, attributeSyntax.GetLocation())] = resolvedKeyType; @@ -234,12 +202,12 @@ private static void ResolveCsharpKeyTypes( AttributeSyntax attributeSyntax, INamedTypeSymbol attributeClass) { - if (attributeSyntax.Parent is not AttributeListSyntax attributeList) + if(attributeSyntax.Parent is not AttributeListSyntax attributeList) return null; var syntaxTree = attributeSyntax.SyntaxTree; - if (attributeList.Target?.Identifier.IsKind(SyntaxKind.AssemblyKeyword) is true) + if(attributeList.Target?.Identifier.IsKind(SyntaxKind.AssemblyKeyword) is true) { return context.SemanticModel.Compilation.Assembly.GetAttributes() .FirstOrDefault(attr => @@ -249,7 +217,7 @@ private static void ResolveCsharpKeyTypes( } var targetType = GetTypeLevelTargetType(context, attributeSyntax); - if (targetType is null) + if(targetType is null) return null; return targetType.GetAttributes() @@ -280,17 +248,17 @@ private static bool TryGetIoCAttribute(AttributeData attribute, AnalyzerContext { isIoCRegisterFor = false; var attributeClass = attribute.AttributeClass; - if (attributeClass is null) + if(attributeClass is null) return false; // Check IoCRegisterAttribute variants (non-generic and generic) - if (AnalyzerHelpers.IsIoCRegisterAttribute(attributeClass, analyzerContext.AttributeSymbols)) + if(AnalyzerHelpers.IsIoCRegisterAttribute(attributeClass, analyzerContext.AttributeSymbols)) { return true; } // Check IoCRegisterForAttribute variants (non-generic and generic) - if (AnalyzerHelpers.IsIoCRegisterForAttribute(attributeClass, analyzerContext.AttributeSymbols)) + if(AnalyzerHelpers.IsIoCRegisterForAttribute(attributeClass, analyzerContext.AttributeSymbols)) { isIoCRegisterFor = true; return true; @@ -307,32 +275,49 @@ private static void RegisterServiceWithIndex( INamedTypeSymbol targetType, ServiceLifetime lifetime, Location? location, + string? serviceKey = null, ITypeSymbol? keyTypeSymbol = null, bool hasKey = false, bool hasFactory = false, - bool hasInstance = false) + bool hasInstance = false, + IReadOnlyCollection? registrationServiceTypes = null) { - var serviceInfo = new ServiceInfo(targetType, lifetime, location, keyTypeSymbol, hasKey, hasFactory, hasInstance); + var serviceInfo = new ServiceInfo(targetType, lifetime, location, serviceKey, keyTypeSymbol, hasKey, hasFactory, hasInstance); + + if(!analyzerContext.RegisteredServices.TryAdd(targetType, serviceInfo)) + { + // Same implementation type already registered (e.g., multiple assembly-level IocRegisterFor + // attributes targeting the same impl type with different keys). Record the additional + // (service type, key) pairs on the existing entry so SGIOC030 analysis sees all registrations. + if(registrationServiceTypes is not null + && analyzerContext.RegisteredServices.TryGetValue(targetType, out var existingInfo)) + { + foreach(var st in registrationServiceTypes) + existingInfo.AllRegistrations.Add((st, serviceKey)); + } + return; + } - if (!analyzerContext.RegisteredServices.TryAdd(targetType, serviceInfo)) - return; // Already registered + if(registrationServiceTypes is not null) + foreach(var st in registrationServiceTypes) + serviceInfo.AllRegistrations.Add((st, serviceKey)); // Only build type index for non-keyed services. // Keyed services are resolved by key + type, not by type alone, // so including them in the type-only index could lead to false positive // circular dependency (SGIOC002) or lifetime conflict (SGIOC003-005) diagnostics. - if (hasKey) + if(hasKey) return; // Build index for interfaces - foreach (var iface in targetType.AllInterfaces) + foreach(var iface in targetType.AllInterfaces) { analyzerContext.ServiceTypeIndex.TryAdd(iface, serviceInfo); } // Build index for base classes var baseType = targetType.BaseType; - while (baseType is not null && baseType.SpecialType is not SpecialType.System_Object) + while(baseType is not null && baseType.SpecialType is not SpecialType.System_Object) { analyzerContext.ServiceTypeIndex.TryAdd(baseType, serviceInfo); baseType = baseType.BaseType; @@ -346,26 +331,26 @@ private static void RegisterServiceWithIndex( /// private static DefaultSettingsMap CollectDefaults( Compilation compilation, - IoCAttributeSymbols attributeSymbols, + IocAttributeSymbols attributeSymbols, ConcurrentBag<(string TargetTypeName, Location? Location)> duplicatedDefaults, ConcurrentDictionary<(string TargetTypeName, string Tag), Location?> seenTargetTypes, CancellationToken cancellationToken) { - if (attributeSymbols.IocRegisterDefaultsAttribute is null && attributeSymbols.IocRegisterDefaultsAttribute_T1 is null) + if(attributeSymbols.IocRegisterDefaultsAttribute is null && attributeSymbols.IocRegisterDefaultsAttribute_T1 is null) return new DefaultSettingsMap([]); var settingsBuilder = ImmutableArray.CreateBuilder(); - foreach (var attribute in compilation.Assembly.GetAttributes()) + foreach(var attribute in compilation.Assembly.GetAttributes()) { cancellationToken.ThrowIfCancellationRequested(); var attributeClass = attribute.AttributeClass; - if (attributeClass is null) + if(attributeClass is null) continue; // Check if this is an IoCRegisterDefaultsAttribute (non-generic or generic) - if (!AnalyzerHelpers.IsIoCRegisterDefaultsAttribute(attributeClass, attributeSymbols)) + if(!AnalyzerHelpers.IsIoCRegisterDefaultsAttribute(attributeClass, attributeSymbols)) { continue; } @@ -375,7 +360,7 @@ private static DefaultSettingsMap CollectDefaults( var settings = attributeClass.IsGenericType ? attribute.ExtractDefaultSettingsFromGenericAttribute() : attribute.ExtractDefaultSettings(); - if (settings is not null) + if(settings is not null) { var targetTypeName = settings.TargetServiceType.Name; var tags = settings.Tags; @@ -385,17 +370,17 @@ private static DefaultSettingsMap CollectDefaults( // SGIOC012: Check each effective tag for duplicates var hasDuplicate = false; - foreach (var tag in effectiveTags) + foreach(var tag in effectiveTags) { var defaultKey = (targetTypeName, tag); - if (!seenTargetTypes.TryAdd(defaultKey, attribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken).GetLocation())) + if(!seenTargetTypes.TryAdd(defaultKey, attribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken).GetLocation())) { hasDuplicate = true; break; // Only need to find one duplicate } } - if (hasDuplicate) + if(hasDuplicate) { var location = attribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken).GetLocation(); duplicatedDefaults.Add((targetTypeName, location)); @@ -421,24 +406,24 @@ private static ServiceLifetime GetEffectiveLifetime( ServiceLifetime explicitLifetime) { // If lifetime is explicitly set, use it - if (hasExplicitLifetime) + if(hasExplicitLifetime) return explicitLifetime; var defaultSettings = analyzerContext.DefaultSettings; - if (defaultSettings.IsEmpty) + if(defaultSettings.IsEmpty) return explicitLifetime; // Check default settings for matching interfaces - foreach (var iface in targetType.AllInterfaces) + foreach(var iface in targetType.AllInterfaces) { var ifaceTypeData = iface.GetTypeData(); // Try exact match first - if (defaultSettings.TryGetExactMatches(ifaceTypeData.Name, out var exactIndex)) + if(defaultSettings.TryGetExactMatches(ifaceTypeData.Name, out var exactIndex)) return defaultSettings[exactIndex].Lifetime; // Try generic match (e.g., IGenericTest<> matches IGenericTest) - if (iface.IsGenericType + if(iface.IsGenericType && ifaceTypeData is GenericTypeData genericInterfaceTypeData && defaultSettings.TryGetGenericMatches(genericInterfaceTypeData.NameWithoutGeneric, genericInterfaceTypeData.GenericArity, out var genericIndex)) return defaultSettings[genericIndex].Lifetime; @@ -446,16 +431,16 @@ private static ServiceLifetime GetEffectiveLifetime( // Check default settings for matching base classes var baseType = targetType.BaseType; - while (baseType is not null && baseType.SpecialType is not SpecialType.System_Object) + while(baseType is not null && baseType.SpecialType is not SpecialType.System_Object) { var baseTypeData = baseType.GetTypeData(); // Try exact match first - if (defaultSettings.TryGetExactMatches(baseTypeData.Name, out var exactIndex)) + if(defaultSettings.TryGetExactMatches(baseTypeData.Name, out var exactIndex)) return defaultSettings[exactIndex].Lifetime; // Try generic match for base classes - if (baseType.IsGenericType + if(baseType.IsGenericType && baseTypeData is GenericTypeData genericBaseTypeData && defaultSettings.TryGetGenericMatches(genericBaseTypeData.NameWithoutGeneric, genericBaseTypeData.GenericArity, out var genericIndex)) return defaultSettings[genericIndex].Lifetime; @@ -477,7 +462,7 @@ public int GetHashCode((INamedTypeSymbol Type, Location? Location) obj) } private sealed class AnalyzerContext( - IoCAttributeSymbols attributeSymbols, + IocAttributeSymbols attributeSymbols, ConcurrentDictionary registeredServices, ConcurrentDictionary serviceTypeIndex, DefaultSettingsMap defaultSettings, @@ -485,7 +470,7 @@ private sealed class AnalyzerContext( ConcurrentDictionary<(string TargetTypeName, string Tag), Location?> seenDefaultTargetTypes, IocFeatures features) { - public IoCAttributeSymbols AttributeSymbols { get; } = attributeSymbols; + public IocAttributeSymbols AttributeSymbols { get; } = attributeSymbols; public ConcurrentDictionary RegisteredServices { get; } = registeredServices; /// @@ -542,6 +527,11 @@ private sealed record ServiceInfo /// public string FullyQualifiedName { get; } + /// + /// The registration key string, or null if the service is unkeyed. + /// + public string? ServiceKey { get; } + /// /// The type symbol of the registration key, or null if no key is specified or KeyType is Csharp. /// @@ -562,6 +552,14 @@ private sealed record ServiceInfo /// public bool HasInstance { get; } + /// + /// All (service type, key) pairs from every registration that maps to this implementation type. + /// Populated during collection to capture all registrations including assembly-level duplicates + /// where only stores the first entry per implementation type. + /// Used by SGIOC030 analysis to correctly identify all async-init-only service type/key pairs. + /// + public ConcurrentBag<(INamedTypeSymbol ServiceType, string? Key)> AllRegistrations { get; } = []; + /// /// Cached constructor to avoid repeated SpecifiedOrPrimaryOrMostParametersConstructor lookups. /// @@ -576,6 +574,7 @@ public ServiceInfo( INamedTypeSymbol type, ServiceLifetime lifetime, Location? location, + string? serviceKey = null, ITypeSymbol? keyTypeSymbol = null, bool hasKey = false, bool hasFactory = false, @@ -585,6 +584,7 @@ public ServiceInfo( Lifetime = lifetime; Location = location; FullyQualifiedName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + ServiceKey = serviceKey; KeyTypeSymbol = keyTypeSymbol; HasKey = hasKey; HasFactory = hasFactory; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs index 04d394e..cdcfcb0 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs @@ -80,7 +80,7 @@ public sealed partial class RegisterAnalyzer : DiagnosticAnalyzer category: Constants.Category_Usage, defaultSeverity: DiagnosticSeverity.Error, isEnabledByDefault: true, - description: "InjectAttribute cannot be applied to static members, members that cannot be assigned/invoked, or methods that do not return void."); + description: "InjectAttribute cannot be applied to static members, members that cannot be assigned/invoked, or methods that do not return void (or non-generic Task when AsyncMethodInject is enabled)."); /// /// SGIOC008: Invalid Attribute Usage - Factory or Instance uses nameof() but the referenced member is not static or is inaccessible. @@ -249,7 +249,44 @@ public sealed partial class RegisterAnalyzer : DiagnosticAnalyzer category: Constants.Category_Usage, defaultSeverity: DiagnosticSeverity.Error, isEnabledByDefault: true, - description: "Members specified in InjectMembers must be injectable: instance properties with accessible setters, non-readonly fields, and ordinary non-generic void-returning methods, all of which must be public, internal, or protected internal."); + description: "Members specified in InjectMembers must be injectable: instance properties with accessible setters, non-readonly fields, and ordinary non-generic void-returning methods (or non-generic Task-returning when AsyncMethodInject is enabled), all of which must be public, internal, or protected internal."); + + /// + /// SGIOC026: AsyncMethodInject feature requires MethodInject to be enabled. + /// + public static readonly DiagnosticDescriptor AsyncMethodInjectRequiresMethodInject = new( + id: "SGIOC026", + title: "Invalid feature combination", + messageFormat: "'AsyncMethodInject' feature requires 'MethodInject' to be enabled. Add 'MethodInject' to .", + category: Constants.Category_Usage, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "AsyncMethodInject delegates async method injection to the source generator, which requires MethodInject to already be enabled. Add MethodInject alongside AsyncMethodInject in SourceGenIocFeatures.", + customTags: [WellKnownDiagnosticTags.CompilationEnd]); + + /// + /// SGIOC030: Synchronous dependency requested for async-init-only service. + /// + public static readonly DiagnosticDescriptor SyncDependencyOnAsyncInitService = new( + id: "SGIOC030", + title: "Synchronous dependency requested for async-init service", + messageFormat: "'{0}' requires '{1}' but this service has async inject methods and no synchronous registration exists. Use 'Task<{1}>'.", + category: Constants.Category_Usage, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "When a service is registered with async inject methods, consumers must request Task instead of T because no synchronous resolution path exists."); + + /// + /// SGIOC028: [IocInject] method is declared as async void, which cannot be awaited. + /// + public static readonly DiagnosticDescriptor AsyncVoidInjectMethod = new( + id: "SGIOC028", + title: "async void injection method cannot be awaited", + messageFormat: "[IocInject] method '{0}' is 'async void' which cannot be awaited. Change return type to 'Task' for async initialization, or remove the 'async' modifier for synchronous injection.", + category: Constants.Category_Usage, + defaultSeverity: DiagnosticSeverity.Warning, + isEnabledByDefault: true, + description: "Methods marked with [IocInject] that are declared as async void cannot be awaited by the source generator. Change the return type to Task to enable async injection, or remove the async modifier if the method is synchronous."); public override ImmutableArray SupportedDiagnostics { get; } = [ @@ -272,7 +309,10 @@ public sealed partial class RegisterAnalyzer : DiagnosticAnalyzer DuplicatedGenericFactoryPlaceholders, InjectFeatureDisabled, InjectMembersInvalidFormat, - InjectMembersNonInjectableMember + InjectMembersNonInjectableMember, + AsyncMethodInjectRequiresMethodInject, + SyncDependencyOnAsyncInitService, + AsyncVoidInjectMethod ]; public override void Initialize(AnalysisContext context) @@ -288,11 +328,18 @@ private static void OnCompilationStart(CompilationStartAnalysisContext context) { var features = ParseIocFeatures(context.Options); + // SGIOC026: AsyncMethodInject requires MethodInject — report once per compilation + if((features & IocFeatures.AsyncMethodInject) != 0 && (features & IocFeatures.MethodInject) == 0) + { + context.RegisterCompilationEndAction(static ctx => + ctx.ReportDiagnostic(Diagnostic.Create(AsyncMethodInjectRequiresMethodInject, Location.None))); + } + // Get attribute type symbols for faster lookup (including generic variants) - var attributeSymbols = new IoCAttributeSymbols(context.Compilation); + var attributeSymbols = new IocAttributeSymbols(context.Compilation); // Check if any IoC attribute is available - if (!attributeSymbols.HasAnyRegistrationAttribute) + if(!attributeSymbols.HasAnyRegistrationAttribute) return; // Use ConcurrentDictionary for thread-safe collection during parallel symbol analysis @@ -356,7 +403,7 @@ private static void OnCompilationStart(CompilationStartAnalysisContext context) private static void AnalyzeAllDependencies(CompilationAnalysisContext context, AnalyzerContext analyzerContext) { // SGIOC012: Report duplicated IoCRegisterDefaults - foreach (var (targetTypeName, location) in analyzerContext.DuplicatedDefaults) + foreach(var (targetTypeName, location) in analyzerContext.DuplicatedDefaults) { context.ReportDiagnostic(Diagnostic.Create( DuplicatedDefaultSettings, @@ -368,7 +415,7 @@ private static void AnalyzeAllDependencies(CompilationAnalysisContext context, A var visited = new HashSet(SymbolEqualityComparer.Default); var pathStack = new Stack(); - foreach (var kvp in analyzerContext.RegisteredServices) + foreach(var kvp in analyzerContext.RegisteredServices) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -388,6 +435,206 @@ private static void AnalyzeAllDependencies(CompilationAnalysisContext context, A pathStack, context.CancellationToken); } + + var asyncInitOnlyServiceTypes = BuildAsyncInitOnlyServiceTypes(analyzerContext); + if(asyncInitOnlyServiceTypes.Count == 0) + return; + + foreach(var serviceInfo in analyzerContext.RegisteredServices.Values) + { + context.CancellationToken.ThrowIfCancellationRequested(); + AnalyzeSyncDependenciesOnAsyncInitServices( + context.ReportDiagnostic, + serviceInfo, + asyncInitOnlyServiceTypes, + context.CancellationToken); + } + } + + private static ImmutableHashSet<(INamedTypeSymbol ServiceType, string? Key)> BuildAsyncInitOnlyServiceTypes(AnalyzerContext analyzerContext) + { + var serviceTypeStates = new Dictionary<(INamedTypeSymbol ServiceType, string? Key), (bool HasAsync, bool HasSync)>(AnalyzerHelpers.ServiceTypeAndKeyComparer); + + foreach(var serviceInfo in analyzerContext.RegisteredServices.Values) + { + var isAsyncInit = AnalyzerHelpers.IsAsyncInitImplementation(serviceInfo.Type, analyzerContext.Features); + + // Use pre-computed AllRegistrations (populated at collection time from attribute data) when + // available. This correctly handles assembly-level registrations where multiple IocRegisterFor + // attributes target the same impl type with different keys — the second and subsequent + // registrations are not stored in RegisteredServices but are recorded in AllRegistrations. + IEnumerable<(INamedTypeSymbol ServiceType, string? Key)> serviceTypes = + serviceInfo.AllRegistrations.IsEmpty + ? CollectRegisteredServiceTypesForAnalysis(serviceInfo, analyzerContext.AttributeSymbols) + : serviceInfo.AllRegistrations; + + foreach(var serviceType in serviceTypes) + { + serviceTypeStates.TryGetValue(serviceType, out var state); + serviceTypeStates[serviceType] = isAsyncInit + ? (HasAsync: true, HasSync: state.HasSync) + : (HasAsync: state.HasAsync, HasSync: true); + } + } + + var asyncOnlyServiceTypes = ImmutableHashSet.CreateBuilder<(INamedTypeSymbol ServiceType, string? Key)>(AnalyzerHelpers.ServiceTypeAndKeyComparer); + foreach(var kvp in serviceTypeStates) + { + var serviceType = kvp.Key; + var state = kvp.Value; + if(state is { HasAsync: true, HasSync: false }) + asyncOnlyServiceTypes.Add(serviceType); + } + + return asyncOnlyServiceTypes.ToImmutable(); + } + + private static HashSet<(INamedTypeSymbol ServiceType, string? Key)> CollectRegisteredServiceTypesForAnalysis( + ServiceInfo serviceInfo, + IocAttributeSymbols attributeSymbols) + { + var serviceTypes = new HashSet<(INamedTypeSymbol ServiceType, string? Key)>(AnalyzerHelpers.ServiceTypeAndKeyComparer); + var hasRegistrationAttribute = false; + + var implementationType = serviceInfo.Type; + + foreach(var attribute in implementationType.GetAttributes()) + { + var attrClass = attribute.AttributeClass; + if(attrClass is null || !AnalyzerHelpers.IsIoCRegistrationAttribute(attrClass, attributeSymbols)) + continue; + + hasRegistrationAttribute = true; + var (serviceKey, _, _) = attribute.GetKeyInfo(); + foreach(var serviceType in AnalyzerHelpers.EnumerateRegisteredServiceTypes(implementationType, attribute, attributeSymbols)) + { + serviceTypes.Add((serviceType, serviceKey)); + } + } + + if(!hasRegistrationAttribute) + { + foreach(var serviceType in AnalyzerHelpers.EnumerateImplicitServiceTypes(implementationType)) + { + serviceTypes.Add((serviceType, serviceInfo.ServiceKey)); + } + } + + return serviceTypes; + } + + private static void AnalyzeSyncDependenciesOnAsyncInitServices( + Action reportDiagnostic, + ServiceInfo serviceInfo, + ImmutableHashSet<(INamedTypeSymbol ServiceType, string? Key)> asyncInitOnlyServiceTypes, + CancellationToken cancellationToken) + { + if(serviceInfo.Constructor is not null) + { + AnalyzeParameterDependenciesOnAsyncInitServices( + reportDiagnostic, + serviceInfo.Constructor.Parameters, + asyncInitOnlyServiceTypes, + cancellationToken); + } + + foreach(var (member, injectAttribute) in serviceInfo.InjectedMembers) + { + cancellationToken.ThrowIfCancellationRequested(); + + switch(member) + { + case IPropertySymbol property: + AnalyzeDependencyOnAsyncInitService( + reportDiagnostic, + property.Name, + property.Type, + injectAttribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken).GetLocation() + ?? property.Locations.FirstOrDefault(), + injectAttribute.GetKeyInfo().Key, + asyncInitOnlyServiceTypes); + break; + + case IFieldSymbol field: + AnalyzeDependencyOnAsyncInitService( + reportDiagnostic, + field.Name, + field.Type, + injectAttribute.ApplicationSyntaxReference?.GetSyntax(cancellationToken).GetLocation() + ?? field.Locations.FirstOrDefault(), + injectAttribute.GetKeyInfo().Key, + asyncInitOnlyServiceTypes); + break; + + case IMethodSymbol method: + AnalyzeParameterDependenciesOnAsyncInitServices( + reportDiagnostic, + method.Parameters, + asyncInitOnlyServiceTypes, + cancellationToken); + break; + } + } + } + + private static void AnalyzeParameterDependenciesOnAsyncInitServices( + Action reportDiagnostic, + ImmutableArray parameters, + ImmutableHashSet<(INamedTypeSymbol ServiceType, string? Key)> asyncInitOnlyServiceTypes, + CancellationToken cancellationToken) + { + foreach(var parameter in parameters) + { + cancellationToken.ThrowIfCancellationRequested(); + + if(ShouldSkipAsyncInitDependencyCheck(parameter)) + continue; + + AnalyzeDependencyOnAsyncInitService( + reportDiagnostic, + parameter.Name, + parameter.Type, + parameter.Locations.FirstOrDefault(), + parameter.GetServiceKeyAndAttributeInfo().ServiceKey, + asyncInitOnlyServiceTypes); + } + } + + private static void AnalyzeDependencyOnAsyncInitService( + Action reportDiagnostic, + string memberName, + ITypeSymbol dependencyType, + Location? location, + string? serviceKey, + ImmutableHashSet<(INamedTypeSymbol ServiceType, string? Key)> asyncInitOnlyServiceTypes) + { + if(AnalyzerHelpers.TryGetAsyncWrapperElementType(dependencyType) is not null) + return; + + if(dependencyType.WithNullableAnnotation(NullableAnnotation.NotAnnotated) is not INamedTypeSymbol namedDependencyType) + return; + + if(!asyncInitOnlyServiceTypes.Contains((namedDependencyType, serviceKey))) + return; + + reportDiagnostic(Diagnostic.Create( + SyncDependencyOnAsyncInitService, + location, + memberName, + namedDependencyType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat))); + } + + private static bool ShouldSkipAsyncInitDependencyCheck(IParameterSymbol parameter) + { + if(parameter.HasExplicitDefaultValue) + return true; + + if(AnalyzerHelpers.IsWellKnownServiceType(parameter.Type)) + return true; + + return parameter.GetAttributes().Any(static attr => + attr.AttributeClass?.Name == "ServiceKeyAttribute" + && attr.AttributeClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection"); } private static IocFeatures ParseIocFeatures(AnalyzerOptions options) diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md index 1b23a65..83459ce 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md @@ -86,7 +86,7 @@ Report when `IocInjectAttribute`/`InjectAttribute` is mark on: - member without public, internal, or protected internal accessibility - property without setter or with private setter - readonly field -- method that does not return void +- method with an unsupported return type - method that is generic (has type parameters) - method that is not an ordinary method (e.g., constructor, operator) @@ -98,10 +98,22 @@ Report when `IocInjectAttribute`/`InjectAttribute` is mark on: - Member is not `public`, `internal`, or `protected internal` (private, protected, or private protected members are rejected because generated code runs in a public static context). - Property has no setter or setter is private. - Field is readonly. - - Method does not return void. + - Method returns a type other than `void` or supported non-generic `Task`. `async void` is handled separately by `SGIOC028`. - Method is generic (has type parameters). - Method is not an ordinary method (i.e., constructors, operators, and other special methods are rejected). +#### Method Return-Type Truth Table + +|Method shape|`AsyncMethodInject` enabled|Diagnostic result| +|:-----------|:--------------------------|:----------------| +|`void Initialize(...)`|No/Yes|No return-type diagnostic.| +|`void Initialize(...)` declared as `async void`|No/Yes|`SGIOC028` MUST report. `SGIOC007` SHOULD NOT duplicate the return-type diagnostic.| +|`Task InitializeAsync(...)`|Yes|No `SGIOC007` return-type diagnostic.| +|`Task InitializeAsync(...)`|No|`SGIOC022` MUST report the disabled feature. `SGIOC007` MUST NOT report a duplicate return-type diagnostic.| +|`Task InitializeAsync(...)`|No/Yes|`SGIOC007` MUST report.| +|`ValueTask InitializeAsync(...)` or `ValueTask InitializeAsync(...)`|No/Yes|`SGIOC007` MUST report.| +|Any other non-void return type|No/Yes|`SGIOC007` MUST report.| + --- ### SGIOC008 - Error - Usage - Invalid Attribute Usage @@ -316,12 +328,24 @@ Report when a partial method or property accessor in a container class reference - Checks classes marked with `[IocContainer]` attribute that have `IntegrateServiceProvider = false`. - Scans the container class members for partial methods (non-void, parameterless, non-generic) and partial properties (with getter). -- Checks if the return type of each non-nullable partial accessor is a registered service type. -- Reports when a non-nullable partial accessor's return type is not found among registered services. +- For non-nullable partial accessors, determines the effective service type by: + - Reporting SGIOC021 immediately for unsupported return types: non-generic `Task`, non-generic `ValueTask`. + - `ValueTask`: unwraps `T` and checks registration. If `T` references an async-init service, `SGIOC029` is reported instead (see SGIOC029). Otherwise `SGIOC021` is reported. + - Recursively unwraps Generator-supported wrapper types to extract the innermost service type for resolution checking. Supported wrappers: `Task`, `Lazy`, `Func` / `Func` (extracts the last type argument as the return/service type), `IEnumerable`, `IReadOnlyCollection`, `ICollection`, `IReadOnlyList`, `IList`, `T[]`, `IDictionary`, `IReadOnlyDictionary`, `Dictionary`, `KeyValuePair`. + - Mirrors the Generator downgrade rules while recursively unwrapping wrappers: + - `Task>` and `Wrapper>` are treated as unresolvable when `IntegrateServiceProvider = false`. + - `ValueTask` is not a Generator-supported wrapper; if encountered during wrapper recursion, the accessor is treated as unresolvable. + - A top-level collection wrapper whose element type contains nested non-collection wrappers, for example `IEnumerable>>`, is treated as unresolvable when `IntegrateServiceProvider = false`. + - If the innermost unwrapped service type matches a registration (with the same service key) that has async-init implementations, `SGIOC021` skips the accessor and `SGIOC029` owns the return-type diagnostic regardless of `IntegrateServiceProvider`. + - For supported wrapper shapes that successfully unwrap, registration lookup and the diagnostic message use the innermost service type. + - For downgraded or unsupported wrapper shapes whose innermost service type is not async-init, the accessor is treated as unresolvable and the diagnostic message uses the full return type. + - For non-wrapper types, the return type itself is checked directly. +- Reports when a non-nullable partial accessor's effective service type is not found among registered services. +- Nullable accessors are exempt (can safely return `null`). **Rationale:** -When `IntegrateServiceProvider = false`, there is no fallback to an external `IServiceProvider`. If a partial accessor references a service type that is not registered, it cannot be resolved at runtime. Nullable accessors are exempt because they can safely return `null`. +When `IntegrateServiceProvider = false`, there is no fallback to an external `IServiceProvider`. If a partial accessor references a service type that is not registered after recursively unwrapping any Generator-supported wrapper, it cannot be resolved at runtime. Wrapper shapes that the Generator downgrades to `IServiceProvider` fallback are also unresolvable in this mode. Unsupported return types cannot be generated regardless of registration state. If the innermost unwrapped service type is async-init, the accessor is excluded from `SGIOC021` because `SGIOC029` owns async-init partial accessor return-type mismatches for all containers. **Message format:** `Unable to resolve service '{ServiceType}' for partial accessor '{MemberName}' in container '{ContainerType}'.` @@ -338,9 +362,19 @@ Report when a member has `[IocInject]`/`[Inject]` but its corresponding feature - Checks members marked with `[IocInject]`/`[Inject]`: - `IPropertySymbol` requires `PropertyInject` - `IFieldSymbol` requires `FieldInject` - - `IMethodSymbol` requires `MethodInject` + - `IMethodSymbol` returning `void` requires `MethodInject` + - `IMethodSymbol` returning non-generic `Task` requires `AsyncMethodInject` - Reports when the required feature flag is not enabled. +#### Feature Gate Mapping + +|Member shape|Required feature|Notes| +|:-----------|:---------------|:----| +|Property|`PropertyInject`|Unchanged.| +|Field|`FieldInject`|Unchanged.| +|Method returning `void`|`MethodInject`|Covers synchronous method injection.| +|Method returning non-generic `Task`|`AsyncMethodInject`|`MethodInject` remains a project-level prerequisite when `AsyncMethodInject` is enabled; invalid combinations are reported by `SGIOC026`.| + **Message format:** `'{MemberName}' has [IocInject] but {FeatureName} feature is not enabled. Add '{FeatureName}' to in your project file.` --- @@ -377,7 +411,8 @@ Report when a member resolved from `nameof()` in `InjectMembers` cannot be injec - not `public`, `internal`, or `protected internal` (private, protected, or private protected members are rejected because generated registration code runs in a public static context) - property without setter or with private setter - readonly field - - method that doesn't return void or is generic + - method that doesn't return `void` (or non-generic `Task` when `AsyncMethodInject` is enabled) + - generic method - method that is not an ordinary method (i.e., constructors, operators, and other special methods are rejected) - This validation reuses the same logic as SGIOC007 but specifically for members specified via `InjectMembers`. @@ -403,3 +438,89 @@ Circular module imports create static initializer deadlocks. When `_serviceResol **Message format:** `Container 'TestNamespace.ModuleA' has a circular module import dependency: TestNamespace.ModuleA → TestNamespace.ModuleB → TestNamespace.ModuleA` Both `{ContainerType}` and types in `{CyclePath}` use `NameAndContainingTypesAndNamespaces` display format (without `global::` prefix). For types in the global namespace, this format may produce simple names. + +--- + +### SGIOC026 - Error - Usage - Invalid feature combination + +Report when `SourceGenIocFeatures` enables `AsyncMethodInject` without also enabling `MethodInject`. + +**Analysis:** + +- Reads and parses `SourceGenIocFeatures` during `CompilationStart`. +- If `AsyncMethodInject` is enabled and `MethodInject` is disabled, the analyzer MUST report `SGIOC026`. +- The diagnostic SHOULD be reported once per compilation because the problem is project-wide rather than member-specific. + +**Message format:** `'AsyncMethodInject' feature requires 'MethodInject' to be enabled.` + +--- + +### SGIOC027 - Error - Design - Partial accessor must return `Task` for async-init service + +Report when a partial accessor returns the direct synchronous service type with no wrapper even though the matched implementation has async inject methods. + +**Analysis:** + +- Applies to partial methods and partial properties declared in `[IocContainer]` types. +- Applies regardless of the `IntegrateServiceProvider` setting because returning the wrong type for an async-init service is a semantic error, not a fallback-resolution issue. +- Matches the accessor's service type and key against container registrations. +- If the matched implementation contains async inject methods and the accessor returns the direct synchronous service type `TService` with no wrapper instead of `Task`, the analyzer MUST report `SGIOC027`. +- `SGIOC029` owns all non-`Task` return types for async-init services, including wrappers, arrays, and unsupported async shapes; `SGIOC027` is only for the direct synchronous `TService` case. + +**Message format:** `Partial accessor '{MemberName}' returns '{ServiceType}' but the implementation has async inject methods. Use 'Task<{ServiceType}>'.` + +--- + +### SGIOC028 - Warning - Usage - `async void` injection method cannot be awaited + +Report when an `[IocInject]`/`[Inject]` method is declared as `async void`. + +**Analysis:** + +- Checks methods marked with `[IocInject]` or `[Inject]`. +- If a method is declared `async` and returns `void`, the analyzer MUST report `SGIOC028` because the generator cannot await it. +- `SGIOC028` SHOULD be the user-facing diagnostic for this case; `SGIOC007` SHOULD NOT add a duplicate return-type diagnostic for the same method. + +**Message format:** `[IocInject] method '{MethodName}' is 'async void' which cannot be awaited. Change return type to 'Task'.` + +--- + +### SGIOC029 - Error - Design - Unsupported async partial accessor type + +Report when a partial accessor targets an async-init service but returns any non-`Task` shape, including wrappers, arrays, or unsupported async shapes. + +**Analysis:** + +- Applies to partial methods and partial properties declared in `[IocContainer]` types. +- Applies regardless of the `IntegrateServiceProvider` setting because returning any non-`Task` shape for an async-init service is a semantic error, not a resolution fallback issue. +- Checks accessors that target registrations whose implementation contains async inject methods. +- For generic wrappers and arrays, recursively unwraps wrapper element types to find the innermost service type. This diagnostic-only analysis intentionally ignores Generator downgrade rules so that shapes such as `Task>`, `Lazy>`, nested wrappers, and arrays are still classified by their innermost service type. +- If the innermost unwrapped service type is an async-init service and the declared return type is not exactly `Task`, the analyzer MUST report `SGIOC029`. +- The only supported async accessor return shape is `Task`. +- `Task` is supported and does not produce a diagnostic. +- `SGIOC027` owns only the direct synchronous `TService` return case with no wrapper. +- `SGIOC029` owns all remaining non-`Task` return types for async-init services, including `ValueTask`, wrapper types, collection wrappers, arrays, nested wrappers, and downgraded async-shaped returns. + +**Message format:** `Partial accessor '{MemberName}' returns '{ReturnType}' which is not a supported async type. Only 'Task' is supported.` + +--- + +### SGIOC030 - Error - Usage - Synchronous dependency requested for async-init service + +Report when a consumer requests `TService` but the matched registration can only be resolved asynchronously. + +**Analysis:** + +- Applies to constructor parameters, injected properties, injected fields, and parameters of `[IocInject]`/`[Inject]` methods. +- Matches the requested service type and key against available registrations. +- If the matched service has async inject methods and there is no synchronous registration for the same service type/key, the analyzer MUST report `SGIOC030`. +- Consumers SHOULD request `Task` instead of `TService` in this scenario. +- Partial accessors are handled separately by `SGIOC027` and `SGIOC029`. + +**Message format:** `'{MemberName}' requires '{ServiceType}' but this service has async inject methods and no synchronous registration exists. Use 'Task<{ServiceType}>'.` + +--- + +### Known Limitations + +None. diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/AnalyzerReleases.Unshipped.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/AnalyzerReleases.Unshipped.md index fa5cb74..97d31bb 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/AnalyzerReleases.Unshipped.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/AnalyzerReleases.Unshipped.md @@ -27,4 +27,8 @@ SGIOC022 | Usage | Warning | Inject attribute ignored due to disabled feature - SGIOC023 | Usage | Error | Invalid InjectMembers element format - Each element in InjectMembers must be nameof(member) or new object[] { nameof(member), key [, KeyType] }. SGIOC024 | Usage | Error | InjectMembers specifies non-injectable member - Members in InjectMembers must be injectable (instance properties with accessible setters, non-readonly fields, and ordinary non-generic void-returning methods, all of which must be public, internal, or protected internal). SGIOC025 | Design | Error | Circular module import detected - A container has a circular [IocImportModule] dependency that would cause a static initializer deadlock. - +SGIOC026 | Usage | Error | Invalid feature combination - AsyncMethodInject feature requires MethodInject to be enabled. +SGIOC027 | Design | Error | Partial accessor must return Task of T for async-init service - The matched implementation has async inject methods but the accessor returns plain TService. +SGIOC028 | Usage | Warning | async void injection method cannot be awaited - [IocInject] method is declared as async void. +SGIOC029 | Design | Error | Unsupported async partial accessor type - Partial accessor targets an async-init service but returns an async type other than Task of T. +SGIOC030 | Usage | Error | Synchronous dependency requested for async-init service - Consumer requests T but the service has async inject methods and no synchronous registration exists. diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs index b469234..4e1f40a 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs @@ -367,6 +367,10 @@ private static ImmutableEquatableArray CollectContainerFuncE if(reg.IsOpenGeneric) continue; + // Async-init services cannot be resolved synchronously — exclude from Func entries + if(cached.IsAsyncInit) + continue; + var entryKey = $"{serviceType}|{reg.ImplementationType.Name}|{reg.Key}"; if(!addedKeys.Add(entryKey)) continue; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs index a6d4779..411bd49 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs @@ -129,7 +129,7 @@ private static ImmutableEquatableArray FilterCachedRegistrat filteredRegistrations.Add(registrations[j]); } - filteredRegistrations.Add(registration with { Registration = filteredRegistration }); + filteredRegistrations.Add(registration with { Registration = filteredRegistration, IsAsyncInit = HasAsyncInitMembers(filteredRegistration) }); } return filteredRegistrations is null ? registrations : filteredRegistrations.ToImmutableEquatableArray(); @@ -562,12 +562,23 @@ private static void WritePartialAccessorImplementations( foreach(var accessor in container.PartialAccessors) { + var isTaskReturn = accessor.Kind == PartialAccessorKind.Method + && TryExtractTaskInnerType(accessor.ReturnTypeName, out _); + var resolveExpression = ResolvePartialAccessorExpression(accessor, container, groups); switch(accessor.Kind) { case PartialAccessorKind.Method: - writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name}() => {resolveExpression};"); + // Async partial methods (returning Task) require the 'async' modifier + if(isTaskReturn) + { + writer.WriteLine($"public partial async {accessor.ReturnTypeName} {accessor.Name}() => {resolveExpression};"); + } + else + { + writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name}() => {resolveExpression};"); + } break; case PartialAccessorKind.Property: @@ -585,6 +596,7 @@ private static void WritePartialAccessorImplementations( /// /// Resolves the expression to use for a partial accessor implementation. /// Looks up the registration by return type and optional key, with fallback to IServiceProvider. + /// For Task<T> return types, routes through the async resolver. /// private static string ResolvePartialAccessorExpression( PartialAccessorData accessor, @@ -594,6 +606,37 @@ private static string ResolvePartialAccessorExpression( var serviceType = accessor.ReturnTypeName; var key = accessor.Key; + // Handle Task return types — route through the async resolver. + if(TryExtractTaskInnerType(serviceType, out var innerTypeName)) + { + if(groups.ByServiceTypeAndKey.TryGetValue((innerTypeName, key), out var taskRegistrations)) + { + var cached = taskRegistrations[^1]; // Last registration wins + + if(cached.IsAsyncInit) + { + // Async-init: await the shared async resolver and let the async method wrap the cast. + var asyncMethodName = GetAsyncResolverMethodName(cached.ResolverMethodName); + return $"await {asyncMethodName}()"; + } + else + { + // Sync-only service wrapped as Task: use Task.FromResult with cast. + return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){cached.ResolverMethodName}())"; + } + } + + // Fallback: delegate to IServiceProvider if available + if(container.IntegrateServiceProvider) + { + if(key is not null) + return $"({serviceType})GetRequiredKeyedService(typeof({serviceType}), {key})"; + return $"({serviceType})GetRequiredService(typeof({serviceType}))"; + } + + return $"""throw new global::System.InvalidOperationException("Service '{innerTypeName}' is not registered.")"""; + } + // Try to find direct resolver in this container if(groups.ByServiceTypeAndKey.TryGetValue((serviceType, key), out var registrations)) { @@ -733,6 +776,23 @@ private static void WriteServiceResolverMethod( return; } + // For async-init services: generate async resolver instead of sync resolver + if(cached.IsAsyncInit) + { + switch(reg.Lifetime) + { + case ServiceLifetime.Singleton: + case ServiceLifetime.Scoped: + WriteAsyncServiceResolverMethod(writer, strategy, methodName, returnType, fieldName, reg, hasFactory, hasDecorators, groups); + break; + + case ServiceLifetime.Transient: + WriteAsyncTransientResolverMethod(writer, methodName, returnType, reg, hasFactory, hasDecorators, groups); + break; + } + return; + } + switch(reg.Lifetime) { case ServiceLifetime.Singleton: @@ -1070,6 +1130,325 @@ private static void WriteTransientResolverMethod( writer.WriteLine("}"); } + // ────────────────────────────────────────────────────────────────────────────── + // Async-init service resolver generation + // Async-init services have at least one InjectionMemberType.AsyncMethod member. + // They use Task caching and async resolver methods. + // ────────────────────────────────────────────────────────────────────────────── + + /// + /// Returns the async routing resolver method name by appending "Async" to the sync method name. + /// + private static string GetAsyncResolverMethodName(string syncMethodName) + => syncMethodName + "Async"; + + /// + /// Returns the async creation method name (e.g. "CreateFooBarAsync" from "GetFooBar"). + /// + private static string GetAsyncCreateMethodName(string syncMethodName) + { + if(syncMethodName.Length > 3 && syncMethodName.StartsWith("Get", StringComparison.Ordinal)) + return "Create" + syncMethodName[3..] + "Async"; + return syncMethodName + "_CreateAsync"; + } + + /// + /// Returns the effective thread-safety strategy for a registration. + /// Async-init services auto-upgrade async-incompatible strategies to . + /// + private static ThreadSafeStrategy GetEffectiveThreadSafeStrategy( + ThreadSafeStrategy strategy, + bool isAsyncInit) + { + if(!isAsyncInit) + return strategy; + + return strategy is ThreadSafeStrategy.None ? ThreadSafeStrategy.None : ThreadSafeStrategy.SemaphoreSlim; + } + + /// + /// Writes the field declaration for an async-init service's cached Task<T>. + /// The caller must pass the effective async-init strategy. + /// + private static void WriteAsyncServiceInstanceField( + SourceWriter writer, + ThreadSafeStrategy strategy, + string fieldName, + string taskFieldTypeName) + { + writer.WriteLine($"private {taskFieldTypeName}? {fieldName};"); + + // Only SemaphoreSlim is async-compatible; others fall back to unsynchronized access. + if(strategy == ThreadSafeStrategy.SemaphoreSlim) + { + writer.WriteLine($"private readonly global::System.Threading.SemaphoreSlim {fieldName}Semaphore = new(1, 1);"); + } + } + + /// + /// Writes an async routing resolver + async creation method for a singleton/scoped async-init service. + /// + private static void WriteAsyncServiceResolverMethod( + SourceWriter writer, + ThreadSafeStrategy strategy, + string syncMethodName, + string returnType, + string fieldName, + ServiceRegistrationModel reg, + bool hasFactory, + bool hasDecorators, + ContainerRegistrationGroups groups) + { + var asyncMethodName = GetAsyncResolverMethodName(syncMethodName); + var createMethodName = GetAsyncCreateMethodName(syncMethodName); + var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; + var effectiveStrategy = GetEffectiveThreadSafeStrategy(strategy, true); + + // Write the Task? instance field (+ semaphore if SemaphoreSlim) + WriteAsyncServiceInstanceField(writer, effectiveStrategy, fieldName, taskReturnType); + writer.WriteLine(); + + // ── Routing resolver method ── + writer.WriteLine($"private async {taskReturnType} {asyncMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"if({fieldName} is not null)"); + writer.Indentation++; + writer.WriteLine($"return await {fieldName};"); + writer.Indentation--; + writer.WriteLine(); + + if(effectiveStrategy == ThreadSafeStrategy.SemaphoreSlim) + { + WriteAsyncResolverBodySemaphoreSlim(writer, fieldName, createMethodName); + } + else + { + WriteAsyncResolverBodyNone(writer, fieldName, createMethodName); + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // ── Creation method ── + writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteAsyncInstanceCreationBody(writer, reg, hasFactory, hasDecorators, groups); + + writer.Indentation--; + writer.WriteLine("}"); + } + + /// + /// Writes an async creation method for a transient async-init service. + /// Each call produces a new Task (no caching). + /// + private static void WriteAsyncTransientResolverMethod( + SourceWriter writer, + string syncMethodName, + string returnType, + ServiceRegistrationModel reg, + bool hasFactory, + bool hasDecorators, + ContainerRegistrationGroups groups) + { + var createMethodName = GetAsyncCreateMethodName(syncMethodName); + var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; + + writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteAsyncInstanceCreationBody(writer, reg, hasFactory, hasDecorators, groups); + + writer.Indentation--; + writer.WriteLine("}"); + } + + /// + /// Writes the instance creation body for an async-init service: + /// constructor, sync injection (properties + sync methods), await async methods, optional decorators. + /// + private static void WriteAsyncInstanceCreationBody( + SourceWriter writer, + ServiceRegistrationModel reg, + bool hasFactory, + bool hasDecorators, + ContainerRegistrationGroups groups) + { + var (properties, syncMethods, asyncMethods) = CategorizeInjectionMembersAsync(reg.InjectionMembers); + var args = BuildConstructorArgumentsString(reg, groups); + + // When decorators are present AND there is method injection (sync or async), we cannot + // type the instance as the service interface because [IocInject] methods may be on the + // concrete implementation only. + // + // Two-variable pattern: + // var baseInstance = new Impl(args) { Props... }; + // baseInstance.SyncMethod(...); + // await baseInstance.AsyncMethod(...); + // ServiceType instance = baseInstance; + // instance = new Decorator(instance); // decorator chain + // + // Single-variable pattern (no decorators, or decorators + pure property injection): + // var instance = new Impl(args) { Props... }; + // await instance.AsyncInit(...); + bool hasMethods = syncMethods is { Count: > 0 } || asyncMethods is { Count: > 0 }; + bool needsTwoVarPattern = hasDecorators && hasMethods; + + // ── Create the instance ── + string injectionVar = needsTwoVarPattern ? "baseInstance" : "instance"; + string varTypeDecl = (hasDecorators && !needsTwoVarPattern) ? reg.ServiceType.Name : "var"; + + if(hasFactory) + { + var factoryCall = BuildFactoryCallForContainer(reg.Factory!, reg, groups); + writer.WriteLine($"{varTypeDecl} {injectionVar} = ({reg.ImplementationType.Name}){factoryCall};"); + } + else + { + WriteConstructorWithPropertyInitializers(writer, injectionVar, varTypeDecl, reg.ImplementationType.Name, args, properties, groups); + } + + // ── Sync method injection ── + if(syncMethods is { Count: > 0 }) + { + foreach(var method in syncMethods) + { + var methodArgs = method.Parameters is { Length: > 0 } + ? string.Join(", ", method.Parameters.Select(p => BuildParameterForContainer(p, reg, groups))) + : ""; + writer.WriteLine($"{injectionVar}.{method.Name}({methodArgs});"); + } + } + + // ── Awaited async method injection ── + if(asyncMethods is { Count: > 0 }) + { + foreach(var method in asyncMethods) + { + var methodArgs = method.Parameters is { Length: > 0 } + ? string.Join(", ", method.Parameters.Select(p => BuildParameterForContainer(p, reg, groups))) + : ""; + writer.WriteLine($"await {injectionVar}.{method.Name}({methodArgs});"); + } + } + + // ── Apply decorators after all injection ── + if(hasDecorators) + { + writer.WriteLine(); + if(needsTwoVarPattern) + { + // Convert the concrete implementation variable to the service type + // so the decorator chain can reassign the variable. + writer.WriteLine($"{reg.ServiceType.Name} instance = {injectionVar};"); + } + WriteDecoratorApplication(writer, "instance", reg, groups); + } + + writer.WriteLine("return instance;"); + } + + /// + /// Writes the async routing resolver body for (no synchronization). + /// + private static void WriteAsyncResolverBodyNone( + SourceWriter writer, + string fieldName, + string createMethodName) + { + writer.WriteLine($"{fieldName} = {createMethodName}();"); + writer.WriteLine($"return await {fieldName};"); + } + + /// + /// Writes the async routing resolver body for . + /// Uses WaitAsync() for async-compatible locking. + /// + private static void WriteAsyncResolverBodySemaphoreSlim( + SourceWriter writer, + string fieldName, + string createMethodName) + { + writer.WriteLine($"await {fieldName}Semaphore.WaitAsync();"); + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"if({fieldName} is null)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"{fieldName} = {createMethodName}();"); + writer.Indentation--; + writer.WriteLine("}"); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("finally"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"{fieldName}Semaphore.Release();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine($"return await {fieldName};"); + } + + /// + /// Categorizes injection members into properties/fields, sync methods, and async methods. + /// + private static (List? Properties, List? SyncMethods, List? AsyncMethods) CategorizeInjectionMembersAsync( + ImmutableEquatableArray injectionMembers) + { + List? properties = null; + List? syncMethods = null; + List? asyncMethods = null; + + foreach(var member in injectionMembers) + { + switch(member.MemberType) + { + case InjectionMemberType.Property: + case InjectionMemberType.Field: + properties ??= []; + properties.Add(member); + break; + case InjectionMemberType.Method: + syncMethods ??= []; + syncMethods.Add(member); + break; + case InjectionMemberType.AsyncMethod: + asyncMethods ??= []; + asyncMethods.Add(member); + break; + } + } + + return (properties, syncMethods, asyncMethods); + } + + /// + /// Tries to extract the inner type name from a + /// global::System.Threading.Tasks.Task<T> type name string. + /// Returns and sets if matched. + /// + private static bool TryExtractTaskInnerType(string typeName, out string innerTypeName) + { + const string TaskPrefix = "global::System.Threading.Tasks.Task<"; + if(typeName.StartsWith(TaskPrefix, StringComparison.Ordinal) + && typeName.EndsWith(">", StringComparison.Ordinal)) + { + innerTypeName = typeName[TaskPrefix.Length..^1]; + return true; + } + innerTypeName = string.Empty; + return false; + } + /// /// Writes instance creation with property/method injection. /// @@ -1416,6 +1795,20 @@ private static IEnumerable GetFactoryArguments(FactoryMethodData factory yield return BuildParameterForContainer(param, reg, groups); } + private static string BuildServiceProviderFallbackExpression( + string typeName, + string? key, + bool isOptional) + { + if(key is not null) + return isOptional + ? $"GetKeyedService(typeof({typeName}), {key}) as {typeName}" + : $"({typeName})GetRequiredKeyedService(typeof({typeName}), {key})"; + return isOptional + ? $"GetService(typeof({typeName})) as {typeName}" + : $"({typeName})GetRequiredService(typeof({typeName}))"; + } + /// /// Builds a service resolution call for container (direct call or GetService/GetRequiredService). /// When the dependency is registered in the same container, calls the resolver method directly. @@ -1464,8 +1857,8 @@ private static string BuildServiceResolutionCallForContainer( return $"GetServices<{elementTypeName}>()"; } - // Wrapper types - Lazy, Func, KeyValuePair - if(type is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData) + // Wrapper types - Lazy, Func, KeyValuePair, Task + if(type is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) { return BuildWrapperExpressionForContainer(type, key, isOptional, groups); } @@ -1474,27 +1867,20 @@ private static string BuildServiceResolutionCallForContainer( if(groups.ByServiceTypeAndKey.TryGetValue((type.Name, key), out var registrations)) { var cached = registrations[^1]; // Last wins - return $"{cached.ResolverMethodName}()"; - } - - // Fallback to GetService/GetRequiredService for dependencies not in this container - - // Keyed services - if(key is not null) - { - if(isOptional) + // Async-init services: the sync method was not generated; use the async method instead. + // Callers that depend on an async-init service should be taking Task, not T directly. + // The analyzer (SGIOC027/029) normally prevents this, but fall back gracefully. + if(cached.IsAsyncInit) { - return $"GetKeyedService(typeof({type.Name}), {key}) as {type.Name}"; + if(cached.Registration.Lifetime == ServiceLifetime.Transient) + return $"{GetAsyncCreateMethodName(cached.ResolverMethodName)}()"; + return $"{GetAsyncResolverMethodName(cached.ResolverMethodName)}()"; } - return $"({type.Name})GetRequiredKeyedService(typeof({type.Name}), {key})"; + return $"{cached.ResolverMethodName}()"; } - // Regular services - if(isOptional) - { - return $"GetService(typeof({type.Name})) as {type.Name}"; - } - return $"({type.Name})GetRequiredService(typeof({type.Name}))"; + // Fallback to GetService/GetRequiredService for dependencies not in this container + return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); } /// @@ -1523,8 +1909,9 @@ private static string BuildWrapperExpressionForContainer( var safeImplType = GetSafeIdentifier(lastReg.Registration.ImplementationType.Name); return $"_lazy_{safeInnerType}_{safeImplType}"; } - // Fallback: no matching inner service - return BuildServiceResolutionCallForContainer(type, key, isOptional, groups); + // Fallback: inner type not in this container — build inline via IServiceProvider + var lazyFallbackExpr = BuildServiceProviderFallbackExpression(innerType.Name, key, isOptional); + return $"new global::System.Lazy<{innerType.Name}>(() => {lazyFallbackExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; } // Nested wrapper or inside nested context — inline construction var lazyInnerExpr = BuildInnerResolutionForContainer(innerType, key, isOptional, groups); @@ -1543,7 +1930,11 @@ private static string BuildWrapperExpressionForContainer( return BuildContainerMultiParamFuncExpression(func, targetRegistration, groups); } - return BuildServiceResolutionCallForContainer(type, key, isOptional, groups); + // Fallback: inner return type not in this container — resolve the full Func<...> type + // directly from IServiceProvider. Do NOT call BuildServiceResolutionCallForContainer + // here as that would route FuncTypeData back to BuildWrapperExpressionForContainer, + // causing infinite recursion. + return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); } // Direct Func where T is not a wrapper — call wrapper resolver if available (only at top level) @@ -1556,8 +1947,9 @@ private static string BuildWrapperExpressionForContainer( var safeImplType = GetSafeIdentifier(lastReg.Registration.ImplementationType.Name); return $"_func_{safeInnerType}_{safeImplType}"; } - // Fallback: no matching inner service - return BuildServiceResolutionCallForContainer(type, key, isOptional, groups); + // Fallback: inner type not in this container — build inline via IServiceProvider + var funcFallbackExpr = BuildServiceProviderFallbackExpression(innerType.Name, key, isOptional); + return $"new global::System.Func<{innerType.Name}>(() => {funcFallbackExpr})"; } // Nested wrapper or inside nested context — inline construction var funcInnerExpr = BuildInnerResolutionForContainer(innerType, key, isOptional, groups); @@ -1591,6 +1983,33 @@ private static string BuildWrapperExpressionForContainer( return $"GetServices<{kvpTypeName}>().ToDictionary()"; } + case TaskTypeData task: + { + // Task wrapper: route based on sync vs async-init registration. + var innerType = task.InnerType; + var innerTypeName = innerType.Name; + + if(groups.ByServiceTypeAndKey.TryGetValue((innerTypeName, key), out var innerRegs)) + { + var lastReg = innerRegs[^1]; + if(lastReg.IsAsyncInit) + { + // Async-init: project Task → Task via async lambda (not ContinueWith) + // so that exceptions propagate as-awaited rather than wrapped in AggregateException. + var asyncMethodName = GetAsyncResolverMethodName(lastReg.ResolverMethodName); + return $"((global::System.Func>)(async () => ({innerTypeName})(await {asyncMethodName}())))()"; + } + else + { + // Sync-only: wrap in Task.FromResult with cast. + return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){lastReg.ResolverMethodName}())"; + } + } + + // Fallback to IServiceProvider + return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); + } + default: return BuildServiceResolutionCallForContainer(type, key, isOptional, groups); } @@ -1607,9 +2026,12 @@ private static string BuildInnerResolutionForContainer( bool isOptional, ContainerRegistrationGroups groups) { - if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData) + if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) { - // Inner wrappers always use inline construction (no resolver methods) + // Inner wrappers always use inline construction (no resolver methods). + // NOTE: Nested Task shapes such as Lazy> or IEnumerable> are not + // supported by the spec. The transform layer prevents these from reaching code generation + // by downgrading their WrapperKind to None, so they fall back to IServiceProvider. return BuildWrapperExpressionForContainer(innerType, key, isOptional, groups, useResolverMethods: false); } @@ -2240,6 +2662,22 @@ private static void WriteIIocContainerImplementation( // Instance registration: directly return the instance resolverExpr = $"static _ => {cached.Registration.Instance}"; } + else if(cached.IsAsyncInit) + { + // Async-init services: expose Task from GetService. + // Singleton/Scoped: call the routing async resolver method. + // Transient: call the async creation method directly (no caching). + if(cached.Registration.Lifetime == ServiceLifetime.Transient) + { + var createMethodName = GetAsyncCreateMethodName(cached.ResolverMethodName); + resolverExpr = $"static c => c.{createMethodName}()"; + } + else + { + var asyncMethodName = GetAsyncResolverMethodName(cached.ResolverMethodName); + resolverExpr = $"static c => c.{asyncMethodName}()"; + } + } else if(cached.IsEager) { // Eager services: directly access the field @@ -2382,7 +2820,9 @@ private static void WriteDisposalImplementation( writer.WriteLine("}"); writer.WriteLine(); - WriteDisposalHelperMethods(writer); + var hasAsyncInitServices = groups.ReversedSingletonsForDisposal.Any(static c => c.IsAsyncInit) + || groups.ReversedScopedForDisposal.Any(static c => c.IsAsyncInit); + WriteDisposalHelperMethods(writer, hasAsyncInitServices); writer.WriteLine("#endregion"); } @@ -2716,10 +3156,12 @@ private static void WriteDisposalCalls( if(cached.Registration.Instance is not null) continue; + var effectiveStrategy = GetEffectiveThreadSafeStrategy(strategy, cached.IsAsyncInit); + writer.WriteLine($"{serviceMethod}({cached.FieldName});"); // Dispose SemaphoreSlim if using SemaphoreSlim strategy (only for non-eager services) - if(strategy == ThreadSafeStrategy.SemaphoreSlim && !cached.IsEager) + if(effectiveStrategy == ThreadSafeStrategy.SemaphoreSlim && !cached.IsEager) { writer.WriteLine($"{cached.FieldName}Semaphore.Dispose();"); } @@ -2735,7 +3177,7 @@ private static void WriteDisposalCalls( /// /// Writes the static helper methods for disposal. /// - private static void WriteDisposalHelperMethods(SourceWriter writer) + private static void WriteDisposalHelperMethods(SourceWriter writer, bool hasAsyncInitServices) { // Helper method to throw ObjectDisposedException if disposed writer.WriteLine("private void ThrowIfDisposed()"); @@ -2764,6 +3206,58 @@ private static void WriteDisposalHelperMethods(SourceWriter writer) writer.Indentation--; writer.WriteLine("}"); writer.WriteLine(); + + if(hasAsyncInitServices) + { + // Overload for async-init services stored as Task? + writer.WriteLine("private static async ValueTask DisposeServiceAsync(Task? task)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("await DisposeServiceAsync(await task);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("catch(Exception ex)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("private static void DisposeService(Task? task)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult());"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("catch(Exception ex)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + } } /// diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs index c89278f..19b97f6 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs @@ -154,6 +154,22 @@ private static string GenerateExtensionMethodSource( .OrderBy(static kvp => kvp.Key, StringComparer.Ordinal) .ToList(); + // Compute the set of service type names (and impl type names) that have async-init injection. + // Used by Task wrapper resolution to distinguish async-init vs sync-only services. + var asyncInitServiceTypeSet = new HashSet(StringComparer.Ordinal); + foreach(var group in tagGroups.Values) + { + foreach(var reg in group) + { + if(reg.InjectionMembers.Any(m => m.MemberType == InjectionMemberType.AsyncMethod)) + { + asyncInitServiceTypeSet.Add(reg.ServiceType.Name); + asyncInitServiceTypeSet.Add(reg.ImplementationType.Name); + } + } + } + HashSet? asyncInitServiceTypeNames = asyncInitServiceTypeSet.Count > 0 ? asyncInitServiceTypeSet : null; + bool isFirstGroup = true; foreach(var kvp in sortedGroups) { @@ -178,12 +194,12 @@ private static string GenerateExtensionMethodSource( if(tagList.Length == 0) { // No tags - only register when no tags passed (mutually exclusive model) - WriteNoTagConditionalBlock(mainWriter, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries); + WriteNoTagConditionalBlock(mainWriter, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries, asyncInitServiceTypeNames); } else { // Has tags - only register when tags match - WriteConditionalTagBlock(mainWriter, tagList, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries); + WriteConditionalTagBlock(mainWriter, tagList, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries, asyncInitServiceTypeNames); } } @@ -205,11 +221,11 @@ private static string GenerateExtensionMethodSource( /// /// Writes a group of registrations without any conditional wrapper. /// - private static void WriteRegistrationGroup(SourceWriter writer, List registrations) + private static void WriteRegistrationGroup(SourceWriter writer, List registrations, HashSet? asyncInitServiceTypeNames = null) { foreach(var registration in registrations) { - WriteRegistration(writer, registration); + WriteRegistration(writer, registration, asyncInitServiceTypeNames); } } @@ -223,7 +239,8 @@ private static void WriteConditionalTagBlock( List registrations, List? lazyEntries = null, List? funcEntries = null, - List? kvpEntries = null) + List? kvpEntries = null, + HashSet? asyncInitServiceTypeNames = null) { // Build the condition - only register when tags match var tagConditions = tags.Select(static tag => $"tags.Contains(\"{tag}\")"); @@ -235,7 +252,7 @@ private static void WriteConditionalTagBlock( foreach(var registration in registrations) { - WriteRegistration(writer, registration); + WriteRegistration(writer, registration, asyncInitServiceTypeNames); } WriteLazyRegistrations(writer, lazyEntries); @@ -255,7 +272,8 @@ private static void WriteNoTagConditionalBlock( List registrations, List? lazyEntries = null, List? funcEntries = null, - List? kvpEntries = null) + List? kvpEntries = null, + HashSet? asyncInitServiceTypeNames = null) { writer.WriteLine("if (!tags.Any())"); writer.WriteLine("{"); @@ -263,7 +281,7 @@ private static void WriteNoTagConditionalBlock( foreach(var registration in registrations) { - WriteRegistration(writer, registration); + WriteRegistration(writer, registration, asyncInitServiceTypeNames); } WriteLazyRegistrations(writer, lazyEntries); @@ -292,7 +310,7 @@ private static void WriteServiceRegistrationLambdaStart( writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) =>"); } - private static void WriteRegistration(SourceWriter writer, ServiceRegistrationModel registration) + private static void WriteRegistration(SourceWriter writer, ServiceRegistrationModel registration, HashSet? asyncInitServiceTypeNames = null) { var lifetime = registration.Lifetime.Name; var serviceTypeName = registration.ServiceType.Name; @@ -306,7 +324,7 @@ private static void WriteRegistration(SourceWriter writer, ServiceRegistrationMo // Handle Factory registration first (takes precedence) if(hasFactory) { - WriteFactoryMethodRegistration(writer, registration, lifetime); + WriteFactoryMethodRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); return; } @@ -334,7 +352,7 @@ private static void WriteRegistration(SourceWriter writer, ServiceRegistrationMo } else { - WriteDecoratorRegistration(writer, registration, lifetime); + WriteDecoratorRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); return; } } @@ -367,17 +385,26 @@ private static void WriteRegistration(SourceWriter writer, ServiceRegistrationMo // For non-open-generic, service type registrations (interface/base class): // Always use forwarding to implementation type to ensure single instance per scope/lifetime // This generates: sp => sp.GetRequiredService() + // When impl is async-init, forwards Task → Task instead. if(!registration.IsOpenGeneric && isServiceTypeRegistration) { // Service type registration (interface/base class) forwards to implementation - WriteServiceTypeForwardingRegistration(writer, registration, lifetime); + WriteServiceTypeForwardingRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); return; } if(needsFactoryConstruction && !registration.IsOpenGeneric) { // Self registration with injection members or constructor params with special handling - generate factory method - WriteInjectionRegistration(writer, registration, lifetime); + bool hasAsyncInjectionMembers = registration.InjectionMembers.Any(m => m.MemberType == InjectionMemberType.AsyncMethod); + if(hasAsyncInjectionMembers) + { + WriteAsyncInjectionRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); + } + else + { + WriteInjectionRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); + } return; } @@ -423,7 +450,7 @@ private static void WriteRegistration(SourceWriter writer, ServiceRegistrationMo /// /// If the factory return type differs from the service type, adds a cast. /// - private static void WriteFactoryMethodRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime) + private static void WriteFactoryMethodRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) { var serviceTypeName = registration.ServiceType.Name; var factory = registration.Factory!; @@ -452,7 +479,7 @@ private static void WriteFactoryMethodRegistration(SourceWriter writer, ServiceR // Use multi-line lambda format for readability WriteFactoryMethodRegistrationWithAdditionalParams( writer, registration, lifetime, serviceTypeName, factoryPath, - hasServiceProvider, hasKey, returnTypeName, additionalParameters, isKeyedRegistration, genericTypeArgs); + hasServiceProvider, hasKey, returnTypeName, additionalParameters, isKeyedRegistration, genericTypeArgs, asyncInitServiceTypeNames); return; } @@ -484,7 +511,8 @@ private static void WriteFactoryMethodRegistrationWithAdditionalParams( string? returnTypeName, ImmutableEquatableArray additionalParameters, bool isKeyedRegistration, - string? genericTypeArgs = null) + string? genericTypeArgs = null, + HashSet? asyncInitServiceTypeNames = null) { // Open the registration lambda WriteServiceRegistrationLambdaStart(writer, lifetime, serviceTypeName, registration.Key); @@ -498,7 +526,7 @@ private static void WriteFactoryMethodRegistrationWithAdditionalParams( foreach(var param in additionalParameters) { var paramVar = $"f_p{paramIndex}"; - var varName = ResolveParamAndEmitVar(writer, param, paramVar, isKeyedRegistration, registration.Key); + var varName = ResolveParamAndEmitVar(writer, param, paramVar, isKeyedRegistration, registration.Key, asyncInitServiceTypeNames); // Use the resolved variable name paramVars.Add(varName); @@ -622,11 +650,35 @@ private static void WriteInstanceRegistration(SourceWriter writer, ServiceRegist /// services.AddTransient<IMyService>(sp => sp.GetRequiredService<MyService>()); /// /// - private static void WriteServiceTypeForwardingRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime) + private static void WriteServiceTypeForwardingRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) { var serviceTypeName = registration.ServiceType.Name; var implTypeName = registration.ImplementationType.Name; + // When the implementation type is registered as async-init (Task), + // the forwarding registration must also use Task. + // We must use "async ... => await ..." because Task is invariant in C# — + // Task cannot be implicitly assigned to Task even when + // MyService : IMyService. The await unwraps the result and the async lambda + // re-wraps it as Task. + bool isAsyncInit = asyncInitServiceTypeNames?.Contains(implTypeName) == true; + if(isAsyncInit) + { + var taskServiceTypeName = $"global::System.Threading.Tasks.Task<{serviceTypeName}>"; + var taskImplTypeName = $"global::System.Threading.Tasks.Task<{implTypeName}>"; + if(registration.Key is not null) + { + var requiredCall = BuildServiceCall(GetRequiredKeyedService, taskImplTypeName, "key"); + writer.WriteLine($"services.AddKeyed{lifetime}<{taskServiceTypeName}>({registration.Key}, async ({IServiceProviderGlobalTypeName} sp, object? key) => await {requiredCall});"); + } + else + { + var requiredCall = BuildServiceCall(GetRequiredService, taskImplTypeName, serviceKey: null); + writer.WriteLine($"services.Add{lifetime}<{taskServiceTypeName}>(async ({IServiceProviderGlobalTypeName} sp) => await {requiredCall});"); + } + return; + } + if(registration.Key is not null) { // Keyed registration - forward to keyed implementation @@ -663,7 +715,7 @@ private static void WriteServiceTypeForwardingRegistration(SourceWriter writer, /// var s0 = p0 is not null ? new MyService(optDep: p0) : new MyService(); /// /// - private static void WriteInjectionRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime) + private static void WriteInjectionRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) { var serviceTypeName = registration.ServiceType.Name; var implTypeName = registration.ImplementationType.Name; @@ -687,7 +739,8 @@ private static void WriteInjectionRegistration(SourceWriter writer, ServiceRegis serviceTypeNames: null, ctorTypeNameResolver: null, memberTypeNameResolver: null, - decoratedPrevVar: null); + decoratedPrevVar: null, + asyncInitServiceTypeNames: asyncInitServiceTypeNames); writer.WriteLine("return s0;"); @@ -695,6 +748,77 @@ private static void WriteInjectionRegistration(SourceWriter writer, ServiceRegis writer.WriteLine("});"); } + /// + /// Writes registration code for services that have async injection methods. + /// Generates a Task<T> registration with an async local Init() function + /// that performs construction, sync injection, and awaited async injection in order. + /// + /// + /// Generates code like: + /// + /// services.AddSingleton<Task<MyService>>((IServiceProvider sp) => + /// { + /// async Task<MyService> Init() + /// { + /// var s0_p0 = sp.GetRequiredService<ILogger>(); + /// var s0_m0 = sp.GetRequiredService<IAsyncInitializer>(); + /// var s0 = new MyService { Logger = s0_p0 }; + /// await s0.InitAsync(s0_m0); + /// return s0; + /// } + /// return Init(); + /// }); + /// + /// + private static void WriteAsyncInjectionRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) + { + var serviceTypeName = registration.ServiceType.Name; + var implTypeName = registration.ImplementationType.Name; + var injectionMembers = registration.InjectionMembers; + bool isKeyedRegistration = registration.Key is not null; + + // Service type is wrapped in Task for async-init registrations + var taskServiceTypeName = $"global::System.Threading.Tasks.Task<{serviceTypeName}>"; + var taskImplTypeName = $"global::System.Threading.Tasks.Task<{implTypeName}>"; + + // Open the registration lambda for Task + WriteServiceRegistrationLambdaStart(writer, lifetime, taskServiceTypeName, registration.Key); + + writer.WriteLine("{"); + writer.Indentation++; + + // Write the async local function header + writer.WriteLine($"async {taskImplTypeName} Init()"); + writer.WriteLine("{"); + writer.Indentation++; + + // Emit construction with sync + async injection (isAsyncMode = true) + WriteConstructInstanceWithInjection( + writer, + instanceVarName: "s0", + implTypeName: implTypeName, + constructorParams: registration.ImplementationType.ConstructorParameters, + injectionMembers: injectionMembers, + isKeyedRegistration: isKeyedRegistration, + registrationKey: registration.Key, + serviceTypeNames: null, + ctorTypeNameResolver: null, + memberTypeNameResolver: null, + decoratedPrevVar: null, + asyncInitServiceTypeNames: asyncInitServiceTypeNames, + isAsyncMode: true); + + writer.WriteLine("return s0;"); + + writer.Indentation--; + writer.WriteLine("}"); + + writer.WriteLine("return Init();"); + + writer.Indentation--; + writer.WriteLine("});"); + } + /// @@ -712,7 +836,9 @@ private static void WriteConstructInstanceWithInjection( HashSet? serviceTypeNames, Func? ctorTypeNameResolver, Func? memberTypeNameResolver, - string? decoratedPrevVar) + string? decoratedPrevVar, + HashSet? asyncInitServiceTypeNames = null, + bool isAsyncMode = false) { var ctorParams = constructorParams ?? []; var constructorParamEntries = new List<(string Name, string? Value, bool NeedsConditional)>(ctorParams.Length); @@ -728,7 +854,7 @@ private static void WriteConstructInstanceWithInjection( else { var varName = decoratedPrevVar is not null ? $"{instanceVarName}_p{paramIndex}" : $"p{paramIndex}"; - var resolvedVar = ResolveParamAndEmitVar(writer, param, varName, isKeyedRegistration, registrationKey); + var resolvedVar = ResolveParamAndEmitVar(writer, param, varName, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); constructorParamEntries.Add((param.Name, resolvedVar, false)); paramIndex++; } @@ -781,7 +907,9 @@ private static void WriteConstructInstanceWithInjection( injectionMembers, isKeyedRegistration, registrationKey, - memberTypeNameResolver); + memberTypeNameResolver, + asyncInitServiceTypeNames, + isAsyncMode); } } @@ -794,7 +922,9 @@ private static void EmitConstruction( ImmutableEquatableArray injectionMembers, bool isKeyedRegistration, string? registrationKey, - Func? memberTypeNameResolver) + Func? memberTypeNameResolver, + HashSet? asyncInitServiceTypeNames = null, + bool isAsyncMode = false) { // Resolve property/field injection parameters int pfCount = injectionMembers.Count(m => m.MemberType is InjectionMemberType.Property or InjectionMemberType.Field); @@ -809,17 +939,63 @@ private static void EmitConstruction( var mt = m.Type; var mtName = mt is null ? "object" : (memberTypeNameResolver is not null ? memberTypeNameResolver(mt) : mt.Name); bool hasNonNullDefault = m.HasDefaultValue && !m.DefaultValueIsNull; - ResolveMemberValue(writer, mt, mtName, varN, m.Key, m.IsNullable, hasNonNullDefault, m.DefaultValue); + ResolveMemberValue(writer, mt, mtName, varN, m.Key, m.IsNullable, hasNonNullDefault, m.DefaultValue, asyncInitServiceTypeNames); preProps[idx++] = (m.Name, varN); pfIdxCounter++; } - // Resolve method parameters before instance creation (unified for both decorator and non-decorator) - var methodParamResolutions = new List<(string MethodName, string?[] ParamVars, string[] ParamNames)>(); + // Resolve sync method parameters before instance creation int memberParamIndex = pfIdxCounter; + var methodParamResolutions = ResolveMethodParamResolutions( + writer, injectionMembers, InjectionMemberType.Method, instanceVarName, ref memberParamIndex, + isKeyedRegistration, registrationKey, memberTypeNameResolver, asyncInitServiceTypeNames); + + // Resolve async method parameters before instance creation (only when in async mode) + var asyncMethodParamResolutions = isAsyncMode + ? ResolveMethodParamResolutions( + writer, injectionMembers, InjectionMemberType.AsyncMethod, instanceVarName, ref memberParamIndex, + isKeyedRegistration, registrationKey, memberTypeNameResolver, asyncInitServiceTypeNames) + : []; + + // Property/field values are resolved upfront; include all properties in the object initializer + var propertyInits = preProps.Select(p => $"{p.Name} = {p.ParamVar}").ToArray(); + var constructorArgs = BuildArgumentListFromEntries([.. constructorParamEntries]); + var initializerPart = propertyInits.Length > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : ""; + var constructorInvocation = BuildConstructorInvocation(implTypeName, constructorArgs, initializerPart); + writer.WriteLine($"var {instanceVarName} = {constructorInvocation};"); + + // Call sync methods + EmitMethodInvocations(writer, instanceVarName, methodParamResolutions, useAwait: false); + + // Await async methods (only in async mode) + EmitMethodInvocations(writer, instanceVarName, asyncMethodParamResolutions, useAwait: true); + } + + /// + /// Builds a constructor invocation expression with an optional initializer. + /// + private static string BuildConstructorInvocation(string implTypeName, string args, string initializerPart) => + $"new {implTypeName}({args}){initializerPart}"; + + /// + /// Resolves method parameters of a given and emits their variable declarations. + /// Shared by sync () and async () resolution loops. + /// + private static List<(string MethodName, string?[] ParamVars, string[] ParamNames)> ResolveMethodParamResolutions( + SourceWriter writer, + ImmutableEquatableArray injectionMembers, + InjectionMemberType targetType, + string instanceVarName, + ref int memberParamIndex, + bool isKeyedRegistration, + string? registrationKey, + Func? memberTypeNameResolver, + HashSet? asyncInitServiceTypeNames) + { + var resolutions = new List<(string MethodName, string?[] ParamVars, string[] ParamNames)>(); foreach(var method in injectionMembers) { - if(method.MemberType != InjectionMemberType.Method) + if(method.MemberType != targetType) continue; var mParams = method.Parameters ?? []; var mVars = new string?[mParams.Length]; @@ -838,43 +1014,43 @@ private static void EmitConstruction( method.Key, isKeyedRegistration, registrationKey, - memberTypeNameResolver); + memberTypeNameResolver, + asyncInitServiceTypeNames); } else { - var resolvedVar = ResolveParamAndEmitVar(writer, p, pVar, isKeyedRegistration, registrationKey); - mVars[mi] = resolvedVar; + mVars[mi] = ResolveParamAndEmitVar(writer, p, pVar, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); } mi++; memberParamIndex++; } - methodParamResolutions.Add((method.Name, mVars, mNames)); + resolutions.Add((method.Name, mVars, mNames)); } + return resolutions; + } - // Property/field values are resolved upfront; include all properties in the object initializer - var propertyInits = preProps.Select(p => $"{p.Name} = {p.ParamVar}").ToArray(); - var constructorArgs = BuildArgumentListFromEntries([.. constructorParamEntries]); - var initializerPart = propertyInits.Length > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : ""; - var constructorInvocation = BuildConstructorInvocation(implTypeName, constructorArgs, initializerPart); - writer.WriteLine($"var {instanceVarName} = {constructorInvocation};"); - - // Call methods - foreach(var (mName, mVars, mNames) in methodParamResolutions) + /// + /// Emits method invocation statements for the given . + /// When is , each call is prefixed with await. + /// + private static void EmitMethodInvocations( + SourceWriter writer, + string instanceVarName, + List<(string MethodName, string?[] ParamVars, string[] ParamNames)> resolutions, + bool useAwait) + { + foreach(var (mName, mVars, mNames) in resolutions) { var entries = new List<(string Name, string? Value, bool NeedsConditional)>(mVars.Length); for(int i = 0; i < mVars.Length; i++) entries.Add((mNames[i], mVars[i], false)); var args = BuildArgumentListFromEntries([.. entries]); - writer.WriteLine($"{instanceVarName}.{mName}({args});"); + writer.WriteLine(useAwait + ? $"await {instanceVarName}.{mName}({args});" + : $"{instanceVarName}.{mName}({args});"); } } - /// - /// Builds a constructor invocation expression with an optional initializer. - /// - private static string BuildConstructorInvocation(string implTypeName, string args, string initializerPart) => - $"new {implTypeName}({args}){initializerPart}"; - /// /// Resolves a property or field injection value and emits its variable declaration. /// @@ -886,7 +1062,8 @@ private static void ResolveMemberValue( string? serviceKey, bool isNullable, bool hasNonNullDefault, - string? defaultValue) + string? defaultValue, + HashSet? asyncInitServiceTypeNames = null) { if(memberType is CollectionWrapperTypeData) { @@ -894,9 +1071,9 @@ private static void ResolveMemberValue( return; } - if(memberType is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData) + if(memberType is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) { - WriteWrapperResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable); + WriteWrapperResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable, asyncInitServiceTypeNames); return; } @@ -925,9 +1102,10 @@ private static void ResolveMemberValue( string methodKey, bool isKeyedRegistration, string? registrationKey, - Func? typeNameResolver) + Func? typeNameResolver, + HashSet? asyncInitServiceTypeNames = null) { - if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, methodKey, isOptional: false, typeNameResolver, out var resolvedVar)) + if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, methodKey, isOptional: false, typeNameResolver, asyncInitServiceTypeNames, out var resolvedVar)) { return resolvedVar!; } @@ -1019,7 +1197,7 @@ private static string GetGenericString(in int arity) => /// /// For open generic decorators, falls back to ActivatorUtilities.CreateInstance. /// - private static void WriteDecoratorRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime) + private static void WriteDecoratorRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) { var decorators = registration.Decorators; // Decorators array is in order from outermost to innermost, @@ -1078,7 +1256,8 @@ private static void WriteDecoratorRegistration(SourceWriter writer, ServiceRegis serviceTypeNames: serviceTypeNames, ctorTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, memberTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, - decoratedPrevVar: prevVar); + decoratedPrevVar: prevVar, + asyncInitServiceTypeNames: asyncInitServiceTypeNames); } // Return the outermost decorator @@ -1260,11 +1439,12 @@ private static string ResolveParamAndEmitVar( ParameterData param, string paramVar, bool isKeyedRegistration, - string? registrationKey = null) + string? registrationKey = null, + HashSet? asyncInitServiceTypeNames = null) { var paramTypeName = param.Type.Name; - if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, param.ServiceKey, param.IsOptional, typeNameResolver: null, out var resolvedVar)) + if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, param.ServiceKey, param.IsOptional, typeNameResolver: null, asyncInitServiceTypeNames, out var resolvedVar)) { return resolvedVar!; } @@ -1402,9 +1582,10 @@ private static void WriteWrapperResolution( TypeData type, string paramVar, string? serviceKey, - bool isOptional = false) + bool isOptional = false, + HashSet? asyncInitServiceTypeNames = null) { - var expr = BuildWrapperExpression(type, serviceKey, isOptional); + var expr = BuildWrapperExpression(type, serviceKey, isOptional, asyncInitServiceTypeNames); writer.WriteLine($"var {paramVar} = {expr};"); } @@ -1415,7 +1596,7 @@ private static void WriteWrapperResolution( /// The service key (null for non-keyed services). /// Whether this resolution is optional. /// A C# expression string that resolves the wrapper type. - private static string BuildWrapperExpression(TypeData type, string? serviceKey, bool isOptional) + private static string BuildWrapperExpression(TypeData type, string? serviceKey, bool isOptional, HashSet? asyncInitServiceTypeNames = null) { switch(type) { @@ -1431,7 +1612,7 @@ private static string BuildWrapperExpression(TypeData type, string? serviceKey, return BuildServiceCall(methodName, type.Name, serviceKey); } // Nested wrapper (e.g., Lazy>) — inline construction - var lazyInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional); + var lazyInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); return $"new global::System.Lazy<{innerType.Name}>(() => {lazyInnerExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; } @@ -1455,7 +1636,7 @@ private static string BuildWrapperExpression(TypeData type, string? serviceKey, return BuildServiceCall(methodName, type.Name, serviceKey); } // Nested wrapper (e.g., Func>) — inline construction - var funcInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional); + var funcInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); return $"new global::System.Func<{innerType.Name}>(() => {funcInnerExpr})"; } @@ -1465,7 +1646,7 @@ private static string BuildWrapperExpression(TypeData type, string? serviceKey, var valueType = kvp.ValueType; // KeyValuePair uses the registration's service key as the key value var keyExpr = serviceKey ?? "default"; - var valueExpr = BuildInnerResolutionExpression(valueType, serviceKey, isOptional); + var valueExpr = BuildInnerResolutionExpression(valueType, serviceKey, isOptional, asyncInitServiceTypeNames); return $"new global::System.Collections.Generic.KeyValuePair<{keyType.Name}, {valueType.Name}>({keyExpr}, {valueExpr})"; } @@ -1482,6 +1663,30 @@ private static string BuildWrapperExpression(TypeData type, string? serviceKey, return $"{getServicesCall}.ToDictionary()"; } + case TaskTypeData task: + { + // Task wrapper: if inner type is an async-init service, resolve Task directly; + // otherwise wrap synchronous resolution in Task.FromResult(...). + var innerTypeName = task.InnerType.Name; + if(asyncInitServiceTypeNames?.Contains(innerTypeName) == true) + { + // Async-init service: Task is registered directly. + var methodName = isOptional + ? (serviceKey is not null ? GetKeyedService : GetService) + : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); + return BuildServiceCall(methodName, type.Name, serviceKey); + } + else + { + // Sync-only service: wrap with Task.FromResult. + var syncMethodName = isOptional + ? (serviceKey is not null ? GetKeyedService : GetService) + : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); + var syncCall = BuildServiceCall(syncMethodName, innerTypeName, serviceKey); + return $"global::System.Threading.Tasks.Task.FromResult({syncCall})"; + } + } + default: { // Fallback for non-wrapper types @@ -1497,12 +1702,12 @@ private static string BuildWrapperExpression(TypeData type, string? serviceKey, /// Builds an inner resolution expression — either a nested wrapper expression, a collection /// expression, or a direct service call. Supports nesting such as Lazy<IEnumerable<T>>. /// - private static string BuildInnerResolutionExpression(TypeData innerType, string? serviceKey, bool isOptional) + private static string BuildInnerResolutionExpression(TypeData innerType, string? serviceKey, bool isOptional, HashSet? asyncInitServiceTypeNames = null) { // If the inner type is itself a wrapper, recurse to handle nesting (e.g., Lazy>) - if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData) + if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) { - return BuildWrapperExpression(innerType, serviceKey, isOptional); + return BuildWrapperExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); } // Collection types inside wrappers (e.g., Lazy>) @@ -1550,6 +1755,7 @@ private static bool TryResolveCommonParameter( string? serviceKey, bool isOptional, Func? typeNameResolver, + HashSet? asyncInitServiceTypeNames, out string? resolvedVar) { if(IsServiceProviderType(param.Type.Name)) @@ -1572,9 +1778,9 @@ private static bool TryResolveCommonParameter( return true; } - if(param.Type is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData) + if(param.Type is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) { - WriteWrapperResolution(writer, param.Type, paramVar, serviceKey, isOptional); + WriteWrapperResolution(writer, param.Type, paramVar, serviceKey, isOptional, asyncInitServiceTypeNames); resolvedVar = paramVar; return true; } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs index e15fa64..b2041b5 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs @@ -193,6 +193,7 @@ private static ContainerRegistrationGroups BuildContainerRegistrationGroups( var transients = transientMap.Values.Select(static v => v.Cached).ToList(); // Collect service types with multiple registrations for IEnumerable resolution + // Async-init services are excluded from collection resolution (only Task can access them). var collectionServiceTypes = new List(); var collectionRegistrations = new Dictionary>(); @@ -201,9 +202,12 @@ private static ContainerRegistrationGroups BuildContainerRegistrationGroups( // Include non-keyed service types with multiple registrations if(kvp.Key.Key is null && kvp.Value.Count > 1) { + // Filter out async-init registrations — they cannot appear in IEnumerable resolvers + var effectiveRegistrations = kvp.Value.Where(static c => !c.IsAsyncInit).ToImmutableEquatableArray(); + // Deduplicate resolver method names to count unique implementations var uniqueResolvers = new HashSet(); - foreach(var cached in kvp.Value) + foreach(var cached in effectiveRegistrations) { uniqueResolvers.Add(cached.ResolverMethodName); } @@ -212,7 +216,7 @@ private static ContainerRegistrationGroups BuildContainerRegistrationGroups( if(uniqueResolvers.Count > 1) { collectionServiceTypes.Add(kvp.Key.ServiceType); - collectionRegistrations[kvp.Key.ServiceType] = kvp.Value.ToImmutableEquatableArray(); + collectionRegistrations[kvp.Key.ServiceType] = effectiveRegistrations; } } } @@ -302,14 +306,30 @@ private static CachedRegistration CreateCachedRegistration( // Determine if this registration should be eagerly resolved // Instance registrations are inherently eager (no field caching needed) // Transient services are not supported for eager resolution - var isEager = reg.Instance is null && reg.Lifetime switch + // Async-init services must always be lazy (cannot be started in constructor) + var isAsyncInit = HasAsyncInitMembers(reg); + var isEager = reg.Instance is null && !isAsyncInit && reg.Lifetime switch { ServiceLifetime.Singleton => (eagerResolveOptions & EagerResolveOptions.Singleton) != 0, ServiceLifetime.Scoped => (eagerResolveOptions & EagerResolveOptions.Scoped) != 0, _ => false // Transient is never eager }; - return new CachedRegistration(reg, fieldName, methodName, isEager); + return new CachedRegistration(reg, fieldName, methodName, isEager, isAsyncInit); + } + + /// + /// Returns when the registration has at least one + /// member, making it an async-init service. + /// + private static bool HasAsyncInitMembers(ServiceRegistrationModel reg) + { + foreach(var m in reg.InjectionMembers) + { + if(m.MemberType == InjectionMemberType.AsyncMethod) + return true; + } + return false; } /// diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs index 80ec151..71a090f 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs @@ -178,6 +178,10 @@ private static ImmutableEquatableArray CollectContainerLazyE if(reg.IsOpenGeneric) continue; + // Async-init services cannot be resolved synchronously — exclude from Lazy entries + if(cached.IsAsyncInit) + continue; + var entryKey = $"{serviceType}|{reg.ImplementationType.Name}|{reg.Key}"; if(!addedKeys.Add(entryKey)) continue; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md index 19d3cc7..1ff5cec 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md @@ -327,6 +327,37 @@ partial class AppContainer : IIocContainer, IServiceProvid if(service is IDisposable disposable) disposable.Dispose(); } + // These overloads handle async-init services that store their instance as Task?. + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + private void ThrowIfDisposed() { if (_disposed != 0) throw new ObjectDisposedException(GetType().Name); diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md index 8657396..921ca29 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md @@ -132,6 +132,10 @@ When a constructor parameter or injected member uses a wrapper type (`Lazy`, > - `Func>` → `new Func>(() => new Lazy(() => GetMyService()))` > - `Lazy>` → `new Lazy>(() => GetServices())` > - `IEnumerable>` / `IEnumerable>` — Resolved via `GetServices>()` which uses the wrapper resolver methods +> +> Non-collection outer wrappers (`Lazy`, `Func`) are recursively resolved to arbitrary depth. Collection outer wrappers (`IEnumerable`, etc.) support at most **1 level of inner wrapping** (2 levels total); deeper nesting (e.g., `IEnumerable>>`) falls back to `IServiceProvider` resolution via `GetRequiredService(typeof(...))`. +> +> `ValueTask` is **not** a recognized wrapper type in any context. Only `Task` is supported for async-init wrapping. When used as a partial accessor return type: if the target service uses async-init, `SGIOC029` is reported; otherwise `SGIOC021` is reported. ```csharp #region Define: diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md index b25a6a4..e9991bd 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md @@ -2,7 +2,7 @@ ## Overview -The container supports all injection patterns including constructor parameters, properties, methods, and optional parameters, consistent with the registration generator. +The container supports all injection patterns including constructor parameters, properties, synchronous methods, awaited async methods, and optional parameters, consistent with the registration generator. ## Injection Support @@ -69,6 +69,231 @@ partial class AppContainer #endregion ``` +## Async Method Injection + +When a registration contains one or more async inject methods, the container MUST switch that implementation path from synchronous `T` resolution to asynchronous `Task` resolution. + +### Resolution Rules + +|Condition|Required behavior| +|:--------|:----------------| +|Registration contains one or more `InjectionMemberType.AsyncMethod` members|The container MUST generate an async resolver path that returns `Task`. Service-type aliases MUST project from that resolver as `Task`.| +|Singleton or scoped async-init registration|The container MUST cache `Task` in a field.| +|Transient async-init registration|The container MUST create a new `Task` per resolution and MUST NOT cache it.| +|Multiple service-type aliases resolve the same implementation, key, and instance/factory identity|The container MUST share a single cached `Task` field. The deduplication key MUST be `(ImplementationType, Key, InstanceOrFactory)`. Service type is **not** part of the cache key.| +|`ThreadSafeStrategy.None`|Allowed. The container MAY assign the task field directly without synchronization.| +|`ThreadSafeStrategy.SemaphoreSlim`|Allowed. Singleton/scoped async-init services MUST use `WaitAsync()` / `Release()` around first initialization.| +|`ThreadSafeStrategy.Lock`, `ThreadSafeStrategy.SpinLock`, or `ThreadSafeStrategy.CompareExchange`|Async-incompatible for async-init services and MUST NOT be used for that resolver path.| +|`EagerResolveOptions` includes singleton and/or scoped services|Async-init services MUST be excluded from eager resolution. The container constructor/scope constructor MUST NOT pre-start those tasks.| +|Collection wrappers (`IEnumerable`, `IReadOnlyCollection`, `IReadOnlyList`, `IList`, `T[]`)|Async-init registrations MUST be excluded from collection resolvers. `IEnumerable>` is not supported.| + +```mermaid +flowchart LR + A[Task accessor] --> C[Shared Task cache\nkey=(ImplType, Key, InstanceOrFactory)] + B[Task accessor] --> C + C --> D[Construct FooBar] + D --> E[Assign properties] + E --> F[Assign fields] + F --> G[Call sync inject methods] + G --> H[await async inject methods] + H --> I[Return completed Task] +``` + +### Shared `Task` Field Across Aliases + +Async-init services MUST follow the **same implementation-based field deduplication** as synchronous services. If one implementation is registered for multiple service types, all aliases MUST reuse the same cached `Task` field. + +```csharp +#region Define: +using System.Threading.Tasks; + +public interface IFoo { } +public interface IBar { } +public interface ILogger { } +public interface IAsyncInitializer +{ + Task InitializeAsync(object instance); +} + +[IocRegister(ServiceTypes = [typeof(IFoo), typeof(IBar)], Lifetime = ServiceLifetime.Singleton)] +public sealed class FooBar : IFoo, IBar +{ + [IocInject] + public ILogger Logger { get; set; } = default!; + + [IocInject] + public void InitializeSync() + { + } + + [IocInject] + public async Task InitializeAsync(IAsyncInitializer initializer) + { + await initializer.InitializeAsync(this); + } +} + +[IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim)] +public partial class AppContainer +{ + public partial Task GetFooAsync(); + public partial Task GetBarAsync(); +} +#endregion + +#region Generate: +partial class AppContainer +{ + private global::System.Threading.Tasks.Task? _fooBar; + private readonly global::System.Threading.SemaphoreSlim _fooBarSemaphore = new(1, 1); + + public AppContainer(global::System.IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Async-init singleton is excluded from eager resolution. + // _fooBar = GetFooBarAsync(); // MUST NOT be emitted + } + + private async global::System.Threading.Tasks.Task GetFooBarAsync() + { + if(_fooBar is not null) + return await _fooBar; + + await _fooBarSemaphore.WaitAsync(); + try + { + if(_fooBar is null) + { + _fooBar = CreateFooBarAsync(); + } + } + finally + { + _fooBarSemaphore.Release(); + } + + return await _fooBar; + } + + private async global::System.Threading.Tasks.Task CreateFooBarAsync() + { + var instance = new global::FooBar + { + Logger = (global::ILogger)GetRequiredService(typeof(global::ILogger)), + }; + + instance.InitializeSync(); + await instance.InitializeAsync((global::IAsyncInitializer)GetRequiredService(typeof(global::IAsyncInitializer))); + return instance; + } + + public partial async global::System.Threading.Tasks.Task GetFooAsync() => await GetFooBarAsync(); + + public partial async global::System.Threading.Tasks.Task GetBarAsync() => await GetFooBarAsync(); +} +#endregion +``` + +```csharp +// Invalid outcome (must not happen): aliases must not get separate task caches. +private global::System.Threading.Tasks.Task? _fooBar_IFoo; +private global::System.Threading.Tasks.Task? _fooBar_IBar; +``` + +## Disposal of Async-init Fields + +Cached async-init singleton/scoped services use fields of type `Task?`. Generated disposal code MUST unwrap the completed service instance before calling `DisposeServiceAsync` or `DisposeService`, but only when the cached task completed successfully. + +If accessing the cached task result or disposing the resolved service throws, the generated code MUST catch the exception, invoke `global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex)`, and continue disposal without rethrowing. + +`global::SourceGen.Ioc.IocContainerGlobalOptions` is a static configuration type in the `SourceGen.Ioc` namespace. It exposes `public static Action? OnDisposeException`, which users MAY assign to observe or log disposal-time exceptions. + +### Disposal Rules + +|Disposal path|Field type|Required generated pattern|Forbidden pattern| +|:---|:---|:---|:---| +|`DisposeAsync`|`Task?`|Check `task.IsCompletedSuccessfully`, then try-catch `await DisposeServiceAsync(await _field)`|Direct `await task` without status check or try-catch| +|`Dispose`|`Task?`|Check `task.IsCompletedSuccessfully`, then try-catch `DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult())`|Direct `.GetResult()` without status check or try-catch| + +These rules apply regardless of the container's `ThreadSafeStrategy`. Disposal behavior depends on the cached field type (`Task?`), not on the synchronization primitive used during first initialization. + +### Example + +```csharp +#region Generate: +partial class AppContainer : IAsyncDisposable, IDisposable +{ + private global::System.Threading.Tasks.Task? _fooBar; + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + public async ValueTask DisposeAsync() + { + await DisposeServiceAsync(_fooBar); + } + + public void Dispose() + { + DisposeService(_fooBar); + } +} +#endregion +``` + +```csharp +#region Configure: +global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException = static ex => +{ + Console.Error.WriteLine(ex); +}; +#endregion +``` + +```csharp +// Invalid outcome (must not happen): Task disposal must not read the result +// unless the task completed successfully, and must not let disposal exceptions escape. +public async ValueTask DisposeAsync() +{ + if(_fooBar is not null) + await DisposeServiceAsync(await _fooBar); +} + +public void Dispose() +{ + if(_fooBar is not null) + DisposeService(_fooBar.ConfigureAwait(false).GetAwaiter().GetResult()); +} +``` + ## See Also - [Injection Registration](Register.Injection.spec.md) diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md index 321d29b..2b6f10c 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md @@ -16,6 +16,100 @@ The `ThreadSafeStrategy` enum controls how the container ensures thread-safe ini |`SpinLock`|Uses `SpinLock` with double-checked locking pattern.|High-performance scenarios with very short initialization times. Not recommended for I/O-bound initialization.| |`CompareExchange`|Uses `Interlocked.CompareExchange` (CAS) for lock-free thread safety.|Best performance for lightweight constructors. No synchronization overhead. May create duplicate instances under contention; duplicates are disposed via `DisposeService`.| +## Async-init Strategy Override + +Registrations that contain one or more `InjectionMemberType.AsyncMethod` members generate async-init resolver paths and cached fields of type `Task`. Those resolver paths MUST use an async-compatible synchronization strategy. + +### Effective Strategy Rules for Async-init Services + +|Container `ThreadSafeStrategy`|Effective strategy for async-init field declarations and resolver methods|Required behavior| +|:---|:---|:---| +|`None`|`None`|The generator MUST preserve `None` as-is. This remains an explicit opt-in for single-threaded or externally synchronized usage.| +|`SemaphoreSlim`|`SemaphoreSlim`|The generator MUST use `SemaphoreSlim` for async-init services.| +|`Lock`|`SemaphoreSlim`|The generator MUST automatically override the configured strategy to `SemaphoreSlim` for async-init field declarations and resolver methods because `lock` cannot span `await`.| +|`SpinLock`|`SemaphoreSlim`|The generator MUST automatically override the configured strategy to `SemaphoreSlim` for async-init field declarations and resolver methods because `SpinLock` cannot span `await`.| +|`CompareExchange`|`SemaphoreSlim`|The generator MUST automatically override the configured strategy to `SemaphoreSlim` for async-init field declarations and resolver methods because the CAS path is not compatible with awaited first-initialization.| + +> **Scope**: This override applies only to async-init services. Non-async resolver paths MUST continue to use the container's configured `ThreadSafeStrategy`. + +```mermaid +flowchart TD + A[Container ThreadSafeStrategy] --> B{Registration has\nInjectionMemberType.AsyncMethod?} + B -- No --> C[Use configured strategy as-is] + B -- Yes --> D{Configured strategy == None?} + D -- Yes --> E[Keep None] + D -- No --> F[Override to SemaphoreSlim] +``` + +### Example: `Lock` Container Strategy with Async-init Service + +```csharp +#region Define: +using System.Threading.Tasks; + +public interface IAsyncInitializer +{ + Task InitializeAsync(object instance); +} + +[IocRegister(ServiceLifetime.Singleton)] +public sealed class MyService +{ + [IocInject] + public async Task InitializeAsync(IAsyncInitializer initializer) + { + await initializer.InitializeAsync(this); + } +} + +[IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.Lock)] +public partial class AppContainer; +#endregion + +#region Generate: +partial class AppContainer +{ + private global::System.Threading.Tasks.Task? _myService; + private readonly global::System.Threading.SemaphoreSlim _myServiceSemaphore = new(1, 1); + + private async global::System.Threading.Tasks.Task GetMyServiceAsync() + { + if(_myService is not null) + return await _myService; + + await _myServiceSemaphore.WaitAsync(); + try + { + if(_myService is null) + { + _myService = CreateMyServiceAsync(); + } + } + finally + { + _myServiceSemaphore.Release(); + } + + return await _myService; + } +} +#endregion +``` + +```csharp +// Invalid outcome (must not happen): async-init resolver paths cannot use lock-based strategies. +private readonly Lock _myServiceLock = new(); +private async global::System.Threading.Tasks.Task GetMyServiceAsync() +{ + lock(_myServiceLock) + { + _myService ??= CreateMyServiceAsync(); + } + + return await _myService; +} +``` + ## Generated Code Examples ### ThreadSafeStrategy.None diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md index 320ea0a..4613c5d 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md @@ -75,17 +75,137 @@ When a parameter or member type is a recognized wrapper type, the generator reso | `Dictionary` | — | `sp.GetServices>().ToDictionary(...)` | - **Nullable wrapper types**: Use `GetService()` (optional) instead of `GetRequiredService()` -- **Nested wrappers**: Supported up to 2 levels. Inner wrapper types are recursively constructed **inline** (no standalone registration). +- **Nested wrappers**: Non-collection outer wrappers (`Lazy`, `Func`) are recursively resolved to arbitrary depth via inline construction. Collection outer wrappers support at most **1 level of inner wrapping** (2 levels total). Inner wrapper types are constructed **inline** (no standalone registration). - `Lazy>` → `new Lazy>(() => new Func(() => sp.GetRequiredService()))` - `Func>` → `new Func>(() => new Lazy(() => sp.GetRequiredService()))` - `Lazy>` → `new Lazy>(() => sp.GetServices())` - `IEnumerable>` / `IEnumerable>` — Consumers resolve via `sp.GetServices>()` (uses standalone registrations) + - Collection outer wrapper with 3+ levels (e.g., `IEnumerable>>`) is **not** supported. No wrapper registrations are emitted; the consumer is registered with a plain `AddXXX()` call and the parameter is left to MS.DI runtime resolution. + - `ValueTask` is **not** a recognized wrapper type in any context. Only `Task` is supported for async-init wrapping. When used as a partial accessor return type: if the target service uses async-init, `SGIOC029` is reported; otherwise `SGIOC021` is reported. - **Multi-parameter Func matching**: For `Func`, constructor parameters and injectable members are matched by type against Func inputs using first-unused semantics. Unmatched dependencies are resolved from DI. - **Nested multi-parameter Func**: Not supported (e.g., `Lazy>` with input parameters). - **Open generic dependencies**: Wrapper inner types that reference closed generics trigger automatic closed generic registration (e.g., `Lazy>` → registers `Handler`) - **Factory method requirement**: Only **nested** wrapper types and **nullable** direct Lazy/Func types trigger factory method registration. Direct non-nullable Lazy/Func types resolve from their standalone registrations. - **Tag-awareness**: Standalone `Lazy`/`Func`/`KeyValuePair` registrations inherit the tags of the inner service `T` and are emitted within the same tag conditional block. +### Async Method Injection + +`AsyncMethodInject` extends post-construction member injection with awaited initialization methods. The generator MUST treat async method injection as a distinct injection stage that runs after all synchronous injection steps. + +#### Classification Rules + +|Condition|Required behavior| +|:--------|:----------------| +|Method has `[IocInject]`/`[Inject]`, is an ordinary instance method, returns non-generic `Task`, and `AsyncMethodInject` is enabled|MUST classify the member as `InjectionMemberType.AsyncMethod`.| +|Method has `[IocInject]`/`[Inject]` and returns `void`|MUST continue to classify the member as `InjectionMemberType.Method`.| +|Method returns `Task`|MUST NOT classify the member as async method injection. `Task` is not a supported injection-method return type.| +|Method returns `ValueTask` or `ValueTask`|MUST NOT classify the member as async method injection.| +|Method returns `Task` but `AsyncMethodInject` is disabled|The generator MUST treat the member as feature-gated; the analyzer owns the user-facing warning via `SGIOC022`.| + +#### Ordering Contract + +The generator MUST emit member injection in the following fixed stage order. Source declaration order applies **within** each stage. + +|Stage|Members|Emission rule| +|:----|:------|:------------| +|1|Properties|MUST be assigned first, in source declaration order.| +|2|Fields|MUST be assigned second, in source declaration order.| +|3|Synchronous methods|MUST be invoked third, in source declaration order.| +|4|Async methods|MUST be awaited last, in source declaration order.| + +```mermaid +flowchart LR + A[Construct instance] --> B[Assign properties] + B --> C[Assign fields] + C --> D[Call sync inject methods] + D --> E[await async inject methods] + E --> F[Return completed Task] +``` + +#### Register-Path Generation + +When a registration contains one or more async inject methods, the registration generator MUST emit a `Task` registration. The generated factory MUST create an `async Task Init(...)` local function, perform synchronous injection first, then `await` each async inject method in stage order. + +```csharp +#region Define: +using System.Threading.Tasks; + +public interface ILogger { } +public interface IAsyncInitializer +{ + Task InitializeAsync(); +} + +public interface IMyService { } + +[IocRegister(ServiceTypes = [typeof(IMyService)], Lifetime = ServiceLifetime.Singleton)] +public class MyService : IMyService +{ + [IocInject] + public ILogger Logger { get; set; } = default!; + + [IocInject] + public void InitializeSync() + { + } + + [IocInject] + public async Task InitializeAsync(IAsyncInitializer initializer) + { + await initializer.InitializeAsync(); + } +} +#endregion + +#region Generate: +services.AddSingleton>(sp => +{ + async Task Init(global::System.IServiceProvider provider) + { + var initializer = provider.GetRequiredService(); + var instance = new global::MyService + { + Logger = provider.GetRequiredService(), + }; + + instance.InitializeSync(); + await instance.InitializeAsync(initializer); + return instance; + } + + return Init(sp); +}); +#endregion +``` + +#### `WrapperKind.Task` Detection and Resolution + +`Task` is the async-init wrapper for consumer dependencies and partial accessors. The generator MUST only recognize the direct, non-nested form. Resolution MUST distinguish whether `T` itself has an async-init service path or only a synchronous service path. + +|Requested type|Classification|Resolution behavior| +|:-------------|:-------------|:------------------| +|`Task` where `T` is a non-wrapper service type|`WrapperKind.Task`|MUST classify as `WrapperKind.Task`. Resolution MUST follow the async-init vs. sync-only rules below.| +|`Task>`|Unsupported|MUST NOT be classified as `WrapperKind.Task`.| +|`Lazy>`|Unsupported|MUST NOT be classified as a supported nested wrapper combination.| +|`IEnumerable>`|Unsupported|MUST NOT be classified as a supported collection wrapper.| + +|Inner service `T` shape|Register-path behavior|Container-path behavior| +|:----------------------|:---------------------|:----------------------| +|Async-init service (`T` resolves through generated `Task` init path)|MUST resolve `Task` directly via `sp.GetRequiredService>()` or `sp.GetService>()`, depending on requiredness/default handling rules.|MUST call the generated async resolver path for `Task` directly.| +|Sync-only service (`T` only has synchronous resolution)|MUST wrap the synchronous resolution as `Task.FromResult(sp.GetRequiredService())` or the matching optional/default-aware `Task.FromResult(...)` form.|MUST wrap the sync resolver result as `Task.FromResult(self.GetT_Resolve())` or the equivalent generated sync resolver call.| + +```csharp +// Invalid: nested async wrappers are not supported. +[IocRegister] +public class BadConsumer +{ + [IocInject] + public void Initialize(Task> service) + { + } +} +``` + ### Examples **Basic Injection**: diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md index 66ebdc3..c832a3a 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md @@ -13,7 +13,7 @@ Find detailed documentation for each feature: |Basic Registration|[Register.Basic.spec.md](Register.Basic.spec.md)|Core service registration patterns including implementation types and keyed services| |Decorators|[Register.Decorators.spec.md](Register.Decorators.spec.md)|Decorator pattern for composing services with multiple layers| |Tags|[Register.Tags.spec.md](Register.Tags.spec.md)|Tag-based mutually exclusive service registration| -|Injection Members|[Register.Injection.spec.md](Register.Injection.spec.md)|Field, property, method, and constructor injection patterns| +|Injection Members|[Register.Injection.spec.md](Register.Injection.spec.md)|Field, property, method, async method, and constructor injection patterns| |Imported Modules|[Register.ImportModule.spec.md](Register.ImportModule.spec.md)|Cross-assembly module importing and sharing registrations| |Open Generics|[Register.Generics.spec.md](Register.Generics.spec.md)|Generic service types, closed generic discovery, and generic factory mapping| |IServiceProvider|[Register.ServiceProviderInvocation.spec.md](Register.ServiceProviderInvocation.spec.md)|Automatic service discovery from IServiceProvider invocations| @@ -33,7 +33,7 @@ Find detailed documentation for each feature: |Imported Modules|[Container.ImportModule.spec.md](Container.ImportModule.spec.md)|FrozenDictionary-based service resolution with module composition| |Factory & Instance|[Container.Factory.spec.md](Container.Factory.spec.md)|Factory-created and static instance service handling| |Open Generics|[Container.Generics.spec.md](Container.Generics.spec.md)|Open generic service resolution| -|Collections & Wrappers|[Container.Collections.spec.md](Container.Collections.spec.md)|Collection types (IEnumerable, arrays) and wrapper types (Lazy, Func, KeyValuePair)| +|Collections & Wrappers|[Container.Collections.spec.md](Container.Collections.spec.md)|Collection types (IEnumerable, arrays) and wrapper types (Lazy, Func, Task, KeyValuePair)| |Container Options|[Container.Options.spec.md](Container.Options.spec.md)|Configuration attributes and behavior flags (IntegrateServiceProvider, ExplicitOnly, etc.)| |Thread Safety|[Container.ThreadSafety.spec.md](Container.ThreadSafety.spec.md)|Thread-safe service initialization strategies (Lock, SemaphoreSlim, SpinLock, CompareExchange)| |Partial Accessors|[Container.PartialAccessors.spec.md](Container.PartialAccessors.spec.md)|Fast-path service resolution via partial members| @@ -76,7 +76,7 @@ Find detailed documentation for each feature: |`AllBaseClasses`|All base classes (excluding `System.Object`)| |`TypeParameters`|Generic type parameters with constraints| |`ConstructorParameters`|Constructor parameters (for decorators)| -|`WrapperKind`|`None`, `Enumerable`, `ReadOnlyCollection`, `Collection`, `ReadOnlyList`, `List`, `Array`, `Lazy`, `Func`, `Dictionary`, or `KeyValuePair`| +|`WrapperKind`|`None`, `Enumerable`, `ReadOnlyCollection`, `Collection`, `ReadOnlyList`, `List`, `Array`, `Lazy`, `Func`, `Task`, `Dictionary`, or `KeyValuePair`| ### 4. Injection Members @@ -85,6 +85,7 @@ Find detailed documentation for each feature: |Property|With `[IocInject]`/`[Inject]`, set via object initializer| |Field|With `[IocInject]`/`[Inject]`, set via object initializer| |Method|With `[IocInject]`/`[Inject]`, called after construction| +|AsyncMethod|With `[IocInject]`/`[Inject]`, awaited after synchronous member injection when `AsyncMethodInject` is enabled| ### 5. IServiceProvider Invocations @@ -106,23 +107,42 @@ The `SourceGenIocFeatures` MSBuild property controls which outputs and injection Available features: -|Feature|Description| -|:------|:----------| -|`Register`|Enable generation of the registration extension method output| -|`Container`|Enable generation of the container class output| -|`PropertyInject`|Enable property injection member generation| -|`FieldInject`|Enable field injection member generation| -|`MethodInject`|Enable method injection member generation| +|Feature|Value|Description| +|:------|:----|:----------| +|`Register`|`1 << 0`|Enable generation of the registration extension method output.| +|`Container`|`1 << 1`|Enable generation of the container class output.| +|`PropertyInject`|`1 << 2`|Enable property injection member generation.| +|`FieldInject`|`1 << 3`|Enable field injection member generation.| +|`MethodInject`|`1 << 4`|Enable synchronous method injection member generation.| +|`AsyncMethodInject`|`1 << 5`|Enable awaited `[IocInject]`/`[Inject]` methods that return non-generic `Task`. This feature MUST be combined with `MethodInject`; otherwise the analyzer MUST report `SGIOC026`.| Default value: `Register,Container,PropertyInject,MethodInject` +`AsyncMethodInject` is **NOT** part of `Default`. + Behavior: - `Register`: Controls whether the registration extension method output is generated. - `Container`: Controls whether the container class output is generated. -- `PropertyInject` / `FieldInject` / `MethodInject`: Control which injection member types are included in generated code. +- `PropertyInject` / `FieldInject` / `MethodInject`: Control which synchronous injection member types are included in generated code. +- `AsyncMethodInject`: Controls awaited async method injection for `[IocInject]` methods that return `Task`. + +Feature dependency rules: + +|Condition|Required behavior| +|:--------|:----------------| +|`AsyncMethodInject` enabled and `MethodInject` disabled|The configuration is invalid. The analyzer MUST report `SGIOC026`: `'AsyncMethodInject' feature requires 'MethodInject' to be enabled.`| +|`AsyncMethodInject` omitted|`Task`-returning injection methods are not enabled and MUST NOT participate in generated injection code.| + +Enabling example: + +```xml + + Register,Container,PropertyInject,MethodInject,AsyncMethodInject + +``` Parsing rules: @@ -229,6 +249,7 @@ Only members with `[IocInject]` or `[Inject]`: |`Array`|`ArrayTypeData`|`T[]`|`GetServices().ToArray()`| |`Lazy`|`LazyTypeData`|`Lazy`|Lazy-initialized service wrapper| |`Func`|`FuncTypeData`|`Func` / `Func`|Factory delegate wrapper| +|`Task`|`TaskTypeData`|`Task`|Async-init wrapper; resolve `Task` directly for async-init services or wrap sync resolution with `Task.FromResult(...)` for sync-only services.| |`Dictionary`|`DictionaryTypeData`|`IDictionary`|Dictionary of keyed services| |`KeyValuePair`|`KeyValuePairTypeData`|`KeyValuePair`|Single keyed service entry| @@ -248,6 +269,7 @@ TypeData │ └── ArrayTypeData (Array) ├── LazyTypeData (Lazy) ├── FuncTypeData (Func) + ├── TaskTypeData (Task) ├── DictionaryTypeData (Dictionary) └── KeyValuePairTypeData (KeyValuePair) ``` diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs index c53ddce..1b4e5f2 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs @@ -351,8 +351,13 @@ private static InjectionMemberData CreateMethodInjection(IMethodSymbol method, s }) .ToImmutableEquatableArray(); + // Detect async injection methods: non-generic Task return type → AsyncMethod + var memberType = RoslynExtensions.IsNonGenericTaskReturnType(method) + ? InjectionMemberType.AsyncMethod + : InjectionMemberType.Method; + return new InjectionMemberData( - InjectionMemberType.Method, + memberType, method.Name, null, parameters, @@ -402,7 +407,38 @@ private static ImmutableEquatableArray ExtractAndMergeInjec ? ExtractInjectMembersFromAttribute(attributeData, semanticModel) : []; - return MergeInjectionMembers(iocInjectMembers, attrInjectMembers); + var merged = MergeInjectionMembers(iocInjectMembers, attrInjectMembers); + return SortInjectionMembersByStage(merged); + } + + /// + /// Sorts injection members by injection stage (Property → Field → Method → AsyncMethod), + /// preserving source declaration order within each stage. + /// + private static ImmutableEquatableArray SortInjectionMembersByStage( + ImmutableEquatableArray members) + { + if(members.Length <= 1) + return members; + + // Check if already in correct order to avoid allocation + bool inOrder = true; + for(int i = 1; i < members.Length; i++) + { + if(members[i].MemberType < members[i - 1].MemberType) + { + inOrder = false; + break; + } + } + + if(inOrder) + return members; + + // Stable sort: OrderBy preserves relative order within each stage + return members + .OrderBy(static m => m.MemberType) + .ToImmutableEquatableArray(); } /// @@ -581,7 +617,7 @@ private static bool IsInjectableMember(ISymbol symbol) IFieldSymbol field => !field.IsReadOnly && field.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal or Accessibility.ProtectedOrInternal, IMethodSymbol method => method.MethodKind == MethodKind.Ordinary - && method.ReturnsVoid + && (method.ReturnsVoid || IsNonGenericTaskReturnType(method)) && !method.IsGenericMethod && method.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal or Accessibility.ProtectedOrInternal, _ => false diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs index bb6af42..d4e6008 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs @@ -54,11 +54,13 @@ internal sealed record ContainerRegistrationGroups( /// The pre-computed field name for storing the service instance. /// The pre-computed resolver method name. /// Whether this registration should be eagerly resolved during container/scope construction. +/// Whether this registration has async initialization members (pre-computed from ). internal readonly record struct CachedRegistration( ServiceRegistrationModel Registration, string FieldName, string ResolverMethodName, - bool IsEager); + bool IsEager, + bool IsAsyncInit); /// /// Represents a Lazy resolver entry for container code generation. diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/InjectionMemberData.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/InjectionMemberData.cs index fe2b244..3bdfcd9 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/InjectionMemberData.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/InjectionMemberData.cs @@ -45,5 +45,10 @@ internal enum InjectionMemberType /// /// A method to be called after object creation. /// - Method + Method, + + /// + /// An async method (returning Task) to be awaited after object creation. + /// + AsyncMethod } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs index 0caa675..9198af8 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs @@ -11,12 +11,13 @@ internal enum IocFeatures PropertyInject = 1 << 2, FieldInject = 1 << 3, MethodInject = 1 << 4, + AsyncMethodInject = 1 << 5, Default = Register | Container | PropertyInject | MethodInject } internal static class IocFeaturesHelper { - private const IocFeatures AllInjectionFeatures = IocFeatures.PropertyInject | IocFeatures.FieldInject | IocFeatures.MethodInject; + private const IocFeatures AllInjectionFeatures = IocFeatures.PropertyInject | IocFeatures.FieldInject | IocFeatures.MethodInject | IocFeatures.AsyncMethodInject; public static bool HasAllInjectionFeatures(IocFeatures features) => (features & AllInjectionFeatures) == AllInjectionFeatures; @@ -26,6 +27,7 @@ public static bool IsInjectionFeatureEnabled(InjectionMemberType memberType, Ioc InjectionMemberType.Property => (features & IocFeatures.PropertyInject) != 0, InjectionMemberType.Field => (features & IocFeatures.FieldInject) != 0, InjectionMemberType.Method => (features & IocFeatures.MethodInject) != 0, + InjectionMemberType.AsyncMethod => (features & IocFeatures.AsyncMethodInject) != 0, _ => false }; @@ -49,6 +51,8 @@ public static IocFeatures Parse(string? rawFeatures) features |= IocFeatures.FieldInject; else if(token.Equals("methodinject", StringComparison.OrdinalIgnoreCase)) features |= IocFeatures.MethodInject; + else if(token.Equals("asyncmethodinject", StringComparison.OrdinalIgnoreCase)) + features |= IocFeatures.AsyncMethodInject; } return features; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs index 3c57e96..7360457 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs @@ -105,6 +105,19 @@ public TypeData GetTypeData( var nameWithoutGeneric = GetNameWithoutGeneric(typeName); var wrapperKind = typeSymbol.GetWrapperKind(nameWithoutGeneric); + // Downgrade nested Task or Wrapper shapes to WrapperKind.None. + // These shapes are not supported by the spec and fall back to IServiceProvider resolution. + if(wrapperKind == WrapperKind.Task && typeParameters is { Length: > 0 } && typeParameters[0].Type is WrapperTypeData) + { + wrapperKind = WrapperKind.None; + } + else if(wrapperKind is not WrapperKind.None + && typeParameters is not null + && typeParameters.Any(p => p.Type is TaskTypeData)) + { + wrapperKind = WrapperKind.None; + } + if(wrapperKind is not WrapperKind.None) { return TypeData.CreateWrapper( @@ -441,7 +454,7 @@ public WrapperKind GetWrapperKind(string nameWithoutGeneric) if(IsEnumerableType(nameWithoutGeneric)) return WrapperKind.Enumerable; - return GetNonCollectionWrapperKind(nameWithoutGeneric); + return GetNonCollectionWrapperKind(nameWithoutGeneric, typeSymbol.Arity); } /// @@ -450,7 +463,9 @@ public WrapperKind GetWrapperKind(string nameWithoutGeneric) /// Collection types (IEnumerable, IReadOnlyCollection, etc.) are detected separately /// in via GetWrapperKind. /// - public static WrapperKind GetNonCollectionWrapperKind(string nameWithoutGeneric) => nameWithoutGeneric switch + /// The type name without generic parameters. + /// The number of type parameters. Used to distinguish Task<T> (arity 1) from non-generic Task (arity 0). + public static WrapperKind GetNonCollectionWrapperKind(string nameWithoutGeneric, int arity) => nameWithoutGeneric switch { "global::System.Lazy" or "System.Lazy" or "Lazy" => WrapperKind.Lazy, "global::System.Func" or "System.Func" or "Func" => WrapperKind.Func, @@ -459,6 +474,7 @@ public WrapperKind GetWrapperKind(string nameWithoutGeneric) or "global::System.Collections.Generic.Dictionary" or "System.Collections.Generic.Dictionary" or "Dictionary" => WrapperKind.Dictionary, "global::System.Collections.Generic.KeyValuePair" or "System.Collections.Generic.KeyValuePair" or "KeyValuePair" => WrapperKind.KeyValuePair, + ("global::System.Threading.Tasks.Task" or "System.Threading.Tasks.Task" or "Task") when arity == 1 => WrapperKind.Task, _ => WrapperKind.None }; @@ -474,7 +490,9 @@ public WrapperKind GetWrapperKind(string nameWithoutGeneric) /// - Non-static members only /// - Properties with a setter /// - Non-readonly fields - /// - Ordinary methods that return void and are not generic + /// - Ordinary methods that return (sync) or non-generic + /// (async, when AsyncMethodInject + /// feature is enabled), and are not generic /// /// /// An enumerable of tuples containing the member symbol and its inject attribute. @@ -505,7 +523,7 @@ public WrapperKind GetWrapperKind(string nameWithoutGeneric) IPropertySymbol property => property.SetMethod is not null, IFieldSymbol field => !field.IsReadOnly, IMethodSymbol method => method.MethodKind == MethodKind.Ordinary - && method.ReturnsVoid + && (method.ReturnsVoid || RoslynExtensions.IsNonGenericTaskReturnType(method)) && !method.IsGenericMethod, _ => false }; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs index e3540d9..2d0b1f1 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs @@ -334,6 +334,29 @@ internal sealed record class KeyValuePairTypeData( ConstructorParameters, HasInjectConstructor, InjectionMembers, AllInterfaces, AllBaseClasses); +/// +/// Represents Task<T> wrapper type. Async-initialized service wrapper. +/// Resolved via an async resolver method that awaits async inject methods. +/// +internal sealed record class TaskTypeData( + string Name, + string NameWithoutGeneric, + bool IsOpenGeneric, + int GenericArity, + bool IsNestedOpenGeneric = false, + ImmutableEquatableArray? TypeParameters = null, + ImmutableEquatableArray? ConstructorParameters = null, + bool HasInjectConstructor = false, + ImmutableEquatableArray? InjectionMembers = null, + ImmutableEquatableArray? AllInterfaces = null, + ImmutableEquatableArray? AllBaseClasses = null) + : WrapperTypeData( + Name, NameWithoutGeneric, IsOpenGeneric, GenericArity, + WrapperKind.Task, + IsNestedOpenGeneric, TypeParameters, + ConstructorParameters, HasInjectConstructor, InjectionMembers, + AllInterfaces, AllBaseClasses); + /// /// Represents the kind of wrapper for DI injection purposes. /// Each value has a corresponding sealed TypeData derived type. @@ -398,7 +421,13 @@ internal enum WrapperKind /// /// KeyValuePair<TKey, TValue> - single keyed service entry. /// - KeyValuePair + KeyValuePair, + + /// + /// Task<T> - async-initialized service wrapper. + /// Resolved via an async resolver method that awaits async inject methods. + /// + Task } internal static class TypeDataExtensions @@ -529,6 +558,11 @@ public static WrapperTypeData CreateWrapper( IsNestedOpenGeneric, TypeParameters, ConstructorParameters, HasInjectConstructor, InjectionMembers, AllInterfaces, AllBaseClasses), + WrapperKind.Task => new TaskTypeData( + Name, NameWithoutGeneric, IsOpenGeneric, GenericArity, + IsNestedOpenGeneric, TypeParameters, + ConstructorParameters, HasInjectConstructor, InjectionMembers, + AllInterfaces, AllBaseClasses), _ => new WrapperTypeData( Name, NameWithoutGeneric, IsOpenGeneric, GenericArity, WrapperKind, @@ -662,4 +696,12 @@ public static WrapperTypeData CreateWrapper( /// public TypeData ValueType => typeData.TypeParameters![1].Type; } + + extension(TaskTypeData typeData) + { + /// + /// Gets the inner service type of the Task<T> wrapper. + /// + public TypeData InnerType => typeData.TypeParameters![0].Type; + } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs index 5909a3d..c4e3424 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs @@ -821,6 +821,14 @@ public static bool IsListType(string nameWithoutGeneric) => return symbol as IMethodSymbol; } + /// + /// Returns when the method returns the non-generic + /// type (arity 0). + /// + internal static bool IsNonGenericTaskReturnType(IMethodSymbol method) + => method.ReturnType is INamedTypeSymbol { Arity: 0, Name: "Task" } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + extension(IEnumerable source) { public IEnumerable<(int Index, T Item)> Index() diff --git a/src/Ioc/src/SourceGen.Ioc/IocContainerGlobalOptions.cs b/src/Ioc/src/SourceGen.Ioc/IocContainerGlobalOptions.cs new file mode 100644 index 0000000..1c9fd82 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc/IocContainerGlobalOptions.cs @@ -0,0 +1,13 @@ +namespace SourceGen.Ioc; + +/// +/// Global options for IoC containers generated by SourceGen.Ioc. +/// +public static class IocContainerGlobalOptions +{ + /// + /// Gets or sets an action to be called when an exception occurs during disposal of a service. + /// Set this to a logging delegate to capture disposal errors. + /// + public static Action? OnDisposeException { get; set; } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC007Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC007Tests.cs index aa85dbf..7c59e92 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC007Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC007Tests.cs @@ -392,8 +392,10 @@ public class TestService : IService } [Test] - public async Task SGIOC007_InjectAttribute_OnAsyncMethod_ReportsDiagnostic() + public async Task SGIOC007_InjectAttribute_OnAsyncMethod_ReportsSGIOC022NotSGIOC007() { + // Per spec: Task-returning methods require AsyncMethodInject feature. + // When that feature is OFF (default), SGIOC022 MUST fire and SGIOC007 MUST NOT. const string source = """ using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -413,9 +415,14 @@ public class TestService : IService var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); var sgioc007 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC007").ToList(); + var sgioc022 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC022").ToList(); - await Assert.That(sgioc007).Count().IsEqualTo(1); - await Assert.That(sgioc007[0].GetMessage()).Contains("InitializeAsync").And.Contains("void"); + // SGIOC007 must NOT fire (no duplicate return-type error) + await Assert.That(sgioc007).Count().IsEqualTo(0); + + // SGIOC022 MUST fire with AsyncMethodInject feature name + await Assert.That(sgioc022).Count().IsEqualTo(1); + await Assert.That(sgioc022[0].GetMessage()).Contains("InitializeAsync").And.Contains("AsyncMethodInject"); } [Test] @@ -779,4 +786,105 @@ protected TestService(IService service) { } await Assert.That(sgioc007).Count().IsEqualTo(0); } + + [Test] + public async Task SGIOC007_InjectAttribute_OnTaskMethod_WithAsyncMethodInjectEnabled_NoDiagnostic() + { + // When AsyncMethodInject is ON, non-generic Task return type is allowed — SGIOC007 must not fire. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.CompletedTask; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc007 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC007"); + + await Assert.That(sgioc007).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC007_InjectAttribute_OnGenericTaskMethod_WithAsyncMethodInjectEnabled_ReportsDiagnostic() + { + // Task (generic Task) is NOT allowed even when AsyncMethodInject is ON — only non-generic Task. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.FromResult(0); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc007 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC007").ToList(); + + await Assert.That(sgioc007).Count().IsEqualTo(1); + await Assert.That(sgioc007[0].GetMessage()).Contains("InitializeAsync").And.Contains("void or non-generic Task"); + } + + [Test] + public async Task SGIOC007_InjectAttribute_OnValueTaskMethod_WithAsyncMethodInjectEnabled_ReportsDiagnostic() + { + // ValueTask is NOT allowed even when AsyncMethodInject is ON — only non-generic Task is accepted. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public ValueTask InitializeAsync(IService service) => ValueTask.CompletedTask; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc007 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC007").ToList(); + + await Assert.That(sgioc007).Count().IsEqualTo(1); + await Assert.That(sgioc007[0].GetMessage()).Contains("InitializeAsync").And.Contains("void or non-generic Task"); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC018Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC018Tests.cs index a485834..80f0134 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC018Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC018Tests.cs @@ -339,5 +339,40 @@ public partial class TestContainer { } await Assert.That(sgioc018).Count().IsEqualTo(2); } + + [Test] + public async Task SGIOC018_AssemblyLevelRegisterFor_NonGeneric_DependencyRegistered_NoDiagnostic() + { + // IDependency is registered via assembly-level [IocRegisterFor] (non-generic). + // IMyService depends on IDependency. SGIOC018 should NOT fire. + const string source = """ + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + [assembly: IocRegisterFor(typeof(TestNamespace.MyDependency), ServiceLifetime.Singleton, ServiceTypes = [typeof(TestNamespace.IDependency)])] + + namespace TestNamespace; + + public interface IDependency { } + + public class MyDependency : IDependency { } + + public interface IMyService { } + + [IocRegister(ServiceLifetime.Singleton)] + public class MyService : IMyService + { + public MyService(IDependency dependency) { } + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer { } + """; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); + var sgioc018 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC018"); + + await Assert.That(sgioc018).Count().IsEqualTo(0); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs new file mode 100644 index 0000000..70abe07 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs @@ -0,0 +1,1016 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC021: Unable to resolve partial accessor service. +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC021)] +public class SGIOC021Tests +{ + // ── Positive tests (SGIOC021 expected) ──────────────────────────────────── + + [Test] + public async Task SGIOC021_PartialMethod_UnregisteredService_ReportsDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IUnregistered GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_PartialProperty_UnregisteredService_ReportsDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IUnregistered Service { get; } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_TaskWrapper_UnregisteredInnerType_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_LazyWrapper_UnregisteredInnerType_ReportsDiagnostic() + { + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_FuncWrapper_UnregisteredInnerType_ReportsDiagnostic() + { + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Func GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_MultiArgFuncWrapper_UnregisteredInnerType_ReportsDiagnostic() + { + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Func GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_EnumerableWrapper_UnregisteredInnerType_ReportsDiagnostic() + { + const string source = """ + using System.Collections.Generic; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_KeyedPartialMethod_UnregisteredService_ReportsDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + [IocInject("mykey")] + public partial IUnregistered GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_KeyedAccessor_UnkeyedAsyncInitService_ReportsDiagnostic() + { + // Unkeyed async-init registration exists, but accessor asks for a keyed service. + // SGIOC021 should still report because keyed registration is missing. + const string source = """ + using System; + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + [assembly: IocRegisterFor(ServiceTypes = [typeof(TestNamespace.IMyService)])] + + namespace TestNamespace; + + public interface IMyService { } + + public class MyAsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + [FromKeyedServices("myKey")] + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(1); + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029")).Count().IsEqualTo(0); + } + + // ── Positive tests — unsupported return types ────────────────────────────── + + [Test] + public async Task SGIOC021_NonGenericTask_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_NonGenericValueTask_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_ValueTaskWrapper_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + // ── Negative tests (NO SGIOC021 expected) ───────────────────────────────── + + [Test] + public async Task SGIOC021_RegisteredService_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_NullableUnregistered_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IUnregistered? Service { get; } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_IntegrateServiceProviderTrue_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = true)] + public partial class TestContainer + { + public partial IUnregistered GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_EnumerableRegistered_NoDiagnostic() + { + const string source = """ + using System.Collections.Generic; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_AlwaysResolvableType_NoDiagnostic() + { + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IServiceProvider GetServiceProvider(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_TaskWrapperRegistered_NoDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_LazyWrapperRegistered_NoDiagnostic() + { + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_IntegrateServiceProvider_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = true)] + public partial class TestContainer + { + public partial IUnregistered GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_AssemblyLevelRegisterForGeneric_PartialAccessorRegistered_NoDiagnostic() + { + // IMyService is registered via assembly-level [IocRegisterFor] (generic). + // Container partial accessor returns IMyService. SGIOC021 should NOT fire. + const string source = """ + using SourceGen.Ioc; + + [assembly: IocRegisterFor(ServiceTypes = [typeof(TestNamespace.IMyService)])] + + namespace TestNamespace; + + public interface IMyService { } + + public class MyService : IMyService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IMyService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + // ── Recursive wrapper unwrap tests ─────────────────────────────────────── + + [Test] + public async Task SGIOC021_NestedLazyFunc_RegisteredInnerService_NoDiagnostic() + { + // Lazy> — recursive unwrap reaches registered IService → no diagnostic. + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_NestedLazyFunc_UnregisteredInnerService_ReportsDiagnostic() + { + // Lazy> — recursive unwrap reaches unregistered service → SGIOC021. + const string source = """ + using System; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IUnregistered { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_TaskNestedWrapper_DowngradedShape_ReportsDiagnostic() + { + // Task> — downgrade rule: Task is unresolvable without IServiceProvider. + // Even though IService is registered, the generator falls back for nested-Task shapes. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + // ── SGIOC021 negative tests — downgraded async-init shapes (SGIOC029 fires, not SGIOC021) ── + + [Test] + public async Task SGIOC021_TaskLazy_AsyncInitService_IntegrateServiceProviderFalse_NoDiagnostic() + { + // Task> (downgraded shape) on an async-init service. + // SGIOC029 fires for this shape; SGIOC021 should NOT fire. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_LazyTask_AsyncInitService_IntegrateServiceProviderFalse_NoDiagnostic() + { + // Lazy> (downgraded shape) on an async-init service. + // SGIOC029 fires for this shape; SGIOC021 should NOT fire. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_ValueTaskLazy_AsyncInitService_IntegrateServiceProviderFalse_NoDiagnostic() + { + // ValueTask> on an async-init service. + // SGIOC029 fires for this shape; SGIOC021 should NOT fire. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_CollectionAtTopNestedWrappers_ReportsDiagnostic() + { + // IEnumerable>> — collection-at-top downgrade rule fires because + // IEnumerable resolves to GetServices() which requires IServiceProvider. + const string source = """ + using System; + using System.Collections.Generic; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable>> GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_LazyTaskWrapper_DowngradedShape_ReportsDiagnostic() + { + // Lazy> — downgrade rule: Wrapper is unresolvable without IServiceProvider. + // Even though IService is registered, the generator falls back for Wrapper> shapes. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC021_ValueTaskNestedWrapper_ReportsDiagnostic() + { + // ValueTask> — ValueTask is not a generator-supported recursive wrapper. + // GetAccessorServiceType unwraps once to Lazy, which is not a registered type → SGIOC021. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToArray(); + + await Assert.That(sgioc021).Count().IsEqualTo(1); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC022Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC022Tests.cs index f7deceb..af90685 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC022Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC022Tests.cs @@ -39,4 +39,106 @@ public class TestService : IService await Assert.That(sgioc022).Count().IsEqualTo(1); await Assert.That(sgioc022[0].GetMessage()).Contains("Dependency").And.Contains("PropertyInject"); } + + [Test] + public async Task SGIOC022_InjectAttribute_WhenMethodInjectDisabled_ReportsDiagnostic() + { + // Void-returning method with MethodInject OFF → SGIOC022 with MethodInject feature name. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public void Initialize(IService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + // MethodInject explicitly disabled + ["build_property.SourceGenIocFeatures"] = "Register,Container,PropertyInject,FieldInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc022 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC022").ToList(); + + await Assert.That(sgioc022).Count().IsEqualTo(1); + await Assert.That(sgioc022[0].GetMessage()).Contains("Initialize").And.Contains("MethodInject"); + } + + [Test] + public async Task SGIOC022_InjectAttribute_TaskMethodWhenAsyncMethodInjectDisabled_ReportsDiagnosticWithAsyncMethodInjectName() + { + // Task-returning method with MethodInject ON but AsyncMethodInject OFF → SGIOC022 with AsyncMethodInject feature name. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.CompletedTask; + } + """; + + var analyzerConfigOptions = new Dictionary + { + // MethodInject ON, AsyncMethodInject OFF + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc022 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC022").ToList(); + + await Assert.That(sgioc022).Count().IsEqualTo(1); + await Assert.That(sgioc022[0].GetMessage()).Contains("InitializeAsync").And.Contains("AsyncMethodInject"); + } + + [Test] + public async Task SGIOC022_InjectAttribute_TaskMethodWhenAsyncMethodInjectEnabled_NoDiagnostic() + { + // Task-returning method with both MethodInject and AsyncMethodInject ON → no SGIOC022. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.CompletedTask; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc022 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC022"); + + await Assert.That(sgioc022).Count().IsEqualTo(0); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC024Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC024Tests.cs index 5ad39cb..aff0965 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC024Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC024Tests.cs @@ -247,6 +247,43 @@ public class MyService await Assert.That(sgioc024[0].GetMessage()).Contains("GetDep").And.Contains("void"); } + [Test] + public async Task SGIOC024_NonVoidMethod_WithAsyncMethodInjectEnabled_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IDependency { } + + [IocRegisterFor(typeof(MyService), InjectMembers = [nameof(MyService.GetDepAsync)])] + public static class MyModule { } + + public class MyService + { + public Task GetDepAsync(IDependency dep) => Task.FromResult(0); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc024 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, Constants.SGIOC024).ToList(); + + await Assert.That(sgioc024).Count().IsEqualTo(1); + await Assert.That(sgioc024[0].GetMessage()) + .Contains("GetDepAsync") + .And.Contains("void or non-generic Task"); + } + [Test] public async Task SGIOC024_GenericMethod_ReportsDiagnostic() { diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC026Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC026Tests.cs new file mode 100644 index 0000000..2acf4fe --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC026Tests.cs @@ -0,0 +1,106 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC026: AsyncMethodInject feature requires MethodInject to be enabled. +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC026)] +public class SGIOC026Tests +{ + [Test] + public async Task SGIOC026_AsyncMethodInjectWithoutMethodInject_ReportsDiagnostic() + { + // SGIOC026 fires at compilation level when AsyncMethodInject is ON but MethodInject is OFF. + // Any source referencing SourceGen.Ioc is sufficient — diagnostic has no specific location. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister] + public class TestService { } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc026 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC026").ToList(); + + await Assert.That(sgioc026).Count().IsEqualTo(1); + await Assert.That(sgioc026[0].GetMessage()).Contains("AsyncMethodInject").And.Contains("MethodInject"); + } + + [Test] + public async Task SGIOC026_AsyncMethodInjectWithMethodInject_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister] + public class TestService { } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc026 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC026"); + + await Assert.That(sgioc026).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC026_OnlyMethodInject_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister] + public class TestService { } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc026 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC026"); + + await Assert.That(sgioc026).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC026_DefaultFeatures_NoDiagnostic() + { + // Default features include MethodInject but NOT AsyncMethodInject — SGIOC026 must not fire. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister] + public class TestService { } + """; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); + var sgioc026 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC026"); + + await Assert.That(sgioc026).Count().IsEqualTo(0); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC027Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC027Tests.cs new file mode 100644 index 0000000..59f301d --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC027Tests.cs @@ -0,0 +1,332 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC027: Partial accessor must return Task<T> for an async-init service. +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC027)] +public class SGIOC027Tests +{ + [Test] + public async Task SGIOC027_PartialAccessorReturnsSyncType_ForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc027 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027").ToList(); + + await Assert.That(sgioc027).Count().IsEqualTo(1); + await Assert.That(sgioc027[0].GetMessage()).Contains("GetService").And.Contains("IService"); + } + + [Test] + public async Task SGIOC027_PartialAccessorReturnsTaskType_ForAsyncInitService_NoDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC027_PartialAccessorReturnsSyncType_ForKeyedAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)], Key = "special")] + public class AsyncService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class SyncService : IService + { + [IocInject] + public void Initialize() { } + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + [IocInject("special")] + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc027 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027").ToList(); + + await Assert.That(sgioc027).Count().IsEqualTo(1); + await Assert.That(sgioc027[0].GetMessage()).Contains("GetService").And.Contains("IService"); + } + + [Test] + public async Task SGIOC027_NullablePartialAccessor_ForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService? GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc027 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027").ToList(); + + await Assert.That(sgioc027).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC027_PartialAccessorReturnsSyncType_ForSyncOnlyService_NoDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public void Initialize() { } + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC027_GenericTaskInjectMethod_DoesNotTriggerAsyncInit_NoDiagnostic() + { + // Task is a generic method — only non-generic Task qualifies as async init. + // A partial accessor returning the sync service type should NOT report SGIOC027. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task SomeGenericMethod() => Task.FromResult(0); + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + // Task is generic so the service is NOT async-init; no SGIOC027 should fire. + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC027_AssemblyLevelRegisterForGeneric_SyncPartialAccessor_ReportsDiagnostic() + { + // AsyncService is registered via assembly-level [IocRegisterFor] (generic) with async inject method. + // Container partial accessor returns synchronous IService. SGIOC027 SHOULD fire. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + [assembly: IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(TestNamespace.IService)])] + + namespace TestNamespace; + + public interface IService { } + + public class AsyncService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc027 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027").ToList(); + + await Assert.That(sgioc027).Count().IsEqualTo(1); + await Assert.That(sgioc027[0].GetMessage()).Contains("GetService").And.Contains("IService"); + } + + [Test] + public async Task SGIOC027_IntegrateServiceProviderTrue_SyncPartialAccessor_ForAsyncInitService_ReportsDiagnostic() + { + // SGIOC027 must fire for ALL containers regardless of IntegrateServiceProvider. + // Returning the sync service type for an async-init service is always a semantic error — + // the generator cannot produce a sync accessor for an async-init service. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = true)] + public partial class TestContainer + { + public partial IService GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc027 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC027").ToList(); + + await Assert.That(sgioc027).Count().IsEqualTo(1); + await Assert.That(sgioc027[0].GetMessage()).Contains("GetService").And.Contains("IService"); + } +} + diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC028Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC028Tests.cs new file mode 100644 index 0000000..8ef276b --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC028Tests.cs @@ -0,0 +1,155 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC028: [IocInject] method is declared as async void, which cannot be awaited. +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC028)] +public class SGIOC028Tests +{ + [Test] + public async Task SGIOC028_AsyncVoidMethod_ReportsDiagnostic() + { + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public async void InitializeAsync(IService service) { } + } + """; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); + var sgioc028 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC028").ToList(); + + await Assert.That(sgioc028).Count().IsEqualTo(1); + await Assert.That(sgioc028[0].GetMessage()).Contains("InitializeAsync").And.Contains("async void"); + } + + [Test] + public async Task SGIOC028_AsyncVoidMethod_NeitherSGIOC007NorSGIOC022AlsoFires() + { + // SGIOC028 fires first and returns early - the return-type check (SGIOC007) and + // feature-gate check (SGIOC022) must NOT duplicate the diagnostic. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public async void InitializeAsync(IService service) { } + } + """; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); + + var sgioc007 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC007"); + var sgioc022 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC022"); + + await Assert.That(sgioc007).Count().IsEqualTo(0); + await Assert.That(sgioc022).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC028_AsyncTaskMethod_WithAsyncMethodInjectEnabled_NoDiagnostic() + { + // async Task is a valid signature when AsyncMethodInject is ON - SGIOC028 must not fire. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public async Task InitializeAsync(IService service) => await Task.CompletedTask; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc028 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC028"); + + await Assert.That(sgioc028).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC028_SyncVoidMethod_NoDiagnostic() + { + // Regular void method (non-async) must NOT trigger SGIOC028. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public void Initialize(IService service) { } + } + """; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync(source); + var sgioc028 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC028"); + + await Assert.That(sgioc028).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC028_AsyncVoidMethod_FeaturesDisabled_StillReportsDiagnostic() + { + // SGIOC028 fires regardless of feature flags - async void is never acceptable. + const string source = """ + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister] + public class TestService : IService + { + [IocInject] + public async void InitializeAsync(IService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + // MethodInject and AsyncMethodInject both disabled + ["build_property.SourceGenIocFeatures"] = "Register" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc028 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC028").ToList(); + + await Assert.That(sgioc028).Count().IsEqualTo(1); + await Assert.That(sgioc028[0].GetMessage()).Contains("InitializeAsync"); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC029Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC029Tests.cs new file mode 100644 index 0000000..0eb1088 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC029Tests.cs @@ -0,0 +1,712 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC029: Unsupported async partial accessor type (e.g., ValueTask<T>). +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC029)] +public class SGIOC029Tests +{ + [Test] + public async Task SGIOC029_PartialAccessorReturnsValueTask_ForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync(IService service) => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("ValueTask").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsTaskType_ForAsyncInitService_NoDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsBareTask_ForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + var sgioc021 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(0); + await Assert.That(sgioc021).Count().IsEqualTo(1); + await Assert.That(sgioc021[0].GetMessage()).Contains("Task").And.Contains("GetService"); + } + + [Test] + public async Task SGIOC029_NullableValueTaskAccessor_ForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask? GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("ValueTask").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_AssemblyLevelRegisterFor_NonGeneric_ValueTaskPartialAccessor_ReportsDiagnostic() + { + // AsyncService is registered via assembly-level [IocRegisterFor] (non-generic) with async inject method. + // Container partial accessor returns ValueTask (unsupported). SGIOC029 SHOULD fire. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + [assembly: IocRegisterFor(typeof(TestNamespace.AsyncService), ServiceLifetime.Singleton, ServiceTypes = [typeof(TestNamespace.IService)])] + + namespace TestNamespace; + + public interface IService { } + + public class AsyncService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("ValueTask").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsLazy_ForAsyncInitService_ReportsDiagnostic() + { + // Lazy on an async-init service should trigger SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Lazy").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsFunc_ForAsyncInitService_ReportsDiagnostic() + { + // Func on an async-init service should trigger SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Func GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Func").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_MixedSyncAndAsyncInitRegistrations_ReportsDiagnostic() + { + // When sync and async-init implementations share the same service type+key, + // SGIOC029 should still report for non-Task async accessor shapes. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "mixed")] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "mixed")] + public class SyncService : IMyService + { + [IocInject] + public void Initialize() { } + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + [IocInject("mixed")] + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Lazy").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_IntegrateServiceProviderTrue_ValueTaskAccessor_ForAsyncInitService_ReportsDiagnostic() + { + // SGIOC029 must fire for ALL containers regardless of IntegrateServiceProvider. + // A ValueTask return type for an async-init service is always an error — the generator + // only supports Task for async-init accessor methods. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = true)] + public partial class TestContainer + { + public partial ValueTask GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("ValueTask").And.Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsEnumerable_ForAsyncInitService_ReportsDiagnostic() + { + // IEnumerable on an async-init service should trigger SGIOC029. + const string source = """ + using System.Collections.Generic; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsArray_ForAsyncInitService_ReportsDiagnostic() + { + // IService[] on an async-init service should trigger SGIOC029. + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IService[] GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsTaskLazy_ForAsyncInitService_ReportsDiagnostic() + { + // Task> (downgraded shape) on an async-init service should trigger SGIOC029. + // The generator downgrades this shape to fallback, but SGIOC029 still fires on the async-init wrapper. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Task> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsLazyTask_ForAsyncInitService_ReportsDiagnostic() + { + // Lazy> (downgraded shape) on an async-init service should trigger SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsValueTaskLazy_ForAsyncInitService_ReportsDiagnostic() + { + // ValueTask> (nested ValueTask) on an async-init service should trigger SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial ValueTask> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_PartialAccessorReturnsFuncLazy_ForAsyncInitService_ReportsDiagnostic() + { + // Func> (nested wrapper) on an async-init service should trigger SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Func> GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToList(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(sgioc029[0].GetMessage()).Contains("Task"); + } + + [Test] + public async Task SGIOC029_LazyWrapper_AsyncInitService_IntegrateServiceProviderTrue_ReportsDiagnostic() + { + // Lazy targeting an async-init service with IntegrateServiceProvider=true. + // SGIOC029 fires for ALL containers — async-init wrapper diagnostics are owned by SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = true)] + public partial class TestContainer + { + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToArray(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + } + + [Test] + public async Task SGIOC029_LazyWrapper_AsyncInitService_IntegrateServiceProviderFalse_ReportsDiagnostic() + { + // Lazy targeting an async-init service with IntegrateServiceProvider=false. + // SGIOC029 fires (not SGIOC021) — async-init wrapper diagnostics are owned by SGIOC029. + const string source = """ + using System; + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial Lazy GetService(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + var sgioc029 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC029").ToArray(); + + await Assert.That(sgioc029).Count().IsEqualTo(1); + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs new file mode 100644 index 0000000..768a2bf --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs @@ -0,0 +1,468 @@ +namespace SourceGen.Ioc.Test.Analyzer; + +/// +/// Tests for SGIOC030: Synchronous dependency requested for async-init-only service. +/// +[Category(Constants.Analyzer)] +[Category(Constants.SGIOC030)] +public class SGIOC030Tests +{ + [Test] + public async Task SGIOC030_ConstructorRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(IMyService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_ConstructorRequestsTaskTypeForAsyncInitService_NoDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(Task service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC030_PropertyInjectionRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + [IocInject] + public IMyService Service { get; set; } = default!; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,PropertyInject,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("Service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_FieldInjectionRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + [IocInject] + public IMyService ServiceField = default!; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,FieldInject,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("ServiceField").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_MethodInjectionRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + [IocInject] + public void Initialize(IMyService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_ConstructorRequestsKeyedSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "special")] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer([FromKeyedServices("special")] IMyService service) : IConsumer; + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_ConstructorRequestsSyncType_WhenSyncRegistrationAlsoExists_NoDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class SyncService : IMyService + { + [IocInject] + public void Initialize() { } + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(IMyService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC030_MultiKeyedRegistration_AsyncInitKeyReportsDiagnostic() + { + // AsyncService registered under key "async" (async-init); SyncService under key "sync". + // A consumer requesting IService with key "async" has no sync resolution path → SGIOC030. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "async")] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "sync")] + public class SyncService : IMyService { } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer([FromKeyedServices("async")] IMyService service) : IConsumer; + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_MultiKeyedRegistration_SyncKeyNoDiagnostic() + { + // AsyncService registered under key "async" (async-init); SyncService under key "sync". + // A consumer requesting IService with key "sync" has a sync resolution path → no SGIOC030. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "async")] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "sync")] + public class SyncService : IMyService { } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer([FromKeyedServices("sync")] IMyService service) : IConsumer; + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC030_KeyedPropertyInjectionRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() + { + const string source = """ + using System.Threading.Tasks; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)], Key = "special")] + public class MyService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + [IocInject("special")] + public IMyService Service { get; set; } = default!; + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,PropertyInject,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("Service").And.Contains("IMyService"); + } + + [Test] + public async Task SGIOC030_SameImplRegisteredForTwoKeyedServices_AsyncInitConsumerPathReportsDiagnostic() + { + // A single implementation class is registered for two different keyed service types + // via assembly-level IocRegisterFor attributes. This exercises the AllRegistrations + // duplicate-impl path in RegisterAnalyzer.ServiceCollection.cs: + // first TryAdd succeeds; second TryAdd fails → adds to existingInfo.AllRegistrations. + // The consumer requests IServiceA (key1) synchronously → SGIOC030 because + // the only registration for (IServiceA, "key1") is async-init. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + [assembly: IocRegisterFor(typeof(TestNamespace.MultiKeyImpl), ServiceTypes = [typeof(TestNamespace.IServiceA)], Key = "key1")] + [assembly: IocRegisterFor(typeof(TestNamespace.MultiKeyImpl), ServiceTypes = [typeof(TestNamespace.IServiceB)], Key = "key2")] + + namespace TestNamespace; + + public interface IServiceA { } + public interface IServiceB { } + public interface IConsumer { } + + public class MultiKeyImpl : IServiceA, IServiceB + { + [IocInject] + public Task InjectAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer([FromKeyedServices("key1")] IServiceA service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, Constants.SGIOC030).ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IServiceA"); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromCollectionResolver.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromCollectionResolver.verified.txt new file mode 100644 index 0000000..93a28a3 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromCollectionResolver.verified.txt @@ -0,0 +1,327 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_SyncImpl = parent._testNamespace_SyncImpl; + _testNamespace_AsyncImpl = parent._testNamespace_AsyncImpl; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.SyncImpl? _testNamespace_SyncImpl; + private global::TestNamespace.SyncImpl GetTestNamespace_SyncImpl() + { + if(_testNamespace_SyncImpl is not null) return _testNamespace_SyncImpl; + + var instance = new global::TestNamespace.SyncImpl(); + + _testNamespace_SyncImpl = instance; + return instance; + } + + private global::System.Threading.Tasks.Task? _testNamespace_AsyncImpl; + + private async global::System.Threading.Tasks.Task GetTestNamespace_AsyncImplAsync() + { + if(_testNamespace_AsyncImpl is not null) + return await _testNamespace_AsyncImpl; + + _testNamespace_AsyncImpl = CreateTestNamespace_AsyncImplAsync(); + return await _testNamespace_AsyncImpl; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_AsyncImplAsync() + { + var instance = new global::TestNamespace.AsyncImpl(); + await instance.InitAsync(); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncImpl), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncImpl()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncImplAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.AsyncImpl), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncImplAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_AsyncImpl); + DisposeService(_testNamespace_SyncImpl); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_AsyncImpl); + await DisposeServiceAsync(_testNamespace_SyncImpl); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt new file mode 100644 index 0000000..22559b8 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt @@ -0,0 +1,331 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_SyncService = GetTestNamespace_SyncService(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_AsyncService = parent._testNamespace_AsyncService; + _testNamespace_SyncService = parent._testNamespace_SyncService; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_AsyncService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_AsyncServiceAsync() + { + if(_testNamespace_AsyncService is not null) + return await _testNamespace_AsyncService; + + _testNamespace_AsyncService = CreateTestNamespace_AsyncServiceAsync(); + return await _testNamespace_AsyncService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_AsyncServiceAsync() + { + var instance = new global::TestNamespace.AsyncService(); + await instance.InitAsync(); + return instance; + } + + private global::TestNamespace.SyncService _testNamespace_SyncService = null!; + private global::TestNamespace.SyncService GetTestNamespace_SyncService() + { + if(_testNamespace_SyncService is not null) return _testNamespace_SyncService; + + var instance = new global::TestNamespace.SyncService(); + + _testNamespace_SyncService = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.AsyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IAsyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_SyncService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.ISyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_SyncService!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_SyncService); + DisposeService(_testNamespace_AsyncService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_SyncService); + await DisposeServiceAsync(_testNamespace_AsyncService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedInjection_GeneratesAllCallsInOrder.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedInjection_GeneratesAllCallsInOrder.verified.txt new file mode 100644 index 0000000..580302f --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedInjection_GeneratesAllCallsInOrder.verified.txt @@ -0,0 +1,361 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Dep1 = parent._testNamespace_Dep1; + _testNamespace_Dep2 = parent._testNamespace_Dep2; + _testNamespace_Dep3 = parent._testNamespace_Dep3; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Dep1? _testNamespace_Dep1; + private global::TestNamespace.Dep1 GetTestNamespace_Dep1() + { + if(_testNamespace_Dep1 is not null) return _testNamespace_Dep1; + + var instance = new global::TestNamespace.Dep1(); + + _testNamespace_Dep1 = instance; + return instance; + } + + private global::TestNamespace.Dep2? _testNamespace_Dep2; + private global::TestNamespace.Dep2 GetTestNamespace_Dep2() + { + if(_testNamespace_Dep2 is not null) return _testNamespace_Dep2; + + var instance = new global::TestNamespace.Dep2(); + + _testNamespace_Dep2 = instance; + return instance; + } + + private global::TestNamespace.Dep3? _testNamespace_Dep3; + private global::TestNamespace.Dep3 GetTestNamespace_Dep3() + { + if(_testNamespace_Dep3 is not null) return _testNamespace_Dep3; + + var instance = new global::TestNamespace.Dep3(); + + _testNamespace_Dep3 = instance; + return instance; + } + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService() + { + Dep1 = (global::TestNamespace.IDep1)GetRequiredService(typeof(global::TestNamespace.IDep1)), + }; + instance.SyncInit((global::TestNamespace.IDep2)GetRequiredService(typeof(global::TestNamespace.IDep2))); + await instance.AsyncInit((global::TestNamespace.IDep3)GetRequiredService(typeof(global::TestNamespace.IDep3))); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dep1), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Dep1()), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dep2), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Dep2()), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dep3), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Dep3()), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + DisposeService(_testNamespace_Dep3); + DisposeService(_testNamespace_Dep2); + DisposeService(_testNamespace_Dep1); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + await DisposeServiceAsync(_testNamespace_Dep3); + await DisposeServiceAsync(_testNamespace_Dep2); + await DisposeServiceAsync(_testNamespace_Dep1); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleServiceTypes_ShareSingleTaskField.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleServiceTypes_ShareSingleTaskField.verified.txt new file mode 100644 index 0000000..17a5e56 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleServiceTypes_ShareSingleTaskField.verified.txt @@ -0,0 +1,313 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync(); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IFoo), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IBar), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainServiceInContainer.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainServiceInContainer.verified.txt new file mode 100644 index 0000000..9cbfc6c --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainServiceInContainer.verified.txt @@ -0,0 +1,274 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Consumer? _testNamespace_Consumer; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Threading.Tasks.Task)GetRequiredService(typeof(global::System.Threading.Tasks.Task))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Consumer()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_PartialTaskAccessor_GeneratesAsyncMethod.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_PartialTaskAccessor_GeneratesAsyncMethod.verified.txt new file mode 100644 index 0000000..2c239c7 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_PartialTaskAccessor_GeneratesAsyncMethod.verified.txt @@ -0,0 +1,318 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync(); + return instance; + } + + #endregion + + #region Partial Accessor Implementations + + public partial async global::System.Threading.Tasks.Task GetMyServiceAsync() => await GetTestNamespace_MyServiceAsync(); + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_None.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_None.verified.txt new file mode 100644 index 0000000..006064e --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_None.verified.txt @@ -0,0 +1,327 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Dependency = parent._testNamespace_Dependency; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Dependency? _testNamespace_Dependency; + private global::TestNamespace.Dependency GetTestNamespace_Dependency() + { + if(_testNamespace_Dependency is not null) return _testNamespace_Dependency; + + var instance = new global::TestNamespace.Dependency(); + + _testNamespace_Dependency = instance; + return instance; + } + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync((global::TestNamespace.IDependency)GetRequiredService(typeof(global::TestNamespace.IDependency))); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dependency), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Dependency()), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + DisposeService(_testNamespace_Dependency); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + await DisposeServiceAsync(_testNamespace_Dependency); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_SemaphoreSlim.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_SemaphoreSlim.verified.txt new file mode 100644 index 0000000..1b9dc9d --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_SemaphoreSlim.verified.txt @@ -0,0 +1,354 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Dependency = parent._testNamespace_Dependency; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Dependency? _testNamespace_Dependency; + private readonly SemaphoreSlim _testNamespace_DependencySemaphore = new(1, 1); + private global::TestNamespace.Dependency GetTestNamespace_Dependency() + { + if(_testNamespace_Dependency is not null) return _testNamespace_Dependency; + + _testNamespace_DependencySemaphore.Wait(); + try + { + if(_testNamespace_Dependency is not null) return _testNamespace_Dependency; + + var instance = new global::TestNamespace.Dependency(); + + _testNamespace_Dependency = instance; + return instance; + } + finally + { + _testNamespace_DependencySemaphore.Release(); + } + } + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + private readonly global::System.Threading.SemaphoreSlim _testNamespace_MyServiceSemaphore = new(1, 1); + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + await _testNamespace_MyServiceSemaphore.WaitAsync(); + try + { + if(_testNamespace_MyService is null) + { + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + } + } + finally + { + _testNamespace_MyServiceSemaphore.Release(); + } + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync((global::TestNamespace.IDependency)GetRequiredService(typeof(global::TestNamespace.IDependency))); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dependency), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Dependency()), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + _testNamespace_MyServiceSemaphore.Dispose(); + DisposeService(_testNamespace_Dependency); + _testNamespace_DependencySemaphore.Dispose(); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + _testNamespace_MyServiceSemaphore.Dispose(); + await DisposeServiceAsync(_testNamespace_Dependency); + _testNamespace_DependencySemaphore.Dispose(); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_AsyncInitService_UsesAsyncResolver.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_AsyncInitService_UsesAsyncResolver.verified.txt new file mode 100644 index 0000000..272433e --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_AsyncInitService_UsesAsyncResolver.verified.txt @@ -0,0 +1,327 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync(); + return instance; + } + + private global::TestNamespace.Consumer? _testNamespace_Consumer; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(((global::System.Func>)(async () => (global::TestNamespace.IMyService)(await GetTestNamespace_MyServiceAsync())))()); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Consumer()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_SyncService_UsesTaskFromResult.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_SyncService_UsesTaskFromResult.verified.txt new file mode 100644 index 0000000..a05016b --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskDependency_SyncService_UsesTaskFromResult.verified.txt @@ -0,0 +1,290 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_SyncService = parent._testNamespace_SyncService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.SyncService? _testNamespace_SyncService; + private global::TestNamespace.SyncService GetTestNamespace_SyncService() + { + if(_testNamespace_SyncService is not null) return _testNamespace_SyncService; + + var instance = new global::TestNamespace.SyncService(); + + _testNamespace_SyncService = instance; + return instance; + } + + private global::TestNamespace.Consumer? _testNamespace_Consumer; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(global::System.Threading.Tasks.Task.FromResult((global::TestNamespace.ISyncService)GetTestNamespace_SyncService())); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncService()), + new(new ServiceIdentifier(typeof(global::TestNamespace.ISyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncService()), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_Consumer()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_SyncService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_SyncService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt new file mode 100644 index 0000000..29bf117 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt @@ -0,0 +1,286 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_Dependency = GetTestNamespace_Dependency(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Dependency = parent._testNamespace_Dependency; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Dependency _testNamespace_Dependency = null!; + private global::TestNamespace.Dependency GetTestNamespace_Dependency() + { + if(_testNamespace_Dependency is not null) return _testNamespace_Dependency; + + var instance = new global::TestNamespace.Dependency(); + + _testNamespace_Dependency = instance; + return instance; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var instance = new global::TestNamespace.MyService(); + await instance.InitAsync((global::TestNamespace.IDependency)GetRequiredService(typeof(global::TestNamespace.IDependency))); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Dependency), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Dependency!), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.CreateTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.CreateTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Dependency); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Dependency); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_WithDecorator_GeneratesAsyncResolverWithDecoratorApplication.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_WithDecorator_GeneratesAsyncResolverWithDecoratorApplication.verified.txt new file mode 100644 index 0000000..8663f8a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_WithDecorator_GeneratesAsyncResolverWithDecoratorApplication.verified.txt @@ -0,0 +1,315 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_MyService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_MyServiceAsync() + { + if(_testNamespace_MyService is not null) + return await _testNamespace_MyService; + + _testNamespace_MyService = CreateTestNamespace_MyServiceAsync(); + return await _testNamespace_MyService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_MyServiceAsync() + { + var baseInstance = new global::TestNamespace.MyService(); + await baseInstance.InitAsync(); + + global::TestNamespace.IMyService instance = baseInstance; + instance = new global::TestNamespace.MyServiceDecorator(instance); + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_MyServiceAsync()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs new file mode 100644 index 0000000..722a86c --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs @@ -0,0 +1,548 @@ +namespace SourceGen.Ioc.Test.ContainerSourceGeneratorSnapshot; + +/// +/// Snapshot tests for async method injection container generation (Phase 3B). +/// Verifies that async-init services generate Task<ImplType> cached fields, +/// async routing resolver methods, and async creation methods. +/// +[Category(Constants.SourceGeneratorSnapshot)] +[Category(Constants.ContainerGeneration)] +[Category(Constants.AsyncMethodInject)] +public class AsyncMethodInjectTests +{ + private const string AsyncMethodInjectFeatures = "Register,Container,PropertyInject,FieldInject,MethodInject,AsyncMethodInject"; + + /// + /// Suppressed diagnostics for partial accessor tests: + /// CS8795 (partial method must have implementation) and CS9248 (partial property must have implementation). + /// These are expected because the source generator provides the implementation. + /// + private static readonly IReadOnlySet SuppressedPartialDiagnostics = new HashSet(["CS8795", "CS9248"]); + + // ───────────────────────────────────────────────────────────────────────── + // Basic async resolver generation + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_None() + { + // A singleton with a single async-init method should get a Task? field + // and async routing + creation methods. With ThreadSafeStrategy.None, no semaphore. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency : IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync(IDependency dep) { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_SingletonWithAsyncInit_GeneratesAsyncResolver_SemaphoreSlim() + { + // A singleton with a single async-init method and SemaphoreSlim strategy should get + // a Task? field, a SemaphoreSlim field, and the async routing body uses WaitAsync(). + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency : IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync(IDependency dep) { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod() + { + // A transient async-init service produces only a creation method (no caching field). + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency : IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Transient, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync(IDependency dep) { } + } + + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Mixed injection: property + sync method + async method + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_MixedInjection_GeneratesAllCallsInOrder() + { + // Property injection → sync method call → await async method call, all in CreateAsync. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IDep1 { } + public interface IDep2 { } + public interface IDep3 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dep1 : IDep1 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dep2 : IDep2 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dep3 : IDep3 { } + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public IDep1 Dep1 { get; set; } = default!; + + [IocInject] + public void SyncInit(IDep2 dep2) { } + + [IocInject] + public async Task AsyncInit(IDep3 dep3) { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Shared field deduplication — multiple service aliases share ONE Task field + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_MultipleServiceTypes_ShareSingleTaskField() + { + // MyService registered as both IFoo and IBar. + // Both aliases share the same Task? field and resolver method. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IFoo { } + public interface IBar { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IFoo), typeof(IBar)])] + public class MyService : IFoo, IBar + { + [IocInject] + public async Task InitAsync() { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Eager resolve exclusion — async-init services must NOT be in constructor init + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_AsyncService_ExcludedFromEagerInit() + { + // EagerResolveOptions.Singleton is set, but the async-init service must NOT be eager. + // The sync-only SyncService IS eager. Verify that only SyncService appears in the ctor. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IAsyncService { } + public interface ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IAsyncService)])] + public class AsyncService : IAsyncService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(ISyncService)])] + public class SyncService : ISyncService { } + + [IocContainer(EagerResolveOptions = EagerResolveOptions.Singleton, ThreadSafeStrategy = ThreadSafeStrategy.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Collection exclusion — async-init services must NOT appear in collection resolvers + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_AsyncService_ExcludedFromCollectionResolver() + { + // Multiple IMyService registrations — only the sync one appears in the array resolver. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class SyncImpl : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class AsyncImpl : IMyService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Task wrapper resolution + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_TaskDependency_AsyncInitService_UsesAsyncResolver() + { + // A consumer takes Task — should resolve via async/await projection (not ContinueWith) + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task service) { } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_TaskDependency_SyncService_UsesTaskFromResult() + { + // A consumer takes Task — should resolve via Task.FromResult(GetSync()) + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(ISyncService)])] + public class SyncService : ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task service) { } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Partial accessor with Task return type + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_PartialTaskAccessor_GeneratesAsyncMethod() + { + // A partial method returning Task → generated as async + await. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer + { + public partial Task GetMyServiceAsync(); + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + suppressedInitialDiagnosticIds: SuppressedPartialDiagnostics, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + // ───────────────────────────────────────────────────────────────────────── + // Decorator support for async-init services + // ───────────────────────────────────────────────────────────────────────── + + [Test] + public async Task AsyncMethodInject_WithDecorator_GeneratesAsyncResolverWithDecoratorApplication() + { + // An async-init service with a decorator: + // - the creation method must await the async member before applying the decorator. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)], Decorators = [typeof(MyServiceDecorator)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync() { } + } + + public class MyServiceDecorator(IMyService inner) : IMyService { } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainServiceInContainer() + { + // Non-generic Task (arity 0) must NOT be classified as WrapperKind.Task in the + // container output path. The container should resolve it as a plain service, not + // attempt to unwrap a Task.InnerType (which would cause a NullReferenceException). + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task nonGenericTask) + { + } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + // Generator must not throw NRE — arity-0 Task is a plain service dependency. + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..3868ac3 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Collections.Generic.IEnumerable>)GetRequiredService(typeof(global::System.Collections.Generic.IEnumerable>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..8823585 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,277 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class ExplicitContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public ExplicitContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public ExplicitContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private ExplicitContainer(ExplicitContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(GetService(typeof(global::System.Func)) as global::System.Func); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new ExplicitContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.ExplicitContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new ExplicitContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenReturnTypeNotRegistered_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenReturnTypeNotRegistered_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..7053a10 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenReturnTypeNotRegistered_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,277 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class ExplicitContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public ExplicitContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public ExplicitContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private ExplicitContainer(ExplicitContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Func)GetRequiredService(typeof(global::System.Func))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new ExplicitContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.ExplicitContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new ExplicitContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WithKeyedService_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WithKeyedService_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..c066760 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WithKeyedService_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,277 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class ExplicitContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public ExplicitContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public ExplicitContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private ExplicitContainer(ExplicitContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Func)GetRequiredKeyedService(typeof(global::System.Func), "myKey")); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new ExplicitContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.ExplicitContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new ExplicitContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..e10e813 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Func>)GetRequiredService(typeof(global::System.Func>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..6566b5b --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Lazy>)GetRequiredService(typeof(global::System.Lazy>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..ec13a60 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Threading.Tasks.Task>)GetRequiredService(typeof(global::System.Threading.Tasks.Task>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..52cf738 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Threading.Tasks.Task>)GetRequiredService(typeof(global::System.Threading.Tasks.Task>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..aeb281c --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer((global::System.Threading.Tasks.Task>)GetRequiredService(typeof(global::System.Threading.Tasks.Task>))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..517b59c --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(GetServices>>()); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt new file mode 100644 index 0000000..5bd9767 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(new global::System.Func>>(() => new global::System.Lazy>(() => GetServices(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication))); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt new file mode 100644 index 0000000..6a10043 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt @@ -0,0 +1,294 @@ +// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + + // Initialize eager singletons + _testNamespace_MyService = GetTestNamespace_MyService(); + _testNamespace_Consumer = GetTestNamespace_Consumer(); + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_MyService = parent._testNamespace_MyService; + _testNamespace_Consumer = parent._testNamespace_Consumer; + } + + #endregion + + #region Service Resolution + + private global::TestNamespace.MyService _testNamespace_MyService = null!; + private global::TestNamespace.MyService GetTestNamespace_MyService() + { + if(_testNamespace_MyService is not null) return _testNamespace_MyService; + + var instance = new global::TestNamespace.MyService(); + + _testNamespace_MyService = instance; + return instance; + } + + private global::TestNamespace.Consumer _testNamespace_Consumer = null!; + private global::TestNamespace.Consumer GetTestNamespace_Consumer() + { + if(_testNamespace_Consumer is not null) return _testNamespace_Consumer; + + var instance = new global::TestNamespace.Consumer(new global::System.Lazy>>(() => new global::System.Func>(() => new global::System.Lazy(() => GetTestNamespace_MyService(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)); + + _testNamespace_Consumer = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.MyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.IMyService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_MyService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.Consumer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_Consumer!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_Consumer); + DisposeService(_testNamespace_MyService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_Consumer); + await DisposeServiceAsync(_testNamespace_MyService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs index 6da8593..124b777 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs @@ -679,4 +679,412 @@ public partial class TestContainer { } await Verify(generatedSource); } + + [Test] + public async Task FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenReturnTypeNotRegistered_FallsBackToServiceProvider() + { + // Regression test: when ExplicitOnly container has a consumer depending on Func + // and IService is NOT explicitly registered, the generator must NOT recurse into + // BuildWrapperExpressionForContainer via BuildServiceResolutionCallForContainer. + // Expected: generated code uses GetRequiredService(typeof(Func)) fallback. + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + // IService is intentionally NOT registered — not via [IocRegister] and not via [IocRegisterFor] + // Consumer is registered via [IocRegisterFor] on the ExplicitOnly container + public class Consumer(Func serviceFactory) + { + public Func ServiceFactory { get; } = serviceFactory; + } + + [IocContainer(ExplicitOnly = true)] + [IocRegisterFor(typeof(Consumer), Lifetime = ServiceLifetime.Singleton)] + public partial class ExplicitContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WithKeyedService_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + public class Consumer([IocInject(Key = "myKey")] Func serviceFactory) + { + public Func ServiceFactory { get; } = serviceFactory; + } + + [IocContainer(ExplicitOnly = true)] + [IocRegisterFor(typeof(Consumer), Lifetime = ServiceLifetime.Singleton)] + public partial class ExplicitContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + public class Consumer(Func? serviceFactory) + { + public Func? ServiceFactory { get; } = serviceFactory; + } + + [IocContainer(ExplicitOnly = true)] + [IocRegisterFor(typeof(Consumer), Lifetime = ServiceLifetime.Singleton)] + public partial class ExplicitContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TaskLazy_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task LazyTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Lazy> service) + { + public Lazy> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TaskFunc_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task FuncTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Func> service) + { + public Func> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TaskEnumerable_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task EnumerableTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(IEnumerable> service) + { + public IEnumerable> Service { get; } = service; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_LazyFuncLazy_GeneratesNestedResolution() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Lazy>> triple) + { + public Lazy>> Triple { get; } = triple; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(IEnumerable>> triple) + { + public IEnumerable>> Triple { get; } = triple; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_FuncLazyEnumerable_GeneratesNestedResolution() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Func>> triple) + { + public Func>> Triple { get; } = triple; + } + + [IocDiscover] + [IocDiscover] + [IocContainer] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Helpers/Constants.cs b/src/Ioc/test/SourceGen.Ioc.Test/Helpers/Constants.cs index 60509b8..7dadfec 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Helpers/Constants.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Helpers/Constants.cs @@ -26,6 +26,7 @@ internal static class Constants public const string ContainerOptions = "ContainerOptions"; public const string PartialAccessor = "PartialAccessor"; public const string WrapperType = "WrapperType"; + public const string AsyncMethodInject = "AsyncMethodInject"; public const string SGIOC001 = "SGIOC001"; public const string SGIOC002 = "SGIOC002"; @@ -52,4 +53,9 @@ internal static class Constants public const string SGIOC023 = "SGIOC023"; public const string SGIOC024 = "SGIOC024"; public const string SGIOC025 = "SGIOC025"; + public const string SGIOC026 = "SGIOC026"; + public const string SGIOC027 = "SGIOC027"; + public const string SGIOC028 = "SGIOC028"; + public const string SGIOC029 = "SGIOC029"; + public const string SGIOC030 = "SGIOC030"; } diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_BasicAsyncMethod_GeneratesTaskRegistration.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_BasicAsyncMethod_GeneratesTaskRegistration.verified.txt new file mode 100644 index 0000000..a7497d9 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_BasicAsyncMethod_GeneratesTaskRegistration.verified.txt @@ -0,0 +1,42 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton>((global::System.IServiceProvider sp) => + { + async global::System.Threading.Tasks.Task Init() + { + var s0_m0 = sp.GetRequiredService(); + var s0 = new global::TestNamespace.MyService(); + await s0.InitAsync(s0_m0); + return s0; + } + return Init(); + }); + services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedSyncAndAsyncMethods_GeneratesTaskRegistration.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedSyncAndAsyncMethods_GeneratesTaskRegistration.verified.txt new file mode 100644 index 0000000..67a13a9 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MixedSyncAndAsyncMethods_GeneratesTaskRegistration.verified.txt @@ -0,0 +1,47 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton>((global::System.IServiceProvider sp) => + { + async global::System.Threading.Tasks.Task Init() + { + var s0_p0 = sp.GetRequiredService(); + var s0_m1 = sp.GetRequiredService(); + var s0_m2 = sp.GetRequiredService(); + var s0 = new global::TestNamespace.MyService() { Dep1 = s0_p0 }; + s0.SyncInit(s0_m1); + await s0.AsyncInit(s0_m2); + return s0; + } + return Init(); + }); + services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleAsyncMethods_GeneratesTaskRegistration.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleAsyncMethods_GeneratesTaskRegistration.verified.txt new file mode 100644 index 0000000..078a410 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_MultipleAsyncMethods_GeneratesTaskRegistration.verified.txt @@ -0,0 +1,45 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton>((global::System.IServiceProvider sp) => + { + async global::System.Threading.Tasks.Task Init() + { + var s0_m0 = sp.GetRequiredService(); + var s0_m1 = sp.GetRequiredService(); + var s0 = new global::TestNamespace.MyService(); + await s0.InitStep1(s0_m0); + await s0.InitStep2(s0_m1); + return s0; + } + return Init(); + }); + services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainService.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainService.verified.txt new file mode 100644 index 0000000..39e53a3 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainService.verified.txt @@ -0,0 +1,30 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_AsyncInitService_ResolvesDirectly.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_AsyncInitService_ResolvesDirectly.verified.txt new file mode 100644 index 0000000..a44ab04 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_AsyncInitService_ResolvesDirectly.verified.txt @@ -0,0 +1,48 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton>((global::System.IServiceProvider sp) => + { + async global::System.Threading.Tasks.Task Init() + { + var s0_m0 = sp.GetRequiredService(); + var s0 = new global::TestNamespace.MyService(); + await s0.InitAsync(s0_m0); + return s0; + } + return Init(); + }); + services.AddSingleton>(async (global::System.IServiceProvider sp) => await sp.GetRequiredService>()); + services.AddSingleton((global::System.IServiceProvider sp) => + { + var p0 = sp.GetRequiredService>(); + var s0 = new global::TestNamespace.Consumer(p0); + return s0; + }); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_SyncService_UsesTaskFromResult.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_SyncService_UsesTaskFromResult.verified.txt new file mode 100644 index 0000000..36c3a79 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TaskWrapperDependency_SyncService_UsesTaskFromResult.verified.txt @@ -0,0 +1,37 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton((global::System.IServiceProvider sp) => + { + var p0 = global::System.Threading.Tasks.Task.FromResult(sp.GetRequiredService()); + var s0 = new global::TestNamespace.Consumer(p0); + return s0; + }); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.cs new file mode 100644 index 0000000..313f7d2 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/AsyncMethodInjectTests.cs @@ -0,0 +1,262 @@ +namespace SourceGen.Ioc.Test.RegisterSourceGeneratorSnapshot; + +/// +/// Snapshot tests for AsyncMethodInject feature — async method injection code generation. +/// Verifies that the source generator emits correct Task<T> registrations with +/// async local Init() functions when services have async-inject methods. +/// +[Category(Constants.SourceGeneratorSnapshot)] +[Category(Constants.AsyncMethodInject)] +public class AsyncMethodInjectTests +{ + private const string AsyncMethodInjectFeatures = "Register,Container,PropertyInject,FieldInject,MethodInject,AsyncMethodInject"; + + [Test] + public async Task AsyncMethodInject_BasicAsyncMethod_GeneratesTaskRegistration() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency : IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync(IDependency dep) + { + // async init + } + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_MultipleAsyncMethods_GeneratesTaskRegistration() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency1 { } + public interface IDependency2 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency1 : IDependency1 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency2 : IDependency2 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitStep1(IDependency1 dep1) { } + + [IocInject] + public async Task InitStep2(IDependency2 dep2) { } + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_MixedSyncAndAsyncMethods_GeneratesTaskRegistration() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency1 { } + public interface IDependency2 { } + public interface IDependency3 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency1 : IDependency1 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency2 : IDependency2 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency3 : IDependency3 { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public IDependency1 Dep1 { get; init; } + + [IocInject] + public void SyncInit(IDependency2 dep2) { } + + [IocInject] + public async Task AsyncInit(IDependency3 dep3) { } + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_TaskWrapperDependency_AsyncInitService_ResolvesDirectly() + { + // Consumer takes Task — resolved as sp.GetRequiredService>() + // because MyService is async-init. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Dependency : IDependency { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public class MyService : IMyService + { + [IocInject] + public async Task InitAsync(IDependency dep) { } + } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task lazyService) + { + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_NonGenericTaskDependency_TreatedAsPlainService() + { + // Non-generic Task (arity 0) must NOT be classified as WrapperKind.Task. + // Before the fix, GetNonCollectionWrapperKind("Task") returned WrapperKind.Task + // for arity-0 Task, causing a NullReferenceException when TaskTypeData.InnerType + // tried to access TypeParameters![0]. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task nonGenericTask) + { + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + // Generator must not throw NRE — arity-0 Task is a plain service dependency. + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_TaskWrapperDependency_SyncService_UsesTaskFromResult() + { + // Consumer takes Task — resolved as Task.FromResult(sp.GetRequiredService()) + // because SyncService is NOT async-init. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(ISyncService)])] + public class SyncService : ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public class Consumer(Task taskService) + { + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt new file mode 100644 index 0000000..36ee66a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt @@ -0,0 +1,32 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton(); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt new file mode 100644 index 0000000..73b094f --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt @@ -0,0 +1,37 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton((global::System.IServiceProvider sp) => + { + var p0 = new global::System.Func>>(() => new global::System.Lazy>(() => sp.GetServices(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)); + var s0 = new global::TestNamespace.Consumer(p0); + return s0; + }); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt new file mode 100644 index 0000000..883d668 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt @@ -0,0 +1,37 @@ +// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (!tags.Any()) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + services.AddSingleton((global::System.IServiceProvider sp) => + { + var p0 = new global::System.Lazy>>(() => new global::System.Func>(() => sp.GetRequiredService>()), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); + var s0 = new global::TestNamespace.Consumer(p0); + return s0; + }); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs index 4439827..d84af44 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/WrapperTypeDependencyTests.cs @@ -671,4 +671,270 @@ public sealed class EnumConsumer(KeyValuePair entry) await Verify(generatedSource); } + + [Test] + public async Task TaskLazy_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task LazyTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(Lazy> service) + { + public Lazy> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task TaskFunc_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task FuncTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(Func> service) + { + public Func> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task TaskEnumerable_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(System.Threading.Tasks.Task> service) + { + public System.Threading.Tasks.Task> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task EnumerableTask_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(IEnumerable> service) + { + public IEnumerable> Service { get; } = service; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_LazyFuncLazy_GeneratesNestedResolution() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(Lazy>> triple) + { + private readonly Lazy>> _triple = triple; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(IEnumerable>> triple) + { + private readonly IEnumerable>> _triple = triple; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } + + [Test] + public async Task TripleNested_FuncLazyEnumerable_GeneratesNestedResolution() + { + const string source = """ + using System; + using System.Collections.Generic; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IMyService)])] + public sealed class MyService : IMyService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton)] + public sealed class Consumer(Func>> triple) + { + private readonly Func>> _triple = triple; + } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/GlobalUsing.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/GlobalUsing.cs new file mode 100644 index 0000000..b41d8c6 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/GlobalUsing.cs @@ -0,0 +1,3 @@ +global using Microsoft.Extensions.DependencyInjection; +global using SourceGen.Ioc.TestCase; +global using SourceGen.Ioc.TestAot.TestCase; \ No newline at end of file diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj b/src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj index 4fd482c..7a7210c 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj @@ -7,7 +7,7 @@ Exe true true - Register,Container,PropertyInject,FieldInject,MethodInject + Register,Container,PropertyInject,FieldInject,MethodInject,AsyncMethodInject diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncInjectionContainer.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncInjectionContainer.cs new file mode 100644 index 0000000..5ad869d --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncInjectionContainer.cs @@ -0,0 +1,16 @@ +namespace SourceGen.Ioc.TestAot.TestCase; + +/// +/// Composite standalone container that wires and +/// together. Used only for verifying that +/// the composite resolver builds without errors; async-init service access is +/// exercised directly on (with +/// as fallback provider). +/// +[IocImportModule] +[IocImportModule] +[IocContainer( + ExplicitOnly = true, + ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim, + EagerResolveOptions = EagerResolveOptions.None)] +public sealed partial class AsyncInjectionContainer; diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ContainerModule.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ContainerModule.cs index 5dd2d97..cc1ec1d 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ContainerModule.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ContainerModule.cs @@ -1,6 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.TestCase; /// diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/EagerResolveContainer.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/EagerResolveContainer.cs new file mode 100644 index 0000000..c8fb3cd --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/EagerResolveContainer.cs @@ -0,0 +1,9 @@ +namespace SourceGen.Ioc.TestAot.TestCase; + +/// +/// Container with EagerResolveOptions.SingletonAndScoped to verify eager resolution behavior. +/// Singletons are resolved at container construction; scoped services at scope creation. +/// +[IocImportModule] +[IocContainer(EagerResolveOptions = EagerResolveOptions.SingletonAndScoped)] +public sealed partial class EagerResolveContainer; diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/FeatureServiceContainers.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/FeatureServiceContainers.cs new file mode 100644 index 0000000..a9646d6 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/FeatureServiceContainers.cs @@ -0,0 +1,38 @@ +namespace SourceGen.Ioc.TestAot.TestCase; + +/// +/// Feature service interface used to test IocContainer.IncludeTags filtering. +/// Services are defined in the same assembly as the container so the source +/// generator can evaluate their Tags properties at compile time. +/// +public interface IFeatureService +{ + string FeatureName { get; } +} + +// Registered via [IocRegister] so the generator can apply IncludeTags filtering. +[IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IFeatureService)], Tags = ["featureA"])] +public sealed class FeatureAService : IFeatureService +{ + public string FeatureName => "FeatureA"; +} + +[IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IFeatureService)], Tags = ["featureB"])] +public sealed class FeatureBService : IFeatureService +{ + public string FeatureName => "FeatureB"; +} + +/// +/// Container that includes only services tagged "featureA". +/// must NOT appear in this container's resolver. +/// +[IocContainer(IncludeTags = ["featureA"])] +public sealed partial class FeatureAContainer; + +/// +/// Container that includes only services tagged "featureB". +/// must NOT appear in this container's resolver. +/// +[IocContainer(IncludeTags = ["featureB"])] +public sealed partial class FeatureBContainer; diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/OpenGenericDiscovery.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/OpenGenericDiscovery.cs index 21294bf..a030155 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/OpenGenericDiscovery.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/OpenGenericDiscovery.cs @@ -1,6 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.TestCase; #region Discovery Method 1: Constructor Parameter diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/SwitchStatementContainer.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/SwitchStatementContainer.cs new file mode 100644 index 0000000..a1845e3 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/SwitchStatementContainer.cs @@ -0,0 +1,14 @@ +namespace SourceGen.Ioc.TestAot.TestCase; + +#pragma warning disable SGIOC011 // Duplicated Registration Detected + +/// +/// Container with UseSwitchStatement = true for testing the switch-statement resolution path. +/// Direct registrations are used instead of module imports because UseSwitchStatement is +/// ignored when imported modules are present (SGIOC020). +/// +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(ISingletonService)])] +[IocRegisterFor(ServiceLifetime.Scoped, ServiceTypes = [typeof(IScopedService)])] +[IocRegisterFor(ServiceLifetime.Transient, ServiceTypes = [typeof(ITransientService)])] +[IocContainer(UseSwitchStatement = true, ExplicitOnly = true)] +public sealed partial class SwitchStatementContainer; diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ThreadSafeStrategyContainers.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ThreadSafeStrategyContainers.cs index 8a87bbc..04b746c 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ThreadSafeStrategyContainers.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/ThreadSafeStrategyContainers.cs @@ -1,6 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.TestCase; #pragma warning disable SGIOC011 // Duplicated Registration Detected diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs new file mode 100644 index 0000000..9606721 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs @@ -0,0 +1,115 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for async method injection pattern. +/// Verifies that services with [IocInject] async Task methods are properly +/// initialized before first use in both standalone and MS.Extensions.DI scenarios. +/// +public sealed class AsyncInjectionTests +{ + #region Standalone Container Tests (via partial Task accessor on AsyncInjectionModule) + + // AsyncInjectionModule.GetAsyncInitServiceAsync() is the generated partial accessor. + // IInjectionDependency is provided by InjectionModule passed as a fallback provider. + + [Test] + public async Task AsyncInitService_StandaloneContainer_IsInitializedAfterResolve() + { + // Arrange + using var fallback = new InjectionModule(); + await using var module = new AsyncInjectionModule(fallback); + + // Act — resolve via the generated partial Task accessor + var service = await module.GetAsyncInitServiceAsync(); + + // Assert + await Assert.That(service.IsInitialized).IsTrue(); + } + + [Test] + public async Task AsyncInitService_StandaloneContainer_HasCorrectDependencyName() + { + // Arrange + using var fallback = new InjectionModule(); + await using var module = new AsyncInjectionModule(fallback); + + // Act + var service = await module.GetAsyncInitServiceAsync(); + + // Assert — initialized with the IInjectionDependency from InjectionModule + await Assert.That(service.InitializedBy).IsNotNull(); + await Assert.That(service.InitializedBy).IsEqualTo("InjectionDependency"); + } + + [Test] + public async Task AsyncInitService_StandaloneContainer_IsSingleton() + { + // Arrange + using var fallback = new InjectionModule(); + await using var module = new AsyncInjectionModule(fallback); + + // Act — calling twice must return the same instance + var instance1 = await module.GetAsyncInitServiceAsync(); + var instance2 = await module.GetAsyncInitServiceAsync(); + + // Assert + await Assert.That(instance1).IsSameReferenceAs(instance2); + } + + #endregion + + #region MS.Extensions.DI Integration Tests (via Task) + + [Test] + public async Task AsyncInitService_MsDi_IsInitializedAfterAwaitingTask() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act — async-init services are registered as Task in MS DI + var task = provider.GetRequiredService>(); + var service = await task; + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.IsInitialized).IsTrue(); + } + + [Test] + public async Task AsyncInitService_MsDi_TaskResolvesTheSameInstance() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var service1 = await provider.GetRequiredService>(); + var service2 = await provider.GetRequiredService>(); + + // Assert — singleton-backed Task resolves the same underlying instance + await Assert.That(service1).IsSameReferenceAs(service2); + } + + #endregion + + #region Composite Container Tests + + [Test] + public async Task AsyncInjectionContainer_InjectionDependency_IsResolvableFromCompositeContainer() + { + // Arrange — AsyncInjectionContainer imports both InjectionModule and AsyncInjectionModule + using var container = new AsyncInjectionContainer(); + + // Act — IInjectionDependency comes from the imported InjectionModule + var dep = container.GetRequiredService(); + + // Assert + await Assert.That(dep).IsNotNull(); + await Assert.That(dep.Name).IsEqualTo("InjectionDependency"); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CollectionTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CollectionTests.cs new file mode 100644 index 0000000..49e2ca5 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CollectionTests.cs @@ -0,0 +1,121 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for collection injection — IEnumerable<T> and IDictionary<TKey, TValue> resolving +/// multiple registrations for the same service interface. +/// +public sealed class CollectionTests +{ + #region Standalone Container — IEnumerable + + [Test] + public async Task Enumerable_StandaloneContainer_ReturnsAllRegisteredImplementations() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var plugins = container.GetRequiredService>().ToList(); + + // Assert + await Assert.That(plugins.Count).IsEqualTo(3); + } + + [Test] + public async Task Enumerable_StandaloneContainer_ReturnsAllExpectedNames() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var names = container.GetRequiredService>() + .Select(p => p.Name) + .ToHashSet(); + + // Assert + await Assert.That(names.Contains("PluginA")).IsTrue(); + await Assert.That(names.Contains("PluginB")).IsTrue(); + await Assert.That(names.Contains("PluginC")).IsTrue(); + } + + [Test] + public async Task PluginHost_StandaloneContainer_ReceivesAllPluginsViaConstructor() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var host = container.GetRequiredService(); + + // Assert + await Assert.That(host.Plugins.Count).IsEqualTo(3); + } + + #endregion + + #region Standalone Container — IDictionary + + [Test] + public async Task KeyedDictionary_StandaloneContainer_ContainsBothKeyedServices() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var registry = container.GetRequiredService(); + + // Assert + await Assert.That(registry.Processors.ContainsKey("alpha")).IsTrue(); + await Assert.That(registry.Processors.ContainsKey("beta")).IsTrue(); + } + + [Test] + public async Task KeyedDictionary_StandaloneContainer_ReturnsCorrectImplementationPerKey() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var registry = container.GetRequiredService(); + + // Assert + await Assert.That(registry.Processors["alpha"].ProcessorName).IsEqualTo("Alpha"); + await Assert.That(registry.Processors["beta"].ProcessorName).IsEqualTo("Beta"); + } + + #endregion + + #region MS.Extensions.DI Integration Tests + + [Test] + public async Task Enumerable_MsDi_ReturnsAllRegisteredImplementations() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var plugins = provider.GetServices().ToList(); + + // Assert + await Assert.That(plugins.Count).IsEqualTo(3); + } + + [Test] + public async Task PluginHost_MsDi_ReceivesAllPluginsViaConstructor() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var host = provider.GetRequiredService(); + + // Assert + await Assert.That(host.Plugins.Count).IsEqualTo(3); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ContainerTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ContainerTests.cs index b861018..a1c4ea6 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ContainerTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ContainerTests.cs @@ -1,7 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestAot.TestCase; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.Tests; /// @@ -361,6 +357,21 @@ public async Task Container_OpenGeneric_ServiceProviderDiscovery_Resolves() await Assert.That(result.Result).IsEqualTo(20); } + [Test] + public async Task Container_OpenGeneric_IocDiscoverAttribute_Resolves() + { + // Arrange + using var container = new ContainerModule(); + + // Act — IHandler is discovered via [IocDiscover] on Marker in OpenGenericDiscovery.cs + var handler = container.GetRequiredService>(); + + // Assert + await Assert.That(handler).IsNotNull(); + var result = handler.Handle(new RequestC(true)); + await Assert.That(result.Result).IsFalse(); + } + #endregion #region Dispose Tests diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CrossAssemblyAttributeTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CrossAssemblyAttributeTests.cs index 6d3a48f..b243026 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CrossAssemblyAttributeTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/CrossAssemblyAttributeTests.cs @@ -1,5 +1,4 @@ using System.Reflection; -using SourceGen.Ioc.TestCase; namespace SourceGen.Ioc.TestAot.Tests; diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs new file mode 100644 index 0000000..a9f34e3 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs @@ -0,0 +1,47 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for EagerResolveOptions — verifies that containers configured with +/// EagerResolveOptions.SingletonAndScoped resolve singletons during construction +/// rather than on first use. +/// +public sealed class EagerResolveTests +{ + [Test] + public async Task EagerResolveContainer_Singleton_IsNotNullAfterConstruction() + { + // Arrange & Act + using var container = new EagerResolveContainer(); + + // Assert — singleton was eagerly resolved during construction + var service = container.GetRequiredService(); + await Assert.That(service).IsNotNull(); + } + + [Test] + public async Task EagerResolveContainer_Singleton_ReturnsSameInstance() + { + // Arrange + using var container = new EagerResolveContainer(); + + // Act + var s1 = container.GetRequiredService(); + var s2 = container.GetRequiredService(); + + // Assert + await Assert.That(s1.InstanceId).IsEqualTo(s2.InstanceId); + await Assert.That(s1).IsSameReferenceAs(s2); + } + + [Test] + public async Task StandardContainer_Singleton_AlsoBehavesAsSingleton() + { + // Ensure default (non-eager) container also resolves correctly + using var container = new ContainerModule(); + + var s1 = container.GetRequiredService(); + var s2 = container.GetRequiredService(); + + await Assert.That(s1.InstanceId).IsEqualTo(s2.InstanceId); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/FactoryAndInstanceTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/FactoryAndInstanceTests.cs new file mode 100644 index 0000000..ec396fd --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/FactoryAndInstanceTests.cs @@ -0,0 +1,101 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for factory and instance registration patterns. +/// Verifies that services created via Factory and Instance parameters +/// are resolved correctly in both standalone containers and MS.Extensions.DI. +/// +public sealed class FactoryAndInstanceTests +{ + #region Standalone Container Tests + + [Test] + public async Task Factory_StandaloneContainer_ReturnsFactoryCreatedInstance() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var service = container.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.CreatedBy).IsEqualTo("FactoryServiceFactory"); + await Assert.That(service.DepName).IsEqualTo("FactoryDep"); + // Verify the dep was successfully resolved via the generator's factory dep path + var dep = container.GetRequiredService(); + await Assert.That(service.DepName).IsEqualTo(dep.Label); + } + + [Test] + public async Task Instance_StandaloneContainer_ReturnsPredefinedInstance() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var service = container.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.Name).IsEqualTo("InstanceService"); + } + + [Test] + public async Task Instance_StandaloneContainer_ReturnsSameInstanceOnEachResolve() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var instance1 = container.GetRequiredService(); + var instance2 = container.GetRequiredService(); + + // Assert - static pre-created instance is always the same object + await Assert.That(instance1).IsSameReferenceAs(instance2); + } + + #endregion + + #region MS.Extensions.DI Integration Tests + + [Test] + public async Task Factory_MsDi_ReturnsFactoryCreatedInstance() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var service = provider.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.CreatedBy).IsEqualTo("FactoryServiceFactory"); + await Assert.That(service.DepName).IsEqualTo("FactoryDep"); + // Verify the dep was successfully resolved via the generator's factory dep path + var dep = provider.GetRequiredService(); + await Assert.That(service.DepName).IsEqualTo(dep.Label); + } + + [Test] + public async Task Instance_MsDi_ReturnsPredefinedStaticInstance() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var service1 = provider.GetRequiredService(); + var service2 = provider.GetRequiredService(); + + // Assert + await Assert.That(service1).IsNotNull(); + await Assert.That(service1.Name).IsEqualTo("InstanceService"); + await Assert.That(service1).IsSameReferenceAs(service2); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/KeyedCollectionTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/KeyedCollectionTests.cs new file mode 100644 index 0000000..8972545 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/KeyedCollectionTests.cs @@ -0,0 +1,107 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for keyed service collection injection — IDictionary<string, T> resolving +/// all keyed registrations for the same service interface. +/// +public sealed class KeyedCollectionTests +{ + #region Standalone Container Tests + + [Test] + public async Task KeyedDictionary_StandaloneContainer_ContainsBothKeys() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var registry = container.GetRequiredService(); + + // Assert + await Assert.That(registry.Processors.ContainsKey("alpha")).IsTrue(); + await Assert.That(registry.Processors.ContainsKey("beta")).IsTrue(); + } + + [Test] + public async Task KeyedDictionary_StandaloneContainer_ReturnsCorrectImplementation() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var registry = container.GetRequiredService(); + + // Assert + await Assert.That(registry.Processors["alpha"].ProcessorName).IsEqualTo("Alpha"); + await Assert.That(registry.Processors["beta"].ProcessorName).IsEqualTo("Beta"); + } + + [Test] + public async Task KeyedDictionary_StandaloneContainer_ExactlyTwoEntries() + { + // Arrange + using var container = new ContainerModule(); + + // Act + var registry = container.GetRequiredService(); + + // Assert + await Assert.That(registry.Processors.Count).IsEqualTo(2); + } + + #endregion + + #region MS.Extensions.DI Integration Tests + + [Test] + public async Task GetKeyedService_MsDi_AlphaKeyReturnsAlphaProcessor() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var processor = provider.GetKeyedService("alpha"); + + // Assert + await Assert.That(processor).IsNotNull(); + await Assert.That(processor!.ProcessorName).IsEqualTo("Alpha"); + } + + [Test] + public async Task GetKeyedService_MsDi_BetaKeyReturnsBetaProcessor() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var processor = provider.GetKeyedService("beta"); + + // Assert + await Assert.That(processor).IsNotNull(); + await Assert.That(processor!.ProcessorName).IsEqualTo("Beta"); + } + + [Test] + [Skip("IDictionary MS DI injection via IEnumerable> is AOT-incompatible: KeyValuePair is a ValueType and cannot be enumerated in native AOT.")] + public async Task ProcessorRegistry_MsDi_ReceivesAllProcessorsViaDictionary() + { + // NOTE: MS DI resolves IDictionary by enumerating KeyValuePair. + // KeyValuePair is a ValueType, which native AOT cannot enumerate through IEnumerable. + // Use the standalone container path (KeyedCollectionModule) instead. + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + var registry = provider.GetRequiredService(); + + await Assert.That(registry.Processors.Count).IsEqualTo(2); + await Assert.That(registry.Processors["alpha"].ProcessorName).IsEqualTo("Alpha"); + await Assert.That(registry.Processors["beta"].ProcessorName).IsEqualTo("Beta"); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterAllInterfacesTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterAllInterfacesTests.cs new file mode 100644 index 0000000..5eccab2 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterAllInterfacesTests.cs @@ -0,0 +1,81 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for RegisterAllInterfaces = true registration. +/// In standalone containers the generator registers only the concrete type; +/// interface forwarding (IServiceA, IServiceB) is generated only +/// in the MS.Extensions.DI Register path where forwarding lambdas are supported. +/// +public sealed class RegisterAllInterfacesTests +{ + #region Standalone Container Tests + + [Test] + public async Task RegisterAllInterfaces_StandaloneContainer_ConcreteTypeIsResolvable() + { + // Arrange + using var container = new ContainerModule(); + + // Act — standalone container registers the concrete type; access it directly + var service = container.GetRequiredService(); + + // Assert — concrete type implements both interfaces + await Assert.That(service).IsNotNull(); + await Assert.That(service.NameA).IsEqualTo("IServiceA"); + await Assert.That(service.NameB).IsEqualTo("IServiceB"); + } + + #endregion + + #region MS.Extensions.DI Integration Tests + + [Test] + public async Task RegisterAllInterfaces_MsDi_ResolvesViaFirstInterface() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var serviceA = provider.GetRequiredService(); + + // Assert + await Assert.That(serviceA).IsNotNull(); + await Assert.That(serviceA.NameA).IsEqualTo("IServiceA"); + } + + [Test] + public async Task RegisterAllInterfaces_MsDi_ResolvesViaSecondInterface() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var serviceB = provider.GetRequiredService(); + + // Assert + await Assert.That(serviceB).IsNotNull(); + await Assert.That(serviceB.NameB).IsEqualTo("IServiceB"); + } + + [Test] + public async Task RegisterAllInterfaces_MsDi_BothInterfacesPointToSameInstance() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act — singleton registered for all interfaces shares the same instance + var serviceA = provider.GetRequiredService(); + var serviceB = provider.GetRequiredService(); + + // Assert — same underlying object + await Assert.That(serviceA as object).IsSameReferenceAs(serviceB as object); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterIntegrationTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterIntegrationTests.cs index f893562..7e3d722 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterIntegrationTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/RegisterIntegrationTests.cs @@ -1,7 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestAot.TestCase; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.Tests; /// @@ -297,4 +293,79 @@ public async Task Collection_ReturnsAllImplementations() } #endregion + + #region Cross-Assembly Collection / Factory / Wrapper Integration Tests + + [Test] + public async Task CrossAssembly_Collection_ResolvesAllPlugins() + { + // Arrange + await using var provider = CreateServiceProvider(); + + // Act — IPlugin registrations come from CollectionModule inside TestCase assembly + var plugins = provider.GetServices().ToList(); + + // Assert + await Assert.That(plugins.Count).IsEqualTo(3); + } + + [Test] + public async Task CrossAssembly_Factory_ReturnsFactoryCreatedInstanceWithDep() + { + // Arrange + await using var provider = CreateServiceProvider(); + + // Act — IFactoryService uses a factory method that injects ISingletonService + var service = provider.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.CreatedBy).IsEqualTo("FactoryServiceFactory"); + await Assert.That(service.DepName).IsEqualTo("FactoryDep"); + } + + [Test] + public async Task CrossAssembly_Instance_ReturnsPredefinedStaticInstance() + { + // Arrange + await using var provider = CreateServiceProvider(); + + // Act — IInstanceService is registered via Instance = nameof(Default) + var service = provider.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.Name).IsEqualTo("InstanceService"); + } + + [Test] + public async Task CrossAssembly_LazyPluginConsumer_LazyWrapperResolvesPlugin() + { + // Arrange + await using var provider = CreateServiceProvider(); + + // Act — LazyPluginConsumer receives Lazy from the cross-assembly CollectionModule + var consumer = provider.GetRequiredService(); + + // Assert + await Assert.That(consumer).IsNotNull(); + var plugin = consumer.LazyPlugin.Value; + await Assert.That(plugin).IsNotNull(); + } + + [Test] + public async Task CrossAssembly_LazyPluginConsumer_FuncWrapperResolvesPlugin() + { + // Arrange + await using var provider = CreateServiceProvider(); + + // Act + var consumer = provider.GetRequiredService(); + var plugin = consumer.PluginFactory(); + + // Assert + await Assert.That(plugin).IsNotNull(); + } + + #endregion } diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/SwitchStatementTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/SwitchStatementTests.cs new file mode 100644 index 0000000..43ab9b9 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/SwitchStatementTests.cs @@ -0,0 +1,80 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for UseSwitchStatement = true in the generated container resolver. +/// Verifies that the switch-statement dispatch path correctly routes service +/// resolution for common lifecycle scenarios. +/// +public sealed class SwitchStatementTests +{ + [Test] + public async Task SwitchStatement_Singleton_ResolvesCorrectly() + { + // Arrange + using var container = new SwitchStatementContainer(); + + // Act + var service = container.GetRequiredService(); + + // Assert + await Assert.That(service).IsNotNull(); + await Assert.That(service.InstanceId).IsNotEqualTo(Guid.Empty); + } + + [Test] + public async Task SwitchStatement_Singleton_ReturnsSameInstance() + { + // Arrange + using var container = new SwitchStatementContainer(); + + // Act + var s1 = container.GetRequiredService(); + var s2 = container.GetRequiredService(); + + // Assert + await Assert.That(s1.InstanceId).IsEqualTo(s2.InstanceId); + await Assert.That(s1).IsSameReferenceAs(s2); + } + + [Test] + public async Task SwitchStatement_Transient_ReturnsDifferentInstances() + { + // Arrange + using var container = new SwitchStatementContainer(); + + // Act + var t1 = container.GetRequiredService(); + var t2 = container.GetRequiredService(); + + // Assert + await Assert.That(t1.InstanceId).IsNotEqualTo(t2.InstanceId); + } + + [Test] + public async Task SwitchStatement_Scoped_SameScopeReturnsSameInstance() + { + // Arrange + using var container = new SwitchStatementContainer(); + using var scope = container.CreateScope(); + + // Act + var s1 = scope.ServiceProvider.GetRequiredService(); + var s2 = scope.ServiceProvider.GetRequiredService(); + + // Assert + await Assert.That(s1.InstanceId).IsEqualTo(s2.InstanceId); + } + + [Test] + public async Task SwitchStatement_UnregisteredService_ReturnsNull() + { + // Arrange + using var container = new SwitchStatementContainer(); + + // Act — ISingletonService is registered but IAsyncInitService is not in SwitchStatementContainer + var service = container.GetService(); + + // Assert + await Assert.That(service).IsNull(); + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/TagsTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/TagsTests.cs new file mode 100644 index 0000000..90be106 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/TagsTests.cs @@ -0,0 +1,105 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for IncludeTags filtering — verifies that containers with IncludeTags +/// only include services tagged with the specified tags. +/// Services must be defined in the same assembly as the container (using [IocRegister]) +/// for the source generator's IncludeTags filter to apply at compile time. +/// +public sealed class TagsTests +{ + #region FeatureAContainer — IncludeTags = ["featureA"] + + [Test] + public async Task FeatureAContainer_Resolves_FeatureAService() + { + // Arrange + using var container = new FeatureAContainer(); + + // Act + var service = container.GetRequiredService(); + + // Assert — only FeatureAService is tagged "featureA" + await Assert.That(service).IsNotNull(); + await Assert.That(service.FeatureName).IsEqualTo("FeatureA"); + } + + [Test] + public async Task FeatureAContainer_DoesNotContain_FeatureBService() + { + // Arrange — FeatureAContainer only includes services tagged "featureA" + using var container = new FeatureAContainer(); + + // Act — IFeatureService resolves to FeatureAService (only one registered) + var service = container.GetRequiredService(); + + // Assert — must NOT be FeatureBService + await Assert.That(service.FeatureName).IsNotEqualTo("FeatureB"); + } + + #endregion + + #region FeatureBContainer — IncludeTags = ["featureB"] + + [Test] + public async Task FeatureBContainer_Resolves_FeatureBService() + { + // Arrange + using var container = new FeatureBContainer(); + + // Act + var service = container.GetRequiredService(); + + // Assert — only FeatureBService is tagged "featureB" + await Assert.That(service).IsNotNull(); + await Assert.That(service.FeatureName).IsEqualTo("FeatureB"); + } + + [Test] + public async Task FeatureBContainer_DoesNotContain_FeatureAService() + { + // Arrange — FeatureBContainer only includes services tagged "featureB" + using var container = new FeatureBContainer(); + + // Act + var service = container.GetRequiredService(); + + // Assert — must NOT be FeatureAService + await Assert.That(service.FeatureName).IsNotEqualTo("FeatureA"); + } + + #endregion + + #region TagsModule — unfiltered, both tagged services visible + + [Test] + public async Task TagsModule_Resolves_ITaggedService() + { + // Arrange — TagsModule has no IncludeTags filter; both tagged services are registered + using var container = new TagsModule(); + + // Act + var service = container.GetService(); + + // Assert — at least one tagged service is resolvable + await Assert.That(service).IsNotNull(); + } + + [Test] + public async Task TagsModule_Resolves_BothTaggedServicesViaEnumerable() + { + // Arrange — TagsModule has no IncludeTags filter; both tagged services are registered + using var container = new TagsModule(); + + // Act — IEnumerable returns all registered implementations + var allServices = container.GetService>()?.ToList(); + + // Assert + await Assert.That(allServices).IsNotNull(); + await Assert.That(allServices!.Count).IsEqualTo(2); + await Assert.That(allServices.Any(s => s.ServiceName == "TaggedServiceA")).IsTrue(); + await Assert.That(allServices.Any(s => s.ServiceName == "TaggedServiceB")).IsTrue(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ThreadSafeStrategyTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ThreadSafeStrategyTests.cs index 944a527..20d5ed1 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ThreadSafeStrategyTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/ThreadSafeStrategyTests.cs @@ -1,7 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc.TestAot.TestCase; -using SourceGen.Ioc.TestCase; - namespace SourceGen.Ioc.TestAot.Tests; /// diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/WrapperTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/WrapperTests.cs new file mode 100644 index 0000000..101530e --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/WrapperTests.cs @@ -0,0 +1,80 @@ +namespace SourceGen.Ioc.TestAot.Tests; + +/// +/// Tests for Lazy<T> and Func<T> wrapper injection — resolving a wrapper type whose +/// inner type has multiple registrations. +/// +public sealed class WrapperTests +{ + #region Standalone Container — Lazy and Func wrapper + + [Test] + public async Task LazyConsumer_StandaloneContainer_LazyWrapperResolvesService() + { + // Arrange + using var container = new WrapperModule(); + + // Act + var consumer = container.GetRequiredService(); + + // Assert — consumer is resolved; lazy is not yet materialised + await Assert.That(consumer).IsNotNull(); + var plugin = consumer.LazyPlugin.Value; + await Assert.That(plugin).IsNotNull(); + } + + [Test] + public async Task FuncConsumer_StandaloneContainer_FuncWrapperResolvesService() + { + // Arrange + using var container = new WrapperModule(); + + // Act + var consumer = container.GetRequiredService(); + var plugin1 = consumer.PluginFactory(); + var plugin2 = consumer.PluginFactory(); + + // Assert — singleton-backed Func returns the same instance each call + await Assert.That(plugin1).IsNotNull(); + await Assert.That(plugin1).IsSameReferenceAs(plugin2); + } + + #endregion + + #region MS.Extensions.DI Integration Tests + + [Test] + public async Task LazyConsumer_MsDi_LazyWrapperResolvesService() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var consumer = provider.GetRequiredService(); + + // Assert — consumer is resolved; lazy materialises on first access + await Assert.That(consumer).IsNotNull(); + var plugin = consumer.LazyPlugin.Value; + await Assert.That(plugin).IsNotNull(); + } + + [Test] + public async Task FuncConsumer_MsDi_FuncWrapperResolvesService() + { + // Arrange + var services = new ServiceCollection(); + services.AddSourceGen_Ioc_TestCase(); + await using var provider = services.BuildServiceProvider(); + + // Act + var consumer = provider.GetRequiredService(); + var plugin = consumer.PluginFactory(); + + // Assert + await Assert.That(plugin).IsNotNull(); + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs new file mode 100644 index 0000000..572e169 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs @@ -0,0 +1,30 @@ +namespace SourceGen.Ioc.TestCase; + +/// Interface for async-initialized service. +public interface IAsyncInitService +{ + bool IsInitialized { get; } + string? InitializedBy { get; } +} + +internal sealed class AsyncInitService : IAsyncInitService +{ + public bool IsInitialized { get; private set; } + public string? InitializedBy { get; private set; } + + [IocInject] + public async Task InitializeAsync(IInjectionDependency dep) + { + await Task.CompletedTask; + InitializedBy = dep.Name; + IsInitialized = true; + } +} + +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IAsyncInitService)])] +[IocContainer(ExplicitOnly = true, ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim, EagerResolveOptions = EagerResolveOptions.None)] +public sealed partial class AsyncInjectionModule +{ + /// Async accessor — generated as async Task<IAsyncInitService> → awaits the internal resolver. + public partial Task GetAsyncInitServiceAsync(); +} diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Basic.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Basic.cs index e666aec..e35b016 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/Basic.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Basic.cs @@ -1,5 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; - namespace SourceGen.Ioc.TestCase; #region Lifetime Services diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Collection.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Collection.cs new file mode 100644 index 0000000..87fd647 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Collection.cs @@ -0,0 +1,36 @@ +namespace SourceGen.Ioc.TestCase; + +/// Plugin interface for collection injection testing. +public interface IPlugin +{ + string Name { get; } +} + +internal sealed class PluginA : IPlugin +{ + public string Name => "PluginA"; +} + +internal sealed class PluginB : IPlugin +{ + public string Name => "PluginB"; +} + +internal sealed class PluginC : IPlugin +{ + public string Name => "PluginC"; +} + +/// Service that receives a collection of plugins via constructor injection. +public sealed class PluginHost(IEnumerable plugins) +{ + public IReadOnlyList Plugins { get; } = [.. plugins]; +} + +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IPlugin)])] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IPlugin)])] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IPlugin)])] +[IocRegisterFor(ServiceLifetime.Transient)] +[IocDiscover>] +[IocContainer(ExplicitOnly = true)] +public sealed partial class CollectionModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Decorator.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Decorator.cs index 2f350a6..95d5fe4 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/Decorator.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Decorator.cs @@ -1,5 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; - namespace SourceGen.Ioc.TestCase; /// diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Factory.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Factory.cs new file mode 100644 index 0000000..0ef4950 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Factory.cs @@ -0,0 +1,51 @@ +namespace SourceGen.Ioc.TestCase; + +/// Interface for factory-created service. +public interface IFactoryService +{ + string CreatedBy { get; } + + /// Label from the injected dependency, proving the dep was resolved. + string DepName { get; } +} + +internal sealed class FactoryService : IFactoryService +{ + public string CreatedBy { get; init; } = string.Empty; + public string DepName { get; init; } = string.Empty; +} + +/// Dependency injected into the factory method to prove the generator resolves factory parameters. +public interface IFactoryDep +{ + string Label { get; } +} + +internal sealed class FactoryDep : IFactoryDep +{ + public string Label => nameof(FactoryDep); +} + +/// Interface for instance-based service. +public interface IInstanceService +{ + string Name { get; } +} + +internal sealed class InstanceService : IInstanceService +{ + public static readonly InstanceService Default = new(); + public string Name => nameof(InstanceService); +} + +internal static class FactoryServiceFactory +{ + public static IFactoryService Create(IFactoryDep dep) => + new FactoryService { CreatedBy = nameof(FactoryServiceFactory), DepName = dep.Label }; +} + +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IFactoryDep)])] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IFactoryService)], Factory = nameof(FactoryServiceFactory.Create))] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IInstanceService)], Instance = nameof(InstanceService.Default))] +[IocContainer(ExplicitOnly = true)] +public sealed partial class FactoryModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/GlobalUsing.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/GlobalUsing.cs new file mode 100644 index 0000000..3ec8c96 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/GlobalUsing.cs @@ -0,0 +1,2 @@ +global using Microsoft.Extensions.DependencyInjection; +global using SourceGen.Ioc; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Injection.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Injection.cs index b53ae4a..7cb2ccf 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/Injection.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Injection.cs @@ -1,5 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; - namespace SourceGen.Ioc.TestCase; /// Dependency interface for injection testing. diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Keyed.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Keyed.cs index f228754..87b4b50 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/Keyed.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Keyed.cs @@ -1,5 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; - namespace SourceGen.Ioc.TestCase; /// Keyed service interface for testing keyed service resolution. diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/KeyedCollection.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/KeyedCollection.cs new file mode 100644 index 0000000..9fc260f --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/KeyedCollection.cs @@ -0,0 +1,30 @@ +namespace SourceGen.Ioc.TestCase; + +/// Processor interface for keyed collection testing. +public interface IProcessor +{ + string ProcessorName { get; } +} + +internal sealed class ProcessorAlpha : IProcessor +{ + public string ProcessorName => "Alpha"; +} + +internal sealed class ProcessorBeta : IProcessor +{ + public string ProcessorName => "Beta"; +} + +/// Registry that receives all processors via keyed dictionary injection. +public sealed class ProcessorRegistry(IDictionary processors) +{ + public IDictionary Processors => processors; +} + +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IProcessor)], Key = "alpha")] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(IProcessor)], Key = "beta")] +[IocRegisterFor(ServiceLifetime.Transient)] +[IocDiscover>] +[IocContainer(ExplicitOnly = true)] +public sealed partial class KeyedCollectionModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/OpenGeneric.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/OpenGeneric.cs index 24a1c59..2195bab 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/OpenGeneric.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/OpenGeneric.cs @@ -1,6 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; -using SourceGen.Ioc; - [assembly: IocRegisterDefaults( typeof(SourceGen.Ioc.TestCase.IHandler<,>), ServiceLifetime.Transient)] diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/RegisterAllInterfaces.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/RegisterAllInterfaces.cs new file mode 100644 index 0000000..b21433a --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/RegisterAllInterfaces.cs @@ -0,0 +1,23 @@ +namespace SourceGen.Ioc.TestCase; + +/// First interface for multi-interface registration testing. +public interface IServiceA +{ + string NameA { get; } +} + +/// Second interface for multi-interface registration testing. +public interface IServiceB +{ + string NameB { get; } +} + +public sealed class MultiInterfaceService : IServiceA, IServiceB +{ + public string NameA => nameof(IServiceA); + public string NameB => nameof(IServiceB); +} + +[IocRegisterFor(ServiceLifetime.Singleton, RegisterAllInterfaces = true)] +[IocContainer(ExplicitOnly = true)] +public sealed partial class RegisterAllInterfacesModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/SourceGen.Ioc.TestCase.csproj b/src/Ioc/test/SourceGen.Ioc.TestCase/SourceGen.Ioc.TestCase.csproj index 978a3ff..cff76c4 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/SourceGen.Ioc.TestCase.csproj +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/SourceGen.Ioc.TestCase.csproj @@ -5,7 +5,7 @@ true - Register,Container,PropertyInject,FieldInject,MethodInject + Register,Container,PropertyInject,FieldInject,MethodInject,AsyncMethodInject diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Tags.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Tags.cs new file mode 100644 index 0000000..8a76b00 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Tags.cs @@ -0,0 +1,22 @@ +namespace SourceGen.Ioc.TestCase; + +/// Interface for tagged service testing. +public interface ITaggedService +{ + string ServiceName { get; } +} + +internal sealed class TaggedServiceA : ITaggedService +{ + public string ServiceName => "TaggedServiceA"; +} + +internal sealed class TaggedServiceB : ITaggedService +{ + public string ServiceName => "TaggedServiceB"; +} + +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(ITaggedService)], Tags = ["groupA"])] +[IocRegisterFor(ServiceLifetime.Singleton, ServiceTypes = [typeof(ITaggedService)], Tags = ["groupB"])] +[IocContainer(ExplicitOnly = true)] +public sealed partial class TagsModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/TestCaseModule.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/TestCaseModule.cs index 3ad1787..9d7f0ec 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/TestCaseModule.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/TestCaseModule.cs @@ -9,5 +9,11 @@ namespace SourceGen.Ioc.TestCase; [IocImportModule] [IocImportModule] [IocImportModule] +[IocImportModule] +[IocImportModule] +[IocImportModule] +[IocImportModule] +[IocImportModule] +[IocImportModule] [IocContainer(ExplicitOnly = true)] public sealed partial class TestCaseModule; diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/Wrapper.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/Wrapper.cs new file mode 100644 index 0000000..54aca62 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/Wrapper.cs @@ -0,0 +1,13 @@ +namespace SourceGen.Ioc.TestCase; + +/// Service that receives Lazy and Func wrapper dependencies. +public sealed class LazyPluginConsumer(Lazy lazyPlugin, Func pluginFactory) +{ + public Lazy LazyPlugin => lazyPlugin; + public Func PluginFactory => pluginFactory; +} + +[IocImportModule] +[IocRegisterFor(ServiceLifetime.Transient)] +[IocContainer(ExplicitOnly = true)] +public sealed partial class WrapperModule;