Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions emojiasm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def main():
except TranspileError as e:
print(str(e), file=sys.stderr)
sys.exit(1)

if args.debug:
print("Source Map:", file=sys.stderr)
for func in program.functions.values():
for instr in func.instructions:
if instr.source:
arg_str = f" {instr.arg}" if instr.arg is not None else ""
print(
f" py:{instr.line_num}: {instr.source}"
f" -> {instr.op.name}{arg_str}",
file=sys.stderr,
)
else:
if args.file is None:
ap.error("the following arguments are required: file (or use --repl)")
Expand Down
23 changes: 7 additions & 16 deletions emojiasm/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

import math
import re
import time
from functools import lru_cache
Expand All @@ -16,6 +15,7 @@
from .bytecode import OP_MAP, compile_to_bytecode, gpu_tier, GpuProgram, _build_string_table
from .opcodes import Op
from .parser import Program
from .stats import compute_stats


# ── Constants ────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -293,22 +293,13 @@ def _get_kernel():
def _stats(values: list[float]) -> dict:
"""Compute summary statistics from a list of float values.

Returns dict with mean, std, min, max, count. Returns zeros when
*values* is empty.
"""
if not values:
return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0, "count": 0}
    Delegates to the unified ``emojiasm.stats.compute_stats`` function.
Kept as a module-level function for backward compatibility.

n = len(values)
mean = sum(values) / n
variance = sum((x - mean) ** 2 for x in values) / n
return {
"mean": mean,
"std": math.sqrt(variance),
"min": min(values),
"max": max(values),
"count": n,
}
Returns dict with mean, std, min, max, count, median. Returns zeros
when *values* is empty.
"""
return compute_stats(values, histogram_bins=0)


# ── Output reconstruction ────────────────────────────────────────────────
Expand Down
91 changes: 61 additions & 30 deletions emojiasm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import time
from typing import Any

from .stats import compute_stats


class EmojiASMTool:
"""LLM tool that executes EmojiASM programs on GPU.
Expand Down Expand Up @@ -67,6 +69,12 @@ def execute(self, source: str, n: int = 1) -> dict:
def execute_python(self, source: str, n: int = 1) -> dict:
"""Transpile Python source and execute as EmojiASM.

When ``n > 1``, auto-parallelization is attempted: if the source
looks like a single Monte Carlo trial (imports random, no large
loops, assigns to ``result``), a ``print(result)`` is appended
so each parallel instance captures its result. The GPU/CPU
execution pipeline then runs the program N times independently.

Args:
source: Python source code (subset: arithmetic, loops, random)
n: Number of parallel instances (capped at max_instances)
Expand All @@ -78,8 +86,26 @@ def execute_python(self, source: str, n: int = 1) -> dict:
n = min(max(n, 1), self.max_instances)

try:
from .transpiler import transpile
program = transpile(source)
import ast as _ast
from .transpiler import (
transpile,
_is_single_instance,
_ensure_result_capture,
)

# Auto-parallelization: detect single-instance programs and
# ensure result capture so each parallel run returns a value.
effective_source = source
if n > 1 and source and source.strip():
try:
tree = _ast.parse(source)
if _is_single_instance(tree):
tree = _ensure_result_capture(tree)
effective_source = _ast.unparse(tree)
except SyntaxError:
pass # fall through to transpile which will report error

program = transpile(effective_source)
except Exception as exc:
elapsed_ms = (time.perf_counter() - t0) * 1000
return {
Expand Down Expand Up @@ -135,31 +161,46 @@ def _execute_gpu(self, program: Any, n: int, tier: int, t0: float) -> dict:
return self._execute_cpu(program, n, tier, t0)

def _execute_cpu(self, program: Any, n: int, tier: int, t0: float) -> dict:
"""Execute on CPU via agent mode."""
"""Execute on CPU via VM with thread-level parallelism."""
try:
from .agent import run_agent_mode
agent_result = run_agent_mode(
program, filename="<inference>", runs=n, max_steps=self.max_steps
)
from .vm import VM, VMError
from concurrent.futures import ThreadPoolExecutor
import io
from contextlib import redirect_stdout

def _run_one(instance_id: int) -> dict:
"""Run a single VM instance, capturing output."""
try:
buf = io.StringIO()
vm = VM(program)
vm.max_steps = self.max_steps
with redirect_stdout(buf):
vm.run()
return {"status": "ok", "output": buf.getvalue()}
except (VMError, Exception) as e:
return {"status": "error", "output": None, "error": str(e)}

if n == 1:
results = [_run_one(0)]
else:
with ThreadPoolExecutor(max_workers=min(n, 16)) as pool:
results = list(pool.map(_run_one, range(n)))

elapsed_ms = (time.perf_counter() - t0) * 1000

# Extract numeric results from agent output
# Extract numeric results from output
numeric_results: list[float] = []
for r in agent_result.get("results", []):
for r in results:
if r.get("status") == "ok" and r.get("output"):
try:
numeric_results.append(float(r["output"].strip()))
except (ValueError, TypeError):
pass

ok_count = sum(
1
for r in agent_result.get("results", [])
if r.get("status") == "ok"
)
ok_count = sum(1 for r in results if r.get("status") == "ok")

# Compute stats
stats = self._compute_stats(numeric_results)
stats = compute_stats(numeric_results, histogram_bins=0)

return {
"success": ok_count == n,
Expand Down Expand Up @@ -189,22 +230,12 @@ def _execute_cpu(self, program: Any, n: int, tier: int, t0: float) -> dict:

@staticmethod
def _compute_stats(values: list[float]) -> dict:
"""Compute summary statistics from a list of float values."""
import math

if not values:
return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0, "count": 0}
"""Compute summary statistics from a list of float values.

n = len(values)
mean = sum(values) / n
variance = sum((x - mean) ** 2 for x in values) / n
return {
"mean": mean,
"std": math.sqrt(variance),
"min": min(values),
"max": max(values),
"count": n,
}
        Delegates to the unified ``emojiasm.stats.compute_stats`` function.
Kept as a static method for backward compatibility.
"""
return compute_stats(values, histogram_bins=0)

def execute_batch(self, sources: list[str], n_each: int = 1) -> list[dict]:
"""Execute multiple programs, returning results for each."""
Expand Down
94 changes: 94 additions & 0 deletions emojiasm/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unified statistics module for EmojiASM."""

from __future__ import annotations

import math
import statistics
from typing import Any


def compute_stats(
    values: list[float | int], histogram_bins: int = 10
) -> dict[str, Any]:
    """Compute descriptive statistics over a list of numeric values.

    Non-finite entries (NaN, +/-inf) and non-numeric entries are dropped
    before any arithmetic. An input that is empty after filtering yields
    the same all-zero result as an empty list.

    Args:
        values: Numeric values; NaN/inf entries are ignored.
        histogram_bins: Bin count for the optional histogram; 0 disables it.

    Returns:
        Dict with keys: mean, std, min, max, count, median, and
        ``histogram`` when ``histogram_bins > 0``.
    """
    # Drop anything that would poison sums and comparisons.
    finite = [
        v for v in values if isinstance(v, (int, float)) and math.isfinite(v)
    ]
    n = len(finite)

    if n == 0:
        zero: dict[str, Any] = {
            "mean": 0,
            "std": 0,
            "min": 0,
            "max": 0,
            "count": 0,
            "median": 0,
        }
        return zero

    lo = min(finite)
    hi = max(finite)
    avg = sum(finite) / n

    # Population (not sample) standard deviation; a lone value has none.
    if n > 1:
        spread = math.sqrt(sum((x - avg) ** 2 for x in finite) / n)
    else:
        spread = 0.0

    summary: dict[str, Any] = {
        "mean": avg,
        "std": spread,
        "min": lo,
        "max": hi,
        "count": n,
        "median": statistics.median(finite),
    }

    if histogram_bins > 0:
        summary["histogram"] = _histogram(finite, histogram_bins, lo, hi)

    return summary


def _histogram(
values: list[float | int], bins: int, val_min: float | int, val_max: float | int
) -> dict[str, list[float]]:
"""Compute histogram edges and counts.

Returns dict with 'edges' (list of bin edges, length bins+1) and
'counts' (list of counts per bin, length bins).
"""
# All same values — single bin
if val_min == val_max:
edges = [float(val_min), float(val_min)]
counts = [len(values)]
return {"edges": edges, "counts": counts}

# Compute evenly spaced bin edges
step = (val_max - val_min) / bins
edges = [val_min + i * step for i in range(bins)] + [val_max]
counts = [0] * bins

for v in values:
# Find the bin index
idx = int((v - val_min) / step)
# Clamp: values equal to val_max go in the last bin
if idx >= bins:
idx = bins - 1
counts[idx] += 1

return {"edges": [float(e) for e in edges], "counts": counts}
Loading