diff --git a/build.sh b/build.sh index 8980552f94..492c033686 100755 --- a/build.sh +++ b/build.sh @@ -208,6 +208,13 @@ if [ -z "$NCCL_LFLAG" ]; then NCCL_LFLAG=$(python -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "") fi +WHEEL_RPATH_FLAGS=() +for lib_flag in "$CUDNN_LFLAG" "$NCCL_LFLAG"; do + if [[ "$lib_flag" == -L* ]]; then + WHEEL_RPATH_FLAGS+=("-Wl,-rpath,${lib_flag#-L}") + fi +done + export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}" export CCACHE_BASEDIR="$(pwd)" export CCACHE_COMPILERCHECK=content @@ -232,7 +239,7 @@ if [ ! -f "$BINDING_SRC" ]; then fi echo "Compiling static library for $ENV..." -${CC:-clang} -c "${CLANG_OPT[@]}" \ +${CC:-clang} -c "${CLANG_OPT[@]}" $EXTRA_CFLAGS \ -I. -Isrc -I$SRC_DIR -Ivendor \ -I./$RAYLIB_NAME/include -I$CUDA_HOME/include \ -DPLATFORM_DESKTOP \ @@ -268,6 +275,7 @@ if [ -z "$MODE" ]; then ${CXX:-g++} -shared -fPIC -fopenmp build/bindings.o "$STATIC_LIB" "$RAYLIB_A" -L$CUDA_HOME/lib64 $CUDNN_LFLAG $NCCL_LFLAG + "${WHEEL_RPATH_FLAGS[@]}" -lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn $OMP_LIB $LINK_OPT "${SHARED_LDFLAGS[@]}" diff --git a/config/craftax.ini b/config/craftax.ini new file mode 100644 index 0000000000..83bbcbd663 --- /dev/null +++ b/config/craftax.ini @@ -0,0 +1,19 @@ +[base] +env_name = craftax + +[vec] +total_agents = 8192 +num_buffers = 4 +num_threads = 16 + +[env] +seed_offset = 0 +# Pre-generated world pool. Each reset memcpys from a pool entry +# instead of re-running generate_world (~ms -> ~us per reset). +# Bounds world diversity: at most reset_pool_size unique maps are +# ever seen per process. Set to 0 to disable (required for the +# parity harness to maintain exact per-seed determinism). 
+reset_pool_size = 1024 + +[train] +total_timesteps = 200_000_000 diff --git a/config/craftax_classic.ini b/config/craftax_classic.ini new file mode 100644 index 0000000000..f71be6d189 --- /dev/null +++ b/config/craftax_classic.ini @@ -0,0 +1,20 @@ +[base] +env_name = craftax_classic + +[vec] +total_agents = 8192 +num_buffers = 4 +num_threads = 16 + +[env] +# Pre-generated world pool. When > 0, c_reset memcpys from a random pool +# entry instead of re-running generate_world (~30 us -> ~0.5 us per reset). +# Default is 0 (disabled) because on classic the env is not the training +# bottleneck: policy backward/optimizer dominate, so caching doesn't move +# training SPS. Useful to set > 0 for sim-only workloads (data generation, +# evaluation rollouts) where c_step throughput matters. Bounds world +# diversity: at most reset_pool_size unique maps are ever seen per process. +reset_pool_size = 0 + +[train] +total_timesteps = 200_000_000 diff --git a/config/ocean/craftax.ini b/config/ocean/craftax.ini deleted file mode 100644 index 987a4dd314..0000000000 --- a/config/ocean/craftax.ini +++ /dev/null @@ -1,12 +0,0 @@ -[base] -env_name = craftax - -[vec] -total_agents = 8192 -num_buffers = 4 -num_threads = 16 - -[env] - -[train] -total_timesteps = 200_000_000 diff --git a/ocean/craftax/PORT_NOTES.md b/ocean/craftax/PORT_NOTES.md new file mode 100644 index 0000000000..4542b1dcb8 --- /dev/null +++ b/ocean/craftax/PORT_NOTES.md @@ -0,0 +1,543 @@ +# Craftax Full Ocean Port Notes + +## Verification coverage + +The standalone parity harness now supports deterministic action policies beyond +uniform random exploration: + +- `uniform`: the original random action stream. +- `combat`: biases toward `DO`, arrows, fireballs, and iceballs when mobs and + resources make those actions meaningful, otherwise moves toward live mobs. +- `descend`: uses the mirrored state to push toward down ladders, clear blocked + levels through combat, and exercise placement and crafting actions. 
+- `suicide`: steers into adjacent lava, water, mob-occupied, or projectile-heavy + danger and otherwise paths toward the nearest known hazard. +- `boss`: warms up with downward navigation and then repeatedly attempts + descent while continuing to route toward ladders. +- `mixed`: round-robins the above every 500 steps. + +`tests/craftax_parity.py` now reports the policy, seed, step, action, reward +delta, terminal delta, first symbolic-observation field, suspected subsystem, +and the last 10 actions on any divergence. With `--reset-on-done` enabled, the +harness tracks terminal counts and mean episode length by seed. JAX stepping is +run through the no-auto-reset path with the same per-step key split used by the +native env; when a terminal is observed, the mirrored state is advanced through +the native reset helper keyed by the same auto-reset key, and that reset state +and observation are checked field-by-field before continuing. + +The stress battery in `tests/craftax_parity_stress.py` runs: + +- 64 seeds times 10000 steps with `mixed`. +- 16 seeds times 30000 steps with `descend`. +- 32 seeds times 5000 steps with `suicide`. +- 16 seeds times 5000 steps with `combat`. + +All stress cases use `atol=1e-5` for observations and rewards and exact terminal +matching. The phase-10a run completed with zero divergences in 1033.0 seconds: +2883 terminals in `mixed`, 2498 in `descend`, 622 in `suicide`, and 355 in +`combat`. + +Residual caveats: + +- The harness observes live C step state through the public vector API, so step + diagnostics identify the first differing observation field and subsystem class + rather than dumping the entire private C state after every step. +- CPU XLA can fuse reset worldgen noise normalization differently from + materialized JAX by one ULP on exact threshold cells. 
Materialized JAX + worldgen and native reset agree on the targeted sand-threshold keys covered by + `tests/craftax_worldgen_test.py`, so terminal continuation uses the native + reset helper after explicit reset-state verification. + +## 2026-04-18 Native Step Integration and Proxy Removal + +This phase wires the green native reset and all green native step subsystems +into the live Ocean `c_step` path. The Python/JAX proxy has been fully removed: +`c_init`, `c_reset`, `c_step`, and `c_close` are now 100% native. + +- `c_step_native` now mirrors the installed `craftax_step` subsystem order: + floor changes, crafting, action, placement, projectiles, spells, potions, + books, enchantment, boss logic, attributes, movement, mobs, spawning, plants, + intrinsics, clipping, inventory achievements, reward, timestep, light level, + terminal, and symbolic observation encoding. +- The live env keeps the same outer RNG schedule as the old auto-reset proxy: + reset uses the reset key's inner worldgen split, each step splits the external + key once, then splits the per-step key into gameplay and auto-reset keys. +- Step observations reuse the native symbolic encoder, now with mob channels and + boss-vulnerable special value populated for non-reset states. +- `tests/craftax_step_full_test.py` adds the full side-by-side parity check for + 16 seeds times 2000 random-action steps. `tests/craftax_parity.py` remains as + the standalone harness. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [x] Standalone native `spawn_mobs` subsystem with JAX-parity tests. 
+- [x] Standalone native `update_mobs` subsystem with JAX-parity tests. +- [x] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [x] Integrate all green subsystem ports into native `c_step` and remove all + Python/JAX proxy code paths. + +Remaining proxy paths: + +- None. The Craftax Ocean env no longer loads CPython symbols, constructs a JAX + env, or delegates reset/step/close through Python. + +Next phase: + +- Optimize the native path after correctness is locked down. Likely targets are + SIMD-friendly loops, cache-tiled symbolic observation encoding, and mob update + hot paths. Performance claims need measurement. + +## 2026-04-18 Standalone Update Mobs Step Subsystem + +This phase adds a native C port for the `update_mobs` subsystem, still +deliberately without integrating it into `c_step`. The live Ocean environment +continues to delegate step to the Python/JAX proxy. + +- `step_update_mobs.h` contains the standalone in-place helper for: + - `update_mobs` +- The helper mirrors the installed JAX update order for melee mobs, passive + mobs, ranged mobs, mob projectiles, and player projectiles. It preserves the + scan-level Threefry threading, including the melee loop's final right-key + carry, and the top-level split before each mob class. +- Mob movement and collision use the installed collision tables for land, + flying, aquatic, and amphibian mobs, including JAX-style clamped reads, + scatter-drop writes, mob-map exclusion, water/lava/solid checks, despawn + distance, boss-floor despawn suppression, and sequential mob-map updates. +- Combat covers melee player attacks, ranged projectile spawning, projectile + movement, player damage with armour and enchantment defenses, sleeping/resting + wakeups, player projectile damage scaling, first-target mob attacks, kill + achievements, mob-map clearing, and `monsters_killed` updates. 
+- `tests/craftax_step_update_mobs_test.py` builds a temporary C wrapper around + the inline helper and compares full copied states against the installed JAX + function for 16 reset-plus-RNG-action-stepped states. Targeted coverage + includes every mob class on every floor, melee attacks, ranged projectile + firing, mob projectiles hitting the player, walls, and out-of-bounds, player + projectile mob kills, despawn, cooldown decrement, and empty-mask live-effect + checks. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [x] Standalone native `spawn_mobs` subsystem with JAX-parity tests. +- [x] Standalone native `update_mobs` subsystem with JAX-parity tests. +- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. +- [ ] Benchmark CPU throughput only after the proxy path is gone. + +Remaining proxy paths: + +- `c_step` still delegates to the Python/JAX proxy. None of the standalone + subsystem helpers are wired into the live environment yet. +- All gameplay step subsystems now have standalone native ports with parity + tests. Reward/terminal bookkeeping, light-level updates, timestep updates, + RNG threading between subsystems, and achievement-delta logging are still not + integrated natively. +- Rendering remains a no-op. 
+- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. + +## 2026-04-18 Standalone Spawn Mobs Step Subsystem + +This phase adds a native C port for the `spawn_mobs` subsystem, still +deliberately without integrating it into `c_step`. The live Ocean environment +continues to delegate step to the Python/JAX proxy. + +- `step_spawn_mobs.h` contains the standalone in-place helper for: + - `spawn_mobs` +- The helper mirrors the installed JAX split order: passive chance, passive + position, melee chance, melee position, ranged chance, ranged position. It + also keeps the JAX behavior where the selected slot's `type_id` is written + even when the spawn gate fails. +- Spawn maps match the installed function's terrain and distance rules, + including passive distance rejection near the player, monster range gates, + overworld night-zombie light scaling, deep-thing water spawning, grave-only + boss-wave spawning, mob-map exclusion, caps, and sequential mob-map updates + between passive, melee, and ranged attempts. +- `tests/craftax_step_spawn_mobs_test.py` builds a temporary C wrapper around + the inline helper and compares full copied states against the installed JAX + function for 16 reset-plus-NOOP-step seeds. Targeted coverage includes all + nine floors, full mob caps, empty-slot spawns at single candidate positions, + day versus night overworld melee chances, boss spawn-wave pacing, player- + adjacent candidate rejection, and land, water, and grave terrain constraints. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. 
+- [x] Standalone native `spawn_mobs` subsystem with JAX-parity tests. +- [ ] Standalone native `update_mobs` subsystem with JAX-parity tests. +- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. +- [ ] Benchmark CPU throughput only after the proxy path is gone. + +Remaining proxy paths: + +- `c_step` still delegates to the Python/JAX proxy. None of the standalone + subsystem helpers are wired into the live environment yet. +- The only gameplay step subsystem still without a standalone native port is + `update_mobs`. Reward/terminal bookkeeping, light-level updates, timestep + updates, RNG threading between subsystems, and achievement-delta logging are + also still not integrated natively. +- Rendering remains a no-op. +- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. + +## 2026-04-18 Standalone Do Action Step Subsystem + +This phase adds a native C port for the `do_action` subsystem, still +deliberately without integrating it into `c_step`. The live Ocean environment +continues to delegate step to the Python/JAX proxy. + +- `step_do_action.h` contains the standalone in-place helper for: + - `do_action` +- The helper mirrors the installed JAX ordering: mob attack resolution runs + before block interaction; block mining/eating/drinking/inventory/achievement + effects are gated by in-bounds and no mob attack; chest-open flags and boss + progress keep the JAX side effects that are not part of that gate. 
+- Chest looting calls the existing native `craftax_add_items_from_chest_native` + helper after consuming the sapling RNG split, so first-open bow/book rewards + see the old `chests_opened` value and the chest RNG thread matches JAX. +- Mob attacks cover passive, melee, and ranged mob arrays, including first-match + target selection, defense mapping, sword enchantment damage, strength and + intelligence scaling, passive food refill, kill achievements, mob-map updates, + and monster kill counts. +- `tests/craftax_step_do_action_test.py` builds a temporary C wrapper around the + inline helper and compares full copied states against the installed JAX + function for 16 reset-plus-step-through seeds. Coverage includes a seeded + no-op-then-DO sequence, mining success and missing-pickaxe cases, sapling RNG + rolls, plant/passive food and water/fountain drink cases, all chest levels, + all passive/melee/ranged kill achievement mappings, damage modifier cases, + out-of-bounds targets, no-op target blocks, projectile-occupied targets, and + mob-on-chest gating. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [ ] Standalone native ports for the remaining mob step subsystems: + `update_mobs` and `spawn_mobs`. +- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. 
+- [ ] Benchmark CPU throughput only after the proxy path is gone. + +Remaining proxy paths: + +- `c_step` still delegates to the Python/JAX proxy. None of the standalone + subsystem helpers are wired into the live environment yet. +- The only gameplay step subsystems still without standalone native ports are + `update_mobs` and `spawn_mobs`. Reward/terminal bookkeeping, light-level + updates, timestep updates, RNG threading between subsystems, and + achievement-delta logging are also still not integrated natively. +- Rendering remains a no-op. +- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. + +## 2026-04-18 Standalone Crafting And Placement Step Subsystems + +This phase adds native C ports for two more action subsystems, still +deliberately without integrating them into `c_step`. The live Ocean environment +continues to delegate step to the Python/JAX proxy. + +- `step_crafting.h` contains standalone in-place helpers for: + - `do_crafting` + - `place_block` + - `add_new_growing_plant`, used by plant placement and exposed to the test + wrapper as a translation-unit-local helper +- `do_crafting` mirrors the JAX recipe order and sequential inventory updates + for all twelve `MAKE_*` actions present in the current Action enum: + pickaxes, swords, iron/diamond armour, arrows, and torches. +- `place_block` mirrors table, furnace, stone, plant, and torch placement, + including original-block placement tests, item-map gating, mob/out-of-bounds + rollback, first-empty growing-plant slot selection, and the padded 9x9 torch + light update near map boundaries. +- `tests/craftax_step_crafting_test.py` builds a temporary C wrapper around the + inline helpers and compares each subsystem against the installed JAX function + on reset-plus-step-through states for 16 seeds. 
Coverage includes success, + missing-resource/tool-cap, missing-station crafting cases; every JAX-legal + placement target block for each placement action; illegal wall/item/mob/water + cases where applicable; map-boundary rollback; and direct first-available-slot + checks for growing plants. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [ ] Standalone native ports for the remaining mob step subsystems: + `update_mobs` and `spawn_mobs`. +- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. +- [ ] Benchmark CPU throughput only after the proxy path is gone. + +Remaining proxy paths: + +- `c_step` still delegates to the Python/JAX proxy. None of the standalone + subsystem helpers are wired into the live environment yet. +- The only gameplay step subsystems still without standalone native ports are + `update_mobs` and `spawn_mobs`. Reward/terminal bookkeeping, light-level + updates, timestep updates, RNG threading, and achievement-delta logging are + also still not integrated natively. +- Rendering remains a no-op. +- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. 
+ +## 2026-04-18 Standalone Medium Step Subsystems + +This phase adds native C ports for five more step subsystems, again deliberately +without integrating them into `c_step`. The live Ocean environment still +delegates step to the Python/JAX proxy, so the full parity harness should remain +unchanged. + +- `step_medium.h` contains standalone in-place helpers for: + - `shoot_projectile` + - `cast_spell` + - `enchant` + - `change_floor` + - `add_items_from_chest` +- `add_items_from_chest` takes read-only `CraftaxState` context plus the + `CraftaxInventory` being mutated because the JAX helper's special chest drops + depend on `player_level` and `chests_opened`. +- `tests/craftax_step_medium_test.py` builds a temporary C wrapper around the + inline helpers and compares each subsystem against the installed JAX function + on copied reset-plus-step-through states for 16 seeds and targeted cases: + projectile slot and resource gating, learned/unlearned spells, enchantment + table/gem/mana/item gating, every floor transition direction, and chest potion + and special-drop paths. +- The helpers do not allocate, do not call Python, and preserve the JAX details + that matter for these routines, including clamped gather-style indexing, + first-free projectile slot selection, cumulative-probability `choice` with + `1 - uniform`, sequential Threefry split ordering, and the chest helper's + intentionally unused wood roll. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [ ] Standalone native ports for the remaining mob step subsystems: + `update_mobs` and `spawn_mobs`. 
+- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. +- [ ] Benchmark CPU throughput only after the proxy path is gone. + +Remaining proxy paths: + +- `c_step` still delegates to the Python/JAX proxy. None of the new medium + helpers are wired into the live environment yet. +- The only gameplay step subsystems still without standalone native ports are + `update_mobs` and `spawn_mobs`. Reward/terminal bookkeeping, light-level + updates, timestep updates, RNG threading, and achievement-delta logging are + also still not integrated natively. +- Rendering remains a no-op. +- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. + +## 2026-04-18 Standalone Simple Step Subsystems + +This phase adds native C ports for the easy step subsystems, but deliberately +does not integrate them into `c_step`. The live Ocean environment still delegates +step to the Python/JAX proxy, so the full parity harness should remain unchanged. + +- `step_simple.h` contains standalone in-place helpers for: + - `move_player` + - `update_plants` + - `boss_logic` + - `level_up_attributes` + - `clip_inventory_and_intrinsics` + - `calculate_inventory_achievements` + - `update_player_intrinsics` + - `drink_potion` + - `read_book` +- `tests/craftax_state_fixtures.py` provides test-only pickle payloads for JAX + `EnvState` values, a ctypes mirror of `CraftaxState`, C-to-JAX conversion, and + strict state diffing with exact integer/bool checks and `atol=1e-6` float + checks. 
+- `tests/craftax_step_subsystem_test.py` builds a temporary C wrapper around the + inline helpers and compares each subsystem against the JAX function on copied + reset-plus-step-through states for 16 seeds and targeted stress cases. +- The helpers do not allocate, do not call Python, and keep JAX details that + matter for these routines, including clamped gather-style indexing, `where` and + `select` ordering, potion `-1` indexing, and the `read_book` split plus + probability-choice path. + +Native-step roadmap checklist: + +- [x] Native reset PRNG, noise, 9-floor world generation, and reset observation. +- [x] Standalone native simple step subsystems with JAX-parity tests. +- [x] Standalone native medium step subsystems with JAX-parity tests. +- [x] Standalone native crafting and placement subsystems with JAX-parity tests. +- [x] Standalone native `do_action` subsystem with JAX-parity tests. +- [ ] Standalone native ports for the remaining mob step subsystems: + `update_mobs` and `spawn_mobs`. +- [ ] Native reward, terminal, timestep, light-level, RNG, and achievement-delta + bookkeeping around the subsystem calls. +- [ ] Integrate all green subsystem ports into a native `c_step` behind one + explicit switch, then remove the Python/JAX proxy from the normal step path. +- [ ] Restore production vector sizes in `config/ocean/craftax.ini` after native + step is the default. +- [ ] Benchmark CPU throughput only after the proxy path is gone. + +## 2026-04-18 Native 9-Floor Reset Worldgen + +This phase replaces the JAX reset call with native C reset world generation for +the default `Craftax-Symbolic-v1` environment parameters. 
+ +- `worldgen.h` now mirrors `generate_world` for all nine floors: + - floor 0 overworld smoothworld + - floor 1 dungeon + - floor 2 gnomish mines smoothworld + - floor 3 sewers dungeon + - floor 4 vaults dungeon + - floor 5 troll mines smoothworld + - floor 6 fire smoothworld + - floor 7 ice smoothworld + - floor 8 boss smoothworld +- Native reset generation covers `map`, `item_map`, `mob_map`, `light_map`, + ladders, chest flags, `monsters_killed[0] = 10`, empty mob/projectile arrays, + projectile directions, empty plants, the random `potion_mapping`, `state_rng`, + and the scalar reset fields used by symbolic observations. +- `craftax_encode_reset_observation` encodes the native reset state into the + flat symbolic observation, so `c_reset` no longer imports Python or calls JAX. +- `tests/craftax_worldgen_test.py` compares the native C reset state against JAX + `generate_world` for 16 seeds, with exact map/item/ladder/potion/scalar checks + and `atol=1e-6` for light and float state. +- The Python/JAX proxy is still used for `c_step`. Because step state is still + JAX-owned, native `c_reset` marks the proxy dirty and the first delegated step + lazily calls the proxy reset before applying the action. This keeps reset + Python-free while preserving current step parity. + +Remaining proxy paths: + +- All step logic, rewards, achievements, auto-reset behavior after a delegated + step, mob updates, inventory updates, and logging data still come from the + Python/JAX proxy. +- `c_step` still allocates through Python/JAX and serializes on the GIL. The + next porting phase should move gameplay state transitions native and remove + the lazy step-side proxy reset. +- Rendering remains a no-op. +- `config/ocean/craftax.ini` still uses a small proxy-friendly vector size. The + native port should raise this once step no longer calls Python. 
+ +## 2026-04-18 Native Floor-0 Reset Slice + +This phase added the first native C replacement pieces while keeping the JAX +proxy as the oracle for all live game state and step logic. + +- `threefry.h` ports JAX's `threefry2x32` PRNG for uint32 seeds, including + `PRNGKey(seed)`, partitionable `split`/`split_n`, `fold_in`, and + `uniform_u32`/float32 uniform helpers. `tests/craftax_threefry_test.py` + compares bitwise against `jax.random.PRNGKey`, `split`, `fold_in`, and + `bits`. +- `noise.h` ports `craftax/craftax/util/noise.py` for Perlin and fractal 2D + noise. The test uses soft parity because C `sinf`/`cosf` and XLA + transcendental lowering can differ by a few ulps; no JAX FFT path is used. + `tests/craftax_noise_test.py` enforces `atol=rtol=2e-6`. +- `worldgen.h` ports default overworld `generate_smoothworld` for floor 0: + `map`, `item_map`, `light_map`, `ladder_down`, and `ladder_up`. + `tests/craftax_worldgen_floor0_test.py` compares these arrays against JAX for + default reset seeds. +- `c_reset` still calls the JAX proxy to build the full observation and retain + the JAX-owned state, then overwrites the visible floor-0 map/item/light + observation channels from native C. Because native floor-0 generation matches + the JAX reset data for default seeds, end-to-end step parity remains intact. + +Remaining proxy paths: + +- Floors 1..8 are still generated by JAX. +- The live `EnvState`, all step logic, rewards, achievements, auto-reset, mobs, + inventory, and logging data still come from the Python/JAX proxy. +- The native floor-0 arrays are not yet installed into the JAX state object; + this is safe only because the native generator currently matches the JAX + oracle for the covered default reset path. + +## Current Implementation + +`ocean/craftax/` is wired as a full Craftax Ocean environment with the correct +symbolic observation size (`8268`) and action count (`43`). 
The C header declares +the full Craftax enum set and an `EnvState`-shaped C struct matching the field +order in `craftax_state.py`. + +Reset is native for the full initial `generate_world` state and symbolic +observation. In the proxy baseline this section was written for, step was reference-backed: the C env acquired the Python GIL, +called the installed JAX `Craftax-Symbolic-v1` implementation, and copied the +resulting float32 observation, reward, terminal flag, and terminal achievement +log into PufferLib-owned buffers. After a native reset, the first delegated step +performed a proxy reset internally so the JAX-owned step state started from the +same seed and remained aligned with the native reset observation. This section and the sections below describe that baseline; the "Native Step Integration and Proxy Removal" entry earlier in these notes supersedes them and records that `c_init`, `c_reset`, `c_step`, and `c_close` are now native. + +## Deliberate Divergences From The Requested Native Port + +- The Craftax game logic is not yet native C. Step logic, achievements, rewards, + auto-reset behavior after delegated steps, mobs, inventory updates, and other + transition logic are delegated to the JAX oracle. +- `c_step` allocates through Python/JAX and serializes on the GIL. This violates + the final performance target and the intended no-allocation step path. +- `c_close` asks the proxy to drop JAX arrays, then intentionally leaks the small + Python proxy wrapper objects. DECREFing JAX/XLA-owned wrappers during + PufferLib shutdown segfaulted in the proxy baseline; the native port removes + this path. +- Rendering is a no-op. +- `config/ocean/craftax.ini` uses a small proxy-friendly vector size. The native + port should raise this once step no longer calls Python. + +## Known Risks + +- Training throughput is expected to be poor. This baseline is for parity and ABI + validation, not for the Ryzen 9950X3D optimization target. +- `uv run puffer train craftax` currently reaches rollout/train work, but a + 128-step smoke run exits with code 139 during shutdown. The parity harness and + direct `VecEnv` close path exit cleanly; this appears specific to the GPU + trainer plus proxy/JAX runtime cleanup. 
+- The helper forces `JAX_PLATFORM_NAME=cpu` before importing JAX to avoid using + the shared GPU from inside environment steps. +- `build.sh` now embeds rpaths for wheel-provided CUDA libraries so + `pufferlib._C` can find `libnccl.so.2`. The parity harness still preloads NCCL + defensively for older local builds. + +## Next Native Port Steps + +1. Replace one step subsystem at a time with native logic and keep the proxy as a + local oracle until each subsystem matches. +2. Remove Python/JAX calls from `c_step`, restore large vector sizes, then measure + CPU throughput before optimizing observation encoding, mob updates, and light + propagation. diff --git a/ocean/craftax/binding.c b/ocean/craftax/binding.c index 3b2da51956..95930fb8c2 100644 --- a/ocean/craftax/binding.c +++ b/ocean/craftax/binding.c @@ -1,34 +1,59 @@ +#define CRAFTAX_ENABLE_ENV_IMPL #include "craftax.h" +#include "step_crafting.h" +#include "step_update_mobs.h" +#include "step_spawn_mobs.h" -#define OBS_SIZE 1345 +#define OBS_SIZE CRAFTAX_OBS_SIZE #define NUM_ATNS 1 -#define ACT_SIZES {17} +#define ACT_SIZES {CRAFTAX_NUM_ACTIONS} #define OBS_TENSOR_T FloatTensor #define Env Craftax #include "vecenv.h" void my_init(Env* env, Dict* kwargs) { - // No per-env kwargs for Craftax-Classic: the 64x64 map, inventory sizes, - // mob caps, etc. are all compile-time constants. + env->num_agents = 1; + + uint64_t seed_offset = 0; + DictItem* item = dict_get_unsafe(kwargs, "seed_offset"); + if (item != NULL) { + seed_offset = (uint64_t)item->value; + } + env->seed = seed_offset + (uint64_t)env->rng; + + // Process-wide reset pool (first caller wins; later calls return without changing it). + // 0 disables caching -- regenerate every reset (exact parity mode). 
+ int reset_pool_size = 0; + DictItem* pool_item = dict_get_unsafe(kwargs, "reset_pool_size"); + if (pool_item != NULL) reset_pool_size = (int)pool_item->value; + craftax_set_reset_pool_size(reset_pool_size); + c_init(env); } void my_log(Log* log, Dict* out) { - dict_set(out, "perf", log->perf); - dict_set(out, "score", log->score); + dict_set(out, "perf", log->perf); + dict_set(out, "score", log->score); dict_set(out, "episode_return", log->episode_return); dict_set(out, "episode_length", log->episode_length); - static const char* ACH_NAMES[NUM_ACHIEVEMENTS] = { - "collect_wood", "place_table", "eat_cow", "collect_sapling", - "collect_drink", "make_wood_pick", "make_wood_sword","place_plant", - "defeat_zombie", "collect_stone", "place_stone", "eat_plant", - "defeat_skeleton","make_stone_pick","make_stone_sword","wake_up", - "place_furnace", "collect_coal", "collect_iron", "collect_diamond", - "make_iron_pick", "make_iron_sword", + // Log 8 checkpoint achievements that form the tech / exploration curve. + // perf (above) already aggregates all 67 into a normalized score; the + // individual lines here are the milestones worth watching on a dashboard. + // The env still tracks all 67 internally for reward and perf; we just + // don't send every one through the log Dict. 
+ struct { const char* name; int idx; } checkpoints[] = { /* idx is a CraftaxAchievement value indexing log->achievements */ + {"collect_wood", CRAFTAX_ACH_COLLECT_WOOD}, + {"make_wood_pickaxe", CRAFTAX_ACH_MAKE_WOOD_PICKAXE}, + {"make_stone_pickaxe", CRAFTAX_ACH_MAKE_STONE_PICKAXE}, + {"collect_iron", CRAFTAX_ACH_COLLECT_IRON}, + {"make_iron_pickaxe", CRAFTAX_ACH_MAKE_IRON_PICKAXE}, + {"collect_diamond", CRAFTAX_ACH_COLLECT_DIAMOND}, + {"enter_gnomish_mines", CRAFTAX_ACH_ENTER_GNOMISH_MINES}, + {"defeat_necromancer", CRAFTAX_ACH_DEFEAT_NECROMANCER}, /* was 48, which is CRAFTAX_ACH_DAMAGE_NECROMANCER */ }; - for (int i = 0; i < NUM_ACHIEVEMENTS; i++) { - dict_set(out, ACH_NAMES[i], log->achievements[i]); + for (int i = 0; i < (int)(sizeof(checkpoints) / sizeof(checkpoints[0])); i++) { + dict_set(out, checkpoints[i].name, log->achievements[checkpoints[i].idx]); } } diff --git a/ocean/craftax/craftax.c b/ocean/craftax/craftax.c new file mode 100644 index 0000000000..cec3eb2cec --- /dev/null +++ b/ocean/craftax/craftax.c @@ -0,0 +1,76 @@ +// Standalone viewer for Craftax (random-action policy). +// +// Build: +// ./build.sh craftax --fast # optimized +// ./build.sh craftax --local # debug with sanitizers +// Run: +// ./craftax + +#define CRAFTAX_ENABLE_ENV_IMPL +#include "craftax.h" +#include "step_crafting.h" +#include "step_update_mobs.h" +#include "step_spawn_mobs.h" + +#include <stdlib.h> +#include <string.h> +#include <time.h> + +static uint32_t xorshift32(uint32_t* s) { + uint32_t x = *s; + x ^= x << 13; x ^= x >> 17; x ^= x << 5; + *s = x ? x : 0xdeadbeef; + return x; +} + +int main(int argc, char** argv) { + uint64_t seed = (argc > 1) ? 
strtoull(argv[1], NULL, 10) : (uint64_t)time(NULL); + + Craftax env; + memset(&env, 0, sizeof(env)); + env.num_agents = 1; + env.seed = seed; + env.rng = (uint32_t)seed; + + // Minimal buffers for a single agent + env.observations = calloc(CRAFTAX_OBS_SIZE, sizeof(float)); + env.actions = calloc(1, sizeof(float)); + env.rewards = calloc(1, sizeof(float)); + env.terminals = calloc(1, sizeof(float)); + + c_init(&env); + c_reset(&env); + + uint32_t action_rng = (uint32_t)(seed ^ 0x9E3779B9u); + bool human_control = false; + int human_action = CRAFTAX_ACTION_NOOP; + + while (!WindowShouldClose()) { + // Toggle human control + if (IsKeyPressed(KEY_H)) human_control = !human_control; + + if (human_control) { + human_action = CRAFTAX_ACTION_NOOP; + if (IsKeyPressed(KEY_A) || IsKeyPressed(KEY_LEFT)) human_action = CRAFTAX_ACTION_LEFT; + if (IsKeyPressed(KEY_D) || IsKeyPressed(KEY_RIGHT)) human_action = CRAFTAX_ACTION_RIGHT; + if (IsKeyPressed(KEY_W) || IsKeyPressed(KEY_UP)) human_action = CRAFTAX_ACTION_UP; + if (IsKeyPressed(KEY_S) || IsKeyPressed(KEY_DOWN)) human_action = CRAFTAX_ACTION_DOWN; + if (IsKeyPressed(KEY_SPACE)) human_action = CRAFTAX_ACTION_DO; + if (IsKeyPressed(KEY_Z)) human_action = CRAFTAX_ACTION_SLEEP; + env.actions[0] = (float)human_action; + if (human_action != CRAFTAX_ACTION_NOOP || IsKeyPressed(KEY_PERIOD)) c_step(&env); + } else { + env.actions[0] = (float)(xorshift32(&action_rng) % CRAFTAX_NUM_ACTIONS); + c_step(&env); + } + + c_render(&env); + } + + c_close(&env); + free(env.observations); + free(env.actions); + free(env.rewards); + free(env.terminals); + return 0; +} diff --git a/ocean/craftax/craftax.h b/ocean/craftax/craftax.h index e7a9ff4860..b82c0ea470 100644 --- a/ocean/craftax/craftax.h +++ b/ocean/craftax/craftax.h @@ -1,1015 +1,889 @@ -// Craftax-Classic environment for PufferLib Ocean. -// -// Single-header per-env implementation. 
PufferLib's vec layer owns the -// observation/action/reward/terminal buffers and parallelizes c_step -// across env instances via OpenMP; this file never allocates its own -// threads or batches. -// -// Game rules follow Matthews et al. 2024 "Craftax-Classic" (ICML 2024). -// This port is derived from the CPU port at github.com/Infatoshi/craftax.c -// (47.8M SPS standalone), restructured to match the Ocean conventions -// used by breakout/drmario/etc. -// -// Observation: 1345 float32: -// - 63 tiles (7x9 local view) x 21 channels (17 block one-hot + 4 mob) = 1323 -// - 12 inventory (0..9) / 10 -// - 4 intrinsics (health, food, drink, energy / 10) -// - 4 direction one-hot -// - 1 light level [0, 1] -// - 1 is_sleeping {0, 1} -// Matches the JAX/CUDA Craftax-Classic-Symbolic-v1 layout exactly. -// -// Action: 1 discrete in 0..16 (NOOP, 4 moves, DO, SLEEP, -// 4 place, 3 make-pick, 3 make-sword). +// Full native Craftax environment for PufferLib Ocean. #pragma once -#include -#include -#include + #include -#include -#include +#include +#include +#include + +#include "worldgen.h" #include "raylib.h" +#include +#include // ============================================================ // Constants // ============================================================ -#define MAP_SIZE 64 -#define MAP_PACKED_ROW 32 -#define MAP_PACKED_SIZE (MAP_SIZE * MAP_PACKED_ROW) - -#define MAX_ZOMBIES 3 -#define MAX_COWS 3 -#define MAX_SKELETONS 2 -#define MAX_ARROWS 3 -#define MAX_PLANTS 10 -#define NUM_ACHIEVEMENTS 22 -#define NUM_ACTIONS 17 -#define NUM_BLOCK_TYPES 17 -#define OBS_DIM 1345 -#define NUM_INVENTORY 12 -#define MAX_TIMESTEPS 10000 -#define DAY_LENGTH 300 -#define MOB_DESPAWN_DIST 14 - -// Block types -#define BLK_INVALID 0 -#define BLK_OUT_OF_BOUNDS 1 -#define BLK_GRASS 2 -#define BLK_WATER 3 -#define BLK_STONE 4 -#define BLK_TREE 5 -#define BLK_WOOD 6 -#define BLK_PATH 7 -#define BLK_COAL 8 -#define BLK_IRON 9 -#define BLK_DIAMOND 10 -#define BLK_TABLE 11 -#define 
BLK_FURNACE 12 -#define BLK_SAND 13 -#define BLK_LAVA 14 -#define BLK_PLANT 15 -#define BLK_RIPE_PLANT 16 - -// Actions -#define ACT_NOOP 0 -#define ACT_LEFT 1 -#define ACT_RIGHT 2 -#define ACT_UP 3 -#define ACT_DOWN 4 -#define ACT_DO 5 -#define ACT_SLEEP 6 -#define ACT_PLACE_STONE 7 -#define ACT_PLACE_TABLE 8 -#define ACT_PLACE_FURNACE 9 -#define ACT_PLACE_PLANT 10 -#define ACT_MAKE_WOOD_PICK 11 -#define ACT_MAKE_STONE_PICK 12 -#define ACT_MAKE_IRON_PICK 13 -#define ACT_MAKE_WOOD_SWORD 14 -#define ACT_MAKE_STONE_SWORD 15 -#define ACT_MAKE_IRON_SWORD 16 - -// Achievements (index in env->log.achievements[]) -#define ACH_COLLECT_WOOD 0 -#define ACH_PLACE_TABLE 1 -#define ACH_EAT_COW 2 -#define ACH_COLLECT_SAPLING 3 -#define ACH_COLLECT_DRINK 4 -#define ACH_MAKE_WOOD_PICK 5 -#define ACH_MAKE_WOOD_SWORD 6 -#define ACH_PLACE_PLANT 7 -#define ACH_DEFEAT_ZOMBIE 8 -#define ACH_COLLECT_STONE 9 -#define ACH_PLACE_STONE 10 -#define ACH_EAT_PLANT 11 -#define ACH_DEFEAT_SKELETON 12 -#define ACH_MAKE_STONE_PICK 13 -#define ACH_MAKE_STONE_SWORD 14 -#define ACH_WAKE_UP 15 -#define ACH_PLACE_FURNACE 16 -#define ACH_COLLECT_COAL 17 -#define ACH_COLLECT_IRON 18 -#define ACH_COLLECT_DIAMOND 19 -#define ACH_MAKE_IRON_PICK 20 -#define ACH_MAKE_IRON_SWORD 21 - -static const int DIR_DR[5] = {0, 0, 0, -1, 1}; -static const int DIR_DC[5] = {0, -1, 1, 0, 0}; +#define CRAFTAX_OBS_ROWS 9 +#define CRAFTAX_OBS_COLS 11 +#define CRAFTAX_MAP_SIZE 48 +#define CRAFTAX_NUM_LEVELS 9 + +#define CRAFTAX_NUM_BLOCK_TYPES 37 +#define CRAFTAX_NUM_ITEM_TYPES 5 +#define CRAFTAX_NUM_MOB_CLASSES 5 +#define CRAFTAX_NUM_MOB_TYPES 8 +#define CRAFTAX_INVENTORY_OBS_SIZE 51 +#define CRAFTAX_OBS_SIZE 8268 + +#define CRAFTAX_NUM_ACTIONS 43 +#define CRAFTAX_NUM_ACHIEVEMENTS 67 + +#define CRAFTAX_MAX_MELEE_MOBS 3 +#define CRAFTAX_MAX_PASSIVE_MOBS 3 +#define CRAFTAX_MAX_RANGED_MOBS 2 +#define CRAFTAX_MAX_MOB_PROJECTILES 3 +#define CRAFTAX_MAX_PLAYER_PROJECTILES 3 +#define CRAFTAX_MAX_GROWING_PLANTS 10 + +#define 
CRAFTAX_DEFAULT_MAX_TIMESTEPS 100000 +#define CRAFTAX_DAY_LENGTH 300 +#define CRAFTAX_MAX_ATTRIBUTE 5 +#define CRAFTAX_MOB_DESPAWN_DISTANCE 14 +#define CRAFTAX_MONSTERS_KILLED_TO_CLEAR_LEVEL 8 // ============================================================ -// Tiny PCG-style RNG (single 64-bit state) +// Enums copied from craftax/craftax/constants.py // ============================================================ -static inline uint32_t cr_pcg(uint64_t* s) { - *s = *s * 6364136223846793005ULL + 1442695040888963407ULL; - uint32_t x = (uint32_t)(((*s >> 18u) ^ *s) >> 27u); - uint32_t rot = (uint32_t)(*s >> 59u); - return (x >> rot) | (x << ((-(int32_t)rot) & 31)); -} -static inline float cr_rf(uint64_t* s) { return (cr_pcg(s) >> 8) * (1.0f / 16777216.0f); } -static inline int cr_ri(uint64_t* s, int n) { return (int)(cr_pcg(s) % (uint32_t)n); } +typedef enum CraftaxBlockType { + CRAFTAX_BLOCK_INVALID = 0, + CRAFTAX_BLOCK_OUT_OF_BOUNDS = 1, + CRAFTAX_BLOCK_GRASS = 2, + CRAFTAX_BLOCK_WATER = 3, + CRAFTAX_BLOCK_STONE = 4, + CRAFTAX_BLOCK_TREE = 5, + CRAFTAX_BLOCK_WOOD = 6, + CRAFTAX_BLOCK_PATH = 7, + CRAFTAX_BLOCK_COAL = 8, + CRAFTAX_BLOCK_IRON = 9, + CRAFTAX_BLOCK_DIAMOND = 10, + CRAFTAX_BLOCK_CRAFTING_TABLE = 11, + CRAFTAX_BLOCK_FURNACE = 12, + CRAFTAX_BLOCK_SAND = 13, + CRAFTAX_BLOCK_LAVA = 14, + CRAFTAX_BLOCK_PLANT = 15, + CRAFTAX_BLOCK_RIPE_PLANT = 16, + CRAFTAX_BLOCK_WALL = 17, + CRAFTAX_BLOCK_DARKNESS = 18, + CRAFTAX_BLOCK_WALL_MOSS = 19, + CRAFTAX_BLOCK_STALAGMITE = 20, + CRAFTAX_BLOCK_SAPPHIRE = 21, + CRAFTAX_BLOCK_RUBY = 22, + CRAFTAX_BLOCK_CHEST = 23, + CRAFTAX_BLOCK_FOUNTAIN = 24, + CRAFTAX_BLOCK_FIRE_GRASS = 25, + CRAFTAX_BLOCK_ICE_GRASS = 26, + CRAFTAX_BLOCK_GRAVEL = 27, + CRAFTAX_BLOCK_FIRE_TREE = 28, + CRAFTAX_BLOCK_ICE_SHRUB = 29, + CRAFTAX_BLOCK_ENCHANTMENT_TABLE_FIRE = 30, + CRAFTAX_BLOCK_ENCHANTMENT_TABLE_ICE = 31, + CRAFTAX_BLOCK_NECROMANCER = 32, + CRAFTAX_BLOCK_GRAVE = 33, + CRAFTAX_BLOCK_GRAVE2 = 34, + CRAFTAX_BLOCK_GRAVE3 = 35, + 
CRAFTAX_BLOCK_NECROMANCER_VULNERABLE = 36, +} CraftaxBlockType; + +typedef enum CraftaxItemType { + CRAFTAX_ITEM_NONE = 0, + CRAFTAX_ITEM_TORCH = 1, + CRAFTAX_ITEM_LADDER_DOWN = 2, + CRAFTAX_ITEM_LADDER_UP = 3, + CRAFTAX_ITEM_LADDER_DOWN_BLOCKED = 4, +} CraftaxItemType; + +typedef enum CraftaxAction { + CRAFTAX_ACTION_NOOP = 0, + CRAFTAX_ACTION_LEFT = 1, + CRAFTAX_ACTION_RIGHT = 2, + CRAFTAX_ACTION_UP = 3, + CRAFTAX_ACTION_DOWN = 4, + CRAFTAX_ACTION_DO = 5, + CRAFTAX_ACTION_SLEEP = 6, + CRAFTAX_ACTION_PLACE_STONE = 7, + CRAFTAX_ACTION_PLACE_TABLE = 8, + CRAFTAX_ACTION_PLACE_FURNACE = 9, + CRAFTAX_ACTION_PLACE_PLANT = 10, + CRAFTAX_ACTION_MAKE_WOOD_PICKAXE = 11, + CRAFTAX_ACTION_MAKE_STONE_PICKAXE = 12, + CRAFTAX_ACTION_MAKE_IRON_PICKAXE = 13, + CRAFTAX_ACTION_MAKE_WOOD_SWORD = 14, + CRAFTAX_ACTION_MAKE_STONE_SWORD = 15, + CRAFTAX_ACTION_MAKE_IRON_SWORD = 16, + CRAFTAX_ACTION_REST = 17, + CRAFTAX_ACTION_DESCEND = 18, + CRAFTAX_ACTION_ASCEND = 19, + CRAFTAX_ACTION_MAKE_DIAMOND_PICKAXE = 20, + CRAFTAX_ACTION_MAKE_DIAMOND_SWORD = 21, + CRAFTAX_ACTION_MAKE_IRON_ARMOUR = 22, + CRAFTAX_ACTION_MAKE_DIAMOND_ARMOUR = 23, + CRAFTAX_ACTION_SHOOT_ARROW = 24, + CRAFTAX_ACTION_MAKE_ARROW = 25, + CRAFTAX_ACTION_CAST_FIREBALL = 26, + CRAFTAX_ACTION_CAST_ICEBALL = 27, + CRAFTAX_ACTION_PLACE_TORCH = 28, + CRAFTAX_ACTION_DRINK_POTION_RED = 29, + CRAFTAX_ACTION_DRINK_POTION_GREEN = 30, + CRAFTAX_ACTION_DRINK_POTION_BLUE = 31, + CRAFTAX_ACTION_DRINK_POTION_PINK = 32, + CRAFTAX_ACTION_DRINK_POTION_CYAN = 33, + CRAFTAX_ACTION_DRINK_POTION_YELLOW = 34, + CRAFTAX_ACTION_READ_BOOK = 35, + CRAFTAX_ACTION_ENCHANT_SWORD = 36, + CRAFTAX_ACTION_ENCHANT_ARMOUR = 37, + CRAFTAX_ACTION_MAKE_TORCH = 38, + CRAFTAX_ACTION_LEVEL_UP_DEXTERITY = 39, + CRAFTAX_ACTION_LEVEL_UP_STRENGTH = 40, + CRAFTAX_ACTION_LEVEL_UP_INTELLIGENCE = 41, + CRAFTAX_ACTION_ENCHANT_BOW = 42, +} CraftaxAction; + +typedef enum CraftaxMobType { + CRAFTAX_MOB_PASSIVE = 0, + CRAFTAX_MOB_MELEE = 1, + CRAFTAX_MOB_RANGED = 2, + 
CRAFTAX_MOB_PROJECTILE = 3, +} CraftaxMobType; + +typedef enum CraftaxProjectileType { + CRAFTAX_PROJECTILE_ARROW = 0, + CRAFTAX_PROJECTILE_DAGGER = 1, + CRAFTAX_PROJECTILE_FIREBALL = 2, + CRAFTAX_PROJECTILE_ICEBALL = 3, + CRAFTAX_PROJECTILE_ARROW2 = 4, + CRAFTAX_PROJECTILE_SLIMEBALL = 5, + CRAFTAX_PROJECTILE_FIREBALL2 = 6, + CRAFTAX_PROJECTILE_ICEBALL2 = 7, +} CraftaxProjectileType; + +typedef enum CraftaxAchievement { + CRAFTAX_ACH_COLLECT_WOOD = 0, + CRAFTAX_ACH_PLACE_TABLE = 1, + CRAFTAX_ACH_EAT_COW = 2, + CRAFTAX_ACH_COLLECT_SAPLING = 3, + CRAFTAX_ACH_COLLECT_DRINK = 4, + CRAFTAX_ACH_MAKE_WOOD_PICKAXE = 5, + CRAFTAX_ACH_MAKE_WOOD_SWORD = 6, + CRAFTAX_ACH_PLACE_PLANT = 7, + CRAFTAX_ACH_DEFEAT_ZOMBIE = 8, + CRAFTAX_ACH_COLLECT_STONE = 9, + CRAFTAX_ACH_PLACE_STONE = 10, + CRAFTAX_ACH_EAT_PLANT = 11, + CRAFTAX_ACH_DEFEAT_SKELETON = 12, + CRAFTAX_ACH_MAKE_STONE_PICKAXE = 13, + CRAFTAX_ACH_MAKE_STONE_SWORD = 14, + CRAFTAX_ACH_WAKE_UP = 15, + CRAFTAX_ACH_PLACE_FURNACE = 16, + CRAFTAX_ACH_COLLECT_COAL = 17, + CRAFTAX_ACH_COLLECT_IRON = 18, + CRAFTAX_ACH_COLLECT_DIAMOND = 19, + CRAFTAX_ACH_MAKE_IRON_PICKAXE = 20, + CRAFTAX_ACH_MAKE_IRON_SWORD = 21, + CRAFTAX_ACH_MAKE_ARROW = 22, + CRAFTAX_ACH_MAKE_TORCH = 23, + CRAFTAX_ACH_PLACE_TORCH = 24, + CRAFTAX_ACH_MAKE_DIAMOND_SWORD = 25, + CRAFTAX_ACH_MAKE_IRON_ARMOUR = 26, + CRAFTAX_ACH_MAKE_DIAMOND_ARMOUR = 27, + CRAFTAX_ACH_ENTER_GNOMISH_MINES = 28, + CRAFTAX_ACH_ENTER_DUNGEON = 29, + CRAFTAX_ACH_ENTER_SEWERS = 30, + CRAFTAX_ACH_ENTER_VAULT = 31, + CRAFTAX_ACH_ENTER_TROLL_MINES = 32, + CRAFTAX_ACH_ENTER_FIRE_REALM = 33, + CRAFTAX_ACH_ENTER_ICE_REALM = 34, + CRAFTAX_ACH_ENTER_GRAVEYARD = 35, + CRAFTAX_ACH_DEFEAT_GNOME_WARRIOR = 36, + CRAFTAX_ACH_DEFEAT_GNOME_ARCHER = 37, + CRAFTAX_ACH_DEFEAT_ORC_SOLIDER = 38, + CRAFTAX_ACH_DEFEAT_ORC_MAGE = 39, + CRAFTAX_ACH_DEFEAT_LIZARD = 40, + CRAFTAX_ACH_DEFEAT_KOBOLD = 41, + CRAFTAX_ACH_DEFEAT_TROLL = 42, + CRAFTAX_ACH_DEFEAT_DEEP_THING = 43, + CRAFTAX_ACH_DEFEAT_PIGMAN = 44, + 
CRAFTAX_ACH_DEFEAT_FIRE_ELEMENTAL = 45, + CRAFTAX_ACH_DEFEAT_FROST_TROLL = 46, + CRAFTAX_ACH_DEFEAT_ICE_ELEMENTAL = 47, + CRAFTAX_ACH_DAMAGE_NECROMANCER = 48, + CRAFTAX_ACH_DEFEAT_NECROMANCER = 49, + CRAFTAX_ACH_EAT_BAT = 50, + CRAFTAX_ACH_EAT_SNAIL = 51, + CRAFTAX_ACH_FIND_BOW = 52, + CRAFTAX_ACH_FIRE_BOW = 53, + CRAFTAX_ACH_COLLECT_SAPPHIRE = 54, + CRAFTAX_ACH_LEARN_FIREBALL = 55, + CRAFTAX_ACH_CAST_FIREBALL = 56, + CRAFTAX_ACH_LEARN_ICEBALL = 57, + CRAFTAX_ACH_CAST_ICEBALL = 58, + CRAFTAX_ACH_COLLECT_RUBY = 59, + CRAFTAX_ACH_MAKE_DIAMOND_PICKAXE = 60, + CRAFTAX_ACH_OPEN_CHEST = 61, + CRAFTAX_ACH_DRINK_POTION = 62, + CRAFTAX_ACH_ENCHANT_SWORD = 63, + CRAFTAX_ACH_ENCHANT_ARMOUR = 64, + CRAFTAX_ACH_DEFEAT_KNIGHT = 65, + CRAFTAX_ACH_DEFEAT_ARCHER = 66, +} CraftaxAchievement; // ============================================================ -// PufferLib-required structs +// State layout declarations matching craftax_state.py field order // ============================================================ +typedef struct CraftaxInventory { + int32_t wood; + int32_t stone; + int32_t coal; + int32_t iron; + int32_t diamond; + int32_t sapling; + int32_t pickaxe; + int32_t sword; + int32_t bow; + int32_t arrows; + int32_t armour[4]; + int32_t torches; + int32_t ruby; + int32_t sapphire; + int32_t potions[6]; + int32_t books; +} CraftaxInventory; + +typedef struct CraftaxMobs3 { + int32_t position[CRAFTAX_NUM_LEVELS][3][2]; + float health[CRAFTAX_NUM_LEVELS][3]; + bool mask[CRAFTAX_NUM_LEVELS][3]; + int32_t attack_cooldown[CRAFTAX_NUM_LEVELS][3]; + int32_t type_id[CRAFTAX_NUM_LEVELS][3]; +} CraftaxMobs3; + +typedef struct CraftaxMobs2 { + int32_t position[CRAFTAX_NUM_LEVELS][2][2]; + float health[CRAFTAX_NUM_LEVELS][2]; + bool mask[CRAFTAX_NUM_LEVELS][2]; + int32_t attack_cooldown[CRAFTAX_NUM_LEVELS][2]; + int32_t type_id[CRAFTAX_NUM_LEVELS][2]; +} CraftaxMobs2; + +typedef struct CraftaxState { + int32_t map[CRAFTAX_NUM_LEVELS][CRAFTAX_MAP_SIZE][CRAFTAX_MAP_SIZE]; + int32_t 
item_map[CRAFTAX_NUM_LEVELS][CRAFTAX_MAP_SIZE][CRAFTAX_MAP_SIZE]; + bool mob_map[CRAFTAX_NUM_LEVELS][CRAFTAX_MAP_SIZE][CRAFTAX_MAP_SIZE]; + float light_map[CRAFTAX_NUM_LEVELS][CRAFTAX_MAP_SIZE][CRAFTAX_MAP_SIZE]; + int32_t down_ladders[CRAFTAX_NUM_LEVELS][2]; + int32_t up_ladders[CRAFTAX_NUM_LEVELS][2]; + bool chests_opened[CRAFTAX_NUM_LEVELS]; + int32_t monsters_killed[CRAFTAX_NUM_LEVELS]; + + int32_t player_position[2]; + int32_t player_level; + int32_t player_direction; + + float player_health; + int32_t player_food; + int32_t player_drink; + int32_t player_energy; + int32_t player_mana; + bool is_sleeping; + bool is_resting; + + float player_recover; + float player_hunger; + float player_thirst; + float player_fatigue; + float player_recover_mana; + + int32_t player_xp; + int32_t player_dexterity; + int32_t player_strength; + int32_t player_intelligence; + + CraftaxInventory inventory; + + CraftaxMobs3 melee_mobs; + CraftaxMobs3 passive_mobs; + CraftaxMobs2 ranged_mobs; + + CraftaxMobs3 mob_projectiles; + int32_t mob_projectile_directions[CRAFTAX_NUM_LEVELS][CRAFTAX_MAX_MOB_PROJECTILES][2]; + CraftaxMobs3 player_projectiles; + int32_t player_projectile_directions[CRAFTAX_NUM_LEVELS][CRAFTAX_MAX_PLAYER_PROJECTILES][2]; + + int32_t growing_plants_positions[CRAFTAX_MAX_GROWING_PLANTS][2]; + int32_t growing_plants_age[CRAFTAX_MAX_GROWING_PLANTS]; + bool growing_plants_mask[CRAFTAX_MAX_GROWING_PLANTS]; + + int32_t potion_mapping[6]; + bool learned_spells[2]; + + int32_t sword_enchantment; + int32_t bow_enchantment; + int32_t armour_enchantments[4]; + + int32_t boss_progress; + int32_t boss_timesteps_to_spawn_this_round; + + float light_level; + bool achievements[CRAFTAX_NUM_ACHIEVEMENTS]; + uint32_t state_rng[2]; + int32_t timestep; + int32_t fractal_noise_angles[4]; +} CraftaxState; + +typedef char CraftaxStateMatchesWorldState[ + (sizeof(CraftaxState) == sizeof(CraftaxWorldState)) ? 
1 : -1 +]; + +#ifdef CRAFTAX_ENABLE_ENV_IMPL +static inline void craftax_change_floor_native(CraftaxState* state, int32_t action); +static inline void craftax_do_crafting_native(CraftaxState* state, int32_t action); +static inline void craftax_do_action_native( + CraftaxState* state, + int32_t action, + CraftaxThreefryKey rng +); +static inline void craftax_place_block_native(CraftaxState* state, int32_t action); +static inline void craftax_shoot_projectile_native( + CraftaxState* state, + int32_t action +); +static inline void craftax_cast_spell_native(CraftaxState* state, int32_t action); +static inline void craftax_drink_potion_native(CraftaxState* state, int32_t action); +static inline void craftax_read_book_native( + CraftaxState* state, + const uint32_t rng_words[2], + int32_t action +); +static inline void craftax_enchant_native( + CraftaxState* state, + int32_t action, + CraftaxThreefryKey rng +); +static inline void craftax_boss_logic_native(CraftaxState* state); +static inline void craftax_level_up_attributes_native( + CraftaxState* state, + int32_t action, + int32_t max_attribute +); +static inline void craftax_move_player_native( + CraftaxState* state, + int32_t action, + bool god_mode +); +static inline void craftax_update_mobs_native( + CraftaxState* state, + CraftaxThreefryKey rng +); +static inline void craftax_spawn_mobs_native( + CraftaxState* state, + CraftaxThreefryKey rng +); +static inline void craftax_update_plants_native(CraftaxState* state); +static inline void craftax_update_player_intrinsics_native( + CraftaxState* state, + int32_t action +); +static inline void craftax_clip_inventory_and_intrinsics_native( + CraftaxState* state, + bool god_mode +); +static inline void craftax_calculate_inventory_achievements_native( + CraftaxState* state +); +#endif + typedef struct Log { - float perf; // 0-1 normalized progress (achievements / 22) - float score; // sum of episode returns seen so far - float episode_return; // last episode return - float 
episode_length; // last episode length - float achievements[NUM_ACHIEVEMENTS]; - float n; // required counter (last field) + float perf; + float score; + float episode_return; + float episode_length; + float achievements[CRAFTAX_NUM_ACHIEVEMENTS]; + float n; } Log; typedef struct Client { - int dummy; // handled by raylib globally; no per-env handle needed + int unused; } Client; -// ============================================================ -// Env struct -// ============================================================ typedef struct Craftax { Client* client; Log log; - float* observations; // (OBS_DIM,) fp32, PufferLib-owned - float* actions; // (1,) fp32 - float* rewards; // (1,) - float* terminals; // (1,) - - int num_agents; // = 1 - - unsigned int rng; // populated by default my_vec_init (env index) - uint64_t pcg; // actual RNG state (seeded from rng in my_init) - - // Packed map (2 blocks/byte) - uint8_t map_packed[MAP_PACKED_SIZE]; - - // Per-type occupancy bitmaps: bit c of bits[r] = "mob-type at (r,c)" - uint64_t mob_bits[MAP_SIZE]; // zombie | cow | skel (used by has_mob_at / can_move_mob) - uint64_t zombie_bits[MAP_SIZE]; - uint64_t cow_bits[MAP_SIZE]; - uint64_t skel_bits[MAP_SIZE]; - uint64_t arrow_bits[MAP_SIZE]; - - // Player - int16_t player_r, player_c; - int8_t player_dir; - - // Intrinsics - int8_t health, food, drink, energy; - bool is_sleeping; - float recover, hunger, thirst, fatigue; - - // Inventory (wood, stone, coal, iron, diamond, sapling, - // wpick, spick, ipick, wsword, ssword, isword) - int8_t inv[NUM_INVENTORY]; - - // Mobs - int16_t zombie_r[MAX_ZOMBIES], zombie_c[MAX_ZOMBIES]; - int8_t zombie_hp[MAX_ZOMBIES], zombie_cd[MAX_ZOMBIES]; - bool zombie_mask[MAX_ZOMBIES]; + float* observations; + float* actions; + float* rewards; + float* terminals; + int num_agents; - int16_t cow_r[MAX_COWS], cow_c[MAX_COWS]; - int8_t cow_hp[MAX_COWS]; - bool cow_mask[MAX_COWS]; + unsigned int rng; + uint64_t seed; + CraftaxThreefryKey rng_key; + 
CraftaxState state; - int16_t skel_r[MAX_SKELETONS], skel_c[MAX_SKELETONS]; - int8_t skel_hp[MAX_SKELETONS], skel_cd[MAX_SKELETONS]; - bool skel_mask[MAX_SKELETONS]; - - int16_t arrow_r[MAX_ARROWS], arrow_c[MAX_ARROWS]; - int8_t arrow_dr[MAX_ARROWS], arrow_dc[MAX_ARROWS]; - bool arrow_mask[MAX_ARROWS]; - - int16_t plant_r[MAX_PLANTS], plant_c[MAX_PLANTS]; - int16_t plant_age[MAX_PLANTS]; - bool plant_mask[MAX_PLANTS]; - - float light_level; - bool achievements[NUM_ACHIEVEMENTS]; - int32_t timestep; - - // Episode stats (accumulated; flushed into env->log on terminal) + float achievements[CRAFTAX_NUM_ACHIEVEMENTS]; float episode_return_accum; int32_t episode_length_accum; - - // Scratch for per-step reward computation - int8_t old_health; - bool old_achievements[NUM_ACHIEVEMENTS]; } Craftax; +#ifdef CRAFTAX_ENABLE_ENV_IMPL + // ============================================================ -// Map accessors + small helpers +// Native reset, observation, reward, and step glue // ============================================================ -static inline int8_t map_get(const Craftax* s, int r, int c) { - int idx = r * MAP_PACKED_ROW + (c >> 1); - uint8_t b = s->map_packed[idx]; - return (c & 1) ? 
(int8_t)(b >> 4) : (int8_t)(b & 0x0F); -} -static inline void map_set(Craftax* s, int r, int c, int8_t v) { - int idx = r * MAP_PACKED_ROW + (c >> 1); - uint8_t b = s->map_packed[idx]; - if (c & 1) s->map_packed[idx] = (b & 0x0F) | ((v & 0x0F) << 4); - else s->map_packed[idx] = (b & 0xF0) | (v & 0x0F); -} -static inline bool in_bounds(int r, int c) { return (unsigned)r < MAP_SIZE && (unsigned)c < MAP_SIZE; } -static inline bool is_solid(int8_t b) { - return b == BLK_WATER || b == BLK_STONE || b == BLK_TREE || - b == BLK_COAL || b == BLK_IRON || b == BLK_DIAMOND || - b == BLK_TABLE || b == BLK_FURNACE || - b == BLK_PLANT || b == BLK_RIPE_PLANT; +static const float CRAFTAX_ACHIEVEMENT_REWARD_MAP[CRAFTAX_NUM_ACHIEVEMENTS] = { + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 5.0f, 5.0f, + 5.0f, 8.0f, 8.0f, 8.0f, 3.0f, 3.0f, 3.0f, 3.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 8.0f, 8.0f, 8.0f, 8.0f, + 8.0f, 8.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 5.0f, + 5.0f, 5.0f, 5.0f, 3.0f, 3.0f, 3.0f, 3.0f, 5.0f, + 5.0f, 5.0f, 5.0f, +}; + +static inline CraftaxThreefryKey craftax_step_native_next_key( + CraftaxThreefryKey* rng +) { + CraftaxThreefryKey subkey; + craftax_threefry_split(*rng, rng, &subkey); + return subkey; } -static inline int l1_dist(int r1, int c1, int r2, int c2) { - int dr = r1 - r2; if (dr < 0) dr = -dr; - int dc = c1 - c2; if (dc < 0) dc = -dc; - return dr + dc; -} -static inline int cr_clamp_i(int v, int lo, int hi){ return vhi?hi:v); } -static inline int cr_min_i(int a,int b){return ab?a:b;} -static inline float cr_min_f(float a,float b){return a0)-(v<0);} - -// Bitmap maintenance -static inline void mb_set(uint64_t* bits, int r, int c) { bits[r] |= (1ULL << c); } -static inline void mb_clear(uint64_t* bits, int r, int c) { bits[r] &= ~(1ULL << c); } -static inline bool mb_get(const uint64_t* bits, int r, int c) { return (bits[r] >> c) & 
1ULL; } - -static inline bool has_mob_at(const Craftax* s, int r, int c) { - if ((unsigned)r >= MAP_SIZE || (unsigned)c >= MAP_SIZE) return false; - return ((s->mob_bits[r] >> c) & 1ULL) != 0; + +static inline void craftax_copy_world_state_to_state( + CraftaxState* dst, + const CraftaxWorldState* src +) { + memcpy(dst, src, sizeof(*dst)); } -static bool is_near_block(const Craftax* s, int8_t blk) { - int pr = s->player_r, pc = s->player_c; - static const int dr8[8] = {0, 0, -1, 1, -1, -1, 1, 1}; - static const int dc8[8] = {-1, 1, 0, 0, -1, 1, -1, 1}; - for (int i = 0; i < 8; i++) { - int nr = pr + dr8[i], nc = pc + dc8[i]; - if (in_bounds(nr, nc) && map_get(s, nr, nc) == blk) return true; - } - return false; +static inline void craftax_generate_state_from_world_key( + CraftaxThreefryKey world_key, + CraftaxState* out +) { + CraftaxWorldState world_state; + craftax_generate_world_from_key(world_key, &world_state); + craftax_copy_world_state_to_state(out, &world_state); } -static inline int get_damage(const Craftax* s) { - if (s->inv[11] > 0) return 5; - if (s->inv[10] > 0) return 3; - if (s->inv[9] > 0) return 2; - return 1; +static inline void craftax_reset_state_from_reset_key( + CraftaxState* out, + CraftaxThreefryKey reset_key +) { + CraftaxThreefryKey unused; + CraftaxThreefryKey world_key; + craftax_threefry_split(reset_key, &unused, &world_key); + craftax_generate_state_from_world_key(world_key, out); } // ============================================================ -// Perlin worldgen (AVX-512, per-env) +// Reset pool: pre-generate N worlds once, then memcpy on reset. +// Trades world diversity (<= pool_size unique maps per process) for +// ~500x faster reset. Set pool_size=0 to disable (exact per-seed +// world; required for the parity harness). 
// ============================================================ -static inline float perlin_interp(float t) { return t*t*t*(t*(t*6.0f-15.0f)+10.0f); } - -#if defined(__clang__) || defined(__GNUC__) -__attribute__((target("avx512f,avx512bw,avx512dq,avx512vl"))) -#endif -static void generate_world(Craftax* s) { - // Reset maps and bitmaps - for (int i = 0; i < MAP_PACKED_SIZE; i++) - s->map_packed[i] = (uint8_t)(BLK_GRASS | (BLK_GRASS << 4)); - memset(s->mob_bits, 0, sizeof(s->mob_bits)); - memset(s->zombie_bits, 0, sizeof(s->zombie_bits)); - memset(s->cow_bits, 0, sizeof(s->cow_bits)); - memset(s->skel_bits, 0, sizeof(s->skel_bits)); - memset(s->arrow_bits, 0, sizeof(s->arrow_bits)); - - // Perlin gradient tables (precompute cos/sin of the per-grid random angles). - // Padded by +16 floats so AVX-512 permute-load at the last grid row doesn't - // read out of bounds. - enum { GRID = 10, GRID_PAD = GRID * GRID + 16 }; - _Alignas(64) float cos_a[4][GRID_PAD]; - _Alignas(64) float sin_a[4][GRID_PAD]; - for (int layer = 0; layer < 4; layer++) { - for (int i = 0; i < GRID * GRID; i++) { - float a = cr_rf(&s->pcg) * 2.0f * 3.14159265f; - cos_a[layer][i] = cosf(a); - sin_a[layer][i] = sinf(a); +static int g_craftax_reset_pool_size = 0; +static CraftaxState* g_craftax_reset_pool = NULL; +static int g_craftax_reset_pool_ready = 0; + +// Called from my_init which runs single-threaded during env creation +// (vecenv.h iterates envs sequentially). First caller populates the +// pool; subsequent callers are no-ops. 
+static inline void craftax_set_reset_pool_size(int n) { + if (g_craftax_reset_pool_ready) return; + g_craftax_reset_pool_size = n; + if (n > 0) { + g_craftax_reset_pool = (CraftaxState*)calloc((size_t)n, sizeof(CraftaxState)); + for (int i = 0; i < n; i++) { + CraftaxThreefryKey init_key = craftax_prng_key((uint32_t)i); + CraftaxThreefryKey discard, reset_key; + craftax_threefry_split(init_key, &discard, &reset_key); + craftax_reset_state_from_reset_key(&g_craftax_reset_pool[i], reset_key); } - for (int i = GRID * GRID; i < GRID_PAD; i++) { cos_a[layer][i] = 0; sin_a[layer][i] = 0; } } + g_craftax_reset_pool_ready = 1; +} - float scale = (float)MAP_SIZE / (float)(GRID - 1); - float inv_scale = 1.0f / scale; - int center = MAP_SIZE / 2; - - _Alignas(64) float noise[4][MAP_SIZE][MAP_SIZE]; - { - const __m512 c_lane = _mm512_setr_ps(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15); - const __m512 one = _mm512_set1_ps(1.0f); - const __m512 half = _mm512_set1_ps(0.5f); - const __m512 c6 = _mm512_set1_ps(6.0f); - const __m512 c15 = _mm512_set1_ps(15.0f); - const __m512 c10 = _mm512_set1_ps(10.0f); - const __m512 invs = _mm512_set1_ps(inv_scale); - const __m512i i_one = _mm512_set1_epi32(1); - - for (int r = 0; r < MAP_SIZE; r++) { - float nr = (float)r * inv_scale; - int x0 = (int)nr; - float fx = nr - x0; - float fx1 = fx - 1.0f; - float u = perlin_interp(fx); - int row0 = x0 * GRID, row1 = row0 + GRID; - __m512 fx_v = _mm512_set1_ps(fx); - __m512 fx1_v = _mm512_set1_ps(fx1); - __m512 u_v = _mm512_set1_ps(u); - - for (int c_base = 0; c_base < MAP_SIZE; c_base += 16) { - __m512 c_v = _mm512_add_ps(_mm512_set1_ps((float)c_base), c_lane); - __m512 nc_v = _mm512_mul_ps(c_v, invs); - __m512i y0_v = _mm512_cvttps_epi32(nc_v); - __m512 y0_f = _mm512_cvtepi32_ps(y0_v); - __m512 fy_v = _mm512_sub_ps(nc_v, y0_f); - __m512 fy1_v = _mm512_sub_ps(fy_v, one); - __m512 t = _mm512_fmsub_ps(fy_v, c6, c15); - t = _mm512_fmadd_ps(fy_v, t, c10); - __m512 fy2 = _mm512_mul_ps(fy_v, fy_v); - __m512 
fy3 = _mm512_mul_ps(fy2, fy_v); - __m512 v_v = _mm512_mul_ps(fy3, t); - __m512i y1_v = _mm512_add_epi32(y0_v, i_one); - - for (int k = 0; k < 4; k++) { - __m512 cos_r0 = _mm512_loadu_ps(&cos_a[k][row0]); - __m512 cos_r1 = _mm512_loadu_ps(&cos_a[k][row1]); - __m512 sin_r0 = _mm512_loadu_ps(&sin_a[k][row0]); - __m512 sin_r1 = _mm512_loadu_ps(&sin_a[k][row1]); - - __m512 c00 = _mm512_permutexvar_ps(y0_v, cos_r0); - __m512 c10v= _mm512_permutexvar_ps(y0_v, cos_r1); - __m512 c01 = _mm512_permutexvar_ps(y1_v, cos_r0); - __m512 c11 = _mm512_permutexvar_ps(y1_v, cos_r1); - __m512 s00 = _mm512_permutexvar_ps(y0_v, sin_r0); - __m512 s10 = _mm512_permutexvar_ps(y0_v, sin_r1); - __m512 s01 = _mm512_permutexvar_ps(y1_v, sin_r0); - __m512 s11 = _mm512_permutexvar_ps(y1_v, sin_r1); - - __m512 n00 = _mm512_fmadd_ps(c00, fx_v, _mm512_mul_ps(s00, fy_v)); - __m512 n10 = _mm512_fmadd_ps(c10v, fx1_v, _mm512_mul_ps(s10, fy_v)); - __m512 n01 = _mm512_fmadd_ps(c01, fx_v, _mm512_mul_ps(s01, fy1_v)); - __m512 n11 = _mm512_fmadd_ps(c11, fx1_v, _mm512_mul_ps(s11, fy1_v)); - - __m512 nx0 = _mm512_fmadd_ps(u_v, _mm512_sub_ps(n10, n00), n00); - __m512 nx1 = _mm512_fmadd_ps(u_v, _mm512_sub_ps(n11, n01), n01); - __m512 n = _mm512_fmadd_ps(v_v, _mm512_sub_ps(nx1, nx0), nx0); - n = _mm512_mul_ps(_mm512_add_ps(n, one), half); - - _mm512_storeu_ps(&noise[k][r][c_base], n); - } - } - } +static inline void craftax_reset_state_from_seed(Craftax* env) { + CraftaxThreefryKey initial_key = craftax_prng_key((uint32_t)env->seed); + if (g_craftax_reset_pool_size > 0) { + CraftaxThreefryKey discard; + craftax_threefry_split(initial_key, &env->rng_key, &discard); + int idx = (int)(env->seed % (uint64_t)g_craftax_reset_pool_size); + memcpy(&env->state, &g_craftax_reset_pool[idx], sizeof(CraftaxState)); + return; } + CraftaxThreefryKey reset_key; + craftax_threefry_split(initial_key, &env->rng_key, &reset_key); + craftax_reset_state_from_reset_key(&env->state, reset_key); +} - // Tile-logic sweep -- reads 
precomputed noise, writes blocks - for (int r = 0; r < MAP_SIZE; r++) { - for (int c = 0; c < MAP_SIZE; c++) { - float water_noise = noise[0][r][c]; - float mountain_noise = noise[1][r][c]; - float tree_noise = noise[2][r][c]; - float path_noise = noise[3][r][c]; - - float dist = sqrtf((float)((r-center)*(r-center) + (c-center)*(c-center))); - float prox = 1.0f - cr_min_f(dist / 20.0f, 1.0f); - - float water_val = water_noise - prox * 0.3f; - float mountain_val = mountain_noise - prox * 0.3f; - - int8_t blk = BLK_GRASS; - if (water_val > 0.7f) blk = BLK_WATER; - else if (water_val > 0.6f && water_val <= 0.75f) blk = BLK_SAND; - else if (mountain_val > 0.7f) { - blk = BLK_STONE; - if (path_noise > 0.8f) blk = BLK_PATH; - if (mountain_val > 0.85f && water_noise > 0.4f) blk = BLK_PATH; - if (mountain_val > 0.85f && tree_noise > 0.7f) blk = BLK_LAVA; - } - if (blk == BLK_STONE) { - float ore = cr_rf(&s->pcg); - if (ore < 0.005f && mountain_val > 0.8f) blk = BLK_DIAMOND; - else if (ore < 0.035f) blk = BLK_IRON; - else if (ore < 0.075f) blk = BLK_COAL; - } - if (blk == BLK_GRASS && tree_noise > 0.5f && cr_rf(&s->pcg) > 0.8f) - blk = BLK_TREE; - map_set(s, r, c, blk); - } +// Hot-path reset used by c_step on episode-done. Consults the reset pool +// when enabled, falls through to generate_world otherwise. Pool index is +// derived from the reset_key so different done events pick different +// pooled worlds. The direct craftax_reset_state_from_reset_key stays +// pool-free so the parity harness and any other direct caller get exact +// per-key determinism. 
+static inline void craftax_reset_state_on_done( + CraftaxState* out, + CraftaxThreefryKey reset_key +) { + if (g_craftax_reset_pool_size > 0) { + uint32_t idx = reset_key.word[0] % (uint32_t)g_craftax_reset_pool_size; + memcpy(out, &g_craftax_reset_pool[idx], sizeof(CraftaxState)); + return; } + craftax_reset_state_from_reset_key(out, reset_key); +} - map_set(s, center, center, BLK_GRASS); // player spawn always grass - - bool has_diamond = false; - for (int r = 0; r < MAP_SIZE && !has_diamond; r++) - for (int c = 0; c < MAP_SIZE && !has_diamond; c++) - if (map_get(s, r, c) == BLK_DIAMOND) has_diamond = true; - if (!has_diamond) { - for (int att = 0; att < 1000; att++) { - int r = cr_ri(&s->pcg, MAP_SIZE), c = cr_ri(&s->pcg, MAP_SIZE); - if (map_get(s, r, c) == BLK_STONE) { map_set(s, r, c, BLK_DIAMOND); break; } - } +static inline void craftax_encode_native_observation( + const CraftaxState* state, + float* obs +) { + if (obs == NULL) { + return; } + craftax_encode_reset_observation((const CraftaxWorldState*)(const void*)state, obs); +} - // Initial intrinsics + inventory + mobs - s->player_r = center; s->player_c = center; s->player_dir = 4; - s->health = 9; s->food = 9; s->drink = 9; s->energy = 9; - s->is_sleeping = false; - s->recover = s->hunger = s->thirst = s->fatigue = 0; - memset(s->inv, 0, sizeof(s->inv)); - memset(s->zombie_mask, 0, sizeof(s->zombie_mask)); - memset(s->zombie_hp, 0, sizeof(s->zombie_hp)); - memset(s->zombie_cd, 0, sizeof(s->zombie_cd)); - memset(s->cow_mask, 0, sizeof(s->cow_mask)); - memset(s->cow_hp, 0, sizeof(s->cow_hp)); - memset(s->skel_mask, 0, sizeof(s->skel_mask)); - memset(s->skel_hp, 0, sizeof(s->skel_hp)); - memset(s->skel_cd, 0, sizeof(s->skel_cd)); - memset(s->arrow_mask, 0, sizeof(s->arrow_mask)); - memset(s->plant_mask, 0, sizeof(s->plant_mask)); - memset(s->plant_age, 0, sizeof(s->plant_age)); - memset(s->achievements, 0, sizeof(s->achievements)); - s->timestep = 0; - s->light_level = 1.0f; +static inline float 
craftax_calculate_light_level_native(int32_t timestep) { + float progress = fmodf( + (float)timestep / (float)CRAFTAX_DAY_LENGTH, + 1.0f + ) + 0.3f; + float c = cosf(CRAFTAX_WG_PI * progress); + return 1.0f - powf(fabsf(c), 3.0f); } -// ============================================================ -// Step sub-actions -// ============================================================ -static void do_crafting(Craftax* s, int action) { - bool t = is_near_block(s, BLK_TABLE); - bool f = is_near_block(s, BLK_FURNACE); - if (action == ACT_MAKE_WOOD_PICK && t && s->inv[0] >= 1) { s->inv[0]--; s->inv[6]++; s->achievements[ACH_MAKE_WOOD_PICK] = true; } - if (action == ACT_MAKE_STONE_PICK && t && s->inv[0] >= 1 && s->inv[1] >= 1) { s->inv[0]--; s->inv[1]--; s->inv[7]++; s->achievements[ACH_MAKE_STONE_PICK] = true; } - if (action == ACT_MAKE_IRON_PICK && t && f && s->inv[0] >= 1 && s->inv[1] >= 1 && s->inv[3] >= 1 && s->inv[2] >= 1) { - s->inv[0]--; s->inv[1]--; s->inv[3]--; s->inv[2]--; s->inv[8]++; s->achievements[ACH_MAKE_IRON_PICK] = true; - } - if (action == ACT_MAKE_WOOD_SWORD && t && s->inv[0] >= 1) { s->inv[0]--; s->inv[9]++; s->achievements[ACH_MAKE_WOOD_SWORD] = true; } - if (action == ACT_MAKE_STONE_SWORD && t && s->inv[0] >= 1 && s->inv[1] >= 1) { s->inv[0]--; s->inv[1]--; s->inv[10]++; s->achievements[ACH_MAKE_STONE_SWORD] = true; } - if (action == ACT_MAKE_IRON_SWORD && t && f && s->inv[0] >= 1 && s->inv[1] >= 1 && s->inv[3] >= 1 && s->inv[2] >= 1) { - s->inv[0]--; s->inv[1]--; s->inv[3]--; s->inv[2]--; s->inv[11]++; s->achievements[ACH_MAKE_IRON_SWORD] = true; - } +static inline bool craftax_is_game_over_native(const CraftaxState* state) { + return state->timestep >= CRAFTAX_DEFAULT_MAX_TIMESTEPS + || state->player_health <= 0.0f; } -static void do_action(Craftax* s) { - int tr = s->player_r + DIR_DR[s->player_dir]; - int tc = s->player_c + DIR_DC[s->player_dir]; - if (!in_bounds(tr, tc)) return; - int dmg = get_damage(s); - bool attacked = false; - - for (int i 
= 0; i < MAX_ZOMBIES && !attacked; i++) - if (s->zombie_mask[i] && s->zombie_r[i] == tr && s->zombie_c[i] == tc) { - s->zombie_hp[i] -= dmg; - if (s->zombie_hp[i] <= 0) { - s->zombie_mask[i] = false; - mb_clear(s->mob_bits, tr, tc); mb_clear(s->zombie_bits, tr, tc); - s->achievements[ACH_DEFEAT_ZOMBIE] = true; - } - attacked = true; - } - for (int i = 0; i < MAX_COWS && !attacked; i++) - if (s->cow_mask[i] && s->cow_r[i] == tr && s->cow_c[i] == tc) { - s->cow_hp[i] -= dmg; - if (s->cow_hp[i] <= 0) { - s->cow_mask[i] = false; - mb_clear(s->mob_bits, tr, tc); mb_clear(s->cow_bits, tr, tc); - s->achievements[ACH_EAT_COW] = true; - s->food = (int8_t)cr_min_i(9, s->food + 6); s->hunger = 0; - } - attacked = true; - } - for (int i = 0; i < MAX_SKELETONS && !attacked; i++) - if (s->skel_mask[i] && s->skel_r[i] == tr && s->skel_c[i] == tc) { - s->skel_hp[i] -= dmg; - if (s->skel_hp[i] <= 0) { - s->skel_mask[i] = false; - mb_clear(s->mob_bits, tr, tc); mb_clear(s->skel_bits, tr, tc); - s->achievements[ACH_DEFEAT_SKELETON] = true; - } - attacked = true; - } - if (attacked) return; - - int8_t blk = map_get(s, tr, tc); - switch (blk) { - case BLK_TREE: - map_set(s, tr, tc, BLK_GRASS); - s->inv[0] = (int8_t)cr_min_i(9, s->inv[0] + 1); - s->achievements[ACH_COLLECT_WOOD] = true; break; - case BLK_STONE: - if (s->inv[6] > 0 || s->inv[7] > 0 || s->inv[8] > 0) { - map_set(s, tr, tc, BLK_PATH); - s->inv[1] = (int8_t)cr_min_i(9, s->inv[1] + 1); - s->achievements[ACH_COLLECT_STONE] = true; - } break; - case BLK_COAL: - if (s->inv[6] > 0 || s->inv[7] > 0 || s->inv[8] > 0) { - map_set(s, tr, tc, BLK_PATH); - s->inv[2] = (int8_t)cr_min_i(9, s->inv[2] + 1); - s->achievements[ACH_COLLECT_COAL] = true; - } break; - case BLK_IRON: - if (s->inv[7] > 0 || s->inv[8] > 0) { - map_set(s, tr, tc, BLK_PATH); - s->inv[3] = (int8_t)cr_min_i(9, s->inv[3] + 1); - s->achievements[ACH_COLLECT_IRON] = true; - } break; - case BLK_DIAMOND: - if (s->inv[8] > 0) { - map_set(s, tr, tc, BLK_PATH); - s->inv[4] = 
(int8_t)cr_min_i(9, s->inv[4] + 1); - s->achievements[ACH_COLLECT_DIAMOND] = true; - } break; - case BLK_GRASS: - if (cr_rf(&s->pcg) < 0.1f) { - s->inv[5] = (int8_t)cr_min_i(9, s->inv[5] + 1); - s->achievements[ACH_COLLECT_SAPLING] = true; - } break; - case BLK_WATER: - s->drink = (int8_t)cr_min_i(9, s->drink + 1); s->thirst = 0; - s->achievements[ACH_COLLECT_DRINK] = true; break; - case BLK_RIPE_PLANT: - map_set(s, tr, tc, BLK_PLANT); - s->food = (int8_t)cr_min_i(9, s->food + 4); s->hunger = 0; - s->achievements[ACH_EAT_PLANT] = true; - for (int i = 0; i < MAX_PLANTS; i++) - if (s->plant_mask[i] && s->plant_r[i] == tr && s->plant_c[i] == tc) { - s->plant_age[i] = 0; break; - } - break; +static inline void craftax_copy_achievements_to_env( + Craftax* env, + const CraftaxState* state +) { + for (int i = 0; i < CRAFTAX_NUM_ACHIEVEMENTS; i++) { + env->achievements[i] = state->achievements[i] ? 1.0f : 0.0f; } } -static void place_block(Craftax* s, int action) { - int tr = s->player_r + DIR_DR[s->player_dir]; - int tc = s->player_c + DIR_DC[s->player_dir]; - if (!in_bounds(tr, tc)) return; - if (has_mob_at(s, tr, tc)) return; - int8_t blk = map_get(s, tr, tc); - if (action == ACT_PLACE_TABLE && s->inv[0] >= 2 && !is_solid(blk)) { - map_set(s, tr, tc, BLK_TABLE); s->inv[0] -= 2; - s->achievements[ACH_PLACE_TABLE] = true; - } else if (action == ACT_PLACE_FURNACE && s->inv[1] >= 1 && !is_solid(blk)) { - map_set(s, tr, tc, BLK_FURNACE); s->inv[1] -= 1; - s->achievements[ACH_PLACE_FURNACE] = true; - } else if (action == ACT_PLACE_STONE && s->inv[1] >= 1 && (!is_solid(blk) || blk == BLK_WATER)) { - map_set(s, tr, tc, BLK_STONE); s->inv[1] -= 1; - s->achievements[ACH_PLACE_STONE] = true; - } else if (action == ACT_PLACE_PLANT && s->inv[5] >= 1 && blk == BLK_GRASS) { - map_set(s, tr, tc, BLK_PLANT); s->inv[5] -= 1; - s->achievements[ACH_PLACE_PLANT] = true; - for (int i = 0; i < MAX_PLANTS; i++) { - if (!s->plant_mask[i]) { - s->plant_r[i] = tr; s->plant_c[i] = tc; - 
s->plant_age[i] = 0; s->plant_mask[i] = true; break; - } +static void add_log(Craftax* env) { + int unlocked = 0; + for (int i = 0; i < CRAFTAX_NUM_ACHIEVEMENTS; i++) { + if (env->achievements[i] > 0.5f) { + unlocked++; + env->log.achievements[i] += 1.0f; } } + env->log.perf += (float)unlocked / (float)CRAFTAX_NUM_ACHIEVEMENTS; + env->log.score += env->episode_return_accum; + env->log.episode_return += env->episode_return_accum; + env->log.episode_length += (float)env->episode_length_accum; + env->log.n += 1.0f; } -static void move_player(Craftax* s, int action) { - if (action < 1 || action > 4) return; - int nr = s->player_r + DIR_DR[action]; - int nc = s->player_c + DIR_DC[action]; - s->player_dir = (int8_t)action; - if (!in_bounds(nr, nc)) return; - if (is_solid(map_get(s, nr, nc))) return; - if (has_mob_at(s, nr, nc)) return; - s->player_r = (int16_t)nr; s->player_c = (int16_t)nc; -} +static float craftax_gameplay_step_native( + CraftaxState* state, + int32_t action, + CraftaxThreefryKey rng +) { + bool init_achievements[CRAFTAX_NUM_ACHIEVEMENTS]; + memcpy(init_achievements, state->achievements, sizeof(init_achievements)); + float init_health = state->player_health; -static bool can_move_mob(const Craftax* s, int r, int c) { - if (!in_bounds(r, c)) return false; - int8_t blk = map_get(s, r, c); - if (is_solid(blk)) return false; - if (blk == BLK_LAVA) return false; - if (has_mob_at(s, r, c)) return false; - if (r == s->player_r && c == s->player_c) return false; - return true; -} + action = state->is_sleeping ? CRAFTAX_ACTION_NOOP : action; + action = state->is_resting ? 
CRAFTAX_ACTION_NOOP : action; -static void update_mobs(Craftax* s) { - int pr = s->player_r, pc = s->player_c; - - for (int i = 0; i < MAX_ZOMBIES; i++) { - if (!s->zombie_mask[i]) continue; - int zr = s->zombie_r[i], zc = s->zombie_c[i]; - int dist = l1_dist(zr, zc, pr, pc); - if (dist >= MOB_DESPAWN_DIST) { - s->zombie_mask[i] = false; - mb_clear(s->mob_bits, zr, zc); mb_clear(s->zombie_bits, zr, zc); - continue; - } - if (dist <= 1 && s->zombie_cd[i] <= 0) { - int dmg = s->is_sleeping ? 7 : 2; - s->health -= dmg; - s->zombie_cd[i] = 5; - s->is_sleeping = false; - } - s->zombie_cd[i] = (int8_t)cr_max_i(0, s->zombie_cd[i] - 1); - - int dr = 0, dc = 0; - if (dist < 10 && cr_rf(&s->pcg) < 0.75f) { - int adr = abs(pr - zr), adc = abs(pc - zc); - if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = cr_sign_i(pr - zr); - else dc = cr_sign_i(pc - zc); - } else { - int d = cr_ri(&s->pcg, 4); - dr = DIR_DR[d+1]; dc = DIR_DC[d+1]; - } - int nr = zr + dr, nc = zc + dc; - if (can_move_mob(s, nr, nc)) { - mb_clear(s->mob_bits, zr, zc); mb_clear(s->zombie_bits, zr, zc); - s->zombie_r[i] = (int16_t)nr; s->zombie_c[i] = (int16_t)nc; - mb_set(s->mob_bits, nr, nc); mb_set(s->zombie_bits, nr, nc); - } - } + craftax_change_floor_native(state, action); + craftax_do_crafting_native(state, action); - for (int i = 0; i < MAX_COWS; i++) { - if (!s->cow_mask[i]) continue; - int cr = s->cow_r[i], cc = s->cow_c[i]; - int dist = l1_dist(cr, cc, pr, pc); - if (dist >= MOB_DESPAWN_DIST) { - s->cow_mask[i] = false; - mb_clear(s->mob_bits, cr, cc); mb_clear(s->cow_bits, cr, cc); - continue; - } - int d = cr_ri(&s->pcg, 8); - if (d < 4) { - int dr = DIR_DR[d+1], dc2 = DIR_DC[d+1]; - int nr = cr + dr, nc = cc + dc2; - if (can_move_mob(s, nr, nc)) { - mb_clear(s->mob_bits, cr, cc); mb_clear(s->cow_bits, cr, cc); - s->cow_r[i] = (int16_t)nr; s->cow_c[i] = (int16_t)nc; - mb_set(s->mob_bits, nr, nc); mb_set(s->cow_bits, nr, nc); - } - } - } + CraftaxThreefryKey subkey = 
craftax_step_native_next_key(&rng); + craftax_do_action_native(state, action, subkey); - for (int i = 0; i < MAX_SKELETONS; i++) { - if (!s->skel_mask[i]) continue; - int sr = s->skel_r[i], sc = s->skel_c[i]; - int dist = l1_dist(sr, sc, pr, pc); - if (dist >= MOB_DESPAWN_DIST) { - s->skel_mask[i] = false; - mb_clear(s->mob_bits, sr, sc); mb_clear(s->skel_bits, sr, sc); - continue; - } - if (dist >= 4 && dist <= 5 && s->skel_cd[i] <= 0) { - for (int a = 0; a < MAX_ARROWS; a++) { - if (!s->arrow_mask[a]) { - s->arrow_mask[a] = true; - s->arrow_r[a] = (int16_t)sr; s->arrow_c[a] = (int16_t)sc; - mb_set(s->arrow_bits, sr, sc); - int adr = abs(pr - sr), adc = abs(pc - sc); - s->arrow_dr[a] = (int8_t)((adr > 0) ? cr_sign_i(pr - sr) : 0); - s->arrow_dc[a] = (int8_t)((adc > 0) ? cr_sign_i(pc - sc) : 0); - break; - } - } - s->skel_cd[i] = 4; - } - s->skel_cd[i] = (int8_t)cr_max_i(0, s->skel_cd[i] - 1); - - int dr = 0, dc = 0; - bool random_move = cr_rf(&s->pcg) < 0.15f; - if (!random_move) { - if (dist >= 10) { - int adr = abs(pr - sr), adc = abs(pc - sc); - if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = cr_sign_i(pr - sr); - else dc = cr_sign_i(pc - sc); - } else if (dist <= 3) { - int adr = abs(pr - sr), adc = abs(pc - sc); - if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = -cr_sign_i(pr - sr); - else dc = -cr_sign_i(pc - sc); - } else { - random_move = true; - } - } - if (random_move) { - int d = cr_ri(&s->pcg, 4); - dr = DIR_DR[d+1]; dc = DIR_DC[d+1]; - } - int nr = sr + dr, nc = sc + dc; - if (can_move_mob(s, nr, nc)) { - mb_clear(s->mob_bits, sr, sc); mb_clear(s->skel_bits, sr, sc); - s->skel_r[i] = (int16_t)nr; s->skel_c[i] = (int16_t)nc; - mb_set(s->mob_bits, nr, nc); mb_set(s->skel_bits, nr, nc); - } - } + craftax_place_block_native(state, action); + craftax_shoot_projectile_native(state, action); + craftax_cast_spell_native(state, action); + craftax_drink_potion_native(state, action); - for (int i = 0; i < MAX_ARROWS; i++) { - if 
(!s->arrow_mask[i]) continue; - int ar = s->arrow_r[i], ac = s->arrow_c[i]; - int nr = ar + s->arrow_dr[i], nc = ac + s->arrow_dc[i]; - if (!in_bounds(nr, nc)) { s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); continue; } - int8_t blk = map_get(s, nr, nc); - if (is_solid(blk) && blk != BLK_WATER) { - if (blk == BLK_FURNACE || blk == BLK_TABLE) map_set(s, nr, nc, BLK_PATH); - s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); continue; - } - if (nr == pr && nc == pc) { - s->health -= 2; s->is_sleeping = false; - s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); continue; - } - mb_clear(s->arrow_bits, ar, ac); - s->arrow_r[i] = (int16_t)nr; s->arrow_c[i] = (int16_t)nc; - mb_set(s->arrow_bits, nr, nc); - } -} + subkey = craftax_step_native_next_key(&rng); + craftax_read_book_native(state, subkey.word, action); -static bool try_spawn(Craftax* s, int min_d, int max_d, bool need_grass, bool need_path, - int* or_, int* oc_) { - int pr = s->player_r, pc = s->player_c; - for (int att = 0; att < 20; att++) { - int r = cr_ri(&s->pcg, MAP_SIZE), c = cr_ri(&s->pcg, MAP_SIZE); - int dist = l1_dist(r, c, pr, pc); - if (dist < min_d || dist >= max_d) continue; - if (has_mob_at(s, r, c)) continue; - if (r == pr && c == pc) continue; - int8_t blk = map_get(s, r, c); - if (need_grass && blk != BLK_GRASS) continue; - if (need_path && blk != BLK_PATH ) continue; - if (!need_grass && !need_path && blk != BLK_GRASS && blk != BLK_PATH) continue; - *or_ = r; *oc_ = c; return true; - } - return false; -} + subkey = craftax_step_native_next_key(&rng); + craftax_enchant_native(state, action, subkey); -static void spawn_mobs(Craftax* s) { - int n_cows = 0, n_z = 0, n_sk = 0; - for (int i = 0; i < MAX_COWS; i++) n_cows += s->cow_mask[i]; - for (int i = 0; i < MAX_ZOMBIES; i++) n_z += s->zombie_mask[i]; - for (int i = 0; i < MAX_SKELETONS; i++) n_sk += s->skel_mask[i]; - - if (n_cows < MAX_COWS && cr_rf(&s->pcg) < 0.1f) { - int r, c; - if (try_spawn(s, 3, 
MOB_DESPAWN_DIST, true, false, &r, &c)) { - for (int i = 0; i < MAX_COWS; i++) if (!s->cow_mask[i]) { - s->cow_mask[i] = true; s->cow_r[i] = (int16_t)r; s->cow_c[i] = (int16_t)c; s->cow_hp[i] = 3; - mb_set(s->mob_bits, r, c); mb_set(s->cow_bits, r, c); - break; - } - } - } - float zombie_chance = 0.02f + 0.1f * (1.0f - s->light_level) * (1.0f - s->light_level); - if (n_z < MAX_ZOMBIES && cr_rf(&s->pcg) < zombie_chance) { - int r, c; - if (try_spawn(s, 9, MOB_DESPAWN_DIST, false, false, &r, &c)) { - for (int i = 0; i < MAX_ZOMBIES; i++) if (!s->zombie_mask[i]) { - s->zombie_mask[i] = true; s->zombie_r[i] = (int16_t)r; s->zombie_c[i] = (int16_t)c; - s->zombie_hp[i] = 5; s->zombie_cd[i] = 0; - mb_set(s->mob_bits, r, c); mb_set(s->zombie_bits, r, c); - break; - } - } - } - if (n_sk < MAX_SKELETONS && cr_rf(&s->pcg) < 0.05f) { - int r, c; - if (try_spawn(s, 9, MOB_DESPAWN_DIST, false, true, &r, &c)) { - for (int i = 0; i < MAX_SKELETONS; i++) if (!s->skel_mask[i]) { - s->skel_mask[i] = true; s->skel_r[i] = (int16_t)r; s->skel_c[i] = (int16_t)c; - s->skel_hp[i] = 3; s->skel_cd[i] = 0; - mb_set(s->mob_bits, r, c); mb_set(s->skel_bits, r, c); - break; - } - } - } -} + craftax_boss_logic_native(state); + craftax_level_up_attributes_native(state, action, CRAFTAX_MAX_ATTRIBUTE); + craftax_move_player_native(state, action, false); -static void update_plants(Craftax* s) { - for (int i = 0; i < MAX_PLANTS; i++) { - if (!s->plant_mask[i]) continue; - s->plant_age[i]++; - if (s->plant_age[i] >= 600) { - int r = s->plant_r[i], c = s->plant_c[i]; - if (in_bounds(r, c) && map_get(s, r, c) == BLK_PLANT) - map_set(s, r, c, BLK_RIPE_PLANT); - } - } -} + subkey = craftax_step_native_next_key(&rng); + craftax_update_mobs_native(state, subkey); -static void update_intrinsics(Craftax* s, int action) { - if (action == ACT_SLEEP && s->energy < 9) s->is_sleeping = true; - if (s->energy >= 9 && s->is_sleeping) { - s->is_sleeping = false; - s->achievements[ACH_WAKE_UP] = true; - } - float mul = 
s->is_sleeping ? 0.5f : 1.0f; - s->hunger += mul; if (s->hunger > 25.0f) { s->food--; s->hunger = 0; } - s->thirst += mul; if (s->thirst > 20.0f) { s->drink--; s->thirst = 0; } - if (s->is_sleeping) s->fatigue -= 1.0f; else s->fatigue += 1.0f; - if (s->fatigue > 30.0f) { s->energy--; s->fatigue = 0; } - if (s->fatigue < -10.0f) { s->energy = (int8_t)cr_min_i(s->energy + 1, 9); s->fatigue = 0; } - bool ok = (s->food > 0) && (s->drink > 0) && (s->energy > 0 || s->is_sleeping); - if (ok) s->recover += s->is_sleeping ? 2.0f : 1.0f; - else s->recover += s->is_sleeping ? -0.5f : -1.0f; - if (s->recover > 25.0f) { s->health = (int8_t)cr_min_i(s->health + 1, 9); s->recover = 0; } - if (s->recover < -15.0f) { s->health--; s->recover = 0; } -} + subkey = craftax_step_native_next_key(&rng); + craftax_spawn_mobs_native(state, subkey); -// ============================================================ -// Observation builder (writes OBS_DIM floats into env->observations) -// ============================================================ -static void compute_observations(Craftax* s) { - float* obs = s->observations; - int pr = s->player_r, pc = s->player_c; - int idx = 0; - for (int dr = -3; dr <= 3; dr++) { - int r = pr + dr; - bool row_ok = (unsigned)r < MAP_SIZE; - uint64_t zb = row_ok ? s->zombie_bits[r] : 0; - uint64_t cb = row_ok ? s->cow_bits[r] : 0; - uint64_t sb = row_ok ? s->skel_bits[r] : 0; - uint64_t ab = row_ok ? s->arrow_bits[r] : 0; - for (int dc = -4; dc <= 4; dc++) { - int c = pc + dc; - int8_t blk = (row_ok && (unsigned)c < MAP_SIZE) ? map_get(s, r, c) : BLK_OUT_OF_BOUNDS; - float* dst = obs + idx; - for (int b = 0; b < NUM_BLOCK_TYPES; b++) dst[b] = 0.0f; - if ((unsigned)blk < NUM_BLOCK_TYPES) dst[blk] = 1.0f; - idx += NUM_BLOCK_TYPES; - float mz = 0, mc = 0, ms = 0, ma = 0; - if (row_ok && (unsigned)c < MAP_SIZE) { - uint64_t bit = 1ULL << c; - mz = (zb & bit) ? 1.0f : 0.0f; - mc = (cb & bit) ? 1.0f : 0.0f; - ms = (sb & bit) ? 1.0f : 0.0f; - ma = (ab & bit) ? 
1.0f : 0.0f; - } - obs[idx++] = mz; obs[idx++] = mc; obs[idx++] = ms; obs[idx++] = ma; - } - } - for (int i = 0; i < NUM_INVENTORY; i++) obs[idx++] = (float)s->inv[i] * 0.1f; - obs[idx++] = (float)s->health * 0.1f; - obs[idx++] = (float)s->food * 0.1f; - obs[idx++] = (float)s->drink * 0.1f; - obs[idx++] = (float)s->energy * 0.1f; - for (int d = 1; d <= 4; d++) obs[idx++] = (s->player_dir == d) ? 1.0f : 0.0f; - obs[idx++] = s->light_level; - obs[idx++] = s->is_sleeping ? 1.0f : 0.0f; -} + craftax_update_plants_native(state); + craftax_update_player_intrinsics_native(state, action); + craftax_clip_inventory_and_intrinsics_native(state, false); + craftax_calculate_inventory_achievements_native(state); -// ============================================================ -// Logging (stats accumulated into env->log; flushed at vec-level by PufferLib) -// ============================================================ -static void add_log(Craftax* env) { - int unlocked = 0; - for (int i = 0; i < NUM_ACHIEVEMENTS; i++) { - if (env->achievements[i]) { - unlocked++; - env->log.achievements[i] += 1.0f; - } + float reward = 0.0f; + for (int i = 0; i < CRAFTAX_NUM_ACHIEVEMENTS; i++) { + int32_t delta = (int32_t)state->achievements[i] + - (int32_t)init_achievements[i]; + reward += (float)delta * CRAFTAX_ACHIEVEMENT_REWARD_MAP[i]; } - env->log.perf += (float)unlocked / (float)NUM_ACHIEVEMENTS; - env->log.score += env->episode_return_accum; - env->log.episode_return += env->episode_return_accum; - env->log.episode_length += (float)env->episode_length_accum; - env->log.n += 1.0f; + reward += (state->player_health - init_health) * 0.1f; + + subkey = craftax_step_native_next_key(&rng); + state->timestep += 1; + state->light_level = craftax_calculate_light_level_native(state->timestep); + state->state_rng[0] = subkey.word[0]; + state->state_rng[1] = subkey.word[1]; + + return reward; } // ============================================================ -// Public API: c_init / c_reset / c_step 
/ c_close / c_render +// Public API expected by vecenv.h // ============================================================ static void c_init(Craftax* env) { - env->num_agents = 1; env->client = NULL; - // env->rng was seeded by default my_vec_init to the env index; use it to - // initialize a proper 64-bit PCG state. - uint64_t seed = (uint64_t)env->rng; - env->pcg = seed * 0x9E3779B97F4A7C15ULL + 0x87C37B91114253D5ULL; - // Warm the RNG a bit so small seeds don't produce correlated worlds. - for (int i = 0; i < 8; i++) (void)cr_pcg(&env->pcg); + env->num_agents = 1; + env->episode_return_accum = 0.0f; + env->episode_length_accum = 0; + memset(env->achievements, 0, sizeof(env->achievements)); memset(&env->log, 0, sizeof(env->log)); + craftax_reset_state_from_seed(env); } static void c_reset(Craftax* env) { + if (env->rewards != NULL) { + env->rewards[0] = 0.0f; + } + if (env->terminals != NULL) { + env->terminals[0] = 0.0f; + } env->episode_return_accum = 0.0f; env->episode_length_accum = 0; - generate_world(env); - compute_observations(env); + memset(env->achievements, 0, sizeof(env->achievements)); + + craftax_reset_state_from_seed(env); + craftax_encode_native_observation(&env->state, env->observations); } -static void c_step(Craftax* env) { +static void c_step_native(Craftax* env) { env->rewards[0] = 0.0f; env->terminals[0] = 0.0f; int action = (int)env->actions[0]; - if (action < 0) action = 0; - if (action >= NUM_ACTIONS) action = NUM_ACTIONS - 1; - - // Snapshot for reward computation - env->old_health = env->health; - memcpy(env->old_achievements, env->achievements, sizeof(env->achievements)); - - int eff_action = env->is_sleeping ? 
ACT_NOOP : action; - do_crafting(env, eff_action); - if (eff_action == ACT_DO) do_action(env); - if (eff_action >= ACT_PLACE_STONE && eff_action <= ACT_PLACE_PLANT) place_block(env, eff_action); - move_player(env, eff_action); - update_mobs(env); - spawn_mobs(env); - update_plants(env); - update_intrinsics(env, action); - - for (int i = 0; i < NUM_INVENTORY; i++) - env->inv[i] = (int8_t)cr_clamp_i(env->inv[i], 0, 9); - - env->timestep++; - float t_frac = fmodf((float)env->timestep / (float)DAY_LENGTH, 1.0f) + 0.3f; - float cv = cosf(3.14159265f * t_frac); - env->light_level = 1.0f - fabsf(cv * cv * cv); - - // Reward: new achievements + health change * 0.1 - float ach_r = 0.0f; - for (int i = 0; i < NUM_ACHIEVEMENTS; i++) - ach_r += (float)(env->achievements[i] && !env->old_achievements[i]); - float hp_r = (float)(env->health - env->old_health) * 0.1f; - float r = ach_r + hp_r; - env->rewards[0] = r; - env->episode_return_accum += r; - env->episode_length_accum += 1; + if (action < 0) { + action = CRAFTAX_ACTION_NOOP; + } + if (action >= CRAFTAX_NUM_ACTIONS) { + action = CRAFTAX_NUM_ACTIONS - 1; + } - // Terminal conditions - bool done = (env->timestep >= MAX_TIMESTEPS) || (env->health <= 0); - if (in_bounds(env->player_r, env->player_c) - && map_get(env, env->player_r, env->player_c) == BLK_LAVA) done = true; + CraftaxThreefryKey step_key; + craftax_threefry_split(env->rng_key, &env->rng_key, &step_key); + + CraftaxThreefryKey step_rng; + CraftaxThreefryKey reset_key; + craftax_threefry_split(step_key, &step_rng, &reset_key); + + float reward = craftax_gameplay_step_native(&env->state, action, step_rng); + bool done = craftax_is_game_over_native(&env->state); + + craftax_copy_achievements_to_env(env, &env->state); + + env->rewards[0] = reward; + env->terminals[0] = done ? 
1.0f : 0.0f; + env->episode_return_accum += reward; + env->episode_length_accum += 1; if (done) { - env->terminals[0] = 1.0f; add_log(env); - c_reset(env); // auto-reset (observation written inside) - } else { - compute_observations(env); + env->episode_return_accum = 0.0f; + env->episode_length_accum = 0; + memset(env->achievements, 0, sizeof(env->achievements)); + craftax_reset_state_on_done(&env->state, reset_key); } + + craftax_encode_native_observation(&env->state, env->observations); +} + +static void c_step(Craftax* env) { + c_step_native(env); } static void c_close(Craftax* env) { (void)env; } -// ============================================================ -// Minimal raylib rendering (optional; matches breakout pattern) -// ============================================================ +// ------------------------------------------------------------ +// Tile-based renderer using upstream Craftax 16x16 PNG assets +// ------------------------------------------------------------ +// Packed layout (see ocean/craftax/pack_textures.py): +// [0..36] block textures (indexed by CraftaxBlockType) +// [37..41] player: down, up, left, right, sleep +// [42..46] items: none, torch, ladder_down, ladder_up, ladder_down_blocked + +#define CRAFTAX_TEX_TILE_PX 16 +#define CRAFTAX_TEX_SCALE 4 // on-screen px = 64 +#define CRAFTAX_TEX_DRAW_PX (CRAFTAX_TEX_TILE_PX * CRAFTAX_TEX_SCALE) +#define CRAFTAX_TEX_NUM (37 + 5 + 5 + 3 + 4) + +// Render viewport (independent of agent obs window) +#define CRAFTAX_RENDER_ROWS 16 +#define CRAFTAX_RENDER_COLS 16 + +#define CRAFTAX_TEX_PLAYER_DOWN 37 +#define CRAFTAX_TEX_PLAYER_UP 38 +#define CRAFTAX_TEX_PLAYER_LEFT 39 +#define CRAFTAX_TEX_PLAYER_RIGHT 40 +#define CRAFTAX_TEX_PLAYER_SLEEP 41 +#define CRAFTAX_TEX_ITEM_BASE 42 + +static Texture2D craftax_textures[CRAFTAX_TEX_NUM]; +static bool craftax_textures_loaded = false; + +static void craftax_load_textures(void) { + if (craftax_textures_loaded) return; + const char* candidates[] = { + 
"resources/craftax/textures.bin", + "../resources/craftax/textures.bin", + "../../resources/craftax/textures.bin", + }; + FILE* f = NULL; + for (size_t i = 0; i < sizeof(candidates)/sizeof(candidates[0]); i++) { + f = fopen(candidates[i], "rb"); + if (f) break; + } + if (!f) { + fprintf(stderr, "craftax: textures.bin not found in resources/craftax -- run ocean/craftax/pack_textures.py\n"); + exit(1); + } + const size_t tile_bytes = CRAFTAX_TEX_TILE_PX * CRAFTAX_TEX_TILE_PX * 4; + uint8_t* buf = (uint8_t*)malloc(tile_bytes); + for (int i = 0; i < CRAFTAX_TEX_NUM; i++) { + if (fread(buf, 1, tile_bytes, f) != tile_bytes) { + fprintf(stderr, "craftax: short read on textures.bin at tile %d\n", i); + exit(1); + } + Image img = { + .data = buf, + .width = CRAFTAX_TEX_TILE_PX, + .height = CRAFTAX_TEX_TILE_PX, + .mipmaps = 1, + .format = PIXELFORMAT_UNCOMPRESSED_R8G8B8A8, + }; + craftax_textures[i] = LoadTextureFromImage(img); + SetTextureFilter(craftax_textures[i], TEXTURE_FILTER_POINT); + } + free(buf); + fclose(f); + craftax_textures_loaded = true; +} + +static int craftax_player_tex_id(int32_t direction, bool sleeping) { + if (sleeping) return CRAFTAX_TEX_PLAYER_SLEEP; + switch (direction) { + case 1: return CRAFTAX_TEX_PLAYER_LEFT; + case 2: return CRAFTAX_TEX_PLAYER_RIGHT; + case 3: return CRAFTAX_TEX_PLAYER_UP; + case 4: return CRAFTAX_TEX_PLAYER_DOWN; + default: return CRAFTAX_TEX_PLAYER_DOWN; + } +} + +static void craftax_draw_tile(int tex_id, int dst_x, int dst_y, float tint_alpha) { + if (tex_id < 0 || tex_id >= CRAFTAX_TEX_NUM) return; + Rectangle src = {0, 0, CRAFTAX_TEX_TILE_PX, CRAFTAX_TEX_TILE_PX}; + Rectangle dst = {(float)dst_x, (float)dst_y, CRAFTAX_TEX_DRAW_PX, CRAFTAX_TEX_DRAW_PX}; + Color tint = {255, 255, 255, (unsigned char)(tint_alpha * 255.0f)}; + DrawTexturePro(craftax_textures[tex_id], src, dst, (Vector2){0, 0}, 0.0f, tint); +} + static void c_render(Craftax* env) { + const int view_w = CRAFTAX_RENDER_COLS * CRAFTAX_TEX_DRAW_PX; + const int 
view_h = CRAFTAX_RENDER_ROWS * CRAFTAX_TEX_DRAW_PX; + const int hud_h = 80; + if (!IsWindowReady()) { - InitWindow(MAP_SIZE * 10, MAP_SIZE * 10 + 60, "PufferLib Craftax"); + InitWindow(view_w, view_h + hud_h, "PufferLib Craftax"); SetTargetFPS(30); } + if (!craftax_textures_loaded) craftax_load_textures(); if (IsKeyDown(KEY_ESCAPE)) exit(0); + CraftaxState* s = &env->state; + int lvl = s->player_level; + int pr = s->player_position[0]; + int pc = s->player_position[1]; + int half_r = CRAFTAX_RENDER_ROWS / 2; + int half_c = CRAFTAX_RENDER_COLS / 2; + BeginDrawing(); ClearBackground(BLACK); - static const Color PALETTE[17] = { - (Color){0,0,0,255}, // INVALID - (Color){40,40,40,255}, // OUT_OF_BOUNDS - (Color){80,200,120,255}, // GRASS - (Color){50,120,220,255}, // WATER - (Color){110,110,110,255}, // STONE - (Color){40,120,40,255}, // TREE - (Color){140,90,40,255}, // WOOD - (Color){180,170,130,255}, // PATH - (Color){50,50,50,255}, // COAL - (Color){200,200,220,255}, // IRON - (Color){180,240,255,255}, // DIAMOND - (Color){180,120,60,255}, // TABLE - (Color){160,80,40,255}, // FURNACE - (Color){220,200,140,255}, // SAND - (Color){240,80,40,255}, // LAVA - (Color){60,200,60,255}, // PLANT - (Color){250,180,50,255}, // RIPE_PLANT - }; - for (int r = 0; r < MAP_SIZE; r++) { - for (int c = 0; c < MAP_SIZE; c++) { - int8_t blk = map_get(env, r, c); - DrawRectangle(c * 10, r * 10, 10, 10, PALETTE[(int)blk]); + + for (int vr = 0; vr < CRAFTAX_RENDER_ROWS; vr++) { + for (int vc = 0; vc < CRAFTAX_RENDER_COLS; vc++) { + int wr = pr - half_r + vr; + int wc = pc - half_c + vc; + int dst_x = vc * CRAFTAX_TEX_DRAW_PX; + int dst_y = vr * CRAFTAX_TEX_DRAW_PX; + + int blk = CRAFTAX_BLOCK_OUT_OF_BOUNDS; + float light = 1.0f; + if (wr >= 0 && wr < CRAFTAX_MAP_SIZE && wc >= 0 && wc < CRAFTAX_MAP_SIZE) { + blk = s->map[lvl][wr][wc]; + light = s->light_map[lvl][wr][wc]; + if (light < 0.05f) blk = CRAFTAX_BLOCK_DARKNESS; + } + if (blk < 0 || blk >= CRAFTAX_NUM_BLOCK_TYPES) blk = 0; + 
craftax_draw_tile(blk, dst_x, dst_y, 1.0f); + + // item overlay + if (wr >= 0 && wr < CRAFTAX_MAP_SIZE && wc >= 0 && wc < CRAFTAX_MAP_SIZE) { + int it = s->item_map[lvl][wr][wc]; + if (it > 0 && it < 5) { + craftax_draw_tile(CRAFTAX_TEX_ITEM_BASE + it, dst_x, dst_y, 1.0f); + } + } } } - DrawCircle(env->player_c * 10 + 5, env->player_r * 10 + 5, 4, WHITE); - DrawText(TextFormat("HP:%d F:%d D:%d E:%d t:%d", env->health, env->food, - env->drink, env->energy, env->timestep), - 4, MAP_SIZE * 10 + 4, 16, WHITE); + // player in center + int pid = craftax_player_tex_id(s->player_direction, s->is_sleeping); + craftax_draw_tile(pid, half_c * CRAFTAX_TEX_DRAW_PX, half_r * CRAFTAX_TEX_DRAW_PX, 1.0f); + + // night dim overlay + if (s->light_level < 1.0f) { + unsigned char a = (unsigned char)((1.0f - s->light_level) * 140.0f); + DrawRectangle(0, 0, view_w, view_h, (Color){0, 0, 40, a}); + } + + // HUD + int hud_y = view_h; + DrawRectangle(0, hud_y, view_w, hud_h, (Color){20, 20, 20, 255}); + DrawText(TextFormat("HP:%.0f F:%d D:%d E:%d M:%d L:%d t:%d", + s->player_health, s->player_food, s->player_drink, + s->player_energy, s->player_mana, s->player_level, s->timestep), + 4, hud_y + 4, 14, WHITE); + DrawText(TextFormat("XP:%d DEX:%d STR:%d INT:%d light:%.2f", + s->player_xp, s->player_dexterity, s->player_strength, + s->player_intelligence, s->light_level), + 4, hud_y + 22, 14, (Color){200, 200, 200, 255}); + int ach_count = 0; + for (int i = 0; i < CRAFTAX_NUM_ACHIEVEMENTS; i++) ach_count += s->achievements[i] ? 
1 : 0; + DrawText(TextFormat("achievements: %d / %d", ach_count, CRAFTAX_NUM_ACHIEVEMENTS), + 4, hud_y + 40, 14, (Color){180, 220, 180, 255}); + DrawText(TextFormat("ret:%.2f len:%d", env->episode_return_accum, env->episode_length_accum), + 4, hud_y + 58, 14, (Color){200, 200, 140, 255}); + EndDrawing(); } + +#endif diff --git a/ocean/craftax/noise.h b/ocean/craftax/noise.h new file mode 100644 index 0000000000..e81e398509 --- /dev/null +++ b/ocean/craftax/noise.h @@ -0,0 +1,206 @@ +// Native C port of craftax/craftax/util/noise.py. + +#pragma once + +#include +#include +#include + +#include "threefry.h" + +#ifndef CRAFTAX_NOISE_PI2 +#define CRAFTAX_NOISE_PI2 6.28318530717958647692f +#endif + +#ifndef CRAFTAX_NOISE_SQRT2 +#define CRAFTAX_NOISE_SQRT2 1.41421356237309504880f +#endif + +static inline float craftax_noise_interpolant(float t) { + return t * t * t * (t * (t * 6.0f - 15.0f) + 10.0f); +} + +static inline float craftax_noise_gradient_angle( + CraftaxThreefryKey angle_key, + int res_cols, + int row, + int col, + const float* override_angles +) { + int width = res_cols + 1; + uint64_t index = (uint64_t)row * (uint64_t)width + (uint64_t)col; + float unit = override_angles == NULL + ? 
craftax_threefry_uniform_f32_at(angle_key, index) + : override_angles[index]; + return CRAFTAX_NOISE_PI2 * unit; +} + +static inline void craftax_noise_gradient( + CraftaxThreefryKey angle_key, + int res_cols, + int row, + int col, + const float* override_angles, + float* gx, + float* gy +) { + float angle = craftax_noise_gradient_angle( + angle_key, + res_cols, + row, + col, + override_angles + ); + *gx = cosf(angle); + *gy = sinf(angle); +} + +static inline void craftax_generate_perlin_noise_2d( + CraftaxThreefryKey rng, + int rows, + int cols, + int res_rows, + int res_cols, + const float* override_angles, + float* out +) { + CraftaxThreefryKey unused; + CraftaxThreefryKey angle_key; + craftax_threefry_split(rng, &unused, &angle_key); + + int cell_rows = rows / res_rows; + int cell_cols = cols / res_cols; + + for (int row = 0; row < rows; row++) { + int grad_row = row / cell_rows; + float local_row = (float)(row - grad_row * cell_rows) / (float)cell_rows; + float interp_row = craftax_noise_interpolant(local_row); + + for (int col = 0; col < cols; col++) { + int grad_col = col / cell_cols; + float local_col = (float)(col - grad_col * cell_cols) / (float)cell_cols; + float interp_col = craftax_noise_interpolant(local_col); + + float g00x; + float g00y; + float g10x; + float g10y; + float g01x; + float g01y; + float g11x; + float g11y; + craftax_noise_gradient( + angle_key, + res_cols, + grad_row, + grad_col, + override_angles, + &g00x, + &g00y + ); + craftax_noise_gradient( + angle_key, + res_cols, + grad_row + 1, + grad_col, + override_angles, + &g10x, + &g10y + ); + craftax_noise_gradient( + angle_key, + res_cols, + grad_row, + grad_col + 1, + override_angles, + &g01x, + &g01y + ); + craftax_noise_gradient( + angle_key, + res_cols, + grad_row + 1, + grad_col + 1, + override_angles, + &g11x, + &g11y + ); + + float n00 = local_row * g00x; + n00 += local_col * g00y; + float n10 = (local_row - 1.0f) * g10x; + n10 += local_col * g10y; + float n01 = local_row * g01x; 
+ n01 += (local_col - 1.0f) * g01y; + float n11 = (local_row - 1.0f) * g11x; + n11 += (local_col - 1.0f) * g11y; + + float n0 = n00 * (1.0f - interp_row) + interp_row * n10; + float n1 = n01 * (1.0f - interp_row) + interp_row * n11; + out[(size_t)row * (size_t)cols + (size_t)col] = + CRAFTAX_NOISE_SQRT2 * ((1.0f - interp_col) * n0 + interp_col * n1); + } + } +} + +static inline void craftax_generate_fractal_noise_2d( + CraftaxThreefryKey rng, + int rows, + int cols, + int res_rows, + int res_cols, + int octaves, + float persistence, + int lacunarity, + const float* override_angles, + float* out +) { + size_t size = (size_t)rows * (size_t)cols; + for (size_t i = 0; i < size; i++) { + out[i] = 0.0f; + } + + int frequency = 1; + float amplitude = 1.0f; + float perlin[size]; + + for (int octave = 0; octave < octaves; octave++) { + CraftaxThreefryKey next_rng; + CraftaxThreefryKey noise_key; + craftax_threefry_split(rng, &next_rng, &noise_key); + rng = next_rng; + + craftax_generate_perlin_noise_2d( + noise_key, + rows, + cols, + frequency * res_rows, + frequency * res_cols, + override_angles, + perlin + ); + + for (size_t i = 0; i < size; i++) { + out[i] += amplitude * perlin[i]; + } + + frequency *= lacunarity; + amplitude *= persistence; + } + + float min_value = out[0]; + float max_value = out[0]; + for (size_t i = 1; i < size; i++) { + if (out[i] < min_value) { + min_value = out[i]; + } + if (out[i] > max_value) { + max_value = out[i]; + } + } + + float scale = max_value - min_value; + for (size_t i = 0; i < size; i++) { + out[i] = (out[i] - min_value) / scale; + } +} diff --git a/ocean/craftax/pack_textures.py b/ocean/craftax/pack_textures.py new file mode 100644 index 0000000000..9fdfa24dd5 --- /dev/null +++ b/ocean/craftax/pack_textures.py @@ -0,0 +1,136 @@ +"""Pack Craftax upstream 16x16 PNG assets into a single shared textures.bin. + +Consumed by both ocean/craftax (full) and ocean/craftax_classic. 
All files +live in craftax's asset dir; the classic PNGs that overlap are byte-identical +to the full ones. + +Layout: contiguous 16*16*4 RGBA tiles. Order must match the +CRAFTAX_TEX_* / CC_TEX_* enums in the two env headers. + + [0..36] block textures (37) -- BlockType; first 17 entries also valid for classic + [37..41] player: down, up, left, right, sleep + [42..46] items: none(blank), torch, ladder_down, ladder_up, ladder_down_blocked + [47..49] mobs: zombie, skeleton, cow + [50..53] arrows: down, up, left, right +""" + +from pathlib import Path +from PIL import Image +import numpy as np + +ASSETS = Path(__file__).resolve().parents[2] / ( + ".venv/lib/python3.12/site-packages/craftax/craftax/assets" +) +OUT_BIN = Path(__file__).resolve().parents[2] / "resources" / "craftax" / "textures.bin" + +TILE = 16 + +BLOCK_FILES = [ + "debug_tile.png", # 0 INVALID + "debug_tile.png", # 1 OUT_OF_BOUNDS (overwritten solid grey below) + "grass.png", # 2 + "water.png", # 3 + "stone.png", # 4 + "tree.png", # 5 + "wood.png", # 6 + "path.png", # 7 + "coal.png", # 8 + "iron.png", # 9 + "diamond.png", # 10 + "table.png", # 11 crafting table + "furnace.png", # 12 + "sand.png", # 13 + "lava.png", # 14 + "plant_on_grass.png", # 15 + "ripe_plant_on_grass.png", # 16 + "wall2.png", # 17 + "debug_tile.png", # 18 DARKNESS (overwritten solid black below) + "wall_moss.png", # 19 + "stalagmite.png", # 20 + "sapphire.png", # 21 + "ruby.png", # 22 + "chest.png", # 23 + "fountain.png", # 24 + "fire_grass.png", # 25 + "ice_grass.png", # 26 + "gravel.png", # 27 + "fire_tree.png", # 28 + "ice_shrub.png", # 29 + "enchantment_table_fire.png",# 30 + "enchantment_table_ice.png", # 31 + "necromancer.png", # 32 + "grave.png", # 33 + "grave2.png", # 34 + "grave3.png", # 35 + "necromancer_vulnerable.png",# 36 +] + +PLAYER_FILES = [ + "player-down.png", + "player-up.png", + "player-left.png", + "player-right.png", + "player-sleep.png", +] + +ITEM_FILES = [ + None, # NONE -> fully transparent + 
"torch_on_path.png", + "ladder_down.png", + "ladder_up.png", + "ladder_down_blocked.png", +] + +MOB_FILES = [ + "zombie.png", + "skeleton.png", + "cow.png", +] + +ARROW_FILES = [ + "arrow-down.png", + "arrow-up.png", + "arrow-left.png", + "arrow-right.png", +] + + +def load_tile(name: str | None) -> np.ndarray: + if name is None: + return np.zeros((TILE, TILE, 4), dtype=np.uint8) + p = ASSETS / name + img = Image.open(p).convert("RGBA").resize((TILE, TILE), Image.NEAREST) + return np.asarray(img, dtype=np.uint8) + + +def main() -> None: + tiles: list[np.ndarray] = [] + for f in BLOCK_FILES: + tiles.append(load_tile(f)) + + # manual overrides to match upstream renderer + tiles[1] = np.full((TILE, TILE, 4), 128, dtype=np.uint8); tiles[1][..., 3] = 255 # out of bounds + tiles[18] = np.zeros((TILE, TILE, 4), dtype=np.uint8); tiles[18][..., 3] = 255 # darkness + + for f in PLAYER_FILES: + tiles.append(load_tile(f)) + + # torch_in_walls doesn't exist in assets; fall back to torch.png if needed + for f in ITEM_FILES: + if f is not None and not (ASSETS / f).exists(): + alt = "torch.png" if "torch" in f else f + tiles.append(load_tile(alt)) + else: + tiles.append(load_tile(f)) + + for f in MOB_FILES + ARROW_FILES: + tiles.append(load_tile(f)) + + blob = np.stack(tiles, axis=0) # (N, 16, 16, 4) uint8 + assert blob.dtype == np.uint8 + OUT_BIN.write_bytes(blob.tobytes(order="C")) + print(f"wrote {OUT_BIN} — {blob.shape[0]} tiles, {OUT_BIN.stat().st_size} bytes") + + +if __name__ == "__main__": + main() diff --git a/ocean/craftax/step_crafting.h b/ocean/craftax/step_crafting.h new file mode 100644 index 0000000000..a40f3bca9b --- /dev/null +++ b/ocean/craftax/step_crafting.h @@ -0,0 +1,424 @@ +// Standalone native ports of Craftax crafting and placement subsystems. +// +// These helpers intentionally are not integrated into c_step yet. They mutate a +// full CraftaxState in place so tests can compare each subsystem directly +// against the installed JAX implementation. 
+ +#pragma once + +#include "step_simple.h" + +static inline bool craftax_crafting_is_near_block( + const CraftaxState* state, + int32_t block_type +) { + static const int32_t close_blocks[8][2] = { + {0, -1}, + {0, 1}, + {-1, 0}, + {1, 0}, + {-1, -1}, + {-1, 1}, + {1, -1}, + {1, 1}, + }; + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + for (int32_t i = 0; i < 8; i++) { + int32_t row = state->player_position[0] + close_blocks[i][0]; + int32_t col = state->player_position[1] + close_blocks[i][1]; + bool in_bounds = row >= 0 + && row < CRAFTAX_MAP_SIZE + && col >= 0 + && col < CRAFTAX_MAP_SIZE; + if (in_bounds && state->map[level][row][col] == block_type) { + return true; + } + } + return false; +} + +static inline int32_t craftax_crafting_first_armour_below( + const CraftaxInventory* inventory, + int32_t threshold, + int32_t* count +) { + int32_t first = 0; + *count = 0; + for (int32_t i = 0; i < 4; i++) { + bool below = inventory->armour[i] < threshold; + first = (*count == 0 && below) ? 
i : first; + *count += (int32_t)below; + } + return first; +} + +static inline void craftax_do_crafting_native( + CraftaxState* state, + int32_t action +) { + bool is_at_crafting_table = craftax_crafting_is_near_block( + state, + CRAFTAX_BLOCK_CRAFTING_TABLE + ); + bool is_at_furnace = craftax_crafting_is_near_block( + state, + CRAFTAX_BLOCK_FURNACE + ); + + CraftaxInventory* inventory = &state->inventory; + + bool can_craft_wood_pickaxe = inventory->wood >= 1; + bool is_crafting_wood_pickaxe = + action == CRAFTAX_ACTION_MAKE_WOOD_PICKAXE + && can_craft_wood_pickaxe + && is_at_crafting_table + && inventory->pickaxe < 1; + inventory->wood -= 1 * (int32_t)is_crafting_wood_pickaxe; + inventory->pickaxe = + inventory->pickaxe * (1 - (int32_t)is_crafting_wood_pickaxe) + + 1 * (int32_t)is_crafting_wood_pickaxe; + + bool can_craft_stone_pickaxe = + inventory->wood >= 1 && inventory->stone >= 1; + bool is_crafting_stone_pickaxe = + action == CRAFTAX_ACTION_MAKE_STONE_PICKAXE + && can_craft_stone_pickaxe + && is_at_crafting_table + && inventory->pickaxe < 2; + inventory->stone -= 1 * (int32_t)is_crafting_stone_pickaxe; + inventory->wood -= 1 * (int32_t)is_crafting_stone_pickaxe; + inventory->pickaxe = + inventory->pickaxe * (1 - (int32_t)is_crafting_stone_pickaxe) + + 2 * (int32_t)is_crafting_stone_pickaxe; + + bool can_craft_iron_pickaxe = + inventory->wood >= 1 + && inventory->stone >= 1 + && inventory->iron >= 1 + && inventory->coal >= 1; + bool is_crafting_iron_pickaxe = + action == CRAFTAX_ACTION_MAKE_IRON_PICKAXE + && can_craft_iron_pickaxe + && is_at_furnace + && is_at_crafting_table + && inventory->pickaxe < 3; + inventory->iron -= 1 * (int32_t)is_crafting_iron_pickaxe; + inventory->wood -= 1 * (int32_t)is_crafting_iron_pickaxe; + inventory->stone -= 1 * (int32_t)is_crafting_iron_pickaxe; + inventory->coal -= 1 * (int32_t)is_crafting_iron_pickaxe; + inventory->pickaxe = + inventory->pickaxe * (1 - (int32_t)is_crafting_iron_pickaxe) + + 3 * 
(int32_t)is_crafting_iron_pickaxe; + + bool can_craft_diamond_pickaxe = + inventory->wood >= 1 && inventory->diamond >= 3; + bool is_crafting_diamond_pickaxe = + action == CRAFTAX_ACTION_MAKE_DIAMOND_PICKAXE + && can_craft_diamond_pickaxe + && is_at_crafting_table + && inventory->pickaxe < 4; + inventory->diamond -= 3 * (int32_t)is_crafting_diamond_pickaxe; + inventory->wood -= 1 * (int32_t)is_crafting_diamond_pickaxe; + inventory->pickaxe = + inventory->pickaxe * (1 - (int32_t)is_crafting_diamond_pickaxe) + + 4 * (int32_t)is_crafting_diamond_pickaxe; + + bool can_craft_wood_sword = inventory->wood >= 1; + bool is_crafting_wood_sword = + action == CRAFTAX_ACTION_MAKE_WOOD_SWORD + && can_craft_wood_sword + && is_at_crafting_table + && inventory->sword < 1; + inventory->wood -= 1 * (int32_t)is_crafting_wood_sword; + inventory->sword = + inventory->sword * (1 - (int32_t)is_crafting_wood_sword) + + 1 * (int32_t)is_crafting_wood_sword; + + bool can_craft_stone_sword = + inventory->stone >= 1 && inventory->wood >= 1; + bool is_crafting_stone_sword = + action == CRAFTAX_ACTION_MAKE_STONE_SWORD + && can_craft_stone_sword + && is_at_crafting_table + && inventory->sword < 2; + inventory->wood -= 1 * (int32_t)is_crafting_stone_sword; + inventory->stone -= 1 * (int32_t)is_crafting_stone_sword; + inventory->sword = + inventory->sword * (1 - (int32_t)is_crafting_stone_sword) + + 2 * (int32_t)is_crafting_stone_sword; + + bool can_craft_iron_sword = + inventory->iron >= 1 + && inventory->wood >= 1 + && inventory->stone >= 1 + && inventory->coal >= 1; + bool is_crafting_iron_sword = + action == CRAFTAX_ACTION_MAKE_IRON_SWORD + && can_craft_iron_sword + && is_at_furnace + && is_at_crafting_table + && inventory->sword < 3; + inventory->wood -= 1 * (int32_t)is_crafting_iron_sword; + inventory->iron -= 1 * (int32_t)is_crafting_iron_sword; + inventory->stone -= 1 * (int32_t)is_crafting_iron_sword; + inventory->coal -= 1 * (int32_t)is_crafting_iron_sword; + inventory->sword = + 
inventory->sword * (1 - (int32_t)is_crafting_iron_sword) + + 3 * (int32_t)is_crafting_iron_sword; + + bool can_craft_diamond_sword = + inventory->diamond >= 2 && inventory->wood >= 1; + bool is_crafting_diamond_sword = + action == CRAFTAX_ACTION_MAKE_DIAMOND_SWORD + && can_craft_diamond_sword + && is_at_crafting_table + && inventory->sword < 4; + inventory->wood -= 1 * (int32_t)is_crafting_diamond_sword; + inventory->diamond -= 2 * (int32_t)is_crafting_diamond_sword; + inventory->sword = + inventory->sword * (1 - (int32_t)is_crafting_diamond_sword) + + 4 * (int32_t)is_crafting_diamond_sword; + + int32_t armour_count = 0; + int32_t iron_armour_index_to_craft = + craftax_crafting_first_armour_below(inventory, 1, &armour_count); + bool can_craft_iron_armour = + armour_count > 0 && inventory->iron >= 3 && inventory->coal >= 3; + bool is_crafting_iron_armour = + action == CRAFTAX_ACTION_MAKE_IRON_ARMOUR + && can_craft_iron_armour + && is_at_crafting_table + && is_at_furnace; + inventory->iron -= 3 * (int32_t)is_crafting_iron_armour; + inventory->coal -= 3 * (int32_t)is_crafting_iron_armour; + inventory->armour[iron_armour_index_to_craft] = + (int32_t)is_crafting_iron_armour * 1 + + (1 - (int32_t)is_crafting_iron_armour) + * inventory->armour[iron_armour_index_to_craft]; + state->achievements[CRAFTAX_ACH_MAKE_IRON_ARMOUR] = + state->achievements[CRAFTAX_ACH_MAKE_IRON_ARMOUR] + || is_crafting_iron_armour; + + int32_t diamond_armour_count = 0; + int32_t diamond_armour_index_to_craft = + craftax_crafting_first_armour_below(inventory, 2, &diamond_armour_count); + bool can_craft_diamond_armour = + diamond_armour_count > 0 && inventory->diamond >= 3; + bool is_crafting_diamond_armour = + action == CRAFTAX_ACTION_MAKE_DIAMOND_ARMOUR + && can_craft_diamond_armour + && is_at_crafting_table; + inventory->diamond -= 3 * (int32_t)is_crafting_diamond_armour; + inventory->armour[diamond_armour_index_to_craft] = + (int32_t)is_crafting_diamond_armour * 2 + + (1 - 
(int32_t)is_crafting_diamond_armour) + * inventory->armour[diamond_armour_index_to_craft]; + state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_ARMOUR] = + state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_ARMOUR] + || is_crafting_diamond_armour; + + bool can_craft_arrow = inventory->stone >= 1 && inventory->wood >= 1; + bool is_crafting_arrow = + action == CRAFTAX_ACTION_MAKE_ARROW + && can_craft_arrow + && is_at_crafting_table + && inventory->arrows < 99; + inventory->wood -= 1 * (int32_t)is_crafting_arrow; + inventory->stone -= 1 * (int32_t)is_crafting_arrow; + inventory->arrows += 2 * (int32_t)is_crafting_arrow; + + bool can_craft_torch = inventory->coal >= 1 && inventory->wood >= 1; + bool is_crafting_torch = + action == CRAFTAX_ACTION_MAKE_TORCH + && can_craft_torch + && is_at_crafting_table + && inventory->torches < 99; + inventory->wood -= 1 * (int32_t)is_crafting_torch; + inventory->coal -= 1 * (int32_t)is_crafting_torch; + inventory->torches += 4 * (int32_t)is_crafting_torch; +} + +static inline bool craftax_crafting_can_place_item(int32_t block) { + switch (block) { + case CRAFTAX_BLOCK_GRASS: + case CRAFTAX_BLOCK_SAND: + case CRAFTAX_BLOCK_PATH: + case CRAFTAX_BLOCK_FIRE_GRASS: + case CRAFTAX_BLOCK_ICE_GRASS: + return true; + default: + return false; + } +} + +static inline float craftax_crafting_torch_light(int32_t row, int32_t col) { + static const float torch_light_map[9][9] = { + {0.0f, 0.0f, 0.10557288f, 0.17537886f, 0.19999999f, 0.17537886f, 0.10557288f, 0.0f, 0.0f}, + {0.0f, 0.15147191f, 0.27888972f, 0.36754447f, 0.39999998f, 0.36754447f, 0.27888972f, 0.15147191f, 0.0f}, + {0.10557288f, 0.27888972f, 0.43431455f, 0.55278647f, 0.6f, 0.55278647f, 0.43431455f, 0.27888972f, 0.10557288f}, + {0.17537886f, 0.36754447f, 0.55278647f, 0.71715724f, 0.8f, 0.71715724f, 0.55278647f, 0.36754447f, 0.17537886f}, + {0.19999999f, 0.39999998f, 0.6f, 0.8f, 1.0f, 0.8f, 0.6f, 0.39999998f, 0.19999999f}, + {0.17537886f, 0.36754447f, 0.55278647f, 0.71715724f, 0.8f, 0.71715724f, 
0.55278647f, 0.36754447f, 0.17537886f}, + {0.10557288f, 0.27888972f, 0.43431455f, 0.55278647f, 0.6f, 0.55278647f, 0.43431455f, 0.27888972f, 0.10557288f}, + {0.0f, 0.15147191f, 0.27888972f, 0.36754447f, 0.39999998f, 0.36754447f, 0.27888972f, 0.15147191f, 0.0f}, + {0.0f, 0.0f, 0.10557288f, 0.17537886f, 0.19999999f, 0.17537886f, 0.10557288f, 0.0f, 0.0f}, + }; + return torch_light_map[row][col]; +} + +static inline void craftax_crafting_add_torch_light( + CraftaxState* state, + int32_t level, + int32_t row, + int32_t col +) { + for (int32_t dr = -4; dr <= 4; dr++) { + int32_t map_row = row + dr; + if (map_row < 0 || map_row >= CRAFTAX_MAP_SIZE) { + continue; + } + for (int32_t dc = -4; dc <= 4; dc++) { + int32_t map_col = col + dc; + if (map_col < 0 || map_col >= CRAFTAX_MAP_SIZE) { + continue; + } + float light = state->light_map[level][map_row][map_col] + + craftax_crafting_torch_light(dr + 4, dc + 4); + state->light_map[level][map_row][map_col] = + craftax_step_minf32(craftax_step_maxf32(light, 0.0f), 1.0f); + } + } +} + +static inline void craftax_add_new_growing_plant_native( + CraftaxState* state, + const int32_t position[2], + bool is_placing_sapling +) { + int32_t plant_index = 0; + int32_t empty_count = 0; + for (int32_t i = 0; i < CRAFTAX_MAX_GROWING_PLANTS; i++) { + bool is_empty = !state->growing_plants_mask[i]; + plant_index = (empty_count == 0 && is_empty) ? 
i : plant_index; + empty_count += (int32_t)is_empty; + } + + bool is_adding_plant = empty_count > 0 && is_placing_sapling; + if (!is_adding_plant) { + return; + } + + state->growing_plants_positions[plant_index][0] = position[0]; + state->growing_plants_positions[plant_index][1] = position[1]; + state->growing_plants_age[plant_index] = 0; + state->growing_plants_mask[plant_index] = true; +} + +static inline void craftax_place_block_native( + CraftaxState* state, + int32_t action +) { + int32_t direction[2]; + craftax_step_direction(state->player_direction, direction); + + int32_t row = state->player_position[0] + direction[0]; + int32_t col = state->player_position[1] + direction[1]; + bool in_bounds = row >= 0 + && row < CRAFTAX_MAP_SIZE + && col >= 0 + && col < CRAFTAX_MAP_SIZE; + bool in_mob = in_bounds && craftax_step_is_in_mob(state, row, col); + if (!in_bounds || in_mob) { + return; + } + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t original_block = state->map[level][row][col]; + int32_t original_item = state->item_map[level][row][col]; + bool is_placement_on_solid_block_or_item = + craftax_step_is_solid_block(original_block) + || original_item != CRAFTAX_ITEM_NONE; + + CraftaxInventory* inventory = &state->inventory; + + bool is_placing_crafting_table = + action == CRAFTAX_ACTION_PLACE_TABLE + && !is_placement_on_solid_block_or_item + && inventory->wood >= 2; + if (is_placing_crafting_table) { + state->map[level][row][col] = CRAFTAX_BLOCK_CRAFTING_TABLE; + } + inventory->wood -= 2 * (int32_t)is_placing_crafting_table; + state->achievements[CRAFTAX_ACH_PLACE_TABLE] = + state->achievements[CRAFTAX_ACH_PLACE_TABLE] + || is_placing_crafting_table; + + bool is_placing_furnace = + action == CRAFTAX_ACTION_PLACE_FURNACE + && !is_placement_on_solid_block_or_item + && inventory->stone > 0; + if (is_placing_furnace) { + state->map[level][row][col] = CRAFTAX_BLOCK_FURNACE; + } + inventory->stone -= 1 * 
(int32_t)is_placing_furnace; + state->achievements[CRAFTAX_ACH_PLACE_FURNACE] = + state->achievements[CRAFTAX_ACH_PLACE_FURNACE] + || is_placing_furnace; + + bool is_placing_on_valid_stone_block = + original_block == CRAFTAX_BLOCK_WATER + || !is_placement_on_solid_block_or_item; + bool is_placing_stone = + action == CRAFTAX_ACTION_PLACE_STONE + && is_placing_on_valid_stone_block + && inventory->stone > 0; + if (is_placing_stone) { + state->map[level][row][col] = CRAFTAX_BLOCK_STONE; + } + inventory->stone -= 1 * (int32_t)is_placing_stone; + state->achievements[CRAFTAX_ACH_PLACE_STONE] = + state->achievements[CRAFTAX_ACH_PLACE_STONE] + || is_placing_stone; + + bool is_placing_on_valid_torch_block = + craftax_crafting_can_place_item(original_block) + && state->item_map[level][row][col] == CRAFTAX_ITEM_NONE; + bool is_placing_torch = + action == CRAFTAX_ACTION_PLACE_TORCH + && is_placing_on_valid_torch_block + && inventory->torches > 0; + if (is_placing_torch) { + state->item_map[level][row][col] = CRAFTAX_ITEM_TORCH; + craftax_crafting_add_torch_light(state, level, row, col); + } + inventory->torches -= 1 * (int32_t)is_placing_torch; + state->achievements[CRAFTAX_ACH_PLACE_TORCH] = + state->achievements[CRAFTAX_ACH_PLACE_TORCH] + || is_placing_torch; + + bool is_placing_sapling = + action == CRAFTAX_ACTION_PLACE_PLANT + && state->map[level][row][col] == CRAFTAX_BLOCK_GRASS + && inventory->sapling > 0 + && state->item_map[level][row][col] == CRAFTAX_ITEM_NONE; + if (is_placing_sapling) { + int32_t position[2] = {row, col}; + state->map[level][row][col] = CRAFTAX_BLOCK_PLANT; + craftax_add_new_growing_plant_native( + state, + position, + is_placing_sapling + ); + } + inventory->sapling -= 1 * (int32_t)is_placing_sapling; + state->achievements[CRAFTAX_ACH_PLACE_PLANT] = + state->achievements[CRAFTAX_ACH_PLACE_PLANT] + || is_placing_sapling; +} diff --git a/ocean/craftax/step_do_action.h b/ocean/craftax/step_do_action.h new file mode 100644 index 0000000000..6f5dcc934a 
--- /dev/null +++ b/ocean/craftax/step_do_action.h @@ -0,0 +1,605 @@ +// Standalone native port of Craftax do_action. +// +// This helper intentionally is not integrated into c_step yet. It mutates a +// full CraftaxState in place so tests can compare the subsystem directly +// against the installed JAX implementation. + +#pragma once + +#include "step_medium.h" + +#define CRAFTAX_DO_ACTION_BOSS_FIGHT_SPAWN_TURNS 7 + +static inline float craftax_do_action_mob_defense( + int32_t type_id, + int32_t mob_class_index, + int32_t damage_index +) { + static const float defenses[8][4][3] = { + { + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.5f, 0.0f, 0.0f}, + {0.5f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.2f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.9f, 1.0f, 0.0f}, + {0.9f, 1.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {0.9f, 0.0f, 1.0f}, + {0.9f, 0.0f, 1.0f}, + {0.0f, 0.0f, 0.0f}, + }, + }; + + int32_t type_index = craftax_step_jax_index(type_id, 8); + int32_t class_index = craftax_step_jax_index(mob_class_index, 4); + int32_t component = craftax_step_jax_index(damage_index, 3); + return defenses[type_index][class_index][component]; +} + +static inline int32_t craftax_do_action_mob_achievement( + int32_t mob_class_index, + int32_t type_id +) { + static const int32_t achievements[3][8] = { + { + CRAFTAX_ACH_EAT_COW, + CRAFTAX_ACH_EAT_BAT, + CRAFTAX_ACH_EAT_SNAIL, + 0, + 0, + 0, + 0, + 0, + }, + { + CRAFTAX_ACH_DEFEAT_ZOMBIE, + CRAFTAX_ACH_DEFEAT_GNOME_WARRIOR, + CRAFTAX_ACH_DEFEAT_ORC_SOLIDER, + 
CRAFTAX_ACH_DEFEAT_LIZARD, + CRAFTAX_ACH_DEFEAT_KNIGHT, + CRAFTAX_ACH_DEFEAT_TROLL, + CRAFTAX_ACH_DEFEAT_PIGMAN, + CRAFTAX_ACH_DEFEAT_FROST_TROLL, + }, + { + CRAFTAX_ACH_DEFEAT_SKELETON, + CRAFTAX_ACH_DEFEAT_GNOME_ARCHER, + CRAFTAX_ACH_DEFEAT_ORC_MAGE, + CRAFTAX_ACH_DEFEAT_KOBOLD, + CRAFTAX_ACH_DEFEAT_ARCHER, + CRAFTAX_ACH_DEFEAT_DEEP_THING, + CRAFTAX_ACH_DEFEAT_FIRE_ELEMENTAL, + CRAFTAX_ACH_DEFEAT_ICE_ELEMENTAL, + }, + }; + + int32_t class_index = craftax_step_jax_index(mob_class_index, 3); + int32_t type_index = craftax_step_jax_index(type_id, 8); + return achievements[class_index][type_index]; +} + +static inline void craftax_do_action_player_damage_vector( + const CraftaxState* state, + float damage_vector[3] +) { + static const float physical_damages[5] = {1.0f, 2.0f, 3.0f, 5.0f, 8.0f}; + + int32_t sword_index = craftax_step_jax_index(state->inventory.sword, 5); + float physical_damage = physical_damages[sword_index]; + float fire_damage = + physical_damage * (float)(state->sword_enchantment == 1) * 0.5f; + float ice_damage = + physical_damage * (float)(state->sword_enchantment == 2) * 0.5f; + + physical_damage *= 1.0f + 0.25f * (float)(state->player_strength - 1); + fire_damage *= 1.0f + 0.05f * (float)(state->player_intelligence - 1); + ice_damage *= 1.0f + 0.05f * (float)(state->player_intelligence - 1); + + damage_vector[0] = physical_damage; + damage_vector[1] = fire_damage; + damage_vector[2] = ice_damage; +} + +static inline float craftax_do_action_damage_done( + const float damage_vector[3], + int32_t type_id, + int32_t mob_class_index +) { + float damage = 0.0f; + for (int32_t i = 0; i < 3; i++) { + float defense = craftax_do_action_mob_defense( + type_id, + mob_class_index, + i + ); + damage += (1.0f - defense) * damage_vector[i]; + } + return damage; +} + +static inline void craftax_do_action_refresh_mobs3_masks(CraftaxMobs3* mobs) { + for (int32_t level = 0; level < CRAFTAX_NUM_LEVELS; level++) { + for (int32_t i = 0; i < 3; i++) { + 
mobs->mask[level][i] = + mobs->mask[level][i] && mobs->health[level][i] > 0.0f; + } + } +} + +static inline void craftax_do_action_refresh_mobs2_masks(CraftaxMobs2* mobs) { + for (int32_t level = 0; level < CRAFTAX_NUM_LEVELS; level++) { + for (int32_t i = 0; i < 2; i++) { + mobs->mask[level][i] = + mobs->mask[level][i] && mobs->health[level][i] > 0.0f; + } + } +} + +static inline void craftax_do_action_attack_mobs3( + CraftaxState* state, + CraftaxMobs3* mobs, + int32_t row, + int32_t col, + const float damage_vector[3], + bool can_get_achievement, + int32_t mob_class_index, + bool* did_kill_mob, + bool* is_attacking_mob +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + bool is_attacking_array[3]; + *is_attacking_mob = false; + int32_t target_mob_index = 0; + + for (int32_t i = 0; i < 3; i++) { + bool in_mob = mobs->position[level][i][0] == row + && mobs->position[level][i][1] == col; + is_attacking_array[i] = in_mob && mobs->mask[level][i]; + if (is_attacking_array[i] && !*is_attacking_mob) { + target_mob_index = i; + } + *is_attacking_mob = *is_attacking_mob || is_attacking_array[i]; + } + + int32_t target_type_id = mobs->type_id[level][target_mob_index]; + float damage = craftax_do_action_damage_done( + damage_vector, + target_type_id, + mob_class_index + ); + mobs->health[level][target_mob_index] -= + damage * (float)(int32_t)(*is_attacking_mob); + + bool old_mask = mobs->mask[level][target_mob_index]; + craftax_do_action_refresh_mobs3_masks(mobs); + *did_kill_mob = old_mask && !mobs->mask[level][target_mob_index]; + + int32_t achievement_for_kill = craftax_do_action_mob_achievement( + mob_class_index, + target_type_id + ); + bool unlock = *did_kill_mob && can_get_achievement; + state->achievements[achievement_for_kill] = + state->achievements[achievement_for_kill] || unlock; +} + +static inline void craftax_do_action_attack_mobs2( + CraftaxState* state, + CraftaxMobs2* mobs, + int32_t row, + int32_t col, + const 
float damage_vector[3], + bool can_get_achievement, + int32_t mob_class_index, + bool* did_kill_mob, + bool* is_attacking_mob +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + bool is_attacking_array[2]; + *is_attacking_mob = false; + int32_t target_mob_index = 0; + + for (int32_t i = 0; i < 2; i++) { + bool in_mob = mobs->position[level][i][0] == row + && mobs->position[level][i][1] == col; + is_attacking_array[i] = in_mob && mobs->mask[level][i]; + if (is_attacking_array[i] && !*is_attacking_mob) { + target_mob_index = i; + } + *is_attacking_mob = *is_attacking_mob || is_attacking_array[i]; + } + + int32_t target_type_id = mobs->type_id[level][target_mob_index]; + float damage = craftax_do_action_damage_done( + damage_vector, + target_type_id, + mob_class_index + ); + mobs->health[level][target_mob_index] -= + damage * (float)(int32_t)(*is_attacking_mob); + + bool old_mask = mobs->mask[level][target_mob_index]; + craftax_do_action_refresh_mobs2_masks(mobs); + *did_kill_mob = old_mask && !mobs->mask[level][target_mob_index]; + + int32_t achievement_for_kill = craftax_do_action_mob_achievement( + mob_class_index, + target_type_id + ); + bool unlock = *did_kill_mob && can_get_achievement; + state->achievements[achievement_for_kill] = + state->achievements[achievement_for_kill] || unlock; +} + +static inline bool craftax_do_action_update_index( + int32_t index, + int32_t size, + int32_t* mapped_index +) { + if (index < -size || index >= size) { + return false; + } + *mapped_index = index < 0 ? 
index + size : index; + return true; +} + +static inline void craftax_do_action_update_mob_map( + CraftaxState* state, + int32_t row, + int32_t col, + bool did_kill_mob +) { + int32_t update_row; + int32_t update_col; + if (!craftax_do_action_update_index(row, CRAFTAX_MAP_SIZE, &update_row) + || !craftax_do_action_update_index(col, CRAFTAX_MAP_SIZE, &update_col)) { + return; + } + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t read_row = craftax_step_jax_index(row, CRAFTAX_MAP_SIZE); + int32_t read_col = craftax_step_jax_index(col, CRAFTAX_MAP_SIZE); + state->mob_map[level][update_row][update_col] = + state->mob_map[level][read_row][read_col] && !did_kill_mob; +} + +static inline void craftax_do_action_attack_mob( + CraftaxState* state, + int32_t row, + int32_t col, + bool can_eat, + bool* did_attack_mob, + bool* did_kill_mob +) { + float damage_vector[3]; + craftax_do_action_player_damage_vector(state, damage_vector); + + bool did_kill_melee_mob = false; + bool is_attacking_melee_mob = false; + craftax_do_action_attack_mobs3( + state, + &state->melee_mobs, + row, + col, + damage_vector, + true, + 1, + &did_kill_melee_mob, + &is_attacking_melee_mob + ); + + bool did_kill_passive_mob = false; + bool is_attacking_passive_mob = false; + craftax_do_action_attack_mobs3( + state, + &state->passive_mobs, + row, + col, + damage_vector, + can_eat, + 0, + &did_kill_passive_mob, + &is_attacking_passive_mob + ); + + if (did_kill_passive_mob && can_eat) { + state->player_food = craftax_step_mini32( + craftax_step_get_max_food(state), + state->player_food + 6 + ); + state->player_hunger = 0.0f; + } + + bool did_kill_ranged_mob = false; + bool is_attacking_ranged_mob = false; + craftax_do_action_attack_mobs2( + state, + &state->ranged_mobs, + row, + col, + damage_vector, + true, + 2, + &did_kill_ranged_mob, + &is_attacking_ranged_mob + ); + + *did_attack_mob = is_attacking_melee_mob + || is_attacking_passive_mob + || 
is_attacking_ranged_mob; + bool did_kill_monster = did_kill_melee_mob || did_kill_ranged_mob; + *did_kill_mob = did_kill_monster || did_kill_passive_mob; + + craftax_do_action_update_mob_map(state, row, col, *did_kill_mob); + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + state->monsters_killed[level] += (int32_t)did_kill_monster; +} + +static inline bool craftax_do_action_in_bounds(int32_t row, int32_t col) { + return row >= 0 + && row < CRAFTAX_MAP_SIZE + && col >= 0 + && col < CRAFTAX_MAP_SIZE; +} + +static inline bool craftax_do_action_boss_vulnerable( + const CraftaxState* state +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t melee_count = 0; + int32_t ranged_count = 0; + for (int32_t i = 0; i < CRAFTAX_MAX_MELEE_MOBS; i++) { + melee_count += (int32_t)state->melee_mobs.mask[level][i]; + } + for (int32_t i = 0; i < CRAFTAX_MAX_RANGED_MOBS; i++) { + ranged_count += (int32_t)state->ranged_mobs.mask[level][i]; + } + return melee_count == 0 + && ranged_count == 0 + && state->boss_timesteps_to_spawn_this_round <= 0; +} + +static inline void craftax_do_action_update_plants_with_eat( + CraftaxState* state, + int32_t row, + int32_t col +) { + int32_t plant_index = 0; + bool found = false; + for (int32_t i = 0; i < CRAFTAX_MAX_GROWING_PLANTS; i++) { + bool is_plant = state->growing_plants_positions[i][0] == row + && state->growing_plants_positions[i][1] == col; + if (is_plant && !found) { + plant_index = i; + found = true; + } + } + state->growing_plants_age[plant_index] = 0; +} + +static inline void craftax_do_action_native( + CraftaxState* state, + int32_t action, + CraftaxThreefryKey rng +) { + if (action != CRAFTAX_ACTION_DO) { + return; + } + + int32_t direction[2]; + craftax_step_direction(state->player_direction, direction); + int32_t target_row = state->player_position[0] + direction[0]; + int32_t target_col = state->player_position[1] + direction[1]; + + bool 
did_attack_mob = false; + bool did_kill_mob = false; + craftax_do_action_attack_mob( + state, + target_row, + target_col, + true, + &did_attack_mob, + &did_kill_mob + ); + (void)did_kill_mob; + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t read_row = craftax_step_jax_index(target_row, CRAFTAX_MAP_SIZE); + int32_t read_col = craftax_step_jax_index(target_col, CRAFTAX_MAP_SIZE); + int32_t target_block = state->map[level][read_row][read_col]; + + CraftaxThreefryKey sapling_key = craftax_medium_next_random_key(&rng); + CraftaxThreefryKey chest_key = craftax_medium_next_random_key(&rng); + + bool is_opening_chest = target_block == CRAFTAX_BLOCK_CHEST; + bool is_damaging_boss = target_block == CRAFTAX_BLOCK_NECROMANCER + && craftax_do_action_boss_vulnerable(state) + && craftax_step_is_fighting_boss(state); + + bool action_block_in_bounds = + craftax_do_action_in_bounds(target_row, target_col) && !did_attack_mob; + + if (action_block_in_bounds) { + bool is_block_tree = target_block == CRAFTAX_BLOCK_TREE; + bool is_block_fire_tree = target_block == CRAFTAX_BLOCK_FIRE_TREE; + bool is_block_ice_shrub = target_block == CRAFTAX_BLOCK_ICE_SHRUB; + bool is_mining_tree = + is_block_tree || is_block_fire_tree || is_block_ice_shrub; + if (is_mining_tree) { + int32_t replacement = is_block_tree + ? CRAFTAX_BLOCK_GRASS + : (is_block_fire_tree + ? 
CRAFTAX_BLOCK_FIRE_GRASS + : CRAFTAX_BLOCK_ICE_GRASS); + state->map[level][target_row][target_col] = replacement; + state->inventory.wood += 1; + } + + bool is_mining_stone = target_block == CRAFTAX_BLOCK_STONE + && state->inventory.pickaxe >= 1; + if (is_mining_stone) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.stone += 1; + } + + if (target_block == CRAFTAX_BLOCK_FURNACE) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + } + + if (target_block == CRAFTAX_BLOCK_CRAFTING_TABLE) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + } + + bool is_mining_coal = target_block == CRAFTAX_BLOCK_COAL + && state->inventory.pickaxe >= 1; + if (is_mining_coal) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.coal += 1; + } + + bool is_mining_iron = target_block == CRAFTAX_BLOCK_IRON + && state->inventory.pickaxe >= 2; + if (is_mining_iron) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.iron += 1; + } + + bool is_mining_diamond = target_block == CRAFTAX_BLOCK_DIAMOND + && state->inventory.pickaxe >= 3; + if (is_mining_diamond) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.diamond += 1; + } + + bool is_mining_sapphire = target_block == CRAFTAX_BLOCK_SAPPHIRE + && state->inventory.pickaxe >= 4; + if (is_mining_sapphire) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.sapphire += 1; + } + + bool is_mining_ruby = target_block == CRAFTAX_BLOCK_RUBY + && state->inventory.pickaxe >= 4; + if (is_mining_ruby) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.ruby += 1; + } + + bool is_mining_sapling = target_block == CRAFTAX_BLOCK_GRASS + && craftax_threefry_uniform_f32(sapling_key) < 0.1f; + state->inventory.sapling += (int32_t)is_mining_sapling; + + bool is_drinking_water = target_block == CRAFTAX_BLOCK_WATER + || 
target_block == CRAFTAX_BLOCK_FOUNTAIN; + if (is_drinking_water) { + state->player_drink = craftax_step_mini32( + craftax_step_get_max_drink(state), + state->player_drink + 1 + ); + state->player_thirst = 0.0f; + state->achievements[CRAFTAX_ACH_COLLECT_DRINK] = true; + } + + bool is_eating_plant = target_block == CRAFTAX_BLOCK_RIPE_PLANT; + if (is_eating_plant) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PLANT; + state->player_food = craftax_step_mini32( + craftax_step_get_max_food(state), + state->player_food + 4 + ); + state->player_hunger = 0.0f; + state->achievements[CRAFTAX_ACH_EAT_PLANT] = true; + craftax_do_action_update_plants_with_eat( + state, + target_row, + target_col + ); + } + + bool is_mining_stalagmite = target_block == CRAFTAX_BLOCK_STALAGMITE + && state->inventory.pickaxe >= 1; + if (is_mining_stalagmite) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + state->inventory.stone += 1; + } + + if (is_opening_chest) { + state->map[level][target_row][target_col] = CRAFTAX_BLOCK_PATH; + craftax_add_items_from_chest_native( + state, + &state->inventory, + true, + chest_key + ); + state->achievements[CRAFTAX_ACH_OPEN_CHEST] = true; + } + + if (is_damaging_boss) { + state->achievements[CRAFTAX_ACH_DAMAGE_NECROMANCER] = true; + } + } + + state->chests_opened[level] = + state->chests_opened[level] || is_opening_chest; + + state->boss_progress += (int32_t)is_damaging_boss; + if (is_damaging_boss) { + state->boss_timesteps_to_spawn_this_round = + CRAFTAX_DO_ACTION_BOSS_FIGHT_SPAWN_TURNS; + } +} diff --git a/ocean/craftax/step_medium.h b/ocean/craftax/step_medium.h new file mode 100644 index 0000000000..9f5ac1aae1 --- /dev/null +++ b/ocean/craftax/step_medium.h @@ -0,0 +1,459 @@ +// Standalone native ports of medium Craftax step subsystems. +// +// These helpers intentionally are not integrated into c_step yet. 
They mutate a +// full CraftaxState, or an Inventory plus read-only state context, so tests can +// compare each subsystem directly against the installed JAX implementation. + +#pragma once + +#include "step_simple.h" + +static inline CraftaxThreefryKey craftax_medium_next_random_key( + CraftaxThreefryKey* rng +) { + CraftaxThreefryKey draw; + craftax_threefry_split(*rng, rng, &draw); + return draw; +} + +static inline int32_t craftax_medium_randint( + CraftaxThreefryKey key, + int32_t minval, + int32_t maxval +) { + return craftax_randint_i32_at(key, 0u, minval, maxval); +} + +static inline int32_t craftax_medium_choice_weighted( + CraftaxThreefryKey key, + const float* weights, + int32_t count +) { + float total = 0.0f; + for (int32_t i = 0; i < count; i++) { + total += weights[i]; + } + + float draw = total * (1.0f - craftax_threefry_uniform_f32(key)); + float cumulative = 0.0f; + for (int32_t i = 0; i < count; i++) { + cumulative += weights[i]; + if (cumulative >= draw) { + return i; + } + } + return count - 1; +} + +static inline int32_t craftax_medium_projectile_count(const CraftaxState* state) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t count = 0; + for (int32_t i = 0; i < CRAFTAX_MAX_PLAYER_PROJECTILES; i++) { + count += (int32_t)state->player_projectiles.mask[level][i]; + } + return count; +} + +static inline int32_t craftax_medium_first_projectile_slot( + const CraftaxState* state +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + for (int32_t i = 0; i < CRAFTAX_MAX_PLAYER_PROJECTILES; i++) { + if (!state->player_projectiles.mask[level][i]) { + return i; + } + } + return 0; +} + +static inline void craftax_medium_spawn_player_projectile( + CraftaxState* state, + bool is_spawning_projectile, + const int32_t new_projectile_position[2], + const int32_t direction[2], + int32_t projectile_type +) { + if (!is_spawning_projectile) { + return; + } + + int32_t 
level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t index = craftax_medium_first_projectile_slot(state); + state->player_projectiles.position[level][index][0] = new_projectile_position[0]; + state->player_projectiles.position[level][index][1] = new_projectile_position[1]; + state->player_projectiles.mask[level][index] = true; + state->player_projectiles.type_id[level][index] = projectile_type; + state->player_projectile_directions[level][index][0] = direction[0]; + state->player_projectile_directions[level][index][1] = direction[1]; +} + +static inline int32_t craftax_medium_level_achievement(int32_t level) { + switch (craftax_step_jax_index(level, CRAFTAX_NUM_LEVELS)) { + case 1: + return CRAFTAX_ACH_ENTER_DUNGEON; + case 2: + return CRAFTAX_ACH_ENTER_GNOMISH_MINES; + case 3: + return CRAFTAX_ACH_ENTER_SEWERS; + case 4: + return CRAFTAX_ACH_ENTER_VAULT; + case 5: + return CRAFTAX_ACH_ENTER_TROLL_MINES; + case 6: + return CRAFTAX_ACH_ENTER_FIRE_REALM; + case 7: + return CRAFTAX_ACH_ENTER_ICE_REALM; + case 8: + return CRAFTAX_ACH_ENTER_GRAVEYARD; + default: + return CRAFTAX_ACH_COLLECT_WOOD; + } +} + +static inline void craftax_shoot_projectile_native( + CraftaxState* state, + int32_t action +) { + bool is_shooting_arrow = action == CRAFTAX_ACTION_SHOOT_ARROW + && state->inventory.bow >= 1 + && state->inventory.arrows >= 1 + && craftax_medium_projectile_count(state) < CRAFTAX_MAX_PLAYER_PROJECTILES; + + int32_t direction[2]; + craftax_step_direction(state->player_direction, direction); + craftax_medium_spawn_player_projectile( + state, + is_shooting_arrow, + state->player_position, + direction, + CRAFTAX_PROJECTILE_ARROW2 + ); + + state->achievements[CRAFTAX_ACH_FIRE_BOW] = + state->achievements[CRAFTAX_ACH_FIRE_BOW] || is_shooting_arrow; + state->inventory.arrows -= (int32_t)is_shooting_arrow; +} + +static inline void craftax_cast_spell_native( + CraftaxState* state, + int32_t action +) { + bool has_projectile_slot = + 
craftax_medium_projectile_count(state) < CRAFTAX_MAX_PLAYER_PROJECTILES; + bool has_mana = state->player_mana >= 2; + bool is_casting_fireball = action == CRAFTAX_ACTION_CAST_FIREBALL + && has_mana + && has_projectile_slot + && state->learned_spells[0]; + bool is_casting_iceball = action == CRAFTAX_ACTION_CAST_ICEBALL + && has_mana + && has_projectile_slot + && state->learned_spells[1]; + bool is_casting_spell = is_casting_fireball || is_casting_iceball; + + int32_t projectile_type = + (int32_t)is_casting_fireball * CRAFTAX_PROJECTILE_FIREBALL + + (int32_t)is_casting_iceball * CRAFTAX_PROJECTILE_ICEBALL; + + int32_t direction[2]; + craftax_step_direction(state->player_direction, direction); + craftax_medium_spawn_player_projectile( + state, + is_casting_spell, + state->player_position, + direction, + projectile_type + ); + + if (is_casting_fireball) { + state->achievements[CRAFTAX_ACH_CAST_FIREBALL] = true; + } + if (is_casting_iceball) { + state->achievements[CRAFTAX_ACH_CAST_ICEBALL] = true; + } + state->player_mana -= (int32_t)is_casting_spell * 2; +} + +static inline void craftax_enchant_native( + CraftaxState* state, + int32_t action, + CraftaxThreefryKey rng +) { + int32_t direction[2]; + craftax_step_direction(state->player_direction, direction); + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t target_row = craftax_step_jax_index( + state->player_position[0] + direction[0], + CRAFTAX_MAP_SIZE + ); + int32_t target_col = craftax_step_jax_index( + state->player_position[1] + direction[1], + CRAFTAX_MAP_SIZE + ); + int32_t target_block = state->map[level][target_row][target_col]; + + bool is_fire_table = target_block == CRAFTAX_BLOCK_ENCHANTMENT_TABLE_FIRE; + bool is_ice_table = target_block == CRAFTAX_BLOCK_ENCHANTMENT_TABLE_ICE; + bool target_block_is_enchantment_table = is_fire_table || is_ice_table; + int32_t enchantment_type = is_fire_table ? 1 : 2; + int32_t num_gems = is_fire_table + ? 
state->inventory.ruby + : state->inventory.sapphire; + + bool could_enchant = state->player_mana >= 9 + && target_block_is_enchantment_table + && num_gems >= 1; + bool is_enchanting_bow = could_enchant + && action == CRAFTAX_ACTION_ENCHANT_BOW + && state->inventory.bow > 0; + bool is_enchanting_sword = could_enchant + && action == CRAFTAX_ACTION_ENCHANT_SWORD + && state->inventory.sword > 0; + + int32_t armour_count = 0; + for (int32_t i = 0; i < 4; i++) { + armour_count += state->inventory.armour[i]; + } + bool is_enchanting_armour = could_enchant + && action == CRAFTAX_ACTION_ENCHANT_ARMOUR + && armour_count > 0; + + CraftaxThreefryKey armour_key = craftax_medium_next_random_key(&rng); + int32_t unenchanted_count = 0; + for (int32_t i = 0; i < 4; i++) { + unenchanted_count += (int32_t)(state->armour_enchantments[i] == 0); + } + + float armour_targets[4]; + for (int32_t i = 0; i < 4; i++) { + bool unenchanted = state->armour_enchantments[i] == 0; + bool opposite_enchanted = state->armour_enchantments[i] != 0 + && state->armour_enchantments[i] != enchantment_type; + armour_targets[i] = (unenchanted || ( + unenchanted_count == 0 && opposite_enchanted + )) ? 
1.0f : 0.0f; + } + int32_t armour_target = craftax_medium_choice_weighted( + armour_key, + armour_targets, + 4 + ); + + bool is_enchanting = is_enchanting_sword + || is_enchanting_bow + || is_enchanting_armour; + if (is_enchanting_sword) { + state->sword_enchantment = enchantment_type; + state->achievements[CRAFTAX_ACH_ENCHANT_SWORD] = true; + } + if (is_enchanting_bow) { + state->bow_enchantment = enchantment_type; + } + if (is_enchanting_armour) { + state->armour_enchantments[armour_target] = enchantment_type; + state->achievements[CRAFTAX_ACH_ENCHANT_ARMOUR] = true; + } + + state->inventory.sapphire -= + (int32_t)is_enchanting * (int32_t)(enchantment_type == 2); + state->inventory.ruby -= + (int32_t)is_enchanting * (int32_t)(enchantment_type == 1); + state->player_mana -= (int32_t)is_enchanting * 9; +} + +static inline void craftax_change_floor_native( + CraftaxState* state, + int32_t action +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + int32_t player_row = craftax_step_jax_index( + state->player_position[0], + CRAFTAX_MAP_SIZE + ); + int32_t player_col = craftax_step_jax_index( + state->player_position[1], + CRAFTAX_MAP_SIZE + ); + + bool on_down_ladder = + state->item_map[level][player_row][player_col] == CRAFTAX_ITEM_LADDER_DOWN; + bool is_moving_down = action == CRAFTAX_ACTION_DESCEND + && on_down_ladder + && state->monsters_killed[level] >= CRAFTAX_MONSTERS_KILLED_TO_CLEAR_LEVEL + && state->player_level < CRAFTAX_NUM_LEVELS - 1; + + bool on_up_ladder = + state->item_map[level][player_row][player_col] == CRAFTAX_ITEM_LADDER_UP; + bool is_moving_up = action == CRAFTAX_ACTION_ASCEND + && on_up_ladder + && state->player_level > 0; + + int32_t delta_floor = (int32_t)is_moving_down - (int32_t)is_moving_up; + int32_t new_level = state->player_level + delta_floor; + int32_t achievement = craftax_medium_level_achievement(new_level); + bool new_floor = new_level != 0 && !state->achievements[achievement]; + + if 
(is_moving_down) { + int32_t ladder_level = craftax_step_jax_index( + state->player_level + 1, + CRAFTAX_NUM_LEVELS + ); + state->player_position[0] = state->up_ladders[ladder_level][0]; + state->player_position[1] = state->up_ladders[ladder_level][1]; + } else if (is_moving_up) { + int32_t ladder_level = craftax_step_jax_index( + state->player_level - 1, + CRAFTAX_NUM_LEVELS + ); + state->player_position[0] = state->down_ladders[ladder_level][0]; + state->player_position[1] = state->down_ladders[ladder_level][1]; + } + + state->player_level = new_level; + state->achievements[achievement] = + state->achievements[achievement] || new_level != 0; + state->player_xp += (int32_t)new_floor; +} + +static inline void craftax_add_items_from_chest_native( + const CraftaxState* state, + CraftaxInventory* inventory, + bool is_opening_chest, + CraftaxThreefryKey rng +) { + CraftaxThreefryKey draw_key; + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_wood = craftax_threefry_uniform_f32(draw_key) < 0.6f; + draw_key = craftax_medium_next_random_key(&rng); + int32_t wood_loot_amount = + craftax_medium_randint(draw_key, 1, 6) * (int32_t)is_looting_wood; + (void)wood_loot_amount; + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_torch = craftax_threefry_uniform_f32(draw_key) < 0.6f; + draw_key = craftax_medium_next_random_key(&rng); + int32_t torch_loot_amount = + craftax_medium_randint(draw_key, 4, 8) * (int32_t)is_looting_torch; + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_ore = craftax_threefry_uniform_f32(draw_key) < 0.6f; + draw_key = craftax_medium_next_random_key(&rng); + float ore_weights[5] = {0.3f, 0.3f, 0.15f, 0.125f, 0.125f}; + int32_t ore_loot_id = craftax_medium_choice_weighted( + draw_key, + ore_weights, + 5 + ); + draw_key = craftax_medium_next_random_key(&rng); + + int32_t coal_loot_amount = + craftax_medium_randint(draw_key, 1, 4) + * (int32_t)(ore_loot_id == 0) + * (int32_t)is_looting_ore; + 
int32_t iron_loot_amount = + craftax_medium_randint(draw_key, 1, 3) + * (int32_t)(ore_loot_id == 1) + * (int32_t)is_looting_ore; + int32_t diamond_loot_amount = + craftax_medium_randint(draw_key, 1, 2) + * (int32_t)(ore_loot_id == 2) + * (int32_t)is_looting_ore; + int32_t sapphire_loot_amount = + craftax_medium_randint(draw_key, 1, 2) + * (int32_t)(ore_loot_id == 3) + * (int32_t)is_looting_ore; + int32_t ruby_loot_amount = + craftax_medium_randint(draw_key, 1, 2) + * (int32_t)(ore_loot_id == 4) + * (int32_t)is_looting_ore; + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_potion = craftax_threefry_uniform_f32(draw_key) < 0.5f; + draw_key = craftax_medium_next_random_key(&rng); + int32_t potion_loot_index = craftax_medium_randint(draw_key, 0, 6); + draw_key = craftax_medium_next_random_key(&rng); + int32_t potion_loot_amount = craftax_medium_randint(draw_key, 1, 3); + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_arrows = craftax_threefry_uniform_f32(draw_key) < 0.25f; + draw_key = craftax_medium_next_random_key(&rng); + int32_t arrows_loot_amount = + craftax_medium_randint(draw_key, 1, 5) * (int32_t)is_looting_arrows; + + draw_key = craftax_medium_next_random_key(&rng); + bool is_looting_tool = craftax_threefry_uniform_f32(draw_key) < 0.2f; + draw_key = craftax_medium_next_random_key(&rng); + int32_t tool_id = craftax_medium_randint(draw_key, 0, 2); + + bool is_looting_pickaxe = is_looting_tool + && tool_id == 0 + && is_opening_chest; + draw_key = craftax_medium_next_random_key(&rng); + float tool_weights[4] = {0.4f, 0.3f, 0.2f, 0.1f}; + int32_t pickaxe_loot_level = ( + craftax_medium_choice_weighted(draw_key, tool_weights, 4) + 1 + ) * (int32_t)is_looting_pickaxe; + pickaxe_loot_level = craftax_step_maxi32( + pickaxe_loot_level, + inventory->pickaxe + ); + int32_t new_pickaxe_level = is_looting_pickaxe + ? 
pickaxe_loot_level + : inventory->pickaxe; + + bool is_looting_sword = is_looting_tool + && tool_id == 1 + && is_opening_chest; + draw_key = craftax_medium_next_random_key(&rng); + int32_t sword_loot_level = ( + craftax_medium_choice_weighted(draw_key, tool_weights, 4) + 1 + ) * (int32_t)is_looting_sword; + sword_loot_level = craftax_step_maxi32(sword_loot_level, inventory->sword); + int32_t new_sword_level = is_looting_sword + ? sword_loot_level + : inventory->sword; + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + bool is_looting_bow = is_opening_chest + && state->player_level == 1 + && !state->chests_opened[level]; + int32_t new_bow_level = is_looting_bow ? 1 : inventory->bow; + + bool is_looting_book = !state->chests_opened[level] + && (state->player_level == 3 || state->player_level == 4); + + int32_t opening = (int32_t)is_opening_chest; + inventory->torches += torch_loot_amount * opening; + inventory->coal += coal_loot_amount * opening; + inventory->iron += iron_loot_amount * opening; + inventory->diamond += diamond_loot_amount * opening; + inventory->sapphire += sapphire_loot_amount * opening; + inventory->ruby += ruby_loot_amount * opening; + inventory->arrows += arrows_loot_amount * opening; + inventory->pickaxe = new_pickaxe_level; + inventory->sword = new_sword_level; + inventory->potions[potion_loot_index] += + potion_loot_amount * (int32_t)is_looting_potion * opening; + inventory->bow = new_bow_level; + inventory->books += (int32_t)is_looting_book * opening; +} diff --git a/ocean/craftax/step_simple.h b/ocean/craftax/step_simple.h new file mode 100644 index 0000000000..b7ca3203f5 --- /dev/null +++ b/ocean/craftax/step_simple.h @@ -0,0 +1,556 @@ +// Standalone native ports of simple Craftax step subsystems. +// +// These helpers intentionally are not integrated into c_step yet. 
They mutate a
+// full CraftaxState in place so tests can compare each subsystem directly
+// against the installed JAX implementation.
+
+#pragma once
+
+#include "craftax.h"
+
+// Maps `index` into [0, size): a negative index is first offset by `size`
+// (so -1 selects the last element), and anything still out of range is
+// clamped to the nearest valid index. Never fails.
+// NOTE(review): presumably mirrors JAX's clamped gather indexing -- confirm.
+static inline int32_t craftax_step_jax_index(int32_t index, int32_t size) {
+    if (index < 0) {
+        index += size;
+    }
+    if (index < 0) {
+        return 0;
+    }
+    if (index >= size) {
+        return size - 1;
+    }
+    return index;
+}
+
+// Integer minimum of a and b.
+static inline int32_t craftax_step_mini32(int32_t a, int32_t b) {
+    return a < b ? a : b;
+}
+
+// Integer maximum of a and b.
+static inline int32_t craftax_step_maxi32(int32_t a, int32_t b) {
+    return a > b ? a : b;
+}
+
+// Float minimum that returns NAN when either operand is NaN (a plain
+// `a < b ? a : b` would silently return b in that case).
+static inline float craftax_step_minf32(float a, float b) {
+    if (isnan(a) || isnan(b)) {
+        return NAN;
+    }
+    return a < b ? a : b;
+}
+
+// Float maximum that returns NAN when either operand is NaN.
+static inline float craftax_step_maxf32(float a, float b) {
+    if (isnan(a) || isnan(b)) {
+        return NAN;
+    }
+    return a > b ? a : b;
+}
+
+// Health cap: 8 plus the player's strength.
+static inline int32_t craftax_step_get_max_health(const CraftaxState* state) {
+    return 8 + state->player_strength;
+}
+
+// Food cap: 7 plus twice the player's dexterity.
+static inline int32_t craftax_step_get_max_food(const CraftaxState* state) {
+    return 7 + 2 * state->player_dexterity;
+}
+
+// Drink cap: 7 plus twice the player's dexterity.
+static inline int32_t craftax_step_get_max_drink(const CraftaxState* state) {
+    return 7 + 2 * state->player_dexterity;
+}
+
+// Energy cap: 7 plus twice the player's dexterity.
+static inline int32_t craftax_step_get_max_energy(const CraftaxState* state) {
+    return 7 + 2 * state->player_dexterity;
+}
+
+// Mana cap: 6 plus three times the player's intelligence.
+static inline int32_t craftax_step_get_max_mana(const CraftaxState* state) {
+    return 6 + 3 * state->player_intelligence;
+}
+
+// True when the player is on the last level (index CRAFTAX_NUM_LEVELS - 1).
+static inline bool craftax_step_is_fighting_boss(const CraftaxState* state) {
+    return state->player_level == CRAFTAX_NUM_LEVELS - 1;
+}
+
+// True once boss_progress has reached CRAFTAX_NUM_LEVELS - 1.
+static inline bool craftax_step_has_beaten_boss(const CraftaxState* state) {
+    return state->boss_progress >= CRAFTAX_NUM_LEVELS - 1;
+}
+
+// Converts an action into a (row, col) step written to `direction`.
+// direction[0] is the row delta and direction[1] the column delta; both stay
+// 0 for every action other than LEFT/RIGHT/UP/DOWN. The action is first
+// normalized with craftax_step_jax_index(action, 16).
+static inline void craftax_step_direction(int32_t action, int32_t direction[2]) {
+    direction[0] = 0;
+    direction[1] = 0;
+    int32_t direction_index = craftax_step_jax_index(action, 16);
+    if (direction_index == CRAFTAX_ACTION_LEFT) {
+        direction[1] = -1;
+    } else if (direction_index == CRAFTAX_ACTION_RIGHT) {
+        direction[1] = 1;
+    } else if (direction_index == CRAFTAX_ACTION_UP) {
+        direction[0] = -1;
+    } else if (direction_index == CRAFTAX_ACTION_DOWN) {
+        direction[0] = 1;
+    }
+}
+
+// Returns true for the block ids listed below, which
+// craftax_step_valid_land_position treats as impassable; every other id
+// returns false.
+static inline bool craftax_step_is_solid_block(int32_t block) {
+    switch (block) {
+        case CRAFTAX_BLOCK_STONE:
+        case CRAFTAX_BLOCK_TREE:
+        case CRAFTAX_BLOCK_COAL:
+        case CRAFTAX_BLOCK_IRON:
+        case CRAFTAX_BLOCK_DIAMOND:
+        case CRAFTAX_BLOCK_CRAFTING_TABLE:
+        case CRAFTAX_BLOCK_FURNACE:
+        case CRAFTAX_BLOCK_PLANT:
+        case CRAFTAX_BLOCK_RIPE_PLANT:
+        case CRAFTAX_BLOCK_WALL:
+        case CRAFTAX_BLOCK_WALL_MOSS:
+        case CRAFTAX_BLOCK_STALAGMITE:
+        case CRAFTAX_BLOCK_RUBY:
+        case CRAFTAX_BLOCK_SAPPHIRE:
+        case CRAFTAX_BLOCK_CHEST:
+        case CRAFTAX_BLOCK_FOUNTAIN:
+        case CRAFTAX_BLOCK_FIRE_TREE:
+        case CRAFTAX_BLOCK_ENCHANTMENT_TABLE_FIRE:
+        case CRAFTAX_BLOCK_ENCHANTMENT_TABLE_ICE:
+        case CRAFTAX_BLOCK_GRAVE:
+        case CRAFTAX_BLOCK_GRAVE2:
+        case CRAFTAX_BLOCK_GRAVE3:
+        case CRAFTAX_BLOCK_NECROMANCER:
+            return true;
+        default:
+            return false;
+    }
+}
+
+// True when (row, col) on the player's current level is occupied by a mob
+// or by the player. The mob_map lookup uses wrapped/clamped indices, while
+// the player comparison uses the raw (row, col) values.
+static inline bool craftax_step_is_in_mob(
+    const CraftaxState* state,
+    int32_t row,
+    int32_t col
+) {
+    int32_t level = craftax_step_jax_index(state->player_level, CRAFTAX_NUM_LEVELS);
+    int32_t map_row = craftax_step_jax_index(row, CRAFTAX_MAP_SIZE);
+    int32_t map_col = craftax_step_jax_index(col, CRAFTAX_MAP_SIZE);
+    bool player_here = state->player_position[0] == row
+        && state->player_position[1] == col;
+    return state->mob_map[level][map_row][map_col] || player_here;
+}
+
+// True when (row, col) on the player's current level can be walked onto:
+// inside the map, not occupied (see craftax_step_is_in_mob), not a solid
+// block, and neither water nor lava. The block is read through
+// wrapped/clamped indices, so an out-of-bounds position never reads outside
+// the map; it is rejected by pos_in_bounds instead.
+static inline bool craftax_step_valid_land_position(
+    const CraftaxState* state,
+    int32_t row,
+    int32_t col
+) {
+    bool pos_in_bounds = row >= 0
+        && row < CRAFTAX_MAP_SIZE
+        && col >= 0
+        && col < CRAFTAX_MAP_SIZE;
+    int32_t level = craftax_step_jax_index(state->player_level, CRAFTAX_NUM_LEVELS);
+    int32_t map_row = craftax_step_jax_index(row, CRAFTAX_MAP_SIZE);
+    int32_t map_col = craftax_step_jax_index(col, CRAFTAX_MAP_SIZE);
+    int32_t block = state->map[level][map_row][map_col];
+    bool in_solid_block = craftax_step_is_solid_block(block);
+    bool in_mob = craftax_step_is_in_mob(state, row, col);
+    bool in_lava = block == CRAFTAX_BLOCK_LAVA;
+    bool in_water = block == CRAFTAX_BLOCK_WATER;
+
+    bool valid_move = pos_in_bounds && !in_mob && !in_solid_block;
+    valid_move = valid_move && !in_water;
+    valid_move = valid_move && !in_lava;
+    return valid_move;
+}
+
+// Applies a movement action to the player.
+//   action   - action id; only LEFT/RIGHT/UP/DOWN produce a non-zero step.
+//   god_mode - when true the move is applied even if the target cell fails
+//              craftax_step_valid_land_position (including out of bounds).
+// player_direction is overwritten with `action` whenever the action is a
+// direction, even if the move itself was blocked.
+static inline void craftax_move_player_native(
+    CraftaxState* state,
+    int32_t action,
+    bool god_mode
+) {
+    int32_t direction[2];
+    craftax_step_direction(action, direction);
+
+    int32_t proposed_row = state->player_position[0] + direction[0];
+    int32_t proposed_col = state->player_position[1] + direction[1];
+    bool valid_move = craftax_step_valid_land_position(
+        state,
+        proposed_row,
+        proposed_col
+    );
+    valid_move = valid_move || god_mode;
+
+    // Branchless update: the step is multiplied by 0 when the move is invalid.
+    state->player_position[0] += (int32_t)valid_move * direction[0];
+    state->player_position[1] += (int32_t)valid_move * direction[1];
+
+    bool is_new_direction = direction[0] != 0 || direction[1] != 0;
+    state->player_direction = state->player_direction * (1 - (int32_t)is_new_direction)
+        + action * (int32_t)is_new_direction;
+}
+
+// Ages every growing plant by one step and ripens those that reach age 600.
+// Slots whose growing_plants_mask is false have their age reset to 0.
+// Ripening writes CRAFTAX_BLOCK_RIPE_PLANT into level 0 of the map at the
+// plant's position (indices wrapped/clamped into the map).
+static inline void craftax_update_plants_native(CraftaxState* state) {
+    bool finished_growing_plants[CRAFTAX_MAX_GROWING_PLANTS];
+
+    for (int plant = 0; plant < CRAFTAX_MAX_GROWING_PLANTS; plant++) {
+        state->growing_plants_age[plant] =
+            (state->growing_plants_age[plant] + 1)
+            * (int32_t)state->growing_plants_mask[plant];
+        finished_growing_plants[plant] = state->growing_plants_age[plant] >= 600;
+    }
+
+    for (int plant = 0; plant < CRAFTAX_MAX_GROWING_PLANTS; plant++) {
+        int32_t row = craftax_step_jax_index(
+            state->growing_plants_positions[plant][0],
+            CRAFTAX_MAP_SIZE
+        );
+        int32_t col = craftax_step_jax_index(
+            state->growing_plants_positions[plant][1],
+            CRAFTAX_MAP_SIZE
+        );
+        // Unfinished plants write the cell's current value back (a no-op).
+        int32_t new_block = finished_growing_plants[plant]
+            ? CRAFTAX_BLOCK_RIPE_PLANT
+            : state->map[0][row][col];
+        state->map[0][row][col] = new_block;
+    }
+}
+
+// Per-step boss bookkeeping: latches the DEFEAT_NECROMANCER achievement once
+// craftax_step_has_beaten_boss is true, and decrements
+// boss_timesteps_to_spawn_this_round by one on every step spent on the boss
+// level (no lower bound is applied here).
+static inline void craftax_boss_logic_native(CraftaxState* state) {
+    state->achievements[CRAFTAX_ACH_DEFEAT_NECROMANCER] =
+        state->achievements[CRAFTAX_ACH_DEFEAT_NECROMANCER]
+        || craftax_step_has_beaten_boss(state);
+    state->boss_timesteps_to_spawn_this_round -=
+        (int32_t)craftax_step_is_fighting_boss(state);
+}
+
+// Handles the three LEVEL_UP_* actions.
+//   action        - action id for this step.
+//   max_attribute - an attribute already at this value cannot be raised.
+// When the player has at least 1 xp and the chosen attribute is below
+// max_attribute, that attribute is incremented and 1 xp is spent; otherwise
+// the state is unchanged.
+static inline void craftax_level_up_attributes_native(
+    CraftaxState* state,
+    int32_t action,
+    int32_t max_attribute
+) {
+    bool can_level_up = state->player_xp >= 1;
+    bool is_levelling_up_dex = can_level_up
+        && action == CRAFTAX_ACTION_LEVEL_UP_DEXTERITY
+        && state->player_dexterity < max_attribute;
+    bool is_levelling_up_str = can_level_up
+        && action == CRAFTAX_ACTION_LEVEL_UP_STRENGTH
+        && state->player_strength < max_attribute;
+    bool is_levelling_up_int = can_level_up
+        && action == CRAFTAX_ACTION_LEVEL_UP_INTELLIGENCE
+        && state->player_intelligence < max_attribute;
+    bool is_levelling_up = is_levelling_up_dex
+        || is_levelling_up_str
+        || is_levelling_up_int;
+
+    state->player_dexterity += (int32_t)is_levelling_up_dex;
+    state->player_strength += (int32_t)is_levelling_up_str;
+    state->player_intelligence += (int32_t)is_levelling_up_int;
+    state->player_xp -= (int32_t)is_levelling_up;
+}
+
+// Clamps inventory counts and player intrinsics to their legal ranges.
+// Every inventory field is capped at 99 (no lower bound is applied). Food,
+// drink, energy and mana are clamped to [0, max]; health is clamped to
+// [min_health, max_health], where min_health is 9 in god mode and 0
+// otherwise.
+static inline void craftax_clip_inventory_and_intrinsics_native(
+    CraftaxState* state,
+    bool god_mode
+) {
+    state->inventory.wood = craftax_step_mini32(state->inventory.wood, 99);
+    state->inventory.stone = craftax_step_mini32(state->inventory.stone, 99);
+    state->inventory.coal = craftax_step_mini32(state->inventory.coal, 99);
+    state->inventory.iron = craftax_step_mini32(state->inventory.iron, 99);
+    state->inventory.diamond = craftax_step_mini32(state->inventory.diamond, 99);
+    state->inventory.sapling = craftax_step_mini32(state->inventory.sapling, 99);
+    state->inventory.pickaxe = craftax_step_mini32(state->inventory.pickaxe, 99);
+    state->inventory.sword = craftax_step_mini32(state->inventory.sword, 99);
+    state->inventory.bow = craftax_step_mini32(state->inventory.bow, 99);
+    state->inventory.arrows = craftax_step_mini32(state->inventory.arrows, 99);
+    for (int i = 0; i < 4; i++) {
+        state->inventory.armour[i] = craftax_step_mini32(
+            state->inventory.armour[i],
+            99
+        );
+    }
+    state->inventory.torches = craftax_step_mini32(state->inventory.torches, 99);
+    state->inventory.ruby = craftax_step_mini32(state->inventory.ruby, 99);
+    state->inventory.sapphire = craftax_step_mini32(state->inventory.sapphire, 99);
+    for (int i = 0; i < 6; i++) {
+        state->inventory.potions[i] = craftax_step_mini32(
+            state->inventory.potions[i],
+            99
+        );
+    }
+    state->inventory.books = craftax_step_mini32(state->inventory.books, 99);
+
+    // A NaN health stays NaN: the float min/max helpers propagate it.
+    float min_health = god_mode ? 9.0f : 0.0f;
+    state->player_health = craftax_step_minf32(
+        craftax_step_maxf32(state->player_health, min_health),
+        (float)craftax_step_get_max_health(state)
+    );
+    state->player_food = craftax_step_mini32(
+        craftax_step_maxi32(state->player_food, 0),
+        craftax_step_get_max_food(state)
+    );
+    state->player_drink = craftax_step_mini32(
+        craftax_step_maxi32(state->player_drink, 0),
+        craftax_step_get_max_drink(state)
+    );
+    state->player_energy = craftax_step_mini32(
+        craftax_step_maxi32(state->player_energy, 0),
+        craftax_step_get_max_energy(state)
+    );
+    state->player_mana = craftax_step_mini32(
+        craftax_step_maxi32(state->player_mana, 0),
+        craftax_step_get_max_mana(state)
+    );
+}
+
+// Latches the inventory-derived achievements. Each flag is OR-ed with the
+// current inventory condition, so a flag that is already set is never
+// cleared. Pickaxe and sword achievements are tiered on the item level
+// (>= 1 wood, >= 2 stone, >= 3 iron, >= 4 diamond), so a higher tier also
+// sets every lower-tier flag.
+static inline void craftax_calculate_inventory_achievements_native(
+    CraftaxState* state
+) {
+    state->achievements[CRAFTAX_ACH_COLLECT_WOOD] =
+        state->achievements[CRAFTAX_ACH_COLLECT_WOOD] || state->inventory.wood > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_STONE] =
+        state->achievements[CRAFTAX_ACH_COLLECT_STONE] || state->inventory.stone > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_COAL] =
+        state->achievements[CRAFTAX_ACH_COLLECT_COAL] || state->inventory.coal > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_IRON] =
+        state->achievements[CRAFTAX_ACH_COLLECT_IRON] || state->inventory.iron > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_DIAMOND] =
+        state->achievements[CRAFTAX_ACH_COLLECT_DIAMOND] || state->inventory.diamond > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_RUBY] =
+        state->achievements[CRAFTAX_ACH_COLLECT_RUBY] || state->inventory.ruby > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_SAPPHIRE] =
+        state->achievements[CRAFTAX_ACH_COLLECT_SAPPHIRE]
+        || state->inventory.sapphire > 0;
+    state->achievements[CRAFTAX_ACH_COLLECT_SAPLING] =
+        state->achievements[CRAFTAX_ACH_COLLECT_SAPLING]
+        || state->inventory.sapling > 0;
+    state->achievements[CRAFTAX_ACH_FIND_BOW] =
+        state->achievements[CRAFTAX_ACH_FIND_BOW] || state->inventory.bow > 0;
+    state->achievements[CRAFTAX_ACH_MAKE_ARROW] =
+        state->achievements[CRAFTAX_ACH_MAKE_ARROW] || state->inventory.arrows > 0;
+    state->achievements[CRAFTAX_ACH_MAKE_TORCH] =
+        state->achievements[CRAFTAX_ACH_MAKE_TORCH] || state->inventory.torches > 0;
+
+    state->achievements[CRAFTAX_ACH_MAKE_WOOD_PICKAXE] =
+        state->achievements[CRAFTAX_ACH_MAKE_WOOD_PICKAXE]
+        || state->inventory.pickaxe >= 1;
+    state->achievements[CRAFTAX_ACH_MAKE_STONE_PICKAXE] =
+        state->achievements[CRAFTAX_ACH_MAKE_STONE_PICKAXE]
+        || state->inventory.pickaxe >= 2;
+    state->achievements[CRAFTAX_ACH_MAKE_IRON_PICKAXE] =
+        state->achievements[CRAFTAX_ACH_MAKE_IRON_PICKAXE]
+        || state->inventory.pickaxe >= 3;
+    state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_PICKAXE] =
+        state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_PICKAXE]
+        || state->inventory.pickaxe >= 4;
+
+    state->achievements[CRAFTAX_ACH_MAKE_WOOD_SWORD] =
+        state->achievements[CRAFTAX_ACH_MAKE_WOOD_SWORD]
+        || state->inventory.sword >= 1;
+    state->achievements[CRAFTAX_ACH_MAKE_STONE_SWORD] =
+        state->achievements[CRAFTAX_ACH_MAKE_STONE_SWORD]
+        || state->inventory.sword >= 2;
+    state->achievements[CRAFTAX_ACH_MAKE_IRON_SWORD] =
+        state->achievements[CRAFTAX_ACH_MAKE_IRON_SWORD]
+        || state->inventory.sword >= 3;
+    state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_SWORD] =
+        state->achievements[CRAFTAX_ACH_MAKE_DIAMOND_SWORD]
+        || state->inventory.sword >= 4;
+}
+
+// Advances the player's sleep/rest flags and the hunger, thirst, fatigue,
+// health-recovery and mana-recovery counters by one step.
+//
+// Each counter accumulates until it crosses a threshold, then resets to 0
+// and changes the matching intrinsic by one. Food, drink and energy losses
+// and the negative recovery drift are multiplied by `not_boss`, so they are
+// suspended on the boss level. Statement order matters: later sections read
+// values written by earlier ones (is_sleeping, player_food, player_drink,
+// player_energy).
+static inline void craftax_update_player_intrinsics_native(
+    CraftaxState* state,
+    int32_t action
+) {
+    // Sleep starts on SLEEP while energy is below max, and ends (granting the
+    // WAKE_UP achievement) as soon as energy is at max.
+    bool is_starting_sleep = action == CRAFTAX_ACTION_SLEEP
+        && state->player_energy < craftax_step_get_max_energy(state);
+    state->is_sleeping = state->is_sleeping || is_starting_sleep;
+
+    bool is_waking_up = state->player_energy >= craftax_step_get_max_energy(state)
+        && state->is_sleeping;
+    state->is_sleeping = state->is_sleeping && !is_waking_up;
+    state->achievements[CRAFTAX_ACH_WAKE_UP] =
+        state->achievements[CRAFTAX_ACH_WAKE_UP] || is_waking_up;
+
+    // Rest starts on REST while health is below max, and ends when health is
+    // full or food/drink has run out. No achievement is tied to it.
+    bool is_starting_rest = action == CRAFTAX_ACTION_REST
+        && state->player_health < (float)craftax_step_get_max_health(state);
+    state->is_resting = state->is_resting || is_starting_rest;
+
+    is_waking_up = state->is_resting
+        && (
+            state->player_health >= (float)craftax_step_get_max_health(state)
+            || state->player_food <= 0
+            || state->player_drink <= 0
+        );
+    state->is_resting = state->is_resting && !is_waking_up;
+
+    bool not_boss = !craftax_step_is_fighting_boss(state);
+    // Decay slows by 12.5% per dexterity point above 1.
+    float intrinsic_decay_coeff =
+        1.0f - (0.125f * (float)(state->player_dexterity - 1));
+
+    // Hunger: past 25 the counter resets and food drops by one (floor 0).
+    float hunger_add = (state->is_sleeping ? 0.5f : 1.0f) * intrinsic_decay_coeff;
+    float new_hunger = state->player_hunger + hunger_add;
+    int32_t hungered_food = craftax_step_maxi32(
+        state->player_food - (int32_t)not_boss,
+        0
+    );
+    int32_t new_food = new_hunger > 25.0f ? hungered_food : state->player_food;
+    new_hunger = new_hunger > 25.0f ? 0.0f : new_hunger;
+    state->player_hunger = new_hunger;
+    state->player_food = new_food;
+
+    // Thirst: same scheme with a threshold of 20.
+    float thirst_add = (state->is_sleeping ? 0.5f : 1.0f) * intrinsic_decay_coeff;
+    float new_thirst = state->player_thirst + thirst_add;
+    int32_t thirsted_drink = craftax_step_maxi32(
+        state->player_drink - (int32_t)not_boss,
+        0
+    );
+    int32_t new_drink = new_thirst > 20.0f ? thirsted_drink : state->player_drink;
+    new_thirst = new_thirst > 20.0f ? 0.0f : new_thirst;
+    state->player_thirst = new_thirst;
+    state->player_drink = new_drink;
+
+    // Fatigue: while sleeping it becomes min(fatigue - 1, 0); otherwise it
+    // grows by the decay coefficient. Above 30 energy drops by one; below -10
+    // energy rises by one (capped at max). Either crossing resets it to 0.
+    float new_fatigue = state->is_sleeping
+        ? craftax_step_minf32(state->player_fatigue - 1.0f, 0.0f)
+        : state->player_fatigue + intrinsic_decay_coeff;
+    int32_t new_energy = new_fatigue > 30.0f
+        ? craftax_step_maxi32(state->player_energy - (int32_t)not_boss, 0)
+        : state->player_energy;
+    new_fatigue = new_fatigue > 30.0f ? 0.0f : new_fatigue;
+    new_energy = new_fatigue < -10.0f
+        ? craftax_step_mini32(
+            state->player_energy + 1,
+            craftax_step_get_max_energy(state)
+        )
+        : new_energy;
+    new_fatigue = new_fatigue < -10.0f ? 0.0f : new_fatigue;
+    state->player_fatigue = new_fatigue;
+    state->player_energy = new_energy;
+
+    // Health recovery: the counter rises while food, drink and energy are all
+    // positive (energy is waived while sleeping) and falls otherwise. Above
+    // 25 health rises by one (capped at max); below -15 health drops by one
+    // with no floor applied here.
+    bool all_necessities = state->player_food > 0
+        && state->player_drink > 0
+        && (state->player_energy > 0 || state->is_sleeping);
+    float recover_all = state->is_sleeping ? 2.0f : 1.0f;
+    float recover_not_all = (state->is_sleeping ? -0.5f : -1.0f)
+        * (float)(int32_t)not_boss;
+    float recover_add = all_necessities ? recover_all : recover_not_all;
+    float new_recover = state->player_recover + recover_add;
+
+    float recovered_health = craftax_step_minf32(
+        state->player_health + 1.0f,
+        (float)craftax_step_get_max_health(state)
+    );
+    float derecovered_health = state->player_health - 1.0f;
+    float new_health = new_recover > 25.0f
+        ? recovered_health
+        : state->player_health;
+    new_recover = new_recover > 25.0f ? 0.0f : new_recover;
+    new_health = new_recover < -15.0f ? derecovered_health : new_health;
+    new_recover = new_recover < -15.0f ? 0.0f : new_recover;
+    state->player_recover = new_recover;
+    state->player_health = new_health;
+
+    // Mana recovery: above 30 the counter resets and mana rises by one (no
+    // cap applied here). The coefficient multiplies the whole accumulated
+    // counter, not just this step's increment.
+    // NOTE(review): presumably intentional to match the JAX reference -- confirm.
+    float mana_recover_coeff =
+        1.0f + 0.25f * (float)(state->player_intelligence - 1);
+    float new_recover_mana = (
+        state->is_sleeping
+            ? state->player_recover_mana + 2.0f
+            : state->player_recover_mana + 1.0f
+    ) * mana_recover_coeff;
+    int32_t new_mana = new_recover_mana > 30.0f
+        ? state->player_mana + 1
+        : state->player_mana;
+    new_recover_mana = new_recover_mana > 30.0f ? 0.0f : new_recover_mana;
+    state->player_recover_mana = new_recover_mana;
+    state->player_mana = new_mana;
+}
+
+// Handles the six DRINK_POTION_* actions.
+//
+// A potion is drunk only when the action names a colour whose inventory
+// count is positive. The colour index (0 red .. 5 yellow) is translated
+// through state->potion_mapping into an effect id: 0/1 change health by
+// +8/-3, 2/3 change mana by +8/-3, 4/5 change energy by +8/-3. Drinking
+// consumes one potion and sets the DRINK_POTION achievement. Results are not
+// clamped here.
+static inline void craftax_drink_potion_native(
+    CraftaxState* state,
+    int32_t action
+) {
+    int32_t drinking_potion_index = -1;
+    bool is_drinking_potion = false;
+
+    // Branchless select per colour: the index is replaced only when that
+    // colour is being drunk.
+    bool is_drinking_red_potion = action == CRAFTAX_ACTION_DRINK_POTION_RED
+        && state->inventory.potions[0] > 0;
+    drinking_potion_index = (int32_t)is_drinking_red_potion * 0
+        + (1 - (int32_t)is_drinking_red_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_red_potion;
+
+    bool is_drinking_green_potion = action == CRAFTAX_ACTION_DRINK_POTION_GREEN
+        && state->inventory.potions[1] > 0;
+    drinking_potion_index = (int32_t)is_drinking_green_potion * 1
+        + (1 - (int32_t)is_drinking_green_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_green_potion;
+
+    bool is_drinking_blue_potion = action == CRAFTAX_ACTION_DRINK_POTION_BLUE
+        && state->inventory.potions[2] > 0;
+    drinking_potion_index = (int32_t)is_drinking_blue_potion * 2
+        + (1 - (int32_t)is_drinking_blue_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_blue_potion;
+
+    bool is_drinking_pink_potion = action == CRAFTAX_ACTION_DRINK_POTION_PINK
+        && state->inventory.potions[3] > 0;
+    drinking_potion_index = (int32_t)is_drinking_pink_potion * 3
+        + (1 - (int32_t)is_drinking_pink_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_pink_potion;
+
+    bool is_drinking_cyan_potion = action == CRAFTAX_ACTION_DRINK_POTION_CYAN
+        && state->inventory.potions[4] > 0;
+    drinking_potion_index = (int32_t)is_drinking_cyan_potion * 4
+        + (1 - (int32_t)is_drinking_cyan_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_cyan_potion;
+
+    bool is_drinking_yellow_potion = action == CRAFTAX_ACTION_DRINK_POTION_YELLOW
+        && state->inventory.potions[5] > 0;
+    drinking_potion_index = (int32_t)is_drinking_yellow_potion * 5
+        + (1 - (int32_t)is_drinking_yellow_potion) * drinking_potion_index;
+    is_drinking_potion = is_drinking_potion || is_drinking_yellow_potion;
+
+    // When nothing is drunk the index is still -1, which wraps to slot 5; all
+    // updates below are then multiplied by 0, so that slot is left unchanged.
+    int32_t potion_index = craftax_step_jax_index(drinking_potion_index, 6);
+    int32_t potion_effect_index = state->potion_mapping[potion_index];
+
+    int32_t delta_health = 0;
+    delta_health += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 0) * 8;
+    delta_health += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 1) * -3;
+
+    int32_t delta_mana = 0;
+    delta_mana += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 2) * 8;
+    delta_mana += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 3) * -3;
+
+    int32_t delta_energy = 0;
+    delta_energy += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 4) * 8;
+    delta_energy += (int32_t)is_drinking_potion * (int32_t)(potion_effect_index == 5) * -3;
+
+    state->achievements[CRAFTAX_ACH_DRINK_POTION] =
+        state->achievements[CRAFTAX_ACH_DRINK_POTION] || is_drinking_potion;
+    state->inventory.potions[potion_index] =
+        state->inventory.potions[potion_index] - (int32_t)is_drinking_potion;
+    state->player_health += (float)delta_health;
+    state->player_mana += delta_mana;
+    state->player_energy += delta_energy;
+}
+
+// Handles the READ_BOOK action.
+//   rng_words - two 32-bit words forming the Threefry key for this call.
+//
+// The key is split and only the second half is used to choose a spell.
+// Spell 0 (fireball) or 1 (iceball) is drawn uniformly from the spells not
+// yet learned; if both are already learned, index 0 is used. The spell flag
+// and its achievement are set, and one book is consumed, only when the
+// action is READ_BOOK and the player holds a book. A book is still consumed
+// when both spells are already known.
+static inline void craftax_read_book_native(
+    CraftaxState* state,
+    const uint32_t rng_words[2],
+    int32_t action
+) {
+    bool is_reading_book = action == CRAFTAX_ACTION_READ_BOOK
+        && state->inventory.books > 0;
+
+    CraftaxThreefryKey rng = {{rng_words[0], rng_words[1]}};
+    CraftaxThreefryKey unused;
+    CraftaxThreefryKey choice_key;
+    craftax_threefry_split(rng, &unused, &choice_key);
+
+    // Unnormalized weights: 1 for an unlearned spell, 0 for a learned one.
+    float p0 = state->learned_spells[0] ? 0.0f : 1.0f;
+    float p1 = state->learned_spells[1] ? 0.0f : 1.0f;
+    float p_sum = p0 + p1;
+    int32_t spell_to_learn_index = 0;
+    if (p_sum != 0.0f) {
+        p0 /= p_sum;
+        float r = 1.0f - craftax_threefry_uniform_f32(choice_key);
+        spell_to_learn_index = r <= p0 ? 0 : 1;
+    }
+
+    int32_t learn_spell_achievement = spell_to_learn_index
+        ? CRAFTAX_ACH_LEARN_ICEBALL
+        : CRAFTAX_ACH_LEARN_FIREBALL;
+
+    state->achievements[learn_spell_achievement] =
+        state->achievements[learn_spell_achievement] || is_reading_book;
+    state->inventory.books -= (int32_t)is_reading_book;
+    state->learned_spells[spell_to_learn_index] =
+        state->learned_spells[spell_to_learn_index] || is_reading_book;
+}
diff --git a/ocean/craftax/step_spawn_mobs.h b/ocean/craftax/step_spawn_mobs.h
new file mode 100644
index 0000000000..280f7d9d58
--- /dev/null
+++ b/ocean/craftax/step_spawn_mobs.h
@@ -0,0 +1,460 @@
+// Craftax spawn_mobs, optimized for CPU.
+//
+// Bitwise-equivalent to the prior JAX-transliterated baseline (verified by
+// ocean/craftax_exp/parity_vs_baseline.c over 1.28M paired steps), ~6-9x
+// faster per step by stripping JAX-isms:
+//   - full-grid validity masks -> compact coord list collected in one pass
+//   - bounding-box scan (only cells within MOB_DESPAWN_DISTANCE)
+//   - early return on mob-cap / probability-roll failure (no dead writes)
+//   - merged count + first_empty loops
+//
+// The prior reference implementation is archived at
+// ocean/craftax_exp/step_spawn_mobs_baseline.h.
+
+#pragma once
+
+#include "step_medium.h"
+
+#define CRAFTAX_SPAWN_MAP_CELLS (CRAFTAX_MAP_SIZE * CRAFTAX_MAP_SIZE)
+// Side length of the square scan window around the player: rows/cols from
+// player - (DESPAWN - 1) to player + (DESPAWN - 1) inclusive.
+#define CRAFTAX_SPAWN_BBOX_SIDE (2 * CRAFTAX_MOB_DESPAWN_DISTANCE - 1)
+// Capacity of the per-scan candidate list. Derived from the despawn distance
+// (729 = 27*27 at a distance of 14) so the stack buffer can never be smaller
+// than the window the scans iterate over.
+#define CRAFTAX_SPAWN_BBOX_MAX_CELLS (CRAFTAX_SPAWN_BBOX_SIDE * CRAFTAX_SPAWN_BBOX_SIDE)
+
+// Splits *rng in place and returns the freshly drawn subkey. Every call
+// advances the caller's key, so the call order defines the RNG sequence.
+static inline CraftaxThreefryKey craftax_spawn_next_random_key(
+    CraftaxThreefryKey* rng
+) {
+    CraftaxThreefryKey draw;
+    craftax_threefry_split(*rng, rng, &draw);
+    return draw;
+}
+
+// Looks up the mob type id spawned on `floor` for `mob_class` (column index
+// into the table below). Both indices are wrapped/clamped into range by
+// craftax_step_jax_index.
+static inline int32_t craftax_spawn_floor_mob_type(
+    int32_t floor, int32_t mob_class
+) {
+    static const int32_t mapping[CRAFTAX_NUM_LEVELS][3] = {
+        {0, 0, 0}, {2, 2, 2}, {1, 1, 1}, {2, 3, 3}, {2, 4, 4},
+        {1, 5, 5}, {1, 6, 6}, {1, 7, 7}, {0, 0, 0},
+    };
+    int32_t level = craftax_step_jax_index(floor, CRAFTAX_NUM_LEVELS);
+    int32_t class_index = craftax_step_jax_index(mob_class, 3);
+    return mapping[level][class_index];
+}
+
+// Looks up a per-floor spawn probability. As used by the spawners in this
+// file, column 0 is the passive chance, 1 the melee base chance, 2 the ranged
+// chance and 3 the extra melee chance scaled by darkness. Indices are
+// wrapped/clamped into range.
+static inline float craftax_spawn_floor_spawn_chance(
+    int32_t floor, int32_t chance_index
+) {
+    static const float chances[CRAFTAX_NUM_LEVELS][4] = {
+        {0.1f, 0.02f, 0.05f, 0.1f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+        {0.0f, 0.06f, 0.05f, 0.0f},
+        {0.1f, 0.06f, 0.05f, 0.0f},
+    };
+    int32_t level = craftax_step_jax_index(floor, CRAFTAX_NUM_LEVELS);
+    int32_t index = craftax_step_jax_index(chance_index, 4);
+    return chances[level][index];
+}
+
+// Looks up the starting health for a mob of `mob_type` and `mob_class`.
+// Indices are wrapped/clamped into range.
+static inline float craftax_spawn_mob_type_health(
+    int32_t mob_type, int32_t mob_class
+) {
+    static const float health[CRAFTAX_NUM_MOB_TYPES][4] = {
+        {3.0f, 5.0f, 3.0f, 0.0f}, {4.0f, 7.0f, 5.0f, 0.0f},
+        {6.0f, 9.0f, 6.0f, 0.0f}, {8.0f, 11.0f, 8.0f, 0.0f},
+        {0.0f, 12.0f, 12.0f, 0.0f}, {0.0f, 20.0f, 4.0f, 0.0f},
+        {0.0f, 20.0f, 14.0f, 0.0f}, {0.0f, 24.0f, 16.0f, 0.0f},
+    };
+    int32_t type_index = craftax_step_jax_index(mob_type, CRAFTAX_NUM_MOB_TYPES);
+    int32_t class_index = craftax_step_jax_index(mob_class, 4);
+    return health[type_index][class_index];
+}
+
+// True for the walkable ground blocks land mobs may spawn on: grass, path,
+// fire grass and ice grass.
+static inline bool craftax_spawn_is_all_valid_block(int32_t block) {
+    return block == CRAFTAX_BLOCK_GRASS
+        || block == CRAFTAX_BLOCK_PATH
+        || block == CRAFTAX_BLOCK_FIRE_GRASS
+        || block == CRAFTAX_BLOCK_ICE_GRASS;
+}
+
+// True for the three grave block variants, which the scans use as spawn
+// cells while fighting_boss is set.
+static inline bool craftax_spawn_is_grave_block(int32_t block) {
+    return block == CRAFTAX_BLOCK_GRAVE
+        || block == CRAFTAX_BLOCK_GRAVE2
+        || block == CRAFTAX_BLOCK_GRAVE3;
+}
+
+// Squared Euclidean distance from (row, col) to the player.
+static inline int32_t craftax_spawn_player_distance_squared(
+    const CraftaxState* state, int32_t row, int32_t col
+) {
+    int32_t dr = row - state->player_position[0];
+    int32_t dc = col - state->player_position[1];
+    if (dr < 0) dr = -dr;
+    if (dc < 0) dc = -dc;
+    return dr * dr + dc * dc;
+}
+
+// Number of occupied slots on `level` in a 3-slot mob table.
+static inline int32_t craftax_spawn_count_mobs3(
+    const CraftaxMobs3* mobs, int32_t level
+) {
+    int32_t count = 0;
+    for (int32_t i = 0; i < 3; i++) count += (int32_t)mobs->mask[level][i];
+    return count;
+}
+
+// Number of occupied slots on `level` in a 2-slot mob table.
+static inline int32_t craftax_spawn_count_mobs2(
+    const CraftaxMobs2* mobs, int32_t level
+) {
+    int32_t count = 0;
+    for (int32_t i = 0; i < 2; i++) count += (int32_t)mobs->mask[level][i];
+    return count;
+}
+
+// Index of the first free slot on `level` in a 3-slot table; 0 when full.
+static inline int32_t craftax_spawn_first_empty_mobs3(
+    const CraftaxMobs3* mobs, int32_t level
+) {
+    for (int32_t i = 0; i < 3; i++) if (!mobs->mask[level][i]) return i;
+    return 0;
+}
+
+// Index of the first free slot on `level` in a 2-slot table; 0 when full.
+static inline int32_t craftax_spawn_first_empty_mobs2(
+    const CraftaxMobs2* mobs, int32_t level
+) {
+    for (int32_t i = 0; i < 2; i++) if (!mobs->mask[level][i]) return i;
+    return 0;
+}
+
+// Single-pass equivalent of count_mobs3 + first_empty_mobs3.
+//   count_out       - receives the number of occupied slots.
+//   first_empty_out - receives the first free slot, or 0 when the table is
+//                     full (callers must check the count before using it).
+static inline void craftax_spawn_mobs3_count_and_empty(
+    const CraftaxMobs3* mobs, int32_t level,
+    int32_t* count_out, int32_t* first_empty_out
+) {
+    int32_t count = 0, first_empty = 0;
+    bool found = false;
+    for (int32_t i = 0; i < 3; i++) {
+        bool m = mobs->mask[level][i];
+        count += (int32_t)m;
+        if (!m && !found) { first_empty = i; found = true; }
+    }
+    *count_out = count;
+    *first_empty_out = first_empty;
+}
+
+// Single-pass equivalent of count_mobs2 + first_empty_mobs2; same contract
+// as the 3-slot variant.
+static inline void craftax_spawn_mobs2_count_and_empty(
+    const CraftaxMobs2* mobs, int32_t level,
+    int32_t* count_out, int32_t* first_empty_out
+) {
+    int32_t count = 0, first_empty = 0;
+    bool found = false;
+    for (int32_t i = 0; i < 2; i++) {
+        bool m = mobs->mask[level][i];
+        count += (int32_t)m;
+        if (!m && !found) { first_empty = i; found = true; }
+    }
+    *count_out = count;
+    *first_empty_out = first_empty;
+}
+
+// Baseline algorithm on a bool mask:
+//   draw = valid_count * (1.0 - uniform_f32(key));
+//   cum = 0;
+//   for i: if valid[i] { cum += 1.0; if (cum >= draw) return i; }
+// Over a compact list of length valid_count this collapses to a short loop
+// using the same FP arithmetic, preserving bitwise-identical choice.
+//
+// Returns an index in [0, valid_count). Callers must pass valid_count > 0;
+// with 0 the fallback would return -1.
+static inline int32_t craftax_spawn_pick_kth(
+    int32_t valid_count, CraftaxThreefryKey key
+) {
+    float draw = (float)valid_count * (1.0f - craftax_threefry_uniform_f32(key));
+    float cum = 0.0f;
+    for (int32_t k = 0; k < valid_count; k++) {
+        cum += 1.0f;
+        if (cum >= draw) return k;
+    }
+    return valid_count - 1;
+}
+
+// Candidate spawn cell, in map coordinates.
+typedef struct { int16_t row, col; } CraftaxSpawnCoord;
+
+// Inclusive row/column range scanned for candidate cells.
+typedef struct { int32_t r0, r1, c0, c1; } CraftaxSpawnWindow;
+
+// Builds the scan window: the square of half-width DESPAWN - 1 centred on
+// (pr, pc), clamped to the map. The clamped window never holds more than
+// CRAFTAX_SPAWN_BBOX_MAX_CELLS cells; it is empty (r0 > r1 or c0 > c1) when
+// the centre lies far outside the map.
+static inline CraftaxSpawnWindow craftax_spawn_scan_window(
+    int32_t pr, int32_t pc
+) {
+    CraftaxSpawnWindow win;
+    win.r0 = pr - (CRAFTAX_MOB_DESPAWN_DISTANCE - 1);
+    win.r1 = pr + (CRAFTAX_MOB_DESPAWN_DISTANCE - 1);
+    win.c0 = pc - (CRAFTAX_MOB_DESPAWN_DISTANCE - 1);
+    win.c1 = pc + (CRAFTAX_MOB_DESPAWN_DISTANCE - 1);
+    if (win.r0 < 0) win.r0 = 0;
+    if (win.c0 < 0) win.c0 = 0;
+    if (win.r1 > CRAFTAX_MAP_SIZE - 1) win.r1 = CRAFTAX_MAP_SIZE - 1;
+    if (win.c1 > CRAFTAX_MAP_SIZE - 1) win.c1 = CRAFTAX_MAP_SIZE - 1;
+    return win;
+}
+
+// Picks a spawn cell for a passive mob on `level`.
+//
+// Candidates are unoccupied ground cells (grass/path/fire grass/ice grass)
+// whose squared distance d2 to the player satisfies 9 < d2 < DESPAWN^2.
+// Cells are collected in row-major order, which fixes which cell a given
+// draw from pos_key selects.
+//
+// Returns false, leaving the outputs untouched, when there is no candidate;
+// otherwise writes the chosen cell to *out_row / *out_col and returns true.
+static inline bool craftax_spawn_scan_passive(
+    const CraftaxState* state, int32_t level, CraftaxThreefryKey pos_key,
+    int32_t* out_row, int32_t* out_col
+) {
+    const int32_t pr = state->player_position[0];
+    const int32_t pc = state->player_position[1];
+    const CraftaxSpawnWindow win = craftax_spawn_scan_window(pr, pc);
+
+    const int32_t limit2 = CRAFTAX_MOB_DESPAWN_DISTANCE
+        * CRAFTAX_MOB_DESPAWN_DISTANCE;
+    CraftaxSpawnCoord coords[CRAFTAX_SPAWN_BBOX_MAX_CELLS];
+    int32_t n = 0;
+    for (int32_t row = win.r0; row <= win.r1; row++) {
+        int32_t dr = row - pr; if (dr < 0) dr = -dr;
+        int32_t dr2 = dr * dr;
+        const int32_t* map_row = state->map[level][row];
+        const bool* mob_row = state->mob_map[level][row];
+        for (int32_t col = win.c0; col <= win.c1; col++) {
+            int32_t dc = col - pc; if (dc < 0) dc = -dc;
+            int32_t distance2 = dr2 + dc * dc;
+            if (distance2 <= 9 || distance2 >= limit2) continue;
+            if (mob_row[col]) continue;
+            if (!craftax_spawn_is_all_valid_block(map_row[col])) continue;
+            coords[n].row = (int16_t)row;
+            coords[n].col = (int16_t)col;
+            n++;
+        }
+    }
+    if (n == 0) return false;
+    int32_t k = craftax_spawn_pick_kth(n, pos_key);
+    *out_row = coords[k].row;
+    *out_col = coords[k].col;
+    return true;
+}
+
+// Picks a spawn cell for a melee mob on `level`.
+//
+// Candidates are unoccupied cells with squared distance d2 < DESPAWN^2 that
+// also satisfy:
+//   - fighting_boss: d2 <= 36 and the block is a grave;
+//   - otherwise:     d2 > 81 and the block is walkable ground.
+// Cells are collected in row-major order.
+//
+// Returns false, leaving the outputs untouched, when there is no candidate.
+static inline bool craftax_spawn_scan_melee(
+    const CraftaxState* state, int32_t level, bool fighting_boss,
+    CraftaxThreefryKey pos_key, int32_t* out_row, int32_t* out_col
+) {
+    const int32_t pr = state->player_position[0];
+    const int32_t pc = state->player_position[1];
+    const CraftaxSpawnWindow win = craftax_spawn_scan_window(pr, pc);
+
+    const int32_t limit2 = CRAFTAX_MOB_DESPAWN_DISTANCE
+        * CRAFTAX_MOB_DESPAWN_DISTANCE;
+    CraftaxSpawnCoord coords[CRAFTAX_SPAWN_BBOX_MAX_CELLS];
+    int32_t n = 0;
+    for (int32_t row = win.r0; row <= win.r1; row++) {
+        int32_t dr = row - pr; if (dr < 0) dr = -dr;
+        int32_t dr2 = dr * dr;
+        const int32_t* map_row = state->map[level][row];
+        const bool* mob_row = state->mob_map[level][row];
+        for (int32_t col = win.c0; col <= win.c1; col++) {
+            int32_t dc = col - pc; if (dc < 0) dc = -dc;
+            int32_t distance2 = dr2 + dc * dc;
+            if (distance2 >= limit2) continue;
+            bool range_ok = fighting_boss ? (distance2 <= 36) : (distance2 > 81);
+            if (!range_ok) continue;
+            if (mob_row[col]) continue;
+            int32_t block = map_row[col];
+            bool terrain_ok = fighting_boss
+                ? craftax_spawn_is_grave_block(block)
+                : craftax_spawn_is_all_valid_block(block);
+            if (!terrain_ok) continue;
+            coords[n].row = (int16_t)row;
+            coords[n].col = (int16_t)col;
+            n++;
+        }
+    }
+    if (n == 0) return false;
+    int32_t k = craftax_spawn_pick_kth(n, pos_key);
+    *out_row = coords[k].row;
+    *out_col = coords[k].col;
+    return true;
+}
+
+// Picks a spawn cell for a ranged mob of type `new_type` on `level`.
+//
+// Same distance rules as the melee scan. Terrain rules, in priority order:
+//   - fighting_boss: the block must be a grave;
+//   - new_type == 5: the block must be water;
+//   - otherwise:     the block must be walkable ground.
+// Cells are collected in row-major order.
+//
+// Returns false, leaving the outputs untouched, when there is no candidate.
+static inline bool craftax_spawn_scan_ranged(
+    const CraftaxState* state, int32_t level, int32_t new_type,
+    bool fighting_boss, CraftaxThreefryKey pos_key,
+    int32_t* out_row, int32_t* out_col
+) {
+    const int32_t pr = state->player_position[0];
+    const int32_t pc = state->player_position[1];
+    const CraftaxSpawnWindow win = craftax_spawn_scan_window(pr, pc);
+
+    const int32_t limit2 = CRAFTAX_MOB_DESPAWN_DISTANCE
+        * CRAFTAX_MOB_DESPAWN_DISTANCE;
+    CraftaxSpawnCoord coords[CRAFTAX_SPAWN_BBOX_MAX_CELLS];
+    int32_t n = 0;
+    bool water_type = (new_type == 5);
+    for (int32_t row = win.r0; row <= win.r1; row++) {
+        int32_t dr = row - pr; if (dr < 0) dr = -dr;
+        int32_t dr2 = dr * dr;
+        const int32_t* map_row = state->map[level][row];
+        const bool* mob_row = state->mob_map[level][row];
+        for (int32_t col = win.c0; col <= win.c1; col++) {
+            int32_t dc = col - pc; if (dc < 0) dc = -dc;
+            int32_t distance2 = dr2 + dc * dc;
+            if (distance2 >= limit2) continue;
+            bool range_ok = fighting_boss ? (distance2 <= 36) : (distance2 > 81);
+            if (!range_ok) continue;
+            if (mob_row[col]) continue;
+            int32_t block = map_row[col];
+            bool terrain_ok;
+            if (fighting_boss) {
+                terrain_ok = craftax_spawn_is_grave_block(block);
+            } else if (water_type) {
+                terrain_ok = (block == CRAFTAX_BLOCK_WATER);
+            } else {
+                terrain_ok = craftax_spawn_is_all_valid_block(block);
+            }
+            if (!terrain_ok) continue;
+            coords[n].row = (int16_t)row;
+            coords[n].col = (int16_t)col;
+            n++;
+        }
+    }
+    if (n == 0) return false;
+    int32_t k = craftax_spawn_pick_kth(n, pos_key);
+    *out_row = coords[k].row;
+    *out_col = coords[k].col;
+    return true;
+}
+
+// Both RNG keys are always consumed (preserves baseline RNG sequence).
+// Baseline quirk: type_id[level][slot] is written unconditionally, even
+// when no mob spawns. We match that for bitwise parity.
+
+// Attempts to spawn one passive mob on `level`.
+//
+// Always draws two keys from *rng (probability, then position) and always
+// writes the floor's passive type into type_id[level][slot], where slot is
+// the first free slot (0 when the table is full). Nothing else changes when
+// fighting the boss, when the table already holds
+// CRAFTAX_MAX_PASSIVE_MOBS mobs, when the probability roll is not below the
+// floor's passive chance, or when no valid cell exists. On success the
+// slot's position, health and mask are set and the cell is marked in
+// mob_map.
+static inline void craftax_spawn_passive_mob(
+    CraftaxState* state, CraftaxThreefryKey* rng,
+    int32_t level, bool fighting_boss
+) {
+    int32_t count, slot;
+    craftax_spawn_mobs3_count_and_empty(&state->passive_mobs, level, &count, &slot);
+
+    CraftaxThreefryKey prob_key = craftax_spawn_next_random_key(rng);
+    CraftaxThreefryKey pos_key = craftax_spawn_next_random_key(rng);
+
+    int32_t type = craftax_spawn_floor_mob_type(level, CRAFTAX_MOB_PASSIVE);
+    state->passive_mobs.type_id[level][slot] = type;
+
+    if (fighting_boss) return;
+    if (count >= CRAFTAX_MAX_PASSIVE_MOBS) return;
+    if (craftax_threefry_uniform_f32(prob_key)
+        >= craftax_spawn_floor_spawn_chance(level, 0)) return;
+
+    int32_t row, col;
+    if (!craftax_spawn_scan_passive(state, level, pos_key, &row, &col)) return;
+
+    state->passive_mobs.position[level][slot][0] = row;
+    state->passive_mobs.position[level][slot][1] = col;
+    state->passive_mobs.health[level][slot] =
+        craftax_spawn_mob_type_health(type, CRAFTAX_MOB_PASSIVE);
+    state->passive_mobs.mask[level][slot] = true;
+    state->mob_map[level][row][col] = true;
+}
+
+// Attempts to spawn one melee mob on `level`.
+//   monster_spawn_coeff - multiplier applied to the spawn chance.
+//
+// While fighting the boss the mob type is taken from the floor indexed by
+// boss_progress instead of `level`. The spawn chance is the floor's base
+// melee chance plus its column-3 chance scaled by (1 - light_level)^2.
+// Always draws two keys from *rng and always writes type_id[level][slot];
+// the spawn itself is skipped when the table is full, the roll fails, or no
+// valid cell exists.
+static inline void craftax_spawn_melee_mob(
+    CraftaxState* state, CraftaxThreefryKey* rng,
+    int32_t level, bool fighting_boss, int32_t monster_spawn_coeff
+) {
+    int32_t count, slot;
+    craftax_spawn_mobs3_count_and_empty(&state->melee_mobs, level, &count, &slot);
+
+    int32_t type = fighting_boss
+        ? craftax_spawn_floor_mob_type(state->boss_progress, CRAFTAX_MOB_MELEE)
+        : craftax_spawn_floor_mob_type(level, CRAFTAX_MOB_MELEE);
+
+    CraftaxThreefryKey prob_key = craftax_spawn_next_random_key(rng);
+    float night_coeff = 1.0f - state->light_level;
+    float spawn_chance = craftax_spawn_floor_spawn_chance(level, 1)
+        + craftax_spawn_floor_spawn_chance(level, 3) * night_coeff * night_coeff;
+    CraftaxThreefryKey pos_key = craftax_spawn_next_random_key(rng);
+
+    state->melee_mobs.type_id[level][slot] = type;
+
+    if (count >= CRAFTAX_MAX_MELEE_MOBS) return;
+    if (craftax_threefry_uniform_f32(prob_key)
+        >= spawn_chance * (float)monster_spawn_coeff) return;
+
+    int32_t row, col;
+    if (!craftax_spawn_scan_melee(state, level, fighting_boss, pos_key, &row, &col))
+        return;
+
+    state->melee_mobs.position[level][slot][0] = row;
+    state->melee_mobs.position[level][slot][1] = col;
+    state->melee_mobs.health[level][slot] =
+        craftax_spawn_mob_type_health(type, CRAFTAX_MOB_MELEE);
+    state->melee_mobs.mask[level][slot] = true;
+    state->mob_map[level][row][col] = true;
+}
+
+// Attempts to spawn one ranged mob on `level`.
+//   monster_spawn_coeff - multiplier applied to the floor's ranged chance.
+//
+// While fighting the boss the mob type is taken from the floor indexed by
+// boss_progress instead of `level`. The type is also passed to the scan,
+// which uses it to choose the spawn terrain. Always draws two keys from *rng
+// and always writes type_id[level][slot]; the spawn itself is skipped when
+// the table is full, the roll fails, or no valid cell exists.
+static inline void craftax_spawn_ranged_mob(
+    CraftaxState* state, CraftaxThreefryKey* rng,
+    int32_t level, bool fighting_boss, int32_t monster_spawn_coeff
+) {
+    int32_t count, slot;
+    craftax_spawn_mobs2_count_and_empty(&state->ranged_mobs, level, &count, &slot);
+
+    int32_t type = fighting_boss
+        ? craftax_spawn_floor_mob_type(state->boss_progress, CRAFTAX_MOB_RANGED)
+        : craftax_spawn_floor_mob_type(level, CRAFTAX_MOB_RANGED);
+
+    CraftaxThreefryKey prob_key = craftax_spawn_next_random_key(rng);
+    CraftaxThreefryKey pos_key = craftax_spawn_next_random_key(rng);
+
+    state->ranged_mobs.type_id[level][slot] = type;
+
+    if (count >= CRAFTAX_MAX_RANGED_MOBS) return;
+    if (craftax_threefry_uniform_f32(prob_key)
+        >= craftax_spawn_floor_spawn_chance(level, 2) * (float)monster_spawn_coeff)
+        return;
+
+    int32_t row, col;
+    if (!craftax_spawn_scan_ranged(state, level, type, fighting_boss, pos_key,
+                                   &row, &col)) return;
+
+    state->ranged_mobs.position[level][slot][0] = row;
+    state->ranged_mobs.position[level][slot][1] = col;
+    state->ranged_mobs.health[level][slot] =
+        craftax_spawn_mob_type_health(type, CRAFTAX_MOB_RANGED);
+    state->ranged_mobs.mask[level][slot] = true;
+    state->mob_map[level][row][col] = true;
+}
+
+// Runs one spawn attempt per mob class (passive, melee, ranged, in that
+// order) on the player's current level, threading a local copy of `rng`
+// through all three so each consumes its keys in sequence.
+//
+// monster_spawn_coeff is 3 while fewer than
+// CRAFTAX_MONSTERS_KILLED_TO_CLEAR_LEVEL monsters have been killed on the
+// level and 1 afterwards. On the boss level it is further multiplied by 1000
+// while boss_timesteps_to_spawn_this_round >= 1 and by 0 otherwise, which
+// disables melee/ranged spawns outside a wave.
+static inline void craftax_spawn_mobs_native(
+    CraftaxState* state, CraftaxThreefryKey rng
+) {
+    int32_t level = craftax_step_jax_index(
+        state->player_level, CRAFTAX_NUM_LEVELS
+    );
+    bool fighting_boss = craftax_step_is_fighting_boss(state);
+    int32_t monster_spawn_coeff =
+        1
+        + (int32_t)(state->monsters_killed[level]
+                    < CRAFTAX_MONSTERS_KILLED_TO_CLEAR_LEVEL) * 2;
+
+    bool boss_spawn_wave =
+        fighting_boss && state->boss_timesteps_to_spawn_this_round >= 1;
+    if (fighting_boss) {
+        monster_spawn_coeff *= (int32_t)boss_spawn_wave * 1000;
+    }
+
+    craftax_spawn_passive_mob(state, &rng, level, fighting_boss);
+    craftax_spawn_melee_mob(state, &rng, level, fighting_boss, monster_spawn_coeff);
+    craftax_spawn_ranged_mob(state, &rng, level, fighting_boss, monster_spawn_coeff);
+}
diff --git a/ocean/craftax/step_update_mobs.h b/ocean/craftax/step_update_mobs.h
new file mode 100644
index 0000000000..b8627f2623
--- /dev/null
+++ b/ocean/craftax/step_update_mobs.h
@@ -0,0 +1,1120 @@
+// Standalone native port of Craftax update_mobs.
+// +// This helper intentionally is not integrated into c_step yet. It mutates a +// full CraftaxState in place so tests can compare the subsystem directly +// against the installed JAX implementation. + +#pragma once + +#include "step_do_action.h" + +#define CRAFTAX_UPDATE_BOSS_FIGHT_EXTRA_DAMAGE 0.5f + +static inline CraftaxThreefryKey craftax_update_mobs_next_random_key( + CraftaxThreefryKey* rng +) { + CraftaxThreefryKey draw; + craftax_threefry_split(*rng, rng, &draw); + return draw; +} + +static inline bool craftax_update_mobs_scatter_index( + int32_t index, + int32_t size, + int32_t* mapped_index +) { + if (index < -size || index >= size) { + return false; + } + *mapped_index = index < 0 ? index + size : index; + return true; +} + +static inline bool craftax_update_mobs_in_bounds( + int32_t row, + int32_t col +) { + return row >= 0 + && row < CRAFTAX_MAP_SIZE + && col >= 0 + && col < CRAFTAX_MAP_SIZE; +} + +static inline int32_t craftax_update_mobs_read_block( + const CraftaxState* state, + int32_t level, + int32_t row, + int32_t col +) { + int32_t map_level = craftax_step_jax_index(level, CRAFTAX_NUM_LEVELS); + int32_t map_row = craftax_step_jax_index(row, CRAFTAX_MAP_SIZE); + int32_t map_col = craftax_step_jax_index(col, CRAFTAX_MAP_SIZE); + return state->map[map_level][map_row][map_col]; +} + +static inline void craftax_update_mobs_set_block( + CraftaxState* state, + int32_t level, + int32_t row, + int32_t col, + int32_t block +) { + int32_t map_level; + int32_t map_row; + int32_t map_col; + if (!craftax_update_mobs_scatter_index( + level, + CRAFTAX_NUM_LEVELS, + &map_level + ) + || !craftax_update_mobs_scatter_index( + row, + CRAFTAX_MAP_SIZE, + &map_row + ) + || !craftax_update_mobs_scatter_index( + col, + CRAFTAX_MAP_SIZE, + &map_col + )) { + return; + } + state->map[map_level][map_row][map_col] = block; +} + +static inline bool craftax_update_mobs_read_mob_map( + const CraftaxState* state, + int32_t level, + int32_t row, + int32_t col +) { + int32_t 
map_level = craftax_step_jax_index(level, CRAFTAX_NUM_LEVELS); + int32_t map_row = craftax_step_jax_index(row, CRAFTAX_MAP_SIZE); + int32_t map_col = craftax_step_jax_index(col, CRAFTAX_MAP_SIZE); + return state->mob_map[map_level][map_row][map_col]; +} + +static inline void craftax_update_mobs_set_mob_map( + CraftaxState* state, + int32_t level, + int32_t row, + int32_t col, + bool value +) { + int32_t map_level; + int32_t map_row; + int32_t map_col; + if (!craftax_update_mobs_scatter_index( + level, + CRAFTAX_NUM_LEVELS, + &map_level + ) + || !craftax_update_mobs_scatter_index( + row, + CRAFTAX_MAP_SIZE, + &map_row + ) + || !craftax_update_mobs_scatter_index( + col, + CRAFTAX_MAP_SIZE, + &map_col + )) { + return; + } + state->mob_map[map_level][map_row][map_col] = value; +} + +static inline void craftax_update_mobs_clear_old_map_entry( + CraftaxState* state, + int32_t level, + int32_t row, + int32_t col, + bool old_mask +) { + bool old_value = craftax_update_mobs_read_mob_map(state, level, row, col); + craftax_update_mobs_set_mob_map( + state, + level, + row, + col, + old_value && !old_mask + ); +} + +static inline void craftax_update_mobs_enter_new_map_entry( + CraftaxState* state, + int32_t level, + int32_t row, + int32_t col, + bool new_mask +) { + bool old_value = craftax_update_mobs_read_mob_map(state, level, row, col); + craftax_update_mobs_set_mob_map( + state, + level, + row, + col, + old_value || new_mask + ); +} + +static inline void craftax_update_mobs_damage_vector( + int32_t type_id, + int32_t mob_class_index, + float damage[3] +) { + static const float damages[CRAFTAX_NUM_MOB_TYPES][4][3] = { + { + {0.0f, 0.0f, 0.0f}, + {2.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {2.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {4.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {4.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {3.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {0.0f, 3.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {5.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + 
{0.0f, 0.0f, 3.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {6.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {5.0f, 0.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {6.0f, 1.0f, 1.0f}, + {0.0f, 0.0f, 0.0f}, + {4.0f, 3.0f, 3.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {3.0f, 5.0f, 0.0f}, + {0.0f, 0.0f, 0.0f}, + {3.0f, 5.0f, 0.0f}, + }, + { + {0.0f, 0.0f, 0.0f}, + {4.0f, 0.0f, 5.0f}, + {0.0f, 0.0f, 0.0f}, + {4.0f, 0.0f, 5.0f}, + }, + }; + + int32_t type_index = craftax_step_jax_index( + type_id, + CRAFTAX_NUM_MOB_TYPES + ); + int32_t class_index = craftax_step_jax_index(mob_class_index, 4); + for (int32_t i = 0; i < 3; i++) { + damage[i] = damages[type_index][class_index][i]; + } +} + +static inline void craftax_update_mobs_collision_map( + int32_t type_id, + int32_t mob_class_index, + bool collision[3] +) { + static const bool collisions[CRAFTAX_NUM_MOB_TYPES][4][3] = { + { + {false, true, true}, + {false, true, true}, + {false, true, true}, + {false, false, false}, + }, + { + {false, false, false}, + {false, true, true}, + {false, true, true}, + {false, false, false}, + }, + { + {false, true, true}, + {false, true, true}, + {false, true, true}, + {false, false, false}, + }, + { + {false, true, true}, + {false, false, true}, + {false, true, true}, + {false, false, false}, + }, + { + {false, true, true}, + {false, true, true}, + {false, true, true}, + {false, false, false}, + }, + { + {false, true, true}, + {false, true, true}, + {true, false, true}, + {false, false, false}, + }, + { + {false, true, true}, + {false, true, true}, + {false, false, false}, + {false, false, false}, + }, + { + {false, true, true}, + {false, true, true}, + {false, false, false}, + {false, false, false}, + }, + }; + + int32_t type_index = craftax_step_jax_index( + type_id, + CRAFTAX_NUM_MOB_TYPES + ); + int32_t class_index = craftax_step_jax_index(mob_class_index, 4); + for (int32_t i = 0; i < 3; i++) { + collision[i] = collisions[type_index][class_index][i]; + } +} + +static inline int32_t 
craftax_update_mobs_projectile_type_for_ranged( + int32_t ranged_type +) { + static const int32_t mapping[CRAFTAX_NUM_MOB_TYPES] = { + CRAFTAX_PROJECTILE_ARROW, + CRAFTAX_PROJECTILE_ARROW, + CRAFTAX_PROJECTILE_FIREBALL, + CRAFTAX_PROJECTILE_DAGGER, + CRAFTAX_PROJECTILE_ARROW2, + CRAFTAX_PROJECTILE_SLIMEBALL, + CRAFTAX_PROJECTILE_FIREBALL2, + CRAFTAX_PROJECTILE_ICEBALL2, + }; + int32_t type_index = craftax_step_jax_index( + ranged_type, + CRAFTAX_NUM_MOB_TYPES + ); + return mapping[type_index]; +} + +static inline void craftax_update_mobs_direction_choice( + CraftaxThreefryKey key, + int32_t count, + int32_t direction[2] +) { + int32_t choice = craftax_medium_randint(key, 0, count); + direction[0] = 0; + direction[1] = 0; + if (choice == 0) { + direction[1] = -1; + } else if (choice == 1) { + direction[1] = 1; + } else if (choice == 2) { + direction[0] = -1; + } else if (choice == 3) { + direction[0] = 1; + } +} + +static inline int32_t craftax_update_mobs_abs_i32(int32_t value) { + return value < 0 ? -value : value; +} + +static inline int32_t craftax_update_mobs_sign_i32(int32_t value) { + if (value < 0) { + return -1; + } + return value > 0 ? 1 : 0; +} + +static inline int32_t craftax_update_mobs_player_axis_choice( + CraftaxThreefryKey key, + int32_t distance_row, + int32_t distance_col +) { + int32_t max_distance = distance_row > distance_col + ? distance_row + : distance_col; + int32_t total_distance = distance_row + distance_col; + if (total_distance == 0) { + return 1; + } + + float weights[2] = { + (distance_row == max_distance) ? 1.0f / (float)total_distance : 0.0f, + (distance_col == max_distance) ? 
1.0f / (float)total_distance : 0.0f, + }; + return craftax_medium_choice_weighted(key, weights, 2); +} + +static inline bool craftax_update_mobs_valid_position( + const CraftaxState* state, + int32_t row, + int32_t col, + const bool collision[3] +) { + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + bool pos_in_bounds = craftax_update_mobs_in_bounds(row, col); + int32_t block = craftax_update_mobs_read_block(state, level, row, col); + bool in_solid_block = craftax_step_is_solid_block(block); + bool in_mob = craftax_step_is_in_mob(state, row, col); + bool in_lava = block == CRAFTAX_BLOCK_LAVA; + bool in_water = block == CRAFTAX_BLOCK_WATER; + bool on_ground_block = !in_solid_block && !in_water && !in_lava; + + bool valid_move = pos_in_bounds && !in_mob && !in_solid_block; + valid_move = valid_move && (!collision[0] || !on_ground_block); + valid_move = valid_move && (!collision[1] || !in_water); + valid_move = valid_move && (!collision[2] || !in_lava); + return valid_move; +} + +static inline int32_t craftax_update_mobs_manhattan_to_player( + const CraftaxState* state, + int32_t row, + int32_t col +) { + return craftax_update_mobs_abs_i32(row - state->player_position[0]) + + craftax_update_mobs_abs_i32(col - state->player_position[1]); +} + +static inline float craftax_update_mobs_damage_done_to_player( + const CraftaxState* state, + const float damage_vector[3] +) { + float defense_vector[3] = {0.0f, 0.0f, 0.0f}; + for (int32_t i = 0; i < 4; i++) { + defense_vector[0] += (float)state->inventory.armour[i] * 0.1f; + defense_vector[1] += + (float)(int32_t)(state->armour_enchantments[i] == 1) * 0.2f; + defense_vector[2] += + (float)(int32_t)(state->armour_enchantments[i] == 2) * 0.2f; + } + + float boss_coeff = craftax_step_is_fighting_boss(state) + ? 
1.0f + CRAFTAX_UPDATE_BOSS_FIGHT_EXTRA_DAMAGE + : 1.0f; + float damage = 0.0f; + for (int32_t i = 0; i < 3; i++) { + damage += (1.0f - defense_vector[i]) * damage_vector[i] * boss_coeff; + } + return damage; +} + +static inline int32_t craftax_update_mobs_count_mob_projectiles( + const CraftaxState* state, + int32_t level +) { + int32_t count = 0; + for (int32_t i = 0; i < CRAFTAX_MAX_MOB_PROJECTILES; i++) { + count += (int32_t)state->mob_projectiles.mask[level][i]; + } + return count; +} + +static inline int32_t craftax_update_mobs_first_empty_mob_projectile( + const CraftaxState* state, + int32_t level +) { + for (int32_t i = 0; i < CRAFTAX_MAX_MOB_PROJECTILES; i++) { + if (!state->mob_projectiles.mask[level][i]) { + return i; + } + } + return 0; +} + +static inline void craftax_update_mobs_spawn_mob_projectile( + CraftaxState* state, + int32_t level, + bool is_spawning_projectile, + const int32_t position[2], + const int32_t direction[2], + int32_t projectile_type +) { + if (!is_spawning_projectile) { + return; + } + + int32_t index = craftax_update_mobs_first_empty_mob_projectile( + state, + level + ); + state->mob_projectiles.position[level][index][0] = position[0]; + state->mob_projectiles.position[level][index][1] = position[1]; + state->mob_projectiles.mask[level][index] = true; + state->mob_projectiles.type_id[level][index] = projectile_type; + state->mob_projectile_directions[level][index][0] = direction[0]; + state->mob_projectile_directions[level][index][1] = direction[1]; +} + +static inline void craftax_update_mobs_attack_mob_with_damage( + CraftaxState* state, + int32_t row, + int32_t col, + const float damage_vector[3], + bool can_eat, + bool* did_attack_mob, + bool* did_kill_mob +) { + bool did_kill_melee_mob = false; + bool is_attacking_melee_mob = false; + craftax_do_action_attack_mobs3( + state, + &state->melee_mobs, + row, + col, + damage_vector, + true, + CRAFTAX_MOB_MELEE, + &did_kill_melee_mob, + &is_attacking_melee_mob + ); + + bool 
did_kill_passive_mob = false; + bool is_attacking_passive_mob = false; + craftax_do_action_attack_mobs3( + state, + &state->passive_mobs, + row, + col, + damage_vector, + can_eat, + CRAFTAX_MOB_PASSIVE, + &did_kill_passive_mob, + &is_attacking_passive_mob + ); + + if (did_kill_passive_mob && can_eat) { + state->player_food = craftax_step_mini32( + craftax_step_get_max_food(state), + state->player_food + 6 + ); + state->player_hunger = 0.0f; + } + + bool did_kill_ranged_mob = false; + bool is_attacking_ranged_mob = false; + craftax_do_action_attack_mobs2( + state, + &state->ranged_mobs, + row, + col, + damage_vector, + true, + CRAFTAX_MOB_RANGED, + &did_kill_ranged_mob, + &is_attacking_ranged_mob + ); + + *did_attack_mob = is_attacking_melee_mob + || is_attacking_passive_mob + || is_attacking_ranged_mob; + bool did_kill_monster = did_kill_melee_mob || did_kill_ranged_mob; + *did_kill_mob = did_kill_monster || did_kill_passive_mob; + + craftax_do_action_update_mob_map(state, row, col, *did_kill_mob); + + int32_t level = craftax_step_jax_index( + state->player_level, + CRAFTAX_NUM_LEVELS + ); + state->monsters_killed[level] += (int32_t)did_kill_monster; +} + +static inline void craftax_update_mobs_player_projectile_damage_vector( + const CraftaxState* state, + int32_t level, + int32_t projectile_index, + float damage_vector[3] +) { + int32_t projectile_type = + state->player_projectiles.type_id[level][projectile_index]; + craftax_update_mobs_damage_vector( + projectile_type, + CRAFTAX_MOB_PROJECTILE, + damage_vector + ); + + float mask = (float)(int32_t) + state->player_projectiles.mask[level][projectile_index]; + for (int32_t i = 0; i < 3; i++) { + damage_vector[i] *= mask; + } + + bool is_arrow = projectile_type == CRAFTAX_PROJECTILE_ARROW + || projectile_type == CRAFTAX_PROJECTILE_ARROW2; + if (is_arrow) { + float arrow_damage_add[3] = {0.0f, 0.0f, 0.0f}; + int32_t enchantment_index; + if (craftax_update_mobs_scatter_index( + state->bow_enchantment, + 3, + 
&enchantment_index + )) { + arrow_damage_add[enchantment_index] = damage_vector[0] / 2.0f; + } + arrow_damage_add[0] = 0.0f; + for (int32_t i = 0; i < 3; i++) { + damage_vector[i] += arrow_damage_add[i]; + } + } + + if (is_arrow) { + float arrow_damage_coeff = + 1.0f + 0.2f * (float)(state->player_dexterity - 1); + for (int32_t i = 0; i < 3; i++) { + damage_vector[i] *= arrow_damage_coeff; + } + } + + bool is_magic_projectile = projectile_type == CRAFTAX_PROJECTILE_FIREBALL + || projectile_type == CRAFTAX_PROJECTILE_ICEBALL; + if (is_magic_projectile) { + float magic_damage_coeff = + 1.0f + 0.5f * (float)(state->player_intelligence - 1); + for (int32_t i = 0; i < 3; i++) { + damage_vector[i] *= magic_damage_coeff; + } + } +} + +static inline void craftax_update_mobs_move_melee( + CraftaxState* state, + CraftaxThreefryKey* rng, + int32_t index +) { + int32_t level = state->player_level; + bool old_mask = state->melee_mobs.mask[level][index]; + // Dead slot early-out: no observable effect on obs/reward/terminal. + // Skip body and RNG draws for speed. Breaks per-seed replay against + // JAX; define CRAFTAX_JAX_PARITY at build time to restore the + // branchless slow path (same pattern in every move_* below). 
+#ifndef CRAFTAX_JAX_PARITY + if (!old_mask) return; +#endif + int32_t old_row = state->melee_mobs.position[level][index][0]; + int32_t old_col = state->melee_mobs.position[level][index][1]; + int32_t old_cooldown = state->melee_mobs.attack_cooldown[level][index]; + int32_t mob_type = state->melee_mobs.type_id[level][index]; + + CraftaxThreefryKey draw_key = + craftax_update_mobs_next_random_key(rng); + int32_t random_direction[2]; + craftax_update_mobs_direction_choice(draw_key, 4, random_direction); + int32_t random_row = old_row + random_direction[0]; + int32_t random_col = old_col + random_direction[1]; + + int32_t distance_row = + craftax_update_mobs_abs_i32(state->player_position[0] - old_row); + int32_t distance_col = + craftax_update_mobs_abs_i32(state->player_position[1] - old_col); + draw_key = craftax_update_mobs_next_random_key(rng); + int32_t player_move_axis = craftax_update_mobs_player_axis_choice( + draw_key, + distance_row, + distance_col + ); + int32_t player_direction[2] = {0, 0}; + if (player_move_axis == 0) { + player_direction[0] = + craftax_update_mobs_sign_i32(state->player_position[0] - old_row); + } else { + player_direction[1] = + craftax_update_mobs_sign_i32(state->player_position[1] - old_col); + } + int32_t player_row = old_row + player_direction[0]; + int32_t player_col = old_col + player_direction[1]; + + int32_t distance_to_player = distance_row + distance_col; + bool close_to_player = distance_to_player < 10 + || craftax_step_is_fighting_boss(state); + draw_key = craftax_update_mobs_next_random_key(rng); + close_to_player = close_to_player + && craftax_threefry_uniform_f32(draw_key) < 0.75f; + + int32_t proposed_row = close_to_player ? player_row : random_row; + int32_t proposed_col = close_to_player ? 
player_col : random_col; + + bool is_attacking_player = distance_to_player == 1 + && old_cooldown <= 0 + && old_mask; + if (is_attacking_player) { + proposed_row = old_row; + proposed_col = old_col; + } + + float base_damage[3]; + craftax_update_mobs_damage_vector( + mob_type, + CRAFTAX_MOB_MELEE, + base_damage + ); + float sleeping_coeff = 1.0f + 2.5f * (float)(int32_t)state->is_sleeping; + for (int32_t i = 0; i < 3; i++) { + base_damage[i] *= sleeping_coeff; + } + float damage = craftax_update_mobs_damage_done_to_player( + state, + base_damage + ); + + int32_t new_cooldown = is_attacking_player ? 5 : old_cooldown - 1; + bool is_waking_player = state->is_sleeping && is_attacking_player; + state->player_health -= damage * (float)(int32_t)is_attacking_player; + state->is_sleeping = state->is_sleeping && !is_attacking_player; + state->is_resting = state->is_resting && !is_attacking_player; + state->achievements[CRAFTAX_ACH_WAKE_UP] = + state->achievements[CRAFTAX_ACH_WAKE_UP] || is_waking_player; + + bool collision[3]; + craftax_update_mobs_collision_map( + mob_type, + CRAFTAX_MOB_MELEE, + collision + ); + bool valid_move = craftax_update_mobs_valid_position( + state, + proposed_row, + proposed_col, + collision + ); + int32_t new_row = valid_move ? proposed_row : old_row; + int32_t new_col = valid_move ? 
proposed_col : old_col; + + bool should_not_despawn = distance_to_player < CRAFTAX_MOB_DESPAWN_DISTANCE + || craftax_step_is_fighting_boss(state); + + CraftaxThreefryKey unused_left; + CraftaxThreefryKey returned_key; + craftax_threefry_split(*rng, &unused_left, &returned_key); + *rng = returned_key; + + craftax_update_mobs_clear_old_map_entry( + state, + level, + old_row, + old_col, + old_mask + ); + bool new_mask = old_mask && should_not_despawn; + craftax_update_mobs_enter_new_map_entry( + state, + level, + new_row, + new_col, + new_mask + ); + + state->melee_mobs.position[level][index][0] = new_row; + state->melee_mobs.position[level][index][1] = new_col; + state->melee_mobs.attack_cooldown[level][index] = new_cooldown; + state->melee_mobs.mask[level][index] = new_mask; +} + +static inline void craftax_update_mobs_move_passive( + CraftaxState* state, + CraftaxThreefryKey* rng, + int32_t index +) { + int32_t level = state->player_level; + bool old_mask = state->passive_mobs.mask[level][index]; +#ifndef CRAFTAX_JAX_PARITY + if (!old_mask) return; +#endif + int32_t old_row = state->passive_mobs.position[level][index][0]; + int32_t old_col = state->passive_mobs.position[level][index][1]; + int32_t mob_type = state->passive_mobs.type_id[level][index]; + + CraftaxThreefryKey draw_key = + craftax_update_mobs_next_random_key(rng); + int32_t direction[2]; + craftax_update_mobs_direction_choice(draw_key, 8, direction); + int32_t proposed_row = old_row + direction[0]; + int32_t proposed_col = old_col + direction[1]; + + bool collision[3]; + craftax_update_mobs_collision_map( + mob_type, + CRAFTAX_MOB_PASSIVE, + collision + ); + bool valid_move = craftax_update_mobs_valid_position( + state, + proposed_row, + proposed_col, + collision + ); + int32_t new_row = valid_move ? proposed_row : old_row; + int32_t new_col = valid_move ? 
proposed_col : old_col; + + int32_t distance_to_player = craftax_update_mobs_manhattan_to_player( + state, + old_row, + old_col + ); + bool should_not_despawn = + distance_to_player < CRAFTAX_MOB_DESPAWN_DISTANCE; + + craftax_update_mobs_clear_old_map_entry( + state, + level, + old_row, + old_col, + old_mask + ); + bool new_mask = old_mask && should_not_despawn; + craftax_update_mobs_enter_new_map_entry( + state, + level, + new_row, + new_col, + new_mask + ); + + state->passive_mobs.position[level][index][0] = new_row; + state->passive_mobs.position[level][index][1] = new_col; + state->passive_mobs.mask[level][index] = new_mask; +} + +static inline void craftax_update_mobs_move_ranged( + CraftaxState* state, + CraftaxThreefryKey* rng, + int32_t index +) { + int32_t level = state->player_level; + bool old_mask = state->ranged_mobs.mask[level][index]; +#ifndef CRAFTAX_JAX_PARITY + if (!old_mask) return; +#endif + int32_t old_row = state->ranged_mobs.position[level][index][0]; + int32_t old_col = state->ranged_mobs.position[level][index][1]; + int32_t old_cooldown = state->ranged_mobs.attack_cooldown[level][index]; + int32_t mob_type = state->ranged_mobs.type_id[level][index]; + + CraftaxThreefryKey draw_key = + craftax_update_mobs_next_random_key(rng); + int32_t random_direction[2]; + craftax_update_mobs_direction_choice(draw_key, 4, random_direction); + int32_t random_row = old_row + random_direction[0]; + int32_t random_col = old_col + random_direction[1]; + + int32_t distance_row = + craftax_update_mobs_abs_i32(state->player_position[0] - old_row); + int32_t distance_col = + craftax_update_mobs_abs_i32(state->player_position[1] - old_col); + draw_key = craftax_update_mobs_next_random_key(rng); + int32_t player_move_axis = craftax_update_mobs_player_axis_choice( + draw_key, + distance_row, + distance_col + ); + int32_t player_direction[2] = {0, 0}; + if (player_move_axis == 0) { + player_direction[0] = + craftax_update_mobs_sign_i32(state->player_position[0] - 
old_row); + } else { + player_direction[1] = + craftax_update_mobs_sign_i32(state->player_position[1] - old_col); + } + int32_t towards_row = old_row + player_direction[0]; + int32_t towards_col = old_col + player_direction[1]; + int32_t away_row = old_row - player_direction[0]; + int32_t away_col = old_col - player_direction[1]; + + int32_t distance_to_player = distance_row + distance_col; + bool far_from_player = distance_to_player >= 6; + bool too_close_to_player = distance_to_player <= 3; + int32_t proposed_row = far_from_player ? towards_row : random_row; + int32_t proposed_col = far_from_player ? towards_col : random_col; + if (too_close_to_player) { + proposed_row = away_row; + proposed_col = away_col; + } + + draw_key = craftax_update_mobs_next_random_key(rng); + if (!(craftax_threefry_uniform_f32(draw_key) > 0.85f)) { + proposed_row = random_row; + proposed_col = random_col; + } + + bool collision[3]; + craftax_update_mobs_collision_map( + mob_type, + CRAFTAX_MOB_RANGED, + collision + ); + + bool is_attacking_player = + distance_to_player >= 4 && distance_to_player <= 5; + bool proposed_valid = craftax_update_mobs_valid_position( + state, + proposed_row, + proposed_col, + collision + ); + is_attacking_player = is_attacking_player + || (too_close_to_player && !proposed_valid); + is_attacking_player = is_attacking_player + && old_cooldown <= 0 + && old_mask; + + bool can_spawn_projectile = + craftax_update_mobs_count_mob_projectiles(state, level) + < CRAFTAX_MAX_MOB_PROJECTILES; + bool is_spawning_projectile = + is_attacking_player && can_spawn_projectile; + int32_t projectile_position[2] = {old_row, old_col}; + int32_t projectile_type = + craftax_update_mobs_projectile_type_for_ranged(mob_type); + craftax_update_mobs_spawn_mob_projectile( + state, + level, + is_spawning_projectile, + projectile_position, + player_direction, + projectile_type + ); + + if (is_attacking_player) { + proposed_row = old_row; + proposed_col = old_col; + } + int32_t new_cooldown = 
is_attacking_player ? 4 : old_cooldown - 1; + + bool valid_move = craftax_update_mobs_valid_position( + state, + proposed_row, + proposed_col, + collision + ); + int32_t new_row = valid_move ? proposed_row : old_row; + int32_t new_col = valid_move ? proposed_col : old_col; + + bool should_not_despawn = distance_to_player < CRAFTAX_MOB_DESPAWN_DISTANCE + || craftax_step_is_fighting_boss(state); + + craftax_update_mobs_clear_old_map_entry( + state, + level, + old_row, + old_col, + old_mask + ); + bool new_mask = old_mask && should_not_despawn; + craftax_update_mobs_enter_new_map_entry( + state, + level, + new_row, + new_col, + new_mask + ); + + state->ranged_mobs.position[level][index][0] = new_row; + state->ranged_mobs.position[level][index][1] = new_col; + state->ranged_mobs.attack_cooldown[level][index] = new_cooldown; + state->ranged_mobs.mask[level][index] = new_mask; +} + +static inline void craftax_update_mobs_move_mob_projectile( + CraftaxState* state, + int32_t index +) { + int32_t level = state->player_level; + bool old_mask = state->mob_projectiles.mask[level][index]; +#ifndef CRAFTAX_JAX_PARITY + if (!old_mask) return; +#endif + int32_t old_row = state->mob_projectiles.position[level][index][0]; + int32_t old_col = state->mob_projectiles.position[level][index][1]; + int32_t proposed_row = + old_row + state->mob_projectile_directions[level][index][0]; + int32_t proposed_col = + old_col + state->mob_projectile_directions[level][index][1]; + + bool proposed_in_player = + proposed_row == state->player_position[0] + && proposed_col == state->player_position[1]; + bool proposed_in_bounds = craftax_update_mobs_in_bounds( + proposed_row, + proposed_col + ); + int32_t proposed_block = craftax_update_mobs_read_block( + state, + level, + proposed_row, + proposed_col + ); + bool in_wall = craftax_step_is_solid_block(proposed_block) + && proposed_block != CRAFTAX_BLOCK_WATER; + bool in_mob = craftax_step_is_in_mob(state, proposed_row, proposed_col); + bool 
continue_move = proposed_in_bounds && !in_wall && !in_mob; + + bool hit_player0 = + old_row == state->player_position[0] + && old_col == state->player_position[1] + && old_mask; + bool hit_player1 = proposed_in_player && old_mask; + bool hit_player = hit_player0 || hit_player1; + continue_move = continue_move && !hit_player; + + bool new_mask = continue_move && old_mask; + + bool hit_bench_or_furnace = proposed_block == CRAFTAX_BLOCK_FURNACE + || proposed_block == CRAFTAX_BLOCK_CRAFTING_TABLE; + bool removing_block = hit_bench_or_furnace && old_mask; + int32_t new_block = removing_block ? CRAFTAX_BLOCK_PATH : proposed_block; + + int32_t projectile_type = + state->mob_projectiles.type_id[level][index]; + float damage_vector[3]; + craftax_update_mobs_damage_vector( + projectile_type, + CRAFTAX_MOB_PROJECTILE, + damage_vector + ); + float damage = craftax_update_mobs_damage_done_to_player( + state, + damage_vector + ); + + state->mob_projectiles.position[level][index][0] = proposed_row; + state->mob_projectiles.position[level][index][1] = proposed_col; + state->mob_projectiles.mask[level][index] = new_mask; + state->player_health -= damage * (float)(int32_t)hit_player; + state->is_sleeping = state->is_sleeping && !hit_player; + state->is_resting = state->is_resting && !hit_player; + craftax_update_mobs_set_block( + state, + level, + proposed_row, + proposed_col, + new_block + ); +} + +static inline void craftax_update_mobs_move_player_projectile( + CraftaxState* state, + int32_t index +) { + int32_t level = state->player_level; + bool old_mask = state->player_projectiles.mask[level][index]; +#ifndef CRAFTAX_JAX_PARITY + if (!old_mask) return; +#endif + int32_t old_row = state->player_projectiles.position[level][index][0]; + int32_t old_col = state->player_projectiles.position[level][index][1]; + int32_t proposed_row = + old_row + state->player_projectile_directions[level][index][0]; + int32_t proposed_col = + old_col + 
state->player_projectile_directions[level][index][1]; + + float damage_vector[3]; + craftax_update_mobs_player_projectile_damage_vector( + state, + level, + index, + damage_vector + ); + + bool proposed_in_bounds = craftax_update_mobs_in_bounds( + proposed_row, + proposed_col + ); + int32_t proposed_block = craftax_update_mobs_read_block( + state, + level, + proposed_row, + proposed_col + ); + bool in_wall = craftax_step_is_solid_block(proposed_block) + && proposed_block != CRAFTAX_BLOCK_WATER; + + bool did_attack_mob0 = false; + bool did_kill_mob0 = false; + craftax_update_mobs_attack_mob_with_damage( + state, + old_row, + old_col, + damage_vector, + false, + &did_attack_mob0, + &did_kill_mob0 + ); + (void)did_kill_mob0; + + float second_damage_vector[3]; + for (int32_t i = 0; i < 3; i++) { + second_damage_vector[i] = + damage_vector[i] * (float)(int32_t)(!did_attack_mob0); + } + + bool did_attack_mob1 = false; + bool did_kill_mob1 = false; + craftax_update_mobs_attack_mob_with_damage( + state, + proposed_row, + proposed_col, + second_damage_vector, + false, + &did_attack_mob1, + &did_kill_mob1 + ); + (void)did_kill_mob1; + + bool did_attack_mob = did_attack_mob0 || did_attack_mob1; + bool continue_move = proposed_in_bounds && !in_wall && !did_attack_mob; + bool new_mask = continue_move && old_mask; + + state->player_projectiles.position[level][index][0] = proposed_row; + state->player_projectiles.position[level][index][1] = proposed_col; + state->player_projectiles.mask[level][index] = new_mask; +} + +static inline void craftax_update_mobs_native( + CraftaxState* state, + CraftaxThreefryKey rng +) { + CraftaxThreefryKey unused; + + craftax_threefry_split(rng, &rng, &unused); + for (int32_t i = 0; i < CRAFTAX_MAX_MELEE_MOBS; i++) { + craftax_update_mobs_move_melee(state, &rng, i); + } + + craftax_threefry_split(rng, &rng, &unused); + for (int32_t i = 0; i < CRAFTAX_MAX_PASSIVE_MOBS; i++) { + craftax_update_mobs_move_passive(state, &rng, i); + } + + 
craftax_threefry_split(rng, &rng, &unused);
+    for (int32_t i = 0; i < CRAFTAX_MAX_RANGED_MOBS; i++) {
+        craftax_update_mobs_move_ranged(state, &rng, i);
+    }
+
+    craftax_threefry_split(rng, &rng, &unused);
+    for (int32_t i = 0; i < CRAFTAX_MAX_MOB_PROJECTILES; i++) {
+        craftax_update_mobs_move_mob_projectile(state, i);
+    }
+
+    craftax_threefry_split(rng, &rng, &unused);
+    for (int32_t i = 0; i < CRAFTAX_MAX_PLAYER_PROJECTILES; i++) {
+        craftax_update_mobs_move_player_projectile(state, i);
+    }
+}
diff --git a/ocean/craftax/threefry.h b/ocean/craftax/threefry.h
new file mode 100644
index 0000000000..c2f3c2d35d
--- /dev/null
+++ b/ocean/craftax/threefry.h
@@ -0,0 +1,133 @@
+// JAX-compatible threefry2x32 helpers for Craftax.
+//
+// The local JAX version uses the partitionable threefry split path by default:
+// split(key, n)[i] is threefry2x32(key, counter=(0, i)). 32-bit random bits
+// are bits0 ^ bits1 from the same counter schedule.
+
+#pragma once
+
+#include
+#include
+#include
+// Two-word (2 x 32-bit) PRNG key; word[0] and word[1] seed the key schedule in craftax_threefry2x32.
+typedef struct CraftaxThreefryKey {
+    uint32_t word[2];
+} CraftaxThreefryKey;
+// Rotates x left by k bits. Callers must keep 0 < k < 32: k == 0 would evaluate x >> 32, which is undefined in C (the rotation table below only holds 6..29).
+static inline uint32_t craftax_rotl32(uint32_t x, uint32_t k) {
+    return (uint32_t)((x << k) | (x >> (32u - k)));
+}
+// Builds the key {0, seed} from a 32-bit seed (presumably the jax.random.PRNGKey layout - confirm).
+static inline CraftaxThreefryKey craftax_prng_key(uint32_t seed) {
+    CraftaxThreefryKey key = {{0u, seed}};
+    return key;
+}
+// Threefry-2x32 block function: mixes the counter (count0, count1) under `key` and writes the two result words to out[0] and out[1].
+static inline void craftax_threefry2x32(
+    CraftaxThreefryKey key,
+    uint32_t count0,
+    uint32_t count1,
+    uint32_t out[2]
+) {
+    static const uint32_t rotations[2][4] = {
+        {13u, 15u, 26u, 6u},
+        {17u, 29u, 16u, 24u},
+    };
+    // Key schedule: both key words plus their XOR with the constant 0x1BD11BDA.
+    uint32_t ks[3] = {
+        key.word[0],
+        key.word[1],
+        key.word[0] ^ key.word[1] ^ 0x1BD11BDAu,
+    };
+    uint32_t x0 = count0 + ks[0];
+    uint32_t x1 = count1 + ks[1];
+    // 5 groups of 4 add-rotate-xor rounds (20 total); groups alternate between the two rotation rows, and each group ends by adding two schedule words plus the 1-based group number.
+    for (uint32_t block = 0; block < 5u; block++) {
+        const uint32_t* rs = rotations[block & 1u];
+        for (int i = 0; i < 4; i++) {
+            x0 += x1;
+            x1 = craftax_rotl32(x1, rs[i]);
+            x1 ^= x0;
+        }
+        x0 += ks[(block + 1u) % 3u];
+        x1 += ks[(block + 2u) % 3u] + block + 1u;
+    }
+
+    out[0] = x0;
+    out[1] = x1;
+}
+// Runs one block at counter (count0, count1) and returns its two output words as a new key.
+static inline CraftaxThreefryKey craftax_threefry_counter_key(
+    CraftaxThreefryKey key,
+    uint32_t count0,
+    uint32_t count1
+) {
+    uint32_t out[2];
+    craftax_threefry2x32(key, count0, count1, out);
+    CraftaxThreefryKey result = {{out[0], out[1]}};
+    return result;
+}
+// Two-way split: *left is the block at counter (0, 0), *right the block at counter (0, 1). `key` is taken by value, so left or right may point at the caller's own key.
+static inline void craftax_threefry_split(
+    CraftaxThreefryKey key,
+    CraftaxThreefryKey* left,
+    CraftaxThreefryKey* right
+) {
+    *left = craftax_threefry_counter_key(key, 0u, 0u);
+    *right = craftax_threefry_counter_key(key, 0u, 1u);
+}
+// N-way split: out[i] is the block at counter (high 32 bits of i, low 32 bits of i) for i in [0, count); `out` must have room for `count` keys.
+static inline void craftax_threefry_split_n(
+    CraftaxThreefryKey key,
+    CraftaxThreefryKey* out,
+    size_t count
+) {
+    for (size_t i = 0; i < count; i++) {
+        uint64_t counter = (uint64_t)i;
+        out[i] = craftax_threefry_counter_key(
+            key,
+            (uint32_t)(counter >> 32),
+            (uint32_t)counter
+        );
+    }
+}
+// Derives a new key as the block at counter (0, data).
+static inline CraftaxThreefryKey craftax_threefry_fold_in(
+    CraftaxThreefryKey key,
+    uint32_t data
+) {
+    return craftax_threefry_counter_key(key, 0u, data);
+}
+// 32 random bits for stream position `index`: XOR of the two words of the block at counter (high 32 bits of index, low 32 bits of index).
+static inline uint32_t craftax_threefry_uniform_u32_at(
+    CraftaxThreefryKey key,
+    uint64_t index
+) {
+    uint32_t out[2];
+    craftax_threefry2x32(
+        key,
+        (uint32_t)(index >> 32),
+        (uint32_t)index,
+        out
+    );
+    return out[0] ^ out[1];
+}
+// Same as craftax_threefry_uniform_u32_at with index 0.
+static inline uint32_t craftax_threefry_uniform_u32(CraftaxThreefryKey key) {
+    return craftax_threefry_uniform_u32_at(key, 0u);
+}
+// Float in [0, 1) for stream position `index`: the top 23 random bits become the mantissa of a float in [1, 2) (bit pattern 0x3F800000 is 1.0f), then 1.0f is subtracted. Assumes float is IEEE-754 binary32.
+static inline float craftax_threefry_uniform_f32_at(
+    CraftaxThreefryKey key,
+    uint64_t index
+) {
+    uint32_t bits = craftax_threefry_uniform_u32_at(key, index);
+    uint32_t float_bits = (bits >> 9u) | 0x3F800000u;
+    float value;
+    memcpy(&value, &float_bits, sizeof(value));  // copy the bit pattern into the float instead of casting pointers
+    return value - 1.0f;
+}
+// Same as craftax_threefry_uniform_f32_at with index 0.
+static inline float craftax_threefry_uniform_f32(CraftaxThreefryKey key) {
+    return craftax_threefry_uniform_f32_at(key, 0u);
+}
diff --git a/ocean/craftax/worldgen.h b/ocean/craftax/worldgen.h
new file mode 100644
index 0000000000..a6d2bcb11b
--- /dev/null
+++ b/ocean/craftax/worldgen.h
@@ -0,0
+1,1519 @@ +// Native Craftax reset world generation. +// +// This mirrors craftax/craftax/world_gen/world_gen.py for the default +// EnvParams and StaticEnvParams used by Craftax-Symbolic-v1 reset. + +#pragma once + +#include +#include +#include +#include +#include + +#include "noise.h" + +#define CRAFTAX_WG_MAP_SIZE 48 +#define CRAFTAX_WG_MAP_CELLS (CRAFTAX_WG_MAP_SIZE * CRAFTAX_WG_MAP_SIZE) +#define CRAFTAX_WG_NUM_LEVELS 9 +#define CRAFTAX_WG_OBS_ROWS 9 +#define CRAFTAX_WG_OBS_COLS 11 +#define CRAFTAX_WG_NUM_BLOCK_TYPES 37 +#define CRAFTAX_WG_NUM_ITEM_TYPES 5 +#define CRAFTAX_WG_NUM_MOB_CLASSES 5 +#define CRAFTAX_WG_NUM_MOB_TYPES 8 +#define CRAFTAX_WG_INVENTORY_OBS_SIZE 51 +#define CRAFTAX_WG_OBS_SIZE 8268 +#define CRAFTAX_WG_NUM_ACHIEVEMENTS 67 +#define CRAFTAX_WG_MAX_MELEE_MOBS 3 +#define CRAFTAX_WG_MAX_PASSIVE_MOBS 3 +#define CRAFTAX_WG_MAX_RANGED_MOBS 2 +#define CRAFTAX_WG_MAX_MOB_PROJECTILES 3 +#define CRAFTAX_WG_MAX_PLAYER_PROJECTILES 3 +#define CRAFTAX_WG_MAX_GROWING_PLANTS 10 +#define CRAFTAX_WG_MONSTERS_KILLED_TO_CLEAR_LEVEL 8 + +// Backwards-compatible names used by the phase-1 floor-0 test. 
+#define CRAFTAX_OVERWORLD_SIZE CRAFTAX_WG_MAP_SIZE
+#define CRAFTAX_OVERWORLD_CELLS CRAFTAX_WG_MAP_CELLS
+
+// Block ids: the values stored in the per-level `map` grids.
+#define CRAFTAX_WG_BLOCK_INVALID 0
+#define CRAFTAX_WG_BLOCK_OUT_OF_BOUNDS 1
+#define CRAFTAX_WG_BLOCK_GRASS 2
+#define CRAFTAX_WG_BLOCK_WATER 3
+#define CRAFTAX_WG_BLOCK_STONE 4
+#define CRAFTAX_WG_BLOCK_TREE 5
+#define CRAFTAX_WG_BLOCK_WOOD 6
+#define CRAFTAX_WG_BLOCK_PATH 7
+#define CRAFTAX_WG_BLOCK_COAL 8
+#define CRAFTAX_WG_BLOCK_IRON 9
+#define CRAFTAX_WG_BLOCK_DIAMOND 10
+#define CRAFTAX_WG_BLOCK_CRAFTING_TABLE 11
+#define CRAFTAX_WG_BLOCK_FURNACE 12
+#define CRAFTAX_WG_BLOCK_SAND 13
+#define CRAFTAX_WG_BLOCK_LAVA 14
+#define CRAFTAX_WG_BLOCK_PLANT 15
+#define CRAFTAX_WG_BLOCK_RIPE_PLANT 16
+#define CRAFTAX_WG_BLOCK_WALL 17
+#define CRAFTAX_WG_BLOCK_DARKNESS 18
+#define CRAFTAX_WG_BLOCK_WALL_MOSS 19
+#define CRAFTAX_WG_BLOCK_STALAGMITE 20
+#define CRAFTAX_WG_BLOCK_SAPPHIRE 21
+#define CRAFTAX_WG_BLOCK_RUBY 22
+#define CRAFTAX_WG_BLOCK_CHEST 23
+#define CRAFTAX_WG_BLOCK_FOUNTAIN 24
+#define CRAFTAX_WG_BLOCK_FIRE_GRASS 25
+#define CRAFTAX_WG_BLOCK_ICE_GRASS 26
+#define CRAFTAX_WG_BLOCK_GRAVEL 27
+#define CRAFTAX_WG_BLOCK_FIRE_TREE 28
+#define CRAFTAX_WG_BLOCK_ICE_SHRUB 29
+#define CRAFTAX_WG_BLOCK_ENCHANTMENT_TABLE_FIRE 30
+#define CRAFTAX_WG_BLOCK_ENCHANTMENT_TABLE_ICE 31
+#define CRAFTAX_WG_BLOCK_NECROMANCER 32
+#define CRAFTAX_WG_BLOCK_GRAVE 33
+#define CRAFTAX_WG_BLOCK_GRAVE2 34
+#define CRAFTAX_WG_BLOCK_GRAVE3 35
+#define CRAFTAX_WG_BLOCK_NECROMANCER_VULNERABLE 36
+
+// Item ids: the values stored in the per-level `item_map` grids.
+#define CRAFTAX_WG_ITEM_NONE 0
+#define CRAFTAX_WG_ITEM_TORCH 1
+#define CRAFTAX_WG_ITEM_LADDER_DOWN 2
+#define CRAFTAX_WG_ITEM_LADDER_UP 3
+#define CRAFTAX_WG_ITEM_LADDER_DOWN_BLOCKED 4
+
+#define CRAFTAX_WG_ACTION_UP 3
+#define CRAFTAX_WG_BOSS_FIGHT_SPAWN_TURNS 7
+#define CRAFTAX_WG_PI 3.14159265358979323846f
+
+// One generated floor: block grid, item grid, light grid and the (row, col)
+// of its two ladders.
+typedef struct CraftaxOverworldFloor {
+    int32_t map[CRAFTAX_OVERWORLD_SIZE][CRAFTAX_OVERWORLD_SIZE];
+    int32_t item_map[CRAFTAX_OVERWORLD_SIZE][CRAFTAX_OVERWORLD_SIZE];
+    float light_map[CRAFTAX_OVERWORLD_SIZE][CRAFTAX_OVERWORLD_SIZE];
+    int32_t ladder_down[2];
+    int32_t ladder_up[2];
+} CraftaxOverworldFloor;
+
+// Player inventory counters and equipment slots.
+typedef struct CraftaxWGInventory {
+    int32_t wood;
+    int32_t stone;
+    int32_t coal;
+    int32_t iron;
+    int32_t diamond;
+    int32_t sapling;
+    int32_t pickaxe;
+    int32_t sword;
+    int32_t bow;
+    int32_t arrows;
+    int32_t armour[4];
+    int32_t torches;
+    int32_t ruby;
+    int32_t sapphire;
+    int32_t potions[6];
+    int32_t books;
+} CraftaxWGInventory;
+
+// Per-level mob table with three slots per level; position is two int32
+// values per slot.
+typedef struct CraftaxWGMobs3 {
+    int32_t position[CRAFTAX_WG_NUM_LEVELS][3][2];
+    float health[CRAFTAX_WG_NUM_LEVELS][3];
+    bool mask[CRAFTAX_WG_NUM_LEVELS][3];
+    int32_t attack_cooldown[CRAFTAX_WG_NUM_LEVELS][3];
+    int32_t type_id[CRAFTAX_WG_NUM_LEVELS][3];
+} CraftaxWGMobs3;
+
+// Same layout as CraftaxWGMobs3 but with two slots per level.
+typedef struct CraftaxWGMobs2 {
+    int32_t position[CRAFTAX_WG_NUM_LEVELS][2][2];
+    float health[CRAFTAX_WG_NUM_LEVELS][2];
+    bool mask[CRAFTAX_WG_NUM_LEVELS][2];
+    int32_t attack_cooldown[CRAFTAX_WG_NUM_LEVELS][2];
+    int32_t type_id[CRAFTAX_WG_NUM_LEVELS][2];
+} CraftaxWGMobs2;
+
+// Complete world state written by craftax_generate_world_from_key(), which
+// zero-fills the struct before populating it.
+typedef struct CraftaxWorldState {
+    // Per-level grids and ladder coordinates.
+    int32_t map[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+    int32_t item_map[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+    bool mob_map[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+    float light_map[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+    int32_t down_ladders[CRAFTAX_WG_NUM_LEVELS][2];
+    int32_t up_ladders[CRAFTAX_WG_NUM_LEVELS][2];
+    bool chests_opened[CRAFTAX_WG_NUM_LEVELS];
+    int32_t monsters_killed[CRAFTAX_WG_NUM_LEVELS];
+
+    int32_t player_position[2];
+    int32_t player_level;
+    int32_t player_direction;
+
+    float player_health;
+    int32_t player_food;
+    int32_t player_drink;
+    int32_t player_energy;
+    int32_t player_mana;
+    bool is_sleeping;
+    bool is_resting;
+
+    float player_recover;
+    float player_hunger;
+    float player_thirst;
+    float player_fatigue;
+    float player_recover_mana;
+
+    int32_t player_xp;
+    int32_t player_dexterity;
+    int32_t player_strength;
+    int32_t player_intelligence;
+
+    CraftaxWGInventory inventory;
+
+    CraftaxWGMobs3 melee_mobs;
+    CraftaxWGMobs3 passive_mobs;
+    CraftaxWGMobs2 ranged_mobs;
+
+    CraftaxWGMobs3 mob_projectiles;
+    int32_t mob_projectile_directions[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAX_MOB_PROJECTILES][2];
+    CraftaxWGMobs3 player_projectiles;
+    int32_t player_projectile_directions[CRAFTAX_WG_NUM_LEVELS][CRAFTAX_WG_MAX_PLAYER_PROJECTILES][2];
+
+    int32_t growing_plants_positions[CRAFTAX_WG_MAX_GROWING_PLANTS][2];
+    int32_t growing_plants_age[CRAFTAX_WG_MAX_GROWING_PLANTS];
+    bool growing_plants_mask[CRAFTAX_WG_MAX_GROWING_PLANTS];
+
+    int32_t potion_mapping[6];
+    bool learned_spells[2];
+
+    int32_t sword_enchantment;
+    int32_t bow_enchantment;
+    int32_t armour_enchantments[4];
+
+    int32_t boss_progress;
+    int32_t boss_timesteps_to_spawn_this_round;
+
+    float light_level;
+    bool achievements[CRAFTAX_WG_NUM_ACHIEVEMENTS];
+    uint32_t state_rng[2];
+    int32_t timestep;
+    int32_t fractal_noise_angles[4];
+} CraftaxWorldState;
+
+// Parameters read by craftax_generate_smoothworld_config() for one
+// noise-generated floor.
+typedef struct CraftaxSmoothGenConfig {
+    int32_t default_block;          // block used when no other rule applies
+    int32_t sea_block;              // where water noise > water_threshold
+    int32_t coast_block;            // where water noise > sand_threshold but not sea
+    int32_t mountain_block;         // where mountain noise > 0.7
+    int32_t path_block;             // mountain cells whose path noise > 0.8
+    int32_t inner_mountain_block;   // where mountain > 0.85 and water > 0.4
+    int32_t ore_requirement_blocks[5];  // block each ore is allowed to replace
+    int32_t ores[5];                // ore block written for each ore slot
+    float ore_chances[5];           // per-cell probability for each ore slot
+    int32_t tree_requirement_block; // block a tree is allowed to replace
+    int32_t tree;                   // tree block for this floor
+    int32_t lava;                   // where mountain > 0.85 and tree noise > 0.7
+    int32_t player_spawn;           // block forced at the map centre
+    int32_t valid_ladder;           // block type ladders may be placed on
+    bool ladder_up;                 // write the up-ladder item to item_map
+    bool ladder_down;               // write the down-ladder item to item_map
+    // Distance from the map centre is divided by *_strength, clamped to
+    // [0, *_max] and added to the corresponding noise field.
+    float player_proximity_map_water_strength;
+    float player_proximity_map_water_max;
+    float player_proximity_map_mountain_strength;
+    float player_proximity_map_mountain_max;
+    float default_light;            // initial light value of every cell
+    float water_threshold;
+    float sand_threshold;
+    float tree_threshold_uniform;   // per-cell uniform draw must exceed this
+    float tree_threshold_perlin;    // tree noise must exceed this
+} CraftaxSmoothGenConfig;
+
+// Parameters read by craftax_generate_dungeon_config() for one room-based floor.
+typedef struct CraftaxDungeonConfig {
+    int32_t special_block;                // placed at offset (2, 2) inside room 0
+    int32_t fountain_block;               // placed in rooms that roll a fountain
+    int32_t rare_path_replacement_block;  // replaces "rare" empty path cells
+} CraftaxDungeonConfig;
+
+// Smooth-terrain parameters. Entries are positional and follow the field
+// order of CraftaxSmoothGenConfig; craftax_smooth_config_index_for_floor()
+// maps floors 0, 2, 5, 6, 7, 8 to entries 0..5.
+static const CraftaxSmoothGenConfig CRAFTAX_SMOOTHGEN_CONFIGS[6] = {
+    // Entry 0 (floor 0).
+    {
+        CRAFTAX_WG_BLOCK_GRASS,
+        CRAFTAX_WG_BLOCK_WATER,
+        CRAFTAX_WG_BLOCK_SAND,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_PATH,
+        {CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE},
+        {CRAFTAX_WG_BLOCK_COAL, CRAFTAX_WG_BLOCK_IRON, CRAFTAX_WG_BLOCK_DIAMOND, CRAFTAX_WG_BLOCK_OUT_OF_BOUNDS, CRAFTAX_WG_BLOCK_OUT_OF_BOUNDS},
+        {0.03f, 0.02f, 0.001f, 0.0f, 0.0f},
+        CRAFTAX_WG_BLOCK_GRASS,
+        CRAFTAX_WG_BLOCK_TREE,
+        CRAFTAX_WG_BLOCK_LAVA,
+        CRAFTAX_WG_BLOCK_GRASS,
+        CRAFTAX_WG_BLOCK_PATH,
+        false,
+        true,
+        5.0f,
+        1.0f,
+        5.0f,
+        1.0f,
+        1.0f,
+        0.7f,
+        0.6f,
+        0.8f,
+        0.5f,
+    },
+    // Entry 1 (floor 2).
+    {
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_WATER,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        {CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE},
+        {CRAFTAX_WG_BLOCK_COAL, CRAFTAX_WG_BLOCK_IRON, CRAFTAX_WG_BLOCK_DIAMOND, CRAFTAX_WG_BLOCK_SAPPHIRE, CRAFTAX_WG_BLOCK_RUBY},
+        {0.04f, 0.02f, 0.005f, 0.0025f, 0.0025f},
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_STALAGMITE,
+        CRAFTAX_WG_BLOCK_LAVA,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_PATH,
+        true,
+        true,
+        5.0f,
+        1.0f,
+        17.0f,
+        1.5f,
+        0.0f,
+        0.7f,
+        0.6f,
+        0.8f,
+        0.5f,
+    },
+    // Entry 2 (floor 5).
+    {
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_WATER,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        {CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE},
+        {CRAFTAX_WG_BLOCK_COAL, CRAFTAX_WG_BLOCK_IRON, CRAFTAX_WG_BLOCK_DIAMOND, CRAFTAX_WG_BLOCK_SAPPHIRE, CRAFTAX_WG_BLOCK_RUBY},
+        {0.04f, 0.03f, 0.01f, 0.01f, 0.01f},
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_STALAGMITE,
+        CRAFTAX_WG_BLOCK_LAVA,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_PATH,
+        true,
+        true,
+        5.0f,
+        1.0f,
+        17.0f,
+        1.5f,
+        0.0f,
+        0.7f,
+        0.6f,
+        0.8f,
+        0.5f,
+    },
+    // Entry 3 (floor 6).
+    {
+        CRAFTAX_WG_BLOCK_FIRE_GRASS,
+        CRAFTAX_WG_BLOCK_LAVA,
+        CRAFTAX_WG_BLOCK_SAND,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        {CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE},
+        {CRAFTAX_WG_BLOCK_COAL, CRAFTAX_WG_BLOCK_IRON, CRAFTAX_WG_BLOCK_DIAMOND, CRAFTAX_WG_BLOCK_SAPPHIRE, CRAFTAX_WG_BLOCK_RUBY},
+        {0.05f, 0.0f, 0.0f, 0.0f, 0.025f},
+        CRAFTAX_WG_BLOCK_FIRE_GRASS,
+        CRAFTAX_WG_BLOCK_FIRE_TREE,
+        CRAFTAX_WG_BLOCK_LAVA,
+        CRAFTAX_WG_BLOCK_FIRE_GRASS,
+        CRAFTAX_WG_BLOCK_FIRE_GRASS,
+        true,
+        true,
+        5.0f,
+        1.0f,
+        5.0f,
+        1.0f,
+        1.0f,
+        0.5f,
+        0.6f,
+        0.8f,
+        0.5f,
+    },
+    // Entry 4 (floor 7).
+    {
+        CRAFTAX_WG_BLOCK_ICE_GRASS,
+        CRAFTAX_WG_BLOCK_WATER,
+        CRAFTAX_WG_BLOCK_ICE_GRASS,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        CRAFTAX_WG_BLOCK_STONE,
+        {CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE, CRAFTAX_WG_BLOCK_STONE},
+        {CRAFTAX_WG_BLOCK_COAL, CRAFTAX_WG_BLOCK_IRON, CRAFTAX_WG_BLOCK_DIAMOND, CRAFTAX_WG_BLOCK_SAPPHIRE, CRAFTAX_WG_BLOCK_RUBY},
+        {0.0f, 0.0f, 0.005f, 0.02f, 0.0f},
+        CRAFTAX_WG_BLOCK_ICE_GRASS,
+        CRAFTAX_WG_BLOCK_ICE_SHRUB,
+        CRAFTAX_WG_BLOCK_WATER,
+        CRAFTAX_WG_BLOCK_ICE_GRASS,
+        CRAFTAX_WG_BLOCK_ICE_GRASS,
+        true,
+        true,
+        5.0f,
+        1.0f,
+        17.0f,
+        1.5f,
+        0.0f,
+        0.5f,
+        0.6f,
+        0.4f,
+        0.5f,
+    },
+    // Entry 5 (floor 8).
+    {
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_WALL,
+        CRAFTAX_WG_BLOCK_WALL,
+        CRAFTAX_WG_BLOCK_WALL,
+        {CRAFTAX_WG_BLOCK_WALL, CRAFTAX_WG_BLOCK_GRAVE, CRAFTAX_WG_BLOCK_GRAVE, CRAFTAX_WG_BLOCK_WALL, CRAFTAX_WG_BLOCK_WALL},
+        {CRAFTAX_WG_BLOCK_WALL_MOSS, CRAFTAX_WG_BLOCK_GRAVE2, CRAFTAX_WG_BLOCK_GRAVE3, CRAFTAX_WG_BLOCK_SAPPHIRE, CRAFTAX_WG_BLOCK_RUBY},
+        {0.1f, 0.333f, 0.5f, 0.0f, 0.0f},
+        CRAFTAX_WG_BLOCK_PATH,
+        CRAFTAX_WG_BLOCK_GRAVE,
+        CRAFTAX_WG_BLOCK_WALL,
+        CRAFTAX_WG_BLOCK_NECROMANCER,
+        CRAFTAX_WG_BLOCK_PATH,
+        false,
+        false,
+        5.0f,
+        1.0f,
+        10.0f,
+        10.0f,
+        0.0f,
+        0.7f,
+        0.6f,
+        0.95f,
+        -1.0f,
+    },
+};
+
+// Dungeon parameters {special, fountain, rare path replacement};
+// craftax_dungeon_config_index_for_floor() maps floors 1, 3, 4 to entries 0..2.
+static const CraftaxDungeonConfig CRAFTAX_DUNGEON_CONFIGS[3] = {
+    {CRAFTAX_WG_BLOCK_PATH, CRAFTAX_WG_BLOCK_FOUNTAIN, CRAFTAX_WG_BLOCK_PATH},
+    {CRAFTAX_WG_BLOCK_ENCHANTMENT_TABLE_ICE, CRAFTAX_WG_BLOCK_WATER, CRAFTAX_WG_BLOCK_WATER},
+    {CRAFTAX_WG_BLOCK_ENCHANTMENT_TABLE_FIRE, CRAFTAX_WG_BLOCK_FOUNTAIN, CRAFTAX_WG_BLOCK_PATH},
+};
+
+// Clamps value to [low, high].
+static inline float craftax_wg_clampf(float value, float low, float high) {
+    if (value < low) {
+        return low;
+    }
+    if (value > high) {
+        return high;
+    }
+    return value;
+}
+
+// Integer counterpart of craftax_wg_clampf.
+static inline int craftax_wg_clampi(int value, int low, int high) {
+    if (value < low) {
+        return low;
+    }
+    if (value > high) {
+        return high;
+    }
+    return value;
+}
+
+// Row-major flat index of (row, col) in a CRAFTAX_WG_MAP_SIZE-wide grid.
+static inline size_t craftax_wg_index(int row, int col) {
+    return (size_t)row * (size_t)CRAFTAX_WG_MAP_SIZE + (size_t)col;
+}
+
+// Three-way split: the outputs are children 0, 1 and 2 of
+// craftax_threefry_split_n(key, ..., 3).
+static inline void craftax_threefry_split3(
+    CraftaxThreefryKey key,
+    CraftaxThreefryKey* first,
+    CraftaxThreefryKey* second,
+    CraftaxThreefryKey* third
+) {
+    CraftaxThreefryKey keys[3];
+    craftax_threefry_split_n(key, keys, 3);
+    *first = keys[0];
+    *second = keys[1];
+    *third = keys[2];
+}
+
+// Derives the world-generation key for `seed`: seeds a key, splits it, then
+// splits the right child again and returns that split's right child.
+static inline CraftaxThreefryKey craftax_worldgen_key_from_seed(uint32_t seed) {
+    CraftaxThreefryKey key = craftax_prng_key(seed);
+    CraftaxThreefryKey carry;
+    CraftaxThreefryKey reset_key;
+    craftax_threefry_split(key, &carry, &reset_key);
+
+    CraftaxThreefryKey reset_carry;
+    CraftaxThreefryKey world_key;
+    craftax_threefry_split(reset_key, &reset_carry, &world_key);
+    return world_key;
+}
+
+// Returns child 1 of a seven-way split of the world-generation key for `seed`.
+static inline CraftaxThreefryKey craftax_overworld_rng_from_seed(uint32_t seed) {
+    CraftaxThreefryKey world_key = craftax_worldgen_key_from_seed(seed);
+    CraftaxThreefryKey world_keys[7];
+    craftax_threefry_split_n(world_key, world_keys, 7);
+    return world_keys[1];
+}
+
+// Returns an integer in [minval, maxval) for position `index`; when
+// maxval <= minval the span is treated as 1 and minval is returned.
+static inline uint32_t
craftax_randint_u32_at(
+    CraftaxThreefryKey key,
+    uint64_t index,
+    uint32_t minval,
+    uint32_t maxval
+) {
+    CraftaxThreefryKey k1;
+    CraftaxThreefryKey k2;
+    craftax_threefry_split(key, &k1, &k2);
+
+    // Two independent 32-bit draws are combined as
+    // (high * (2^16 mod span)^2 + low) mod span, using 64-bit intermediates.
+    uint32_t higher_bits = craftax_threefry_uniform_u32_at(k1, index);
+    uint32_t lower_bits = craftax_threefry_uniform_u32_at(k2, index);
+    uint32_t span = maxval > minval ? maxval - minval : 1u;
+    uint32_t multiplier = 65536u % span;
+    multiplier = (uint32_t)(((uint64_t)multiplier * (uint64_t)multiplier) % (uint64_t)span);
+    uint32_t random_offset = (uint32_t)(
+        (((uint64_t)(higher_bits % span) * (uint64_t)multiplier) + (uint64_t)(lower_bits % span))
+        % (uint64_t)span
+    );
+    return minval + random_offset;
+}
+
+// Signed wrapper around craftax_randint_u32_at; bounds and result are cast
+// through uint32_t.
+static inline int32_t craftax_randint_i32_at(
+    CraftaxThreefryKey key,
+    uint64_t index,
+    int32_t minval,
+    int32_t maxval
+) {
+    return (int32_t)craftax_randint_u32_at(
+        key,
+        index,
+        (uint32_t)minval,
+        (uint32_t)maxval
+    );
+}
+
+// Picks the index of one true entry of valid[0..count) with equal weight.
+// Returns 0 when no entry is true.
+static inline int craftax_choice_bool_flat(
+    CraftaxThreefryKey key,
+    const bool* valid,
+    int count
+) {
+    int valid_count = 0;
+    int last_valid = 0;
+    for (int i = 0; i < count; i++) {
+        if (valid[i]) {
+            valid_count++;
+            last_valid = i;
+        }
+    }
+    if (valid_count == 0) {
+        return 0;
+    }
+
+    // draw lies in (0, valid_count]; the first index whose running count of
+    // true entries reaches it is selected.
+    float draw = (float)valid_count * (1.0f - craftax_threefry_uniform_f32(key));
+    float cumulative = 0.0f;
+    for (int i = 0; i < count; i++) {
+        if (valid[i]) {
+            cumulative += 1.0f;
+        }
+        if (cumulative >= draw) {
+            return i;
+        }
+    }
+    return last_valid;
+}
+
+// Light value at (row, col) of a 9x9 stamp centred on (4, 4): falls off
+// linearly to zero at distance 5 and is blended on top of default_light.
+static inline float craftax_torch_light_value(int row, int col, float default_light) {
+    float dr = (float)(row - 4);
+    float dc = (float)(col - 4);
+    float distance = sqrtf(dr * dr + dc * dc);
+    float torch = craftax_wg_clampf(1.0f - distance / 5.0f, 0.0f, 1.0f);
+    return torch * (1.0f - default_light) + default_light;
+}
+
+// Overwrites a 9x9 region of light_map with the torch stamp. The region's
+// top-left corner is ladder_up - 4; a negative corner has the map size added
+// before being clamped to [0, map size - 9].
+static inline void craftax_apply_ladder_light(
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    const int32_t ladder_up[2],
+    float default_light
+) {
+    int start_row = ladder_up[0] - 4;
+    int start_col = ladder_up[1] - 4;
+    if (start_row < 0) {
+        start_row += CRAFTAX_WG_MAP_SIZE;
+    }
+    if (start_col < 0) {
+        start_col += CRAFTAX_WG_MAP_SIZE;
+    }
+    start_row = craftax_wg_clampi(start_row, 0, CRAFTAX_WG_MAP_SIZE - 9);
+    start_col = craftax_wg_clampi(start_col, 0, CRAFTAX_WG_MAP_SIZE - 9);
+    for (int row = 0; row < 9; row++) {
+        for (int col = 0; col < 9; col++) {
+            light_map[start_row + row][start_col + col] =
+                craftax_torch_light_value(row, col, default_light);
+        }
+    }
+}
+
+// When lava_emits_light is set, adds to every cell the 3x3 kernel weights of
+// its in-bounds lava neighbours (itself included) and clamps to [0, 1].
+static inline void craftax_add_lava_light(
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    const bool lava_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    bool lava_emits_light
+) {
+    if (!lava_emits_light) {
+        return;
+    }
+
+    static const float kernel[3][3] = {
+        {0.2f, 0.7f, 0.2f},
+        {0.7f, 1.0f, 0.7f},
+        {0.2f, 0.7f, 0.2f},
+    };
+
+    for (int row = 0; row < CRAFTAX_WG_MAP_SIZE; row++) {
+        for (int col = 0; col < CRAFTAX_WG_MAP_SIZE; col++) {
+            float add = 0.0f;
+            for (int kr = 0; kr < 3; kr++) {
+                int src_row = row + kr - 1;
+                if (src_row < 0 || src_row >= CRAFTAX_WG_MAP_SIZE) {
+                    continue;
+                }
+                for (int kc = 0; kc < 3; kc++) {
+                    int src_col = col + kc - 1;
+                    if (src_col < 0 || src_col >= CRAFTAX_WG_MAP_SIZE) {
+                        continue;
+                    }
+                    add += lava_map[src_row][src_col] ? kernel[kr][kc] : 0.0f;
+                }
+            }
+            light_map[row][col] = craftax_wg_clampf(light_map[row][col] + add, 0.0f, 1.0f);
+        }
+    }
+}
+
+// Index into CRAFTAX_SMOOTHGEN_CONFIGS for floor_idx, or -1 if the floor has
+// no smooth config.
+static inline int craftax_smooth_config_index_for_floor(int floor_idx) {
+    switch (floor_idx) {
+        case 0:
+            return 0;
+        case 2:
+            return 1;
+        case 5:
+            return 2;
+        case 6:
+            return 3;
+        case 7:
+            return 4;
+        case 8:
+            return 5;
+        default:
+            return -1;
+    }
+}
+
+// Index into CRAFTAX_DUNGEON_CONFIGS for floor_idx, or -1 if the floor has
+// no dungeon config.
+static inline int craftax_dungeon_config_index_for_floor(int floor_idx) {
+    switch (floor_idx) {
+        case 1:
+            return 0;
+        case 3:
+            return 1;
+        case 4:
+            return 2;
+        default:
+            return -1;
+    }
+}
+
+// Generates one noise-based floor from CRAFTAX_SMOOTHGEN_CONFIGS[config_idx],
+// filling map, item_map, light_map and both ladder coordinates. config_idx is
+// not range-checked. The order of key splits determines the output, so
+// statements must not be reordered.
+static inline void craftax_generate_smoothworld_config(
+    CraftaxThreefryKey rng,
+    int config_idx,
+    int32_t map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t item_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t ladder_down[2],
+    int32_t ladder_up[2]
+) {
+    const CraftaxSmoothGenConfig* config = &CRAFTAX_SMOOTHGEN_CONFIGS[config_idx];
+    const int size = CRAFTAX_WG_MAP_SIZE;
+    const int player_row = CRAFTAX_WG_MAP_SIZE / 2;
+    const int player_col = CRAFTAX_WG_MAP_SIZE / 2;
+    const size_t cells = CRAFTAX_WG_MAP_CELLS;
+
+    CraftaxThreefryKey subkey;
+    float water[CRAFTAX_WG_MAP_CELLS];
+    float mountain[CRAFTAX_WG_MAP_CELLS];
+    float path_x[CRAFTAX_WG_MAP_CELLS];
+    float tree_noise[CRAFTAX_WG_MAP_CELLS];
+    bool lava_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    craftax_generate_fractal_noise_2d(subkey, size, size, 3, 3, 1, 0.5f, 2, NULL, water);
+
+    // This split's subkey is discarded but the carry key still advances.
+    // NOTE(review): presumably keeps the key stream aligned with the Python
+    // reference - confirm.
+    craftax_threefry_split(rng, &rng, &subkey);
+    (void)subkey;
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    craftax_generate_fractal_noise_2d(subkey, size, size, 3, 3, 1, 0.5f, 2, NULL, mountain);
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    craftax_generate_fractal_noise_2d(subkey, size, size, 6, 24, 1, 0.5f, 2, NULL, path_x);
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    (void)subkey;
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    // The carry key itself (not a subkey) seeds the per-cell tree draws.
+    CraftaxThreefryKey tree_uniform_key = rng;
+    craftax_generate_fractal_noise_2d(subkey, size, size, 12, 12, 1, 0.5f, 2, NULL, tree_noise);
+
+    // Terrain pass: later rules override earlier ones for the same cell.
+    for (int row = 0; row < size; row++) {
+        int dr = row > player_row ? row - player_row : player_row - row;
+        for (int col = 0; col < size; col++) {
+            int dc = col > player_col ? col - player_col : player_col - col;
+            float distance = sqrtf((float)(dr * dr + dc * dc));
+            float proximity_water = craftax_wg_clampf(
+                distance / config->player_proximity_map_water_strength,
+                0.0f,
+                config->player_proximity_map_water_max
+            );
+            float proximity_mountain = craftax_wg_clampf(
+                distance / config->player_proximity_map_mountain_strength,
+                0.0f,
+                config->player_proximity_map_mountain_max
+            );
+            size_t idx = craftax_wg_index(row, col);
+
+            water[idx] = water[idx] + proximity_water - 1.0f;
+            int32_t block = water[idx] > config->water_threshold
+                ? config->sea_block
+                : config->default_block;
+            bool sand = water[idx] > config->sand_threshold && block != config->sea_block;
+            if (sand) {
+                block = config->coast_block;
+            }
+
+            mountain[idx] = mountain[idx] + 0.05f + proximity_mountain - 1.0f;
+            if (mountain[idx] > 0.7f) {
+                block = config->mountain_block;
+            }
+
+            bool path = mountain[idx] > 0.7f && path_x[idx] > 0.8f;
+            if (path) {
+                block = config->path_block;
+            }
+
+            // The second path field is path_x read at the transposed cell.
+            float path_y = path_x[craftax_wg_index(col, row)];
+            path = mountain[idx] > 0.7f && path_y > 0.8f;
+            if (path) {
+                block = config->path_block;
+            }
+
+            bool cave = mountain[idx] > 0.85f && water[idx] > 0.4f;
+            if (cave) {
+                block = config->inner_mountain_block;
+            }
+
+            float tree_draw = craftax_threefry_uniform_f32_at(tree_uniform_key, idx);
+            bool tree = tree_noise[idx] > config->tree_threshold_perlin
+                && tree_draw > config->tree_threshold_uniform;
+            if (tree && block == config->tree_requirement_block) {
+                block = config->tree;
+            }
+
+            map[row][col] = block;
+            item_map[row][col] = CRAFTAX_WG_ITEM_NONE;
+            light_map[row][col] = config->default_light;
+        }
+    }
+
+    // Ore pass: one key per ore slot, one uniform draw per cell.
+    CraftaxThreefryKey ore_rng;
+    craftax_threefry_split(rng, &rng, &ore_rng);
+    for (int ore_index = 0; ore_index < 5; ore_index++) {
+        CraftaxThreefryKey ore_key;
+        craftax_threefry_split(ore_rng, &ore_rng, &ore_key);
+        for (int row = 0; row < size; row++) {
+            for (int col = 0; col < size; col++) {
+                size_t idx = craftax_wg_index(row, col);
+                bool is_ore = map[row][col] == config->ore_requirement_blocks[ore_index]
+                    && craftax_threefry_uniform_f32_at(ore_key, idx) < config->ore_chances[ore_index];
+                if (is_ore) {
+                    map[row][col] = config->ores[ore_index];
+                }
+            }
+        }
+    }
+
+    // Lava pass: overrides whatever block the cell held.
+    for (int row = 0; row < size; row++) {
+        for (int col = 0; col < size; col++) {
+            size_t idx = craftax_wg_index(row, col);
+            lava_map[row][col] = mountain[idx] > 0.85f && tree_noise[idx] > 0.7f;
+            if (lava_map[row][col]) {
+                map[row][col] = config->lava;
+            }
+        }
+    }
+
+    // NOTE(review): a cell is chosen among STONE cells and STONE is written
+    // back to it, so the map only changes when no STONE cell exists (index 0
+    // is then overwritten). Confirm against the reference whether this is
+    // intentional parity or whether a diamond block was meant.
+    craftax_threefry_split(rng, &rng, &subkey);
+    bool valid_diamond[CRAFTAX_WG_MAP_CELLS];
+    for (int row = 0; row < size; row++) {
+        for (int col = 0; col < size; col++) {
+            valid_diamond[craftax_wg_index(row, col)] = map[row][col] == CRAFTAX_WG_BLOCK_STONE;
+        }
+    }
+    int diamond_index = craftax_choice_bool_flat(subkey, valid_diamond, (int)cells);
+    map[diamond_index / size][diamond_index % size] = CRAFTAX_WG_BLOCK_STONE;
+
+    map[player_row][player_col] = config->player_spawn;
+
+    // Both ladders are drawn from the same candidate mask with separate keys.
+    bool valid_ladder[CRAFTAX_WG_MAP_CELLS];
+    for (int row = 0; row < size; row++) {
+        for (int col = 0; col < size; col++) {
+            valid_ladder[craftax_wg_index(row, col)] = map[row][col] == config->valid_ladder;
+        }
+    }
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    int ladder_down_index = craftax_choice_bool_flat(subkey, valid_ladder, (int)cells);
+    ladder_down[0] = ladder_down_index / size;
+    ladder_down[1] = ladder_down_index % size;
+    if (config->ladder_down) {
+        item_map[ladder_down[0]][ladder_down[1]] = CRAFTAX_WG_ITEM_LADDER_DOWN;
+    }
+
+    craftax_threefry_split(rng, &rng, &subkey);
+    int ladder_up_index = craftax_choice_bool_flat(subkey, valid_ladder, (int)cells);
+    ladder_up[0] = ladder_up_index / size;
+    ladder_up[1] = ladder_up_index % size;
+
+    // The ladder light stamp is applied even when the ladder item is not
+    // placed; lava adds light only when this floor's lava block is LAVA.
+    craftax_apply_ladder_light(light_map, ladder_up, config->default_light);
+    craftax_add_lava_light(light_map, lava_map, config->lava == CRAFTAX_WG_BLOCK_LAVA);
+
+    if (config->ladder_up) {
+        item_map[ladder_up[0]][ladder_up[1]] = CRAFTAX_WG_ITEM_LADDER_UP;
+    }
+}
+
+// Generates the smooth floor for floor_idx. Floors without a smooth config
+// get zero-filled grids and ladders at (0, 0).
+static inline void craftax_generate_smoothworld_floor(
+    CraftaxThreefryKey seed_key,
+    int floor_idx,
+    int32_t map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t item_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t ladder_down[2],
+    int32_t ladder_up[2]
+) {
+    int config_idx = craftax_smooth_config_index_for_floor(floor_idx);
+    if (config_idx < 0) {
+        memset(map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(int32_t));
+        memset(item_map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(int32_t));
+        memset(light_map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(float));
+        ladder_down[0] = 0;
+        ladder_down[1] = 0;
+        ladder_up[0] = 0;
+        ladder_up[1] = 0;
+        return;
+    }
+    craftax_generate_smoothworld_config(
+        seed_key,
+        config_idx,
+        map,
+        item_map,
+        light_map,
+        ladder_down,
+        ladder_up
+    );
+}
+
+// Generates one room-and-corridor floor from
+// CRAFTAX_DUNGEON_CONFIGS[config_idx]. Rooms are carved into a map padded by
+// max_room_size cells on every side (48 + 2 * 10 = 68), which is cropped back
+// to the real map afterwards. config_idx is not range-checked.
+static inline void craftax_generate_dungeon_config(
+    CraftaxThreefryKey rng,
+    int config_idx,
+    int32_t map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t item_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t ladder_down[2],
+    int32_t ladder_up[2]
+) {
+    const CraftaxDungeonConfig* config = &CRAFTAX_DUNGEON_CONFIGS[config_idx];
+    const int chunk_size = 16;
+    const int world_chunk_height = CRAFTAX_WG_MAP_SIZE / chunk_size;
+    const int num_rooms = 8;
+    const int min_room_size = 5;
+    const int max_room_size = 10;
+    const int padded_size = CRAFTAX_WG_MAP_SIZE + 2 * max_room_size;
+
+    int32_t padded_map[68][68];
+    int32_t padded_item_map[68][68];
+    bool
room_occupancy_chunks[9];
+    int32_t room_sizes[8][2];
+    int32_t room_positions[8][2];
+
+    // The inner 48x48 area starts as WALL; the padding ring holds 0.
+    for (int row = 0; row < padded_size; row++) {
+        for (int col = 0; col < padded_size; col++) {
+            bool inner = row >= max_room_size
+                && row < max_room_size + CRAFTAX_WG_MAP_SIZE
+                && col >= max_room_size
+                && col < max_room_size + CRAFTAX_WG_MAP_SIZE;
+            padded_map[row][col] = inner ? CRAFTAX_WG_BLOCK_WALL : 0;
+            padded_item_map[row][col] = CRAFTAX_WG_ITEM_NONE;
+        }
+    }
+    // true = the 16x16 chunk is still free to receive a room.
+    for (int i = 0; i < 9; i++) {
+        room_occupancy_chunks[i] = true;
+    }
+
+    // Room sizes: two draws per room from [min_room_size, max_room_size).
+    CraftaxThreefryKey room_scan_ignored_key;
+    CraftaxThreefryKey room_size_key;
+    craftax_threefry_split3(rng, &rng, &room_scan_ignored_key, &room_size_key);
+    (void)room_scan_ignored_key;
+    for (int room = 0; room < num_rooms; room++) {
+        room_sizes[room][0] = craftax_randint_i32_at(room_size_key, (uint64_t)room * 2u, min_room_size, max_room_size);
+        room_sizes[room][1] = craftax_randint_i32_at(room_size_key, (uint64_t)room * 2u + 1u, min_room_size, max_room_size);
+    }
+
+    CraftaxThreefryKey room_rng;
+    craftax_threefry_split(rng, &rng, &room_rng);
+
+    for (int room_index = 0; room_index < num_rooms; room_index++) {
+        // Claim a free chunk, then offset the room inside it.
+        CraftaxThreefryKey choice_key;
+        craftax_threefry_split(room_rng, &room_rng, &choice_key);
+        int room_chunk = craftax_choice_bool_flat(choice_key, room_occupancy_chunks, 9);
+        room_occupancy_chunks[room_chunk] = false;
+
+        int room_row = (room_chunk % world_chunk_height) * chunk_size + max_room_size;
+        int room_col = (room_chunk / world_chunk_height) * chunk_size + max_room_size;
+        CraftaxThreefryKey position_key;
+        craftax_threefry_split(room_rng, &room_rng, &position_key);
+        room_row += craftax_randint_i32_at(position_key, 0, 0, chunk_size - min_room_size);
+        room_col += craftax_randint_i32_at(position_key, 1, 0, chunk_size - min_room_size);
+        room_positions[room_index][0] = room_row;
+        room_positions[room_index][1] = room_col;
+
+        // Carve the room interior.
+        for (int row = 0; row < max_room_size; row++) {
+            for (int col = 0; col < max_room_size; col++) {
+                if (row < room_sizes[room_index][0] && col < room_sizes[room_index][1]) {
+                    padded_map[room_row + row][room_col + col] = CRAFTAX_WG_BLOCK_PATH;
+                }
+            }
+        }
+
+        // A torch in each of the four room corners.
+        padded_item_map[room_row][room_col] = CRAFTAX_WG_ITEM_TORCH;
+        padded_item_map[room_row + room_sizes[room_index][0] - 1][room_col] = CRAFTAX_WG_ITEM_TORCH;
+        padded_item_map[room_row][room_col + room_sizes[room_index][1] - 1] = CRAFTAX_WG_ITEM_TORCH;
+        padded_item_map[room_row + room_sizes[room_index][0] - 1][room_col + room_sizes[room_index][1] - 1] = CRAFTAX_WG_ITEM_TORCH;
+
+        // Chest offset is drawn from [1, room size - 1) on each axis.
+        CraftaxThreefryKey chest_key;
+        craftax_threefry_split(room_rng, &room_rng, &chest_key);
+        int chest_row = craftax_randint_i32_at(chest_key, 0, 1, room_sizes[room_index][0] - 1);
+        int chest_col = craftax_randint_i32_at(chest_key, 1, 1, room_sizes[room_index][1] - 1);
+        padded_map[room_row + chest_row][room_col + chest_col] = CRAFTAX_WG_BLOCK_CHEST;
+
+        // The fountain position is always drawn; it is only written when the
+        // uniform draw exceeds 0.5, and it may overwrite the chest cell.
+        CraftaxThreefryKey fountain_key;
+        CraftaxThreefryKey fountain_uniform_key;
+        craftax_threefry_split3(room_rng, &room_rng, &fountain_key, &fountain_uniform_key);
+        int fountain_row = craftax_randint_i32_at(fountain_key, 0, 1, room_sizes[room_index][0] - 1);
+        int fountain_col = craftax_randint_i32_at(fountain_key, 1, 1, room_sizes[room_index][1] - 1);
+        bool room_has_fountain = craftax_threefry_uniform_f32(fountain_uniform_key) > 0.5f;
+        if (room_has_fountain) {
+            padded_map[room_row + fountain_row][room_col + fountain_col] = config->fountain_block;
+        }
+    }
+
+    // Corridors: each room is linked to a room picked from the mask, which
+    // starts with only the last room set and gains each room once linked.
+    CraftaxThreefryKey path_rng;
+    craftax_threefry_split(rng, &rng, &path_rng);
+    bool included_rooms_mask[8] = {false, false, false, false, false, false, false, true};
+
+    for (int path_index = 0; path_index < num_rooms; path_index++) {
+        int source_row = room_positions[path_index][0];
+        int source_col = room_positions[path_index][1];
+
+        CraftaxThreefryKey sink_key;
+        craftax_threefry_split(path_rng, &path_rng, &sink_key);
+        int sink_index = craftax_choice_bool_flat(sink_key, included_rooms_mask, num_rooms);
+        int sink_row = room_positions[sink_index][0];
+        int sink_col = room_positions[sink_index][1];
+
+        // Horizontal leg along the source row; only WALL cells are replaced.
+        int horizontal_distance = sink_col - source_col;
+        int horizontal_sign = (horizontal_distance > 0) - (horizontal_distance < 0);
+        if (horizontal_sign != 0) {
+            int abs_distance = horizontal_distance > 0 ? horizontal_distance : -horizontal_distance;
+            for (int col = 0; col < padded_size; col++) {
+                int path_index_col = (col - source_col) * horizontal_sign;
+                bool horizontal_mask = path_index_col >= 0
+                    && path_index_col <= abs_distance
+                    && padded_map[source_row][col] == CRAFTAX_WG_BLOCK_WALL;
+                if (horizontal_mask) {
+                    padded_map[source_row][col] = CRAFTAX_WG_BLOCK_PATH;
+                }
+            }
+        }
+
+        // Vertical leg along the sink column.
+        int vertical_distance = sink_row - source_row;
+        int vertical_sign = (vertical_distance > 0) - (vertical_distance < 0);
+        if (vertical_sign != 0) {
+            int abs_distance = vertical_distance > 0 ? vertical_distance : -vertical_distance;
+            for (int row = 0; row < padded_size; row++) {
+                int path_index_row = (row - source_row) * vertical_sign;
+                bool vertical_mask = path_index_row >= 0
+                    && path_index_row <= abs_distance
+                    && padded_map[row][sink_col] == CRAFTAX_WG_BLOCK_WALL;
+                if (vertical_mask) {
+                    padded_map[row][sink_col] = CRAFTAX_WG_BLOCK_PATH;
+                }
+            }
+        }
+
+        // Second split per iteration: the right child becomes the next carry.
+        CraftaxThreefryKey unused_left;
+        CraftaxThreefryKey next_path_rng;
+        craftax_threefry_split(path_rng, &unused_left, &next_path_rng);
+        path_rng = next_path_rng;
+        included_rooms_mask[path_index] = true;
+    }
+
+    int special_row = room_positions[0][0] + 2;
+    int special_col = room_positions[0][1] + 2;
+    padded_map[special_row][special_col] = config->special_block;
+
+    // Crop the padding away.
+    for (int row = 0; row < CRAFTAX_WG_MAP_SIZE; row++) {
+        for (int col = 0; col < CRAFTAX_WG_MAP_SIZE; col++) {
+            map[row][col] = padded_map[row + max_room_size][col + max_room_size];
+            item_map[row][col] = padded_item_map[row + max_room_size][col + max_room_size];
+        }
+    }
+
+    // True where the cell or one of its four neighbours is not WALL.
+    bool adjacent_path[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE];
+    for (int row = 0; row < CRAFTAX_WG_MAP_SIZE; row++) {
+        for (int col = 0; col < CRAFTAX_WG_MAP_SIZE; col++) {
+            bool adjacent = map[row][col] != CRAFTAX_WG_BLOCK_WALL;
+            adjacent = adjacent || (row > 0 && map[row - 1][col] != CRAFTAX_WG_BLOCK_WALL);
+            adjacent = adjacent || (row + 1 < CRAFTAX_WG_MAP_SIZE && map[row + 1][col] != CRAFTAX_WG_BLOCK_WALL);
+            adjacent = adjacent || (col > 0 && map[row][col - 1] != CRAFTAX_WG_BLOCK_WALL);
+            adjacent = adjacent || (col + 1 < CRAFTAX_WG_MAP_SIZE && map[row][col + 1] != CRAFTAX_WG_BLOCK_WALL);
+            adjacent_path[row][col] = adjacent;
+        }
+    }
+
+    // Final pass: walls with no non-wall neighbour become DARKNESS, the
+    // remaining walls become WALL or (when the cell is "rare") WALL_MOSS, and
+    // rare empty PATH cells take the configured replacement block.
+    CraftaxThreefryKey rare_key;
+    craftax_threefry_split(rng, &rng, &rare_key);
+    for (int row = 0; row < CRAFTAX_WG_MAP_SIZE; row++) {
+        for (int col = 0; col < CRAFTAX_WG_MAP_SIZE; col++) {
+            size_t idx = craftax_wg_index(row, col);
+            bool rare = (1.0f - craftax_threefry_uniform_f32_at(rare_key, idx)) > 0.9f;
+            int32_t wall_map = rare ? CRAFTAX_WG_BLOCK_WALL_MOSS : CRAFTAX_WG_BLOCK_WALL;
+            bool rare_path = rare && map[row][col] == CRAFTAX_WG_BLOCK_PATH && item_map[row][col] == CRAFTAX_WG_ITEM_NONE;
+            int32_t path_map = rare_path ? config->rare_path_replacement_block : map[row][col];
+            bool is_wall_map = map[row][col] == CRAFTAX_WG_BLOCK_WALL && adjacent_path[row][col];
+            bool is_darkness_map = !adjacent_path[row][col];
+
+            if (is_darkness_map) {
+                map[row][col] = CRAFTAX_WG_BLOCK_DARKNESS;
+            } else if (is_wall_map) {
+                map[row][col] = wall_map;
+            } else {
+                map[row][col] = path_map;
+            }
+            light_map[row][col] = 1.0f;
+        }
+    }
+
+    // Both ladders are drawn from the PATH cells with separate keys and are
+    // always written to item_map.
+    bool valid_ladder[CRAFTAX_WG_MAP_CELLS];
+    for (int row = 0; row < CRAFTAX_WG_MAP_SIZE; row++) {
+        for (int col = 0; col < CRAFTAX_WG_MAP_SIZE; col++) {
+            valid_ladder[craftax_wg_index(row, col)] = map[row][col] == CRAFTAX_WG_BLOCK_PATH;
+        }
+    }
+
+    CraftaxThreefryKey ladder_down_key;
+    craftax_threefry_split(rng, &rng, &ladder_down_key);
+    int ladder_down_index = craftax_choice_bool_flat(ladder_down_key, valid_ladder, CRAFTAX_WG_MAP_CELLS);
+    ladder_down[0] = ladder_down_index / CRAFTAX_WG_MAP_SIZE;
+    ladder_down[1] = ladder_down_index % CRAFTAX_WG_MAP_SIZE;
+    item_map[ladder_down[0]][ladder_down[1]] = CRAFTAX_WG_ITEM_LADDER_DOWN;
+
+    CraftaxThreefryKey ladder_up_key;
+    craftax_threefry_split(rng, &rng, &ladder_up_key);
+    int ladder_up_index = craftax_choice_bool_flat(ladder_up_key, valid_ladder, CRAFTAX_WG_MAP_CELLS);
+    ladder_up[0] = ladder_up_index / CRAFTAX_WG_MAP_SIZE;
+    ladder_up[1] = ladder_up_index % CRAFTAX_WG_MAP_SIZE;
+    item_map[ladder_up[0]][ladder_up[1]] = CRAFTAX_WG_ITEM_LADDER_UP;
+}
+
+// Generates the dungeon floor for floor_idx. Floors without a dungeon config
+// get zero-filled grids and ladders at (0, 0).
+static inline void craftax_generate_dungeon_floor(
+    CraftaxThreefryKey seed_key,
+    int floor_idx,
+    int32_t map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t item_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    float light_map[CRAFTAX_WG_MAP_SIZE][CRAFTAX_WG_MAP_SIZE],
+    int32_t ladder_down[2],
+    int32_t ladder_up[2]
+) {
+    int config_idx = craftax_dungeon_config_index_for_floor(floor_idx);
+    if (config_idx < 0) {
+        memset(map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(int32_t));
+        memset(item_map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(int32_t));
+        memset(light_map, 0, CRAFTAX_WG_MAP_CELLS * sizeof(float));
+        ladder_down[0] = 0;
+        ladder_down[1] = 0;
+        ladder_up[0] = 0;
+        ladder_up[1] = 0;
+        return;
+    }
+    craftax_generate_dungeon_config(
+        seed_key,
+        config_idx,
+        map,
+        item_map,
+        light_map,
+        ladder_down,
+        ladder_up
+    );
+}
+
+// Writes a permutation of 0..5 to `out`: each index gets a random 32-bit sort
+// key and the indices are insertion-sorted by ascending key (ties keep their
+// original order).
+static inline void craftax_permutation_6(CraftaxThreefryKey key, int32_t out[6]) {
+    CraftaxThreefryKey carry;
+    CraftaxThreefryKey sort_key;
+    craftax_threefry_split(key, &carry, &sort_key);
+    (void)carry;
+
+    uint32_t keys[6];
+    for (int i = 0; i < 6; i++) {
+        keys[i] = craftax_threefry_uniform_u32_at(sort_key, (uint64_t)i);
+        out[i] = i;
+    }
+
+    for (int i = 1; i < 6; i++) {
+        uint32_t key_value = keys[i];
+        int32_t value = out[i];
+        int j = i - 1;
+        while (j >= 0 && keys[j] > key_value) {
+            keys[j + 1] = keys[j];
+            out[j + 1] = out[j];
+            j--;
+        }
+        keys[j + 1] = key_value;
+        out[j + 1] = value;
+    }
+}
+
+// Returns 1 - |cos(pi * 0.3)|^3, the light level for a fixed progress of 0.3.
+static inline float craftax_calculate_initial_light_level(void) {
+    float progress = 0.3f;
+    float c = cosf(CRAFTAX_WG_PI * progress);
+    return 1.0f - powf(fabsf(c), 3.0f);
+}
+
+// Sets health to 1.0 for every slot on every level; other fields are left
+// untouched.
+static inline void craftax_init_empty_mobs3(CraftaxWGMobs3* mobs) {
+    for (int level = 0; level < CRAFTAX_WG_NUM_LEVELS; level++) {
+        for (int mob = 0; mob < 3; mob++) {
+            mobs->health[level][mob] = 1.0f;
+        }
+    }
+}
+
+// Two-slot counterpart of craftax_init_empty_mobs3.
+static inline void craftax_init_empty_mobs2(CraftaxWGMobs2* mobs) {
+    for (int level = 0; level < CRAFTAX_WG_NUM_LEVELS; level++) {
+        for (int mob = 0; mob < 2; mob++) {
+            mobs->health[level][mob] = 1.0f;
+        }
+    }
+}
+
+static inline void craftax_generate_world_from_key(
+    CraftaxThreefryKey rng,
+    CraftaxWorldState* out
+) {
+    memset(out, 0, sizeof(*out));
+
+    CraftaxThreefryKey smooth_split[7];
+    craftax_threefry_split_n(rng, smooth_split, 7);
+    rng = smooth_split[0];
+
+    static const int smooth_floor_order[6] = {0, 2, 5, 6, 7, 8};
+    for (int i = 0; i < 6; i++) {
+        int level = smooth_floor_order[i];
+        craftax_generate_smoothworld_config(
+            smooth_split[i + 1],
+            i,
+            out->map[level],
+            out->item_map[level],
+ out->light_map[level], + out->down_ladders[level], + out->up_ladders[level] + ); + } + + CraftaxThreefryKey dungeon_split[4]; + craftax_threefry_split_n(rng, dungeon_split, 4); + rng = dungeon_split[0]; + + static const int dungeon_floor_order[3] = {1, 3, 4}; + for (int i = 0; i < 3; i++) { + int level = dungeon_floor_order[i]; + craftax_generate_dungeon_config( + dungeon_split[i + 1], + i, + out->map[level], + out->item_map[level], + out->light_map[level], + out->down_ladders[level], + out->up_ladders[level] + ); + } + + craftax_init_empty_mobs3(&out->melee_mobs); + craftax_init_empty_mobs3(&out->passive_mobs); + craftax_init_empty_mobs2(&out->ranged_mobs); + craftax_init_empty_mobs3(&out->mob_projectiles); + craftax_init_empty_mobs3(&out->player_projectiles); + for (int level = 0; level < CRAFTAX_WG_NUM_LEVELS; level++) { + for (int projectile = 0; projectile < CRAFTAX_WG_MAX_MOB_PROJECTILES; projectile++) { + out->mob_projectile_directions[level][projectile][0] = 1; + out->mob_projectile_directions[level][projectile][1] = 1; + } + for (int projectile = 0; projectile < CRAFTAX_WG_MAX_PLAYER_PROJECTILES; projectile++) { + out->player_projectile_directions[level][projectile][0] = 1; + out->player_projectile_directions[level][projectile][1] = 1; + } + } + + CraftaxThreefryKey potion_key; + craftax_threefry_split(rng, &rng, &potion_key); + craftax_permutation_6(potion_key, out->potion_mapping); + + CraftaxThreefryKey state_key; + craftax_threefry_split(rng, &rng, &state_key); + (void)rng; + out->state_rng[0] = state_key.word[0]; + out->state_rng[1] = state_key.word[1]; + + out->monsters_killed[0] = 10; + out->player_position[0] = CRAFTAX_WG_MAP_SIZE / 2; + out->player_position[1] = CRAFTAX_WG_MAP_SIZE / 2; + out->player_level = 0; + out->player_direction = CRAFTAX_WG_ACTION_UP; + out->player_health = 9.0f; + out->player_food = 9; + out->player_drink = 9; + out->player_energy = 9; + out->player_mana = 9; + out->player_dexterity = 1; + out->player_strength = 1; + 
/*
 * Normalise an index into [0, size - 1].
 *
 * A negative index is shifted up by `size` exactly once (so -1 refers to the
 * last element); whatever is still out of range after that is clamped to the
 * nearest valid end.
 */
static inline int craftax_wg_jax_index(int32_t index, int32_t size) {
    const int32_t wrapped = index < 0 ? index + size : index;
    if (wrapped < 0) {
        return 0;
    }
    return wrapped >= size ? size - 1 : wrapped;
}
index + size : index; + return true; +} + +static inline bool craftax_wg_is_boss_vulnerable( + const CraftaxWorldState* state +) { + int level = craftax_wg_jax_index(state->player_level, CRAFTAX_WG_NUM_LEVELS); + bool has_melee = false; + bool has_ranged = false; + for (int i = 0; i < CRAFTAX_WG_MAX_MELEE_MOBS; i++) { + has_melee = has_melee || state->melee_mobs.mask[level][i]; + } + for (int i = 0; i < CRAFTAX_WG_MAX_RANGED_MOBS; i++) { + has_ranged = has_ranged || state->ranged_mobs.mask[level][i]; + } + return !has_melee + && !has_ranged + && state->boss_timesteps_to_spawn_this_round <= 0; +} + +static inline void craftax_encode_mobs3_observation( + const CraftaxWorldState* state, + const CraftaxWGMobs3* mobs, + int mob_class_index, + int channels, + int mob_channels_offset, + float* obs +) { + int level = craftax_wg_jax_index(state->player_level, CRAFTAX_WG_NUM_LEVELS); + for (int i = 0; i < 3; i++) { + int local_row = mobs->position[level][i][0] + - state->player_position[0] + + CRAFTAX_WG_OBS_ROWS / 2; + int local_col = mobs->position[level][i][1] + - state->player_position[1] + + CRAFTAX_WG_OBS_COLS / 2; + int type_id = mobs->type_id[level][i]; + int scatter_row; + int scatter_col; + if (!craftax_wg_scatter_index( + local_row, + CRAFTAX_WG_OBS_ROWS, + &scatter_row + ) + || !craftax_wg_scatter_index( + local_col, + CRAFTAX_WG_OBS_COLS, + &scatter_col + ) + || type_id < 0 + || type_id >= CRAFTAX_WG_NUM_MOB_TYPES) { + continue; + } + + bool on_screen = local_row >= 0 + && local_row < CRAFTAX_WG_OBS_ROWS + && local_col >= 0 + && local_col < CRAFTAX_WG_OBS_COLS; + int world_row = mobs->position[level][i][0]; + int world_col = mobs->position[level][i][1]; + bool in_bounds = world_row >= 0 + && world_row < CRAFTAX_WG_MAP_SIZE + && world_col >= 0 + && world_col < CRAFTAX_WG_MAP_SIZE; + float light = in_bounds ? 
state->light_map[level][world_row][world_col] : 0.0f; + bool visible = light > 0.05f; + int obs_base = (scatter_row * CRAFTAX_WG_OBS_COLS + scatter_col) * channels; + int channel = mob_channels_offset + + mob_class_index * CRAFTAX_WG_NUM_MOB_TYPES + + type_id; + obs[obs_base + channel] = + mobs->mask[level][i] && on_screen && visible ? 1.0f : 0.0f; + } +} + +static inline void craftax_encode_mobs2_observation( + const CraftaxWorldState* state, + const CraftaxWGMobs2* mobs, + int mob_class_index, + int channels, + int mob_channels_offset, + float* obs +) { + int level = craftax_wg_jax_index(state->player_level, CRAFTAX_WG_NUM_LEVELS); + for (int i = 0; i < 2; i++) { + int local_row = mobs->position[level][i][0] + - state->player_position[0] + + CRAFTAX_WG_OBS_ROWS / 2; + int local_col = mobs->position[level][i][1] + - state->player_position[1] + + CRAFTAX_WG_OBS_COLS / 2; + int type_id = mobs->type_id[level][i]; + int scatter_row; + int scatter_col; + if (!craftax_wg_scatter_index( + local_row, + CRAFTAX_WG_OBS_ROWS, + &scatter_row + ) + || !craftax_wg_scatter_index( + local_col, + CRAFTAX_WG_OBS_COLS, + &scatter_col + ) + || type_id < 0 + || type_id >= CRAFTAX_WG_NUM_MOB_TYPES) { + continue; + } + + bool on_screen = local_row >= 0 + && local_row < CRAFTAX_WG_OBS_ROWS + && local_col >= 0 + && local_col < CRAFTAX_WG_OBS_COLS; + int world_row = mobs->position[level][i][0]; + int world_col = mobs->position[level][i][1]; + bool in_bounds = world_row >= 0 + && world_row < CRAFTAX_WG_MAP_SIZE + && world_col >= 0 + && world_col < CRAFTAX_WG_MAP_SIZE; + float light = in_bounds ? state->light_map[level][world_row][world_col] : 0.0f; + bool visible = light > 0.05f; + int obs_base = (scatter_row * CRAFTAX_WG_OBS_COLS + scatter_col) * channels; + int channel = mob_channels_offset + + mob_class_index * CRAFTAX_WG_NUM_MOB_TYPES + + type_id; + obs[obs_base + channel] = + mobs->mask[level][i] && on_screen && visible ? 
1.0f : 0.0f; + } +} + +static inline void craftax_encode_reset_observation( + const CraftaxWorldState* state, + float* obs +) { + memset(obs, 0, CRAFTAX_WG_OBS_SIZE * sizeof(float)); + + const int channels = CRAFTAX_WG_NUM_BLOCK_TYPES + + CRAFTAX_WG_NUM_ITEM_TYPES + + CRAFTAX_WG_NUM_MOB_CLASSES * CRAFTAX_WG_NUM_MOB_TYPES + + 1; + const int item_channels_offset = CRAFTAX_WG_NUM_BLOCK_TYPES; + const int mob_channels_offset = CRAFTAX_WG_NUM_BLOCK_TYPES + CRAFTAX_WG_NUM_ITEM_TYPES; + const int light_channel_offset = mob_channels_offset + + CRAFTAX_WG_NUM_MOB_CLASSES * CRAFTAX_WG_NUM_MOB_TYPES; + const int obs_map_size = CRAFTAX_WG_OBS_ROWS * CRAFTAX_WG_OBS_COLS * channels; + const int top = state->player_position[0] - CRAFTAX_WG_OBS_ROWS / 2; + const int left = state->player_position[1] - CRAFTAX_WG_OBS_COLS / 2; + const int level = state->player_level; + + for (int row = 0; row < CRAFTAX_WG_OBS_ROWS; row++) { + for (int col = 0; col < CRAFTAX_WG_OBS_COLS; col++) { + int world_row = top + row; + int world_col = left + col; + int obs_base = (row * CRAFTAX_WG_OBS_COLS + col) * channels; + bool in_bounds = world_row >= 0 + && world_row < CRAFTAX_WG_MAP_SIZE + && world_col >= 0 + && world_col < CRAFTAX_WG_MAP_SIZE; + float light = in_bounds ? state->light_map[level][world_row][world_col] : 0.0f; + bool visible = light > 0.05f; + + if (visible) { + int block = state->map[level][world_row][world_col]; + if (block >= 0 && block < CRAFTAX_WG_NUM_BLOCK_TYPES) { + obs[obs_base + block] = 1.0f; + } + + int item = state->item_map[level][world_row][world_col]; + if (item >= 0 && item < CRAFTAX_WG_NUM_ITEM_TYPES) { + obs[obs_base + item_channels_offset + item] = 1.0f; + } + } + + obs[obs_base + light_channel_offset] = visible ? 
1.0f : 0.0f; + } + } + + craftax_encode_mobs3_observation( + state, + &state->melee_mobs, + 0, + channels, + mob_channels_offset, + obs + ); + craftax_encode_mobs3_observation( + state, + &state->passive_mobs, + 1, + channels, + mob_channels_offset, + obs + ); + craftax_encode_mobs2_observation( + state, + &state->ranged_mobs, + 2, + channels, + mob_channels_offset, + obs + ); + craftax_encode_mobs3_observation( + state, + &state->mob_projectiles, + 3, + channels, + mob_channels_offset, + obs + ); + craftax_encode_mobs3_observation( + state, + &state->player_projectiles, + 4, + channels, + mob_channels_offset, + obs + ); + + int index = obs_map_size; + obs[index++] = sqrtf((float)state->inventory.wood) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.stone) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.coal) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.iron) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.diamond) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.sapphire) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.ruby) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.sapling) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.torches) / 10.0f; + obs[index++] = sqrtf((float)state->inventory.arrows) / 10.0f; + obs[index++] = (float)state->inventory.books / 2.0f; + obs[index++] = (float)state->inventory.pickaxe / 4.0f; + obs[index++] = (float)state->inventory.sword / 4.0f; + obs[index++] = (float)state->sword_enchantment; + obs[index++] = (float)state->bow_enchantment; + obs[index++] = (float)state->inventory.bow; + + for (int i = 0; i < 6; i++) { + obs[index++] = sqrtf((float)state->inventory.potions[i]) / 10.0f; + } + + obs[index++] = state->player_health / 10.0f; + obs[index++] = (float)state->player_food / 10.0f; + obs[index++] = (float)state->player_drink / 10.0f; + obs[index++] = (float)state->player_energy / 10.0f; + obs[index++] = (float)state->player_mana / 10.0f; + obs[index++] = 
(float)state->player_xp / 10.0f; + obs[index++] = (float)state->player_dexterity / 10.0f; + obs[index++] = (float)state->player_strength / 10.0f; + obs[index++] = (float)state->player_intelligence / 10.0f; + + int direction_index = state->player_direction - 1; + for (int i = 0; i < 4; i++) { + obs[index++] = i == direction_index ? 1.0f : 0.0f; + } + + for (int i = 0; i < 4; i++) { + obs[index++] = (float)state->inventory.armour[i] / 2.0f; + } + for (int i = 0; i < 4; i++) { + obs[index++] = (float)state->armour_enchantments[i]; + } + + obs[index++] = state->light_level; + obs[index++] = state->is_sleeping ? 1.0f : 0.0f; + obs[index++] = state->is_resting ? 1.0f : 0.0f; + obs[index++] = state->learned_spells[0] ? 1.0f : 0.0f; + obs[index++] = state->learned_spells[1] ? 1.0f : 0.0f; + obs[index++] = (float)state->player_level / 10.0f; + obs[index++] = state->monsters_killed[level] >= CRAFTAX_WG_MONSTERS_KILLED_TO_CLEAR_LEVEL ? 1.0f : 0.0f; + obs[index++] = craftax_wg_is_boss_vulnerable(state) ? 1.0f : 0.0f; +} diff --git a/ocean/craftax_classic/binding.c b/ocean/craftax_classic/binding.c new file mode 100644 index 0000000000..c32b4c71ed --- /dev/null +++ b/ocean/craftax_classic/binding.c @@ -0,0 +1,38 @@ +#include "craftax_classic.h" + +#define OBS_SIZE 1345 +#define NUM_ATNS 1 +#define ACT_SIZES {17} +#define OBS_TENSOR_T FloatTensor + +#define Env CraftaxClassic +#include "vecenv.h" + +void my_init(Env* env, Dict* kwargs) { + // Process-wide reset pool size. First caller wins (setter is idempotent). + // 0 disables caching (baseline: generate_world on every reset). 
+ int reset_pool_size = 0; + DictItem* item = dict_get_unsafe(kwargs, "reset_pool_size"); + if (item != NULL) reset_pool_size = (int)item->value; + craftax_classic_set_reset_pool_size(reset_pool_size); + c_init(env); +} + +void my_log(Log* log, Dict* out) { + dict_set(out, "perf", log->perf); + dict_set(out, "score", log->score); + dict_set(out, "episode_return", log->episode_return); + dict_set(out, "episode_length", log->episode_length); + + static const char* ACH_NAMES[NUM_ACHIEVEMENTS] = { + "collect_wood", "place_table", "eat_cow", "collect_sapling", + "collect_drink", "make_wood_pick", "make_wood_sword","place_plant", + "defeat_zombie", "collect_stone", "place_stone", "eat_plant", + "defeat_skeleton","make_stone_pick","make_stone_sword","wake_up", + "place_furnace", "collect_coal", "collect_iron", "collect_diamond", + "make_iron_pick", "make_iron_sword", + }; + for (int i = 0; i < NUM_ACHIEVEMENTS; i++) { + dict_set(out, ACH_NAMES[i], log->achievements[i]); + } +} diff --git a/ocean/craftax_classic/craftax_classic.h b/ocean/craftax_classic/craftax_classic.h new file mode 100644 index 0000000000..1ece7e49bb --- /dev/null +++ b/ocean/craftax_classic/craftax_classic.h @@ -0,0 +1,1232 @@ +// Craftax-Classic environment for PufferLib Ocean. +// +// Single-header per-env implementation. PufferLib's vec layer owns the +// observation/action/reward/terminal buffers and parallelizes c_step +// across env instances via OpenMP; this file never allocates its own +// threads or batches. +// +// Game rules follow Matthews et al. 2024 "Craftax-Classic" (ICML 2024). +// This port is derived from the CPU port at github.com/Infatoshi/craftax.c +// (47.8M SPS standalone), restructured to match the Ocean conventions +// used by breakout/drmario/etc. 
+// +// Observation: 1345 float32: +// - 63 tiles (7x9 local view) x 21 channels (17 block one-hot + 4 mob) = 1323 +// - 12 inventory (0..9) / 10 +// - 4 intrinsics (health, food, drink, energy / 10) +// - 4 direction one-hot +// - 1 light level [0, 1] +// - 1 is_sleeping {0, 1} +// Matches the JAX/CUDA Craftax-Classic-Symbolic-v1 layout exactly. +// +// Action: 1 discrete in 0..16 (NOOP, 4 moves, DO, SLEEP, +// 4 place, 3 make-pick, 3 make-sword). + +#pragma once +#include +#include +#include +#include +#include +#include +#include "raylib.h" + +// ============================================================ +// Constants +// ============================================================ +#define MAP_SIZE 64 +#define MAP_PACKED_ROW 32 +#define MAP_PACKED_SIZE (MAP_SIZE * MAP_PACKED_ROW) + +#define MAX_ZOMBIES 3 +#define MAX_COWS 3 +#define MAX_SKELETONS 2 +#define MAX_ARROWS 3 +#define MAX_PLANTS 10 +#define NUM_ACHIEVEMENTS 22 +#define NUM_ACTIONS 17 +#define NUM_BLOCK_TYPES 17 +#define OBS_DIM 1345 +#define NUM_INVENTORY 12 +#define MAX_TIMESTEPS 10000 +#define DAY_LENGTH 300 +#define MOB_DESPAWN_DIST 14 + +// Block types +#define BLK_INVALID 0 +#define BLK_OUT_OF_BOUNDS 1 +#define BLK_GRASS 2 +#define BLK_WATER 3 +#define BLK_STONE 4 +#define BLK_TREE 5 +#define BLK_WOOD 6 +#define BLK_PATH 7 +#define BLK_COAL 8 +#define BLK_IRON 9 +#define BLK_DIAMOND 10 +#define BLK_TABLE 11 +#define BLK_FURNACE 12 +#define BLK_SAND 13 +#define BLK_LAVA 14 +#define BLK_PLANT 15 +#define BLK_RIPE_PLANT 16 + +// Actions +#define ACT_NOOP 0 +#define ACT_LEFT 1 +#define ACT_RIGHT 2 +#define ACT_UP 3 +#define ACT_DOWN 4 +#define ACT_DO 5 +#define ACT_SLEEP 6 +#define ACT_PLACE_STONE 7 +#define ACT_PLACE_TABLE 8 +#define ACT_PLACE_FURNACE 9 +#define ACT_PLACE_PLANT 10 +#define ACT_MAKE_WOOD_PICK 11 +#define ACT_MAKE_STONE_PICK 12 +#define ACT_MAKE_IRON_PICK 13 +#define ACT_MAKE_WOOD_SWORD 14 +#define ACT_MAKE_STONE_SWORD 15 +#define ACT_MAKE_IRON_SWORD 16 + +// Achievements (index in 
/*
 * Advance the 64-bit generator state in *s by one LCG step and return a
 * 32-bit output computed from the NEW state: an xorshift of the state,
 * truncated to 32 bits, rotated right by the state's top five bits.
 */
static inline uint32_t cr_pcg(uint64_t* s) {
    const uint64_t next = *s * 6364136223846793005ULL + 1442695040888963407ULL;
    *s = next;

    const uint32_t mixed = (uint32_t)(((next >> 18u) ^ next) >> 27u);
    const uint32_t right = (uint32_t)(next >> 59u);
    /* Complementary left-shift amount for a 32-bit rotate; the mask keeps the
     * shift at 0 when right == 0. */
    const uint32_t left = (32u - right) & 31u;
    return (mixed >> right) | (mixed << left);
}
============================================================ +// Env struct +// ============================================================ +typedef struct CraftaxClassic { + Client* client; + Log log; + + float* observations; // (OBS_DIM,) fp32, PufferLib-owned + float* actions; // (1,) fp32 + float* rewards; // (1,) + float* terminals; // (1,) + + int num_agents; // = 1 + + unsigned int rng; // populated by default my_vec_init (env index) + uint64_t pcg; // actual RNG state (seeded from rng in my_init) + + // Packed map (2 blocks/byte) + uint8_t map_packed[MAP_PACKED_SIZE]; + + // Per-type occupancy bitmaps: bit c of bits[r] = "mob-type at (r,c)" + uint64_t mob_bits[MAP_SIZE]; // zombie | cow | skel (used by has_mob_at / can_move_mob) + uint64_t zombie_bits[MAP_SIZE]; + uint64_t cow_bits[MAP_SIZE]; + uint64_t skel_bits[MAP_SIZE]; + uint64_t arrow_bits[MAP_SIZE]; + + // Player + int16_t player_r, player_c; + int8_t player_dir; + + // Intrinsics + int8_t health, food, drink, energy; + bool is_sleeping; + float recover, hunger, thirst, fatigue; + + // Inventory (wood, stone, coal, iron, diamond, sapling, + // wpick, spick, ipick, wsword, ssword, isword) + int8_t inv[NUM_INVENTORY]; + + // Mobs + int16_t zombie_r[MAX_ZOMBIES], zombie_c[MAX_ZOMBIES]; + int8_t zombie_hp[MAX_ZOMBIES], zombie_cd[MAX_ZOMBIES]; + bool zombie_mask[MAX_ZOMBIES]; + + int16_t cow_r[MAX_COWS], cow_c[MAX_COWS]; + int8_t cow_hp[MAX_COWS]; + bool cow_mask[MAX_COWS]; + + int16_t skel_r[MAX_SKELETONS], skel_c[MAX_SKELETONS]; + int8_t skel_hp[MAX_SKELETONS], skel_cd[MAX_SKELETONS]; + bool skel_mask[MAX_SKELETONS]; + + int16_t arrow_r[MAX_ARROWS], arrow_c[MAX_ARROWS]; + int8_t arrow_dr[MAX_ARROWS], arrow_dc[MAX_ARROWS]; + bool arrow_mask[MAX_ARROWS]; + + int16_t plant_r[MAX_PLANTS], plant_c[MAX_PLANTS]; + int16_t plant_age[MAX_PLANTS]; + bool plant_mask[MAX_PLANTS]; + + float light_level; + bool achievements[NUM_ACHIEVEMENTS]; + int32_t timestep; + + // Episode stats (accumulated; flushed into 
/*
 * Small integer/float helpers.
 *
 * The text of these one-liners had been damaged: everything from a '<' up to
 * the next '>' was missing, leaving fragments such as "return vhi?hi:v);"
 * that do not parse. The comparisons are restored below.
 */

/* Clamp v into [lo, hi]. */
static inline int cr_clamp_i(int v, int lo, int hi) { return v < lo ? lo : (v > hi ? hi : v); }

/* Smaller of two ints. */
static inline int cr_min_i(int a, int b) { return a < b ? a : b; }

/* Larger of two ints.
 * NOTE(review): this definition was swallowed by the same damage (only its
 * trailing "b?a:b;" survived); the name cr_max_i is reconstructed -- confirm
 * against callers. */
static inline int cr_max_i(int a, int b) { return a > b ? a : b; }

/* Smaller of two floats. */
static inline float cr_min_f(float a, float b) { return a < b ? a : b; }

/* Sign of v: -1, 0 or +1.
 * NOTE(review): only the tail "0)-(v<0);" of this helper survived; the name
 * cr_sign_i is reconstructed -- confirm against callers. */
static inline int cr_sign_i(int v) { return (v > 0) - (v < 0); }
return false; + return ((s->mob_bits[r] >> c) & 1ULL) != 0; +} + +static bool is_near_block(const CraftaxClassic* s, int8_t blk) { + int pr = s->player_r, pc = s->player_c; + static const int dr8[8] = {0, 0, -1, 1, -1, -1, 1, 1}; + static const int dc8[8] = {-1, 1, 0, 0, -1, 1, -1, 1}; + for (int i = 0; i < 8; i++) { + int nr = pr + dr8[i], nc = pc + dc8[i]; + if (in_bounds(nr, nc) && map_get(s, nr, nc) == blk) return true; + } + return false; +} + +static inline int get_damage(const CraftaxClassic* s) { + if (s->inv[11] > 0) return 5; + if (s->inv[10] > 0) return 3; + if (s->inv[9] > 0) return 2; + return 1; +} + +// ============================================================ +// Perlin worldgen (AVX-512, per-env) +// ============================================================ +static inline float perlin_interp(float t) { return t*t*t*(t*(t*6.0f-15.0f)+10.0f); } + +#if defined(__clang__) || defined(__GNUC__) +__attribute__((target("avx512f,avx512bw,avx512dq,avx512vl"))) +#endif +static void generate_world(CraftaxClassic* s) { + // Reset maps and bitmaps + for (int i = 0; i < MAP_PACKED_SIZE; i++) + s->map_packed[i] = (uint8_t)(BLK_GRASS | (BLK_GRASS << 4)); + memset(s->mob_bits, 0, sizeof(s->mob_bits)); + memset(s->zombie_bits, 0, sizeof(s->zombie_bits)); + memset(s->cow_bits, 0, sizeof(s->cow_bits)); + memset(s->skel_bits, 0, sizeof(s->skel_bits)); + memset(s->arrow_bits, 0, sizeof(s->arrow_bits)); + + // Perlin gradient tables (precompute cos/sin of the per-grid random angles). + // Padded by +16 floats so AVX-512 permute-load at the last grid row doesn't + // read out of bounds. 
+ enum { GRID = 10, GRID_PAD = GRID * GRID + 16 }; + _Alignas(64) float cos_a[4][GRID_PAD]; + _Alignas(64) float sin_a[4][GRID_PAD]; + for (int layer = 0; layer < 4; layer++) { + for (int i = 0; i < GRID * GRID; i++) { + float a = cr_rf(&s->pcg) * 2.0f * 3.14159265f; + cos_a[layer][i] = cosf(a); + sin_a[layer][i] = sinf(a); + } + for (int i = GRID * GRID; i < GRID_PAD; i++) { cos_a[layer][i] = 0; sin_a[layer][i] = 0; } + } + + float scale = (float)MAP_SIZE / (float)(GRID - 1); + float inv_scale = 1.0f / scale; + int center = MAP_SIZE / 2; + + _Alignas(64) float noise[4][MAP_SIZE][MAP_SIZE]; + { + const __m512 c_lane = _mm512_setr_ps(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15); + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 half = _mm512_set1_ps(0.5f); + const __m512 c6 = _mm512_set1_ps(6.0f); + const __m512 c15 = _mm512_set1_ps(15.0f); + const __m512 c10 = _mm512_set1_ps(10.0f); + const __m512 invs = _mm512_set1_ps(inv_scale); + const __m512i i_one = _mm512_set1_epi32(1); + + for (int r = 0; r < MAP_SIZE; r++) { + float nr = (float)r * inv_scale; + int x0 = (int)nr; + float fx = nr - x0; + float fx1 = fx - 1.0f; + float u = perlin_interp(fx); + int row0 = x0 * GRID, row1 = row0 + GRID; + __m512 fx_v = _mm512_set1_ps(fx); + __m512 fx1_v = _mm512_set1_ps(fx1); + __m512 u_v = _mm512_set1_ps(u); + + for (int c_base = 0; c_base < MAP_SIZE; c_base += 16) { + __m512 c_v = _mm512_add_ps(_mm512_set1_ps((float)c_base), c_lane); + __m512 nc_v = _mm512_mul_ps(c_v, invs); + __m512i y0_v = _mm512_cvttps_epi32(nc_v); + __m512 y0_f = _mm512_cvtepi32_ps(y0_v); + __m512 fy_v = _mm512_sub_ps(nc_v, y0_f); + __m512 fy1_v = _mm512_sub_ps(fy_v, one); + __m512 t = _mm512_fmsub_ps(fy_v, c6, c15); + t = _mm512_fmadd_ps(fy_v, t, c10); + __m512 fy2 = _mm512_mul_ps(fy_v, fy_v); + __m512 fy3 = _mm512_mul_ps(fy2, fy_v); + __m512 v_v = _mm512_mul_ps(fy3, t); + __m512i y1_v = _mm512_add_epi32(y0_v, i_one); + + for (int k = 0; k < 4; k++) { + __m512 cos_r0 = _mm512_loadu_ps(&cos_a[k][row0]); 
+ __m512 cos_r1 = _mm512_loadu_ps(&cos_a[k][row1]); + __m512 sin_r0 = _mm512_loadu_ps(&sin_a[k][row0]); + __m512 sin_r1 = _mm512_loadu_ps(&sin_a[k][row1]); + + __m512 c00 = _mm512_permutexvar_ps(y0_v, cos_r0); + __m512 c10v= _mm512_permutexvar_ps(y0_v, cos_r1); + __m512 c01 = _mm512_permutexvar_ps(y1_v, cos_r0); + __m512 c11 = _mm512_permutexvar_ps(y1_v, cos_r1); + __m512 s00 = _mm512_permutexvar_ps(y0_v, sin_r0); + __m512 s10 = _mm512_permutexvar_ps(y0_v, sin_r1); + __m512 s01 = _mm512_permutexvar_ps(y1_v, sin_r0); + __m512 s11 = _mm512_permutexvar_ps(y1_v, sin_r1); + + __m512 n00 = _mm512_fmadd_ps(c00, fx_v, _mm512_mul_ps(s00, fy_v)); + __m512 n10 = _mm512_fmadd_ps(c10v, fx1_v, _mm512_mul_ps(s10, fy_v)); + __m512 n01 = _mm512_fmadd_ps(c01, fx_v, _mm512_mul_ps(s01, fy1_v)); + __m512 n11 = _mm512_fmadd_ps(c11, fx1_v, _mm512_mul_ps(s11, fy1_v)); + + __m512 nx0 = _mm512_fmadd_ps(u_v, _mm512_sub_ps(n10, n00), n00); + __m512 nx1 = _mm512_fmadd_ps(u_v, _mm512_sub_ps(n11, n01), n01); + __m512 n = _mm512_fmadd_ps(v_v, _mm512_sub_ps(nx1, nx0), nx0); + n = _mm512_mul_ps(_mm512_add_ps(n, one), half); + + _mm512_storeu_ps(&noise[k][r][c_base], n); + } + } + } + } + + // Tile-logic sweep -- reads precomputed noise, writes blocks + for (int r = 0; r < MAP_SIZE; r++) { + for (int c = 0; c < MAP_SIZE; c++) { + float water_noise = noise[0][r][c]; + float mountain_noise = noise[1][r][c]; + float tree_noise = noise[2][r][c]; + float path_noise = noise[3][r][c]; + + float dist = sqrtf((float)((r-center)*(r-center) + (c-center)*(c-center))); + float prox = 1.0f - cr_min_f(dist / 20.0f, 1.0f); + + float water_val = water_noise - prox * 0.3f; + float mountain_val = mountain_noise - prox * 0.3f; + + int8_t blk = BLK_GRASS; + if (water_val > 0.7f) blk = BLK_WATER; + else if (water_val > 0.6f && water_val <= 0.75f) blk = BLK_SAND; + else if (mountain_val > 0.7f) { + blk = BLK_STONE; + if (path_noise > 0.8f) blk = BLK_PATH; + if (mountain_val > 0.85f && water_noise > 0.4f) blk = BLK_PATH; + 
if (mountain_val > 0.85f && tree_noise > 0.7f) blk = BLK_LAVA; + } + if (blk == BLK_STONE) { + float ore = cr_rf(&s->pcg); + if (ore < 0.005f && mountain_val > 0.8f) blk = BLK_DIAMOND; + else if (ore < 0.035f) blk = BLK_IRON; + else if (ore < 0.075f) blk = BLK_COAL; + } + if (blk == BLK_GRASS && tree_noise > 0.5f && cr_rf(&s->pcg) > 0.8f) + blk = BLK_TREE; + map_set(s, r, c, blk); + } + } + + map_set(s, center, center, BLK_GRASS); // player spawn always grass + + bool has_diamond = false; + for (int r = 0; r < MAP_SIZE && !has_diamond; r++) + for (int c = 0; c < MAP_SIZE && !has_diamond; c++) + if (map_get(s, r, c) == BLK_DIAMOND) has_diamond = true; + if (!has_diamond) { + for (int att = 0; att < 1000; att++) { + int r = cr_ri(&s->pcg, MAP_SIZE), c = cr_ri(&s->pcg, MAP_SIZE); + if (map_get(s, r, c) == BLK_STONE) { map_set(s, r, c, BLK_DIAMOND); break; } + } + } + + // Initial intrinsics + inventory + mobs + s->player_r = center; s->player_c = center; s->player_dir = 4; + s->health = 9; s->food = 9; s->drink = 9; s->energy = 9; + s->is_sleeping = false; + s->recover = s->hunger = s->thirst = s->fatigue = 0; + memset(s->inv, 0, sizeof(s->inv)); + memset(s->zombie_mask, 0, sizeof(s->zombie_mask)); + memset(s->zombie_hp, 0, sizeof(s->zombie_hp)); + memset(s->zombie_cd, 0, sizeof(s->zombie_cd)); + memset(s->cow_mask, 0, sizeof(s->cow_mask)); + memset(s->cow_hp, 0, sizeof(s->cow_hp)); + memset(s->skel_mask, 0, sizeof(s->skel_mask)); + memset(s->skel_hp, 0, sizeof(s->skel_hp)); + memset(s->skel_cd, 0, sizeof(s->skel_cd)); + memset(s->arrow_mask, 0, sizeof(s->arrow_mask)); + memset(s->plant_mask, 0, sizeof(s->plant_mask)); + memset(s->plant_age, 0, sizeof(s->plant_age)); + memset(s->achievements, 0, sizeof(s->achievements)); + s->timestep = 0; + s->light_level = 1.0f; +} + +// ============================================================ +// Step sub-actions +// ============================================================ +static void do_crafting(CraftaxClassic* s, int 
action) { + bool t = is_near_block(s, BLK_TABLE); + bool f = is_near_block(s, BLK_FURNACE); + if (action == ACT_MAKE_WOOD_PICK && t && s->inv[0] >= 1) { s->inv[0]--; s->inv[6]++; s->achievements[ACH_MAKE_WOOD_PICK] = true; } + if (action == ACT_MAKE_STONE_PICK && t && s->inv[0] >= 1 && s->inv[1] >= 1) { s->inv[0]--; s->inv[1]--; s->inv[7]++; s->achievements[ACH_MAKE_STONE_PICK] = true; } + if (action == ACT_MAKE_IRON_PICK && t && f && s->inv[0] >= 1 && s->inv[1] >= 1 && s->inv[3] >= 1 && s->inv[2] >= 1) { + s->inv[0]--; s->inv[1]--; s->inv[3]--; s->inv[2]--; s->inv[8]++; s->achievements[ACH_MAKE_IRON_PICK] = true; + } + if (action == ACT_MAKE_WOOD_SWORD && t && s->inv[0] >= 1) { s->inv[0]--; s->inv[9]++; s->achievements[ACH_MAKE_WOOD_SWORD] = true; } + if (action == ACT_MAKE_STONE_SWORD && t && s->inv[0] >= 1 && s->inv[1] >= 1) { s->inv[0]--; s->inv[1]--; s->inv[10]++; s->achievements[ACH_MAKE_STONE_SWORD] = true; } + if (action == ACT_MAKE_IRON_SWORD && t && f && s->inv[0] >= 1 && s->inv[1] >= 1 && s->inv[3] >= 1 && s->inv[2] >= 1) { + s->inv[0]--; s->inv[1]--; s->inv[3]--; s->inv[2]--; s->inv[11]++; s->achievements[ACH_MAKE_IRON_SWORD] = true; + } +} + +static void do_action(CraftaxClassic* s) { + int tr = s->player_r + DIR_DR[s->player_dir]; + int tc = s->player_c + DIR_DC[s->player_dir]; + if (!in_bounds(tr, tc)) return; + int dmg = get_damage(s); + bool attacked = false; + + for (int i = 0; i < MAX_ZOMBIES && !attacked; i++) + if (s->zombie_mask[i] && s->zombie_r[i] == tr && s->zombie_c[i] == tc) { + s->zombie_hp[i] -= dmg; + if (s->zombie_hp[i] <= 0) { + s->zombie_mask[i] = false; + mb_clear(s->mob_bits, tr, tc); mb_clear(s->zombie_bits, tr, tc); + s->achievements[ACH_DEFEAT_ZOMBIE] = true; + } + attacked = true; + } + for (int i = 0; i < MAX_COWS && !attacked; i++) + if (s->cow_mask[i] && s->cow_r[i] == tr && s->cow_c[i] == tc) { + s->cow_hp[i] -= dmg; + if (s->cow_hp[i] <= 0) { + s->cow_mask[i] = false; + mb_clear(s->mob_bits, tr, tc); mb_clear(s->cow_bits, 
tr, tc); + s->achievements[ACH_EAT_COW] = true; + s->food = (int8_t)cr_min_i(9, s->food + 6); s->hunger = 0; + } + attacked = true; + } + for (int i = 0; i < MAX_SKELETONS && !attacked; i++) + if (s->skel_mask[i] && s->skel_r[i] == tr && s->skel_c[i] == tc) { + s->skel_hp[i] -= dmg; + if (s->skel_hp[i] <= 0) { + s->skel_mask[i] = false; + mb_clear(s->mob_bits, tr, tc); mb_clear(s->skel_bits, tr, tc); + s->achievements[ACH_DEFEAT_SKELETON] = true; + } + attacked = true; + } + if (attacked) return; + + int8_t blk = map_get(s, tr, tc); + switch (blk) { + case BLK_TREE: + map_set(s, tr, tc, BLK_GRASS); + s->inv[0] = (int8_t)cr_min_i(9, s->inv[0] + 1); + s->achievements[ACH_COLLECT_WOOD] = true; break; + case BLK_STONE: + if (s->inv[6] > 0 || s->inv[7] > 0 || s->inv[8] > 0) { + map_set(s, tr, tc, BLK_PATH); + s->inv[1] = (int8_t)cr_min_i(9, s->inv[1] + 1); + s->achievements[ACH_COLLECT_STONE] = true; + } break; + case BLK_COAL: + if (s->inv[6] > 0 || s->inv[7] > 0 || s->inv[8] > 0) { + map_set(s, tr, tc, BLK_PATH); + s->inv[2] = (int8_t)cr_min_i(9, s->inv[2] + 1); + s->achievements[ACH_COLLECT_COAL] = true; + } break; + case BLK_IRON: + if (s->inv[7] > 0 || s->inv[8] > 0) { + map_set(s, tr, tc, BLK_PATH); + s->inv[3] = (int8_t)cr_min_i(9, s->inv[3] + 1); + s->achievements[ACH_COLLECT_IRON] = true; + } break; + case BLK_DIAMOND: + if (s->inv[8] > 0) { + map_set(s, tr, tc, BLK_PATH); + s->inv[4] = (int8_t)cr_min_i(9, s->inv[4] + 1); + s->achievements[ACH_COLLECT_DIAMOND] = true; + } break; + case BLK_GRASS: + if (cr_rf(&s->pcg) < 0.1f) { + s->inv[5] = (int8_t)cr_min_i(9, s->inv[5] + 1); + s->achievements[ACH_COLLECT_SAPLING] = true; + } break; + case BLK_WATER: + s->drink = (int8_t)cr_min_i(9, s->drink + 1); s->thirst = 0; + s->achievements[ACH_COLLECT_DRINK] = true; break; + case BLK_RIPE_PLANT: + map_set(s, tr, tc, BLK_PLANT); + s->food = (int8_t)cr_min_i(9, s->food + 4); s->hunger = 0; + s->achievements[ACH_EAT_PLANT] = true; + for (int i = 0; i < MAX_PLANTS; i++) + if 
(s->plant_mask[i] && s->plant_r[i] == tr && s->plant_c[i] == tc) { + s->plant_age[i] = 0; break; + } + break; + } +} + +static void place_block(CraftaxClassic* s, int action) { + int tr = s->player_r + DIR_DR[s->player_dir]; + int tc = s->player_c + DIR_DC[s->player_dir]; + if (!in_bounds(tr, tc)) return; + if (has_mob_at(s, tr, tc)) return; + int8_t blk = map_get(s, tr, tc); + if (action == ACT_PLACE_TABLE && s->inv[0] >= 2 && !is_solid(blk)) { + map_set(s, tr, tc, BLK_TABLE); s->inv[0] -= 2; + s->achievements[ACH_PLACE_TABLE] = true; + } else if (action == ACT_PLACE_FURNACE && s->inv[1] >= 1 && !is_solid(blk)) { + map_set(s, tr, tc, BLK_FURNACE); s->inv[1] -= 1; + s->achievements[ACH_PLACE_FURNACE] = true; + } else if (action == ACT_PLACE_STONE && s->inv[1] >= 1 && (!is_solid(blk) || blk == BLK_WATER)) { + map_set(s, tr, tc, BLK_STONE); s->inv[1] -= 1; + s->achievements[ACH_PLACE_STONE] = true; + } else if (action == ACT_PLACE_PLANT && s->inv[5] >= 1 && blk == BLK_GRASS) { + map_set(s, tr, tc, BLK_PLANT); s->inv[5] -= 1; + s->achievements[ACH_PLACE_PLANT] = true; + for (int i = 0; i < MAX_PLANTS; i++) { + if (!s->plant_mask[i]) { + s->plant_r[i] = tr; s->plant_c[i] = tc; + s->plant_age[i] = 0; s->plant_mask[i] = true; break; + } + } + } +} + +static void move_player(CraftaxClassic* s, int action) { + if (action < 1 || action > 4) return; + int nr = s->player_r + DIR_DR[action]; + int nc = s->player_c + DIR_DC[action]; + s->player_dir = (int8_t)action; + if (!in_bounds(nr, nc)) return; + if (is_solid(map_get(s, nr, nc))) return; + if (has_mob_at(s, nr, nc)) return; + s->player_r = (int16_t)nr; s->player_c = (int16_t)nc; +} + +static bool can_move_mob(const CraftaxClassic* s, int r, int c) { + if (!in_bounds(r, c)) return false; + int8_t blk = map_get(s, r, c); + if (is_solid(blk)) return false; + if (blk == BLK_LAVA) return false; + if (has_mob_at(s, r, c)) return false; + if (r == s->player_r && c == s->player_c) return false; + return true; +} + +static void 
update_mobs(CraftaxClassic* s) { + int pr = s->player_r, pc = s->player_c; + + for (int i = 0; i < MAX_ZOMBIES; i++) { + if (!s->zombie_mask[i]) continue; + int zr = s->zombie_r[i], zc = s->zombie_c[i]; + int dist = l1_dist(zr, zc, pr, pc); + if (dist >= MOB_DESPAWN_DIST) { + s->zombie_mask[i] = false; + mb_clear(s->mob_bits, zr, zc); mb_clear(s->zombie_bits, zr, zc); + continue; + } + if (dist <= 1 && s->zombie_cd[i] <= 0) { + int dmg = s->is_sleeping ? 7 : 2; + s->health -= dmg; + s->zombie_cd[i] = 5; + s->is_sleeping = false; + } + s->zombie_cd[i] = (int8_t)cr_max_i(0, s->zombie_cd[i] - 1); + + int dr = 0, dc = 0; + if (dist < 10 && cr_rf(&s->pcg) < 0.75f) { + int adr = abs(pr - zr), adc = abs(pc - zc); + if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = cr_sign_i(pr - zr); + else dc = cr_sign_i(pc - zc); + } else { + int d = cr_ri(&s->pcg, 4); + dr = DIR_DR[d+1]; dc = DIR_DC[d+1]; + } + int nr = zr + dr, nc = zc + dc; + if (can_move_mob(s, nr, nc)) { + mb_clear(s->mob_bits, zr, zc); mb_clear(s->zombie_bits, zr, zc); + s->zombie_r[i] = (int16_t)nr; s->zombie_c[i] = (int16_t)nc; + mb_set(s->mob_bits, nr, nc); mb_set(s->zombie_bits, nr, nc); + } + } + + for (int i = 0; i < MAX_COWS; i++) { + if (!s->cow_mask[i]) continue; + int cr = s->cow_r[i], cc = s->cow_c[i]; + int dist = l1_dist(cr, cc, pr, pc); + if (dist >= MOB_DESPAWN_DIST) { + s->cow_mask[i] = false; + mb_clear(s->mob_bits, cr, cc); mb_clear(s->cow_bits, cr, cc); + continue; + } + int d = cr_ri(&s->pcg, 8); + if (d < 4) { + int dr = DIR_DR[d+1], dc2 = DIR_DC[d+1]; + int nr = cr + dr, nc = cc + dc2; + if (can_move_mob(s, nr, nc)) { + mb_clear(s->mob_bits, cr, cc); mb_clear(s->cow_bits, cr, cc); + s->cow_r[i] = (int16_t)nr; s->cow_c[i] = (int16_t)nc; + mb_set(s->mob_bits, nr, nc); mb_set(s->cow_bits, nr, nc); + } + } + } + + for (int i = 0; i < MAX_SKELETONS; i++) { + if (!s->skel_mask[i]) continue; + int sr = s->skel_r[i], sc = s->skel_c[i]; + int dist = l1_dist(sr, sc, pr, pc); + if (dist >= 
MOB_DESPAWN_DIST) { + s->skel_mask[i] = false; + mb_clear(s->mob_bits, sr, sc); mb_clear(s->skel_bits, sr, sc); + continue; + } + if (dist >= 4 && dist <= 5 && s->skel_cd[i] <= 0) { + for (int a = 0; a < MAX_ARROWS; a++) { + if (!s->arrow_mask[a]) { + s->arrow_mask[a] = true; + s->arrow_r[a] = (int16_t)sr; s->arrow_c[a] = (int16_t)sc; + mb_set(s->arrow_bits, sr, sc); + int adr = abs(pr - sr), adc = abs(pc - sc); + s->arrow_dr[a] = (int8_t)((adr > 0) ? cr_sign_i(pr - sr) : 0); + s->arrow_dc[a] = (int8_t)((adc > 0) ? cr_sign_i(pc - sc) : 0); + break; + } + } + s->skel_cd[i] = 4; + } + s->skel_cd[i] = (int8_t)cr_max_i(0, s->skel_cd[i] - 1); + + int dr = 0, dc = 0; + bool random_move = cr_rf(&s->pcg) < 0.15f; + if (!random_move) { + if (dist >= 10) { + int adr = abs(pr - sr), adc = abs(pc - sc); + if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = cr_sign_i(pr - sr); + else dc = cr_sign_i(pc - sc); + } else if (dist <= 3) { + int adr = abs(pr - sr), adc = abs(pc - sc); + if (adr > adc || (adr == adc && cr_rf(&s->pcg) < 0.5f)) dr = -cr_sign_i(pr - sr); + else dc = -cr_sign_i(pc - sc); + } else { + random_move = true; + } + } + if (random_move) { + int d = cr_ri(&s->pcg, 4); + dr = DIR_DR[d+1]; dc = DIR_DC[d+1]; + } + int nr = sr + dr, nc = sc + dc; + if (can_move_mob(s, nr, nc)) { + mb_clear(s->mob_bits, sr, sc); mb_clear(s->skel_bits, sr, sc); + s->skel_r[i] = (int16_t)nr; s->skel_c[i] = (int16_t)nc; + mb_set(s->mob_bits, nr, nc); mb_set(s->skel_bits, nr, nc); + } + } + + for (int i = 0; i < MAX_ARROWS; i++) { + if (!s->arrow_mask[i]) continue; + int ar = s->arrow_r[i], ac = s->arrow_c[i]; + int nr = ar + s->arrow_dr[i], nc = ac + s->arrow_dc[i]; + if (!in_bounds(nr, nc)) { s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); continue; } + int8_t blk = map_get(s, nr, nc); + if (is_solid(blk) && blk != BLK_WATER) { + if (blk == BLK_FURNACE || blk == BLK_TABLE) map_set(s, nr, nc, BLK_PATH); + s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); 
continue; + } + if (nr == pr && nc == pc) { + s->health -= 2; s->is_sleeping = false; + s->arrow_mask[i] = false; mb_clear(s->arrow_bits, ar, ac); continue; + } + mb_clear(s->arrow_bits, ar, ac); + s->arrow_r[i] = (int16_t)nr; s->arrow_c[i] = (int16_t)nc; + mb_set(s->arrow_bits, nr, nc); + } +} + +static bool try_spawn(CraftaxClassic* s, int min_d, int max_d, bool need_grass, bool need_path, + int* or_, int* oc_) { + int pr = s->player_r, pc = s->player_c; + for (int att = 0; att < 20; att++) { + int r = cr_ri(&s->pcg, MAP_SIZE), c = cr_ri(&s->pcg, MAP_SIZE); + int dist = l1_dist(r, c, pr, pc); + if (dist < min_d || dist >= max_d) continue; + if (has_mob_at(s, r, c)) continue; + if (r == pr && c == pc) continue; + int8_t blk = map_get(s, r, c); + if (need_grass && blk != BLK_GRASS) continue; + if (need_path && blk != BLK_PATH ) continue; + if (!need_grass && !need_path && blk != BLK_GRASS && blk != BLK_PATH) continue; + *or_ = r; *oc_ = c; return true; + } + return false; +} + +static void spawn_mobs(CraftaxClassic* s) { + int n_cows = 0, n_z = 0, n_sk = 0; + for (int i = 0; i < MAX_COWS; i++) n_cows += s->cow_mask[i]; + for (int i = 0; i < MAX_ZOMBIES; i++) n_z += s->zombie_mask[i]; + for (int i = 0; i < MAX_SKELETONS; i++) n_sk += s->skel_mask[i]; + + if (n_cows < MAX_COWS && cr_rf(&s->pcg) < 0.1f) { + int r, c; + if (try_spawn(s, 3, MOB_DESPAWN_DIST, true, false, &r, &c)) { + for (int i = 0; i < MAX_COWS; i++) if (!s->cow_mask[i]) { + s->cow_mask[i] = true; s->cow_r[i] = (int16_t)r; s->cow_c[i] = (int16_t)c; s->cow_hp[i] = 3; + mb_set(s->mob_bits, r, c); mb_set(s->cow_bits, r, c); + break; + } + } + } + float zombie_chance = 0.02f + 0.1f * (1.0f - s->light_level) * (1.0f - s->light_level); + if (n_z < MAX_ZOMBIES && cr_rf(&s->pcg) < zombie_chance) { + int r, c; + if (try_spawn(s, 9, MOB_DESPAWN_DIST, false, false, &r, &c)) { + for (int i = 0; i < MAX_ZOMBIES; i++) if (!s->zombie_mask[i]) { + s->zombie_mask[i] = true; s->zombie_r[i] = (int16_t)r; s->zombie_c[i] = 
(int16_t)c; + s->zombie_hp[i] = 5; s->zombie_cd[i] = 0; + mb_set(s->mob_bits, r, c); mb_set(s->zombie_bits, r, c); + break; + } + } + } + if (n_sk < MAX_SKELETONS && cr_rf(&s->pcg) < 0.05f) { + int r, c; + if (try_spawn(s, 9, MOB_DESPAWN_DIST, false, true, &r, &c)) { + for (int i = 0; i < MAX_SKELETONS; i++) if (!s->skel_mask[i]) { + s->skel_mask[i] = true; s->skel_r[i] = (int16_t)r; s->skel_c[i] = (int16_t)c; + s->skel_hp[i] = 3; s->skel_cd[i] = 0; + mb_set(s->mob_bits, r, c); mb_set(s->skel_bits, r, c); + break; + } + } + } +} + +static void update_plants(CraftaxClassic* s) { + for (int i = 0; i < MAX_PLANTS; i++) { + if (!s->plant_mask[i]) continue; + s->plant_age[i]++; + if (s->plant_age[i] >= 600) { + int r = s->plant_r[i], c = s->plant_c[i]; + if (in_bounds(r, c) && map_get(s, r, c) == BLK_PLANT) + map_set(s, r, c, BLK_RIPE_PLANT); + } + } +} + +static void update_intrinsics(CraftaxClassic* s, int action) { + if (action == ACT_SLEEP && s->energy < 9) s->is_sleeping = true; + if (s->energy >= 9 && s->is_sleeping) { + s->is_sleeping = false; + s->achievements[ACH_WAKE_UP] = true; + } + float mul = s->is_sleeping ? 0.5f : 1.0f; + s->hunger += mul; if (s->hunger > 25.0f) { s->food--; s->hunger = 0; } + s->thirst += mul; if (s->thirst > 20.0f) { s->drink--; s->thirst = 0; } + if (s->is_sleeping) s->fatigue -= 1.0f; else s->fatigue += 1.0f; + if (s->fatigue > 30.0f) { s->energy--; s->fatigue = 0; } + if (s->fatigue < -10.0f) { s->energy = (int8_t)cr_min_i(s->energy + 1, 9); s->fatigue = 0; } + bool ok = (s->food > 0) && (s->drink > 0) && (s->energy > 0 || s->is_sleeping); + if (ok) s->recover += s->is_sleeping ? 2.0f : 1.0f; + else s->recover += s->is_sleeping ? 
-0.5f : -1.0f; + if (s->recover > 25.0f) { s->health = (int8_t)cr_min_i(s->health + 1, 9); s->recover = 0; } + if (s->recover < -15.0f) { s->health--; s->recover = 0; } +} + +// ============================================================ +// Observation builder (writes OBS_DIM floats into env->observations) +// ============================================================ +static void compute_observations(CraftaxClassic* s) { + float* obs = s->observations; + int pr = s->player_r, pc = s->player_c; + int idx = 0; + for (int dr = -3; dr <= 3; dr++) { + int r = pr + dr; + bool row_ok = (unsigned)r < MAP_SIZE; + uint64_t zb = row_ok ? s->zombie_bits[r] : 0; + uint64_t cb = row_ok ? s->cow_bits[r] : 0; + uint64_t sb = row_ok ? s->skel_bits[r] : 0; + uint64_t ab = row_ok ? s->arrow_bits[r] : 0; + for (int dc = -4; dc <= 4; dc++) { + int c = pc + dc; + int8_t blk = (row_ok && (unsigned)c < MAP_SIZE) ? map_get(s, r, c) : BLK_OUT_OF_BOUNDS; + float* dst = obs + idx; + for (int b = 0; b < NUM_BLOCK_TYPES; b++) dst[b] = 0.0f; + if ((unsigned)blk < NUM_BLOCK_TYPES) dst[blk] = 1.0f; + idx += NUM_BLOCK_TYPES; + float mz = 0, mc = 0, ms = 0, ma = 0; + if (row_ok && (unsigned)c < MAP_SIZE) { + uint64_t bit = 1ULL << c; + mz = (zb & bit) ? 1.0f : 0.0f; + mc = (cb & bit) ? 1.0f : 0.0f; + ms = (sb & bit) ? 1.0f : 0.0f; + ma = (ab & bit) ? 1.0f : 0.0f; + } + obs[idx++] = mz; obs[idx++] = mc; obs[idx++] = ms; obs[idx++] = ma; + } + } + for (int i = 0; i < NUM_INVENTORY; i++) obs[idx++] = (float)s->inv[i] * 0.1f; + obs[idx++] = (float)s->health * 0.1f; + obs[idx++] = (float)s->food * 0.1f; + obs[idx++] = (float)s->drink * 0.1f; + obs[idx++] = (float)s->energy * 0.1f; + for (int d = 1; d <= 4; d++) obs[idx++] = (s->player_dir == d) ? 1.0f : 0.0f; + obs[idx++] = s->light_level; + obs[idx++] = s->is_sleeping ? 
1.0f : 0.0f; +} + +// ============================================================ +// Logging (stats accumulated into env->log; flushed at vec-level by PufferLib) +// ============================================================ +// add_log: fold the current episode into env->log -- one count per unlocked +// achievement, perf (fraction of achievements unlocked), score and +// episode_return (both += episode_return_accum), episode_length, and n. +static void add_log(CraftaxClassic* env) { + int unlocked = 0; + for (int i = 0; i < NUM_ACHIEVEMENTS; i++) { + if (env->achievements[i]) { + unlocked++; + env->log.achievements[i] += 1.0f; + } + } + env->log.perf += (float)unlocked / (float)NUM_ACHIEVEMENTS; + env->log.score += env->episode_return_accum; + env->log.episode_return += env->episode_return_accum; + env->log.episode_length += (float)env->episode_length_accum; + env->log.n += 1.0f; +} + +// ============================================================ +// Reset-cache: optional pre-generated world pool. When +// craftax_classic_set_reset_pool_size(N>0) is called before any reset, +// c_reset memcpys from cache[idx] instead of running generate_world +// each episode. Drops worldgen (~30 us) to a 5 KB memcpy (~0.5 us). +// N=0 preserves baseline behavior (fresh world per reset). First caller +// wins; subsequent calls with a different size are no-ops, so every +// env's my_init can call safely. +// +// Default for Classic is 0 (see config/craftax_classic.ini): the +// env is not the training bottleneck here (GPU/train dominate the loop), +// so caching does not move training SPS. Useful for sim-only workloads +// (data generation, evaluation rollouts) where c_step throughput matters. +// Verified bitwise-equal to fresh generate_world for any cache entry. 
+// ============================================================ +// Pool state: reset_cache points at reset_cache_size pre-generated worlds +// (NULL / 0 while caching is off); reset_cache_built is set to 1 once the +// first call to the setter below has finished, with or without a pool. +static CraftaxClassic* craftax_classic_reset_cache = NULL; +static int craftax_classic_reset_cache_size = 0; +static int craftax_classic_reset_cache_built = 0; + +// Build the reset pool once. n <= 0 or a failed calloc only marks the cache +// as built and leaves the pool empty. Otherwise entry i gets a PCG state +// derived from a fixed constant plus i, is warmed with 8 draws, and is +// filled by generate_world before the pool pointer and size are published. +// NOTE(review): the built flag is stored only after generation finishes, so +// two threads making the first call at the same time would both pass the +// check and both build a pool -- confirm init is never run concurrently. +static void craftax_classic_set_reset_pool_size(int n) { + if (__atomic_load_n(&craftax_classic_reset_cache_built, __ATOMIC_ACQUIRE)) + return; + if (n <= 0) { + __atomic_store_n(&craftax_classic_reset_cache_built, 1, __ATOMIC_RELEASE); + return; + } + CraftaxClassic* pool = (CraftaxClassic*)calloc((size_t)n, sizeof(*pool)); + if (!pool) { + // Allocation failed: fall back to baseline worldgen. + __atomic_store_n(&craftax_classic_reset_cache_built, 1, __ATOMIC_RELEASE); + return; + } + for (int i = 0; i < n; i++) { + pool[i].pcg = ((uint64_t)(0xCAFEBABE12345678ULL) + (uint64_t)i) + * 0x9E3779B97F4A7C15ULL + 0x87C37B91114253D5ULL; + for (int k = 0; k < 8; k++) (void)cr_pcg(&pool[i].pcg); + generate_world(&pool[i]); + } + craftax_classic_reset_cache = pool; + craftax_classic_reset_cache_size = n; + __atomic_store_n(&craftax_classic_reset_cache_built, 1, __ATOMIC_RELEASE); +} + +// ============================================================ +// Public API: c_init / c_reset / c_step / c_close / c_render +// ============================================================ +static void c_init(CraftaxClassic* env) { + env->num_agents = 1; + env->client = NULL; + // env->rng was seeded by default my_vec_init to the env index; use it to + // initialize a proper 64-bit PCG state. + uint64_t seed = (uint64_t)env->rng; + env->pcg = seed * 0x9E3779B97F4A7C15ULL + 0x87C37B91114253D5ULL; + // Warm the RNG a bit so small seeds don't produce correlated worlds. 
+ for (int i = 0; i < 8; i++) (void)cr_pcg(&env->pcg); + memset(&env->log, 0, sizeof(env->log)); +} + +// Start a new episode: zero the per-episode accumulators, then either run +// generate_world (pool disabled) or copy a pre-generated world out of the +// reset pool, and finally write the first observation. +static void c_reset(CraftaxClassic* env) { + env->episode_return_accum = 0.0f; + env->episode_length_accum = 0; + int pool_size = craftax_classic_reset_cache_size; + if (pool_size <= 0) { + generate_world(env); + } else { + // Pick a pool index using env's own RNG so different envs reset + // to different worlds and each env sees diversity across episodes. + uint32_t r = cr_pcg(&env->pcg); + int idx = (int)(r % (uint32_t)pool_size); + // Preserve runtime fields (client, buffer pointers, num_agents, PCG + // state, log) across the memcpy. Every field not restored below, + // including the accumulators zeroed above, comes from the pool entry. + // NOTE(review): pool entries are calloc'd, so the accumulators are + // presumably still zero there -- confirm generate_world leaves them. + Client* cl = env->client; + float* o = env->observations; + float* a = env->actions; + float* rw = env->rewards; + float* tm = env->terminals; + int na = env->num_agents; + uint64_t pcg = env->pcg; + Log log = env->log; + memcpy(env, &craftax_classic_reset_cache[idx], sizeof(*env)); + env->client = cl; + env->observations = o; + env->actions = a; + env->rewards = rw; + env->terminals = tm; + env->num_agents = na; + env->pcg = pcg; + env->log = log; + } + compute_observations(env); +} + +// Advance the env by one step using env->actions[0], clamped into +// [0, NUM_ACTIONS - 1]; reward and terminal slots are cleared first. +static void c_step(CraftaxClassic* env) { + env->rewards[0] = 0.0f; + env->terminals[0] = 0.0f; + + int action = (int)env->actions[0]; + if (action < 0) action = 0; + if (action >= NUM_ACTIONS) action = NUM_ACTIONS - 1; + + // Snapshot for reward computation + env->old_health = env->health; + memcpy(env->old_achievements, env->achievements, sizeof(env->achievements)); + + int eff_action = env->is_sleeping ? 
ACT_NOOP : action; + do_crafting(env, eff_action); + if (eff_action == ACT_DO) do_action(env); + if (eff_action >= ACT_PLACE_STONE && eff_action <= ACT_PLACE_PLANT) place_block(env, eff_action); + move_player(env, eff_action); + update_mobs(env); + spawn_mobs(env); + update_plants(env); + update_intrinsics(env, action); + + for (int i = 0; i < NUM_INVENTORY; i++) + env->inv[i] = (int8_t)cr_clamp_i(env->inv[i], 0, 9); + + env->timestep++; + float t_frac = fmodf((float)env->timestep / (float)DAY_LENGTH, 1.0f) + 0.3f; + float cv = cosf(3.14159265f * t_frac); + env->light_level = 1.0f - fabsf(cv * cv * cv); + + // Reward: new achievements + health change * 0.1 + float ach_r = 0.0f; + for (int i = 0; i < NUM_ACHIEVEMENTS; i++) + ach_r += (float)(env->achievements[i] && !env->old_achievements[i]); + float hp_r = (float)(env->health - env->old_health) * 0.1f; + float r = ach_r + hp_r; + env->rewards[0] = r; + env->episode_return_accum += r; + env->episode_length_accum += 1; + + // Terminal conditions + bool done = (env->timestep >= MAX_TIMESTEPS) || (env->health <= 0); + if (in_bounds(env->player_r, env->player_c) + && map_get(env, env->player_r, env->player_c) == BLK_LAVA) done = true; + + if (done) { + env->terminals[0] = 1.0f; + add_log(env); + c_reset(env); // auto-reset (observation written inside) + } else { + compute_observations(env); + } +} + +static void c_close(CraftaxClassic* env) { + (void)env; +} + +// ============================================================ +// Tile-based renderer sharing the full-Craftax textures.bin +// ============================================================ +// Shared layout (see ocean/craftax/pack_textures.py): +// [0..36] block textures (first 17 used by classic, indexed by BLK_*) +// [37..41] player: down, up, left, right, sleep +// [42..46] items (unused by classic) +// [47..49] mobs: zombie, skeleton, cow +// [50..53] arrows: down, up, left, right + +#include + +#define CC_TEX_TILE_PX 16 +#define CC_TEX_SCALE 4 +#define 
CC_TEX_DRAW_PX (CC_TEX_TILE_PX * CC_TEX_SCALE) +#define CC_TEX_NUM (37 + 5 + 5 + 3 + 4) + +#define CC_TEX_PLAYER_DOWN 37 +#define CC_TEX_PLAYER_UP 38 +#define CC_TEX_PLAYER_LEFT 39 +#define CC_TEX_PLAYER_RIGHT 40 +#define CC_TEX_PLAYER_SLEEP 41 +#define CC_TEX_MOB_ZOMBIE 47 +#define CC_TEX_MOB_SKELETON 48 +#define CC_TEX_MOB_COW 49 +#define CC_TEX_ARROW_DOWN 50 +#define CC_TEX_ARROW_UP 51 +#define CC_TEX_ARROW_LEFT 52 +#define CC_TEX_ARROW_RIGHT 53 + +#define CC_RENDER_ROWS 16 +#define CC_RENDER_COLS 16 + +static Texture2D cc_textures[CC_TEX_NUM]; +static bool cc_textures_loaded = false; + +static void cc_load_textures(void) { + if (cc_textures_loaded) return; + const char* candidates[] = { + "resources/craftax/textures.bin", + "../resources/craftax/textures.bin", + "../../resources/craftax/textures.bin", + }; + FILE* f = NULL; + for (size_t i = 0; i < sizeof(candidates)/sizeof(candidates[0]); i++) { + f = fopen(candidates[i], "rb"); + if (f) break; + } + if (!f) { + fprintf(stderr, "craftax_classic: textures.bin not found in resources/craftax -- run ocean/craftax/pack_textures.py\n"); + exit(1); + } + const size_t tile_bytes = CC_TEX_TILE_PX * CC_TEX_TILE_PX * 4; + uint8_t* buf = (uint8_t*)malloc(tile_bytes); + for (int i = 0; i < CC_TEX_NUM; i++) { + if (fread(buf, 1, tile_bytes, f) != tile_bytes) { + fprintf(stderr, "craftax_classic: short read on textures.bin at tile %d\n", i); + exit(1); + } + Image img = { + .data = buf, + .width = CC_TEX_TILE_PX, + .height = CC_TEX_TILE_PX, + .mipmaps = 1, + .format = PIXELFORMAT_UNCOMPRESSED_R8G8B8A8, + }; + cc_textures[i] = LoadTextureFromImage(img); + SetTextureFilter(cc_textures[i], TEXTURE_FILTER_POINT); + } + free(buf); + fclose(f); + cc_textures_loaded = true; +} + +static int cc_player_tex_id(int8_t dir, bool sleeping) { + if (sleeping) return CC_TEX_PLAYER_SLEEP; + switch (dir) { + case 1: return CC_TEX_PLAYER_LEFT; + case 2: return CC_TEX_PLAYER_RIGHT; + case 3: return CC_TEX_PLAYER_UP; + case 4: return 
CC_TEX_PLAYER_DOWN; + default: return CC_TEX_PLAYER_DOWN; + } +} + +static int cc_arrow_tex_id(int8_t dr, int8_t dc) { + if (dr < 0) return CC_TEX_ARROW_UP; + if (dr > 0) return CC_TEX_ARROW_DOWN; + if (dc < 0) return CC_TEX_ARROW_LEFT; + return CC_TEX_ARROW_RIGHT; +} + +static void cc_draw_tile(int tex_id, int dst_x, int dst_y) { + if (tex_id < 0 || tex_id >= CC_TEX_NUM) return; + Rectangle src = {0, 0, CC_TEX_TILE_PX, CC_TEX_TILE_PX}; + Rectangle dst = {(float)dst_x, (float)dst_y, CC_TEX_DRAW_PX, CC_TEX_DRAW_PX}; + DrawTexturePro(cc_textures[tex_id], src, dst, (Vector2){0, 0}, 0.0f, WHITE); +} + +static void c_render(CraftaxClassic* env) { + const int view_w = CC_RENDER_COLS * CC_TEX_DRAW_PX; + const int view_h = CC_RENDER_ROWS * CC_TEX_DRAW_PX; + const int hud_h = 60; + + if (!IsWindowReady()) { + InitWindow(view_w, view_h + hud_h, "PufferLib Craftax-Classic"); + SetTargetFPS(30); + } + if (!cc_textures_loaded) cc_load_textures(); + if (IsKeyDown(KEY_ESCAPE)) exit(0); + + int pr = env->player_r; + int pc = env->player_c; + int half_r = CC_RENDER_ROWS / 2; + int half_c = CC_RENDER_COLS / 2; + + BeginDrawing(); + ClearBackground(BLACK); + + for (int vr = 0; vr < CC_RENDER_ROWS; vr++) { + for (int vc = 0; vc < CC_RENDER_COLS; vc++) { + int wr = pr - half_r + vr; + int wc = pc - half_c + vc; + int dst_x = vc * CC_TEX_DRAW_PX; + int dst_y = vr * CC_TEX_DRAW_PX; + + int blk = BLK_OUT_OF_BOUNDS; + if (in_bounds(wr, wc)) blk = map_get(env, wr, wc); + if (blk < 0 || blk >= 17) blk = 0; + cc_draw_tile(blk, dst_x, dst_y); + } + } + + // Mobs + for (int i = 0; i < MAX_ZOMBIES; i++) { + if (!env->zombie_mask[i]) continue; + int vr = env->zombie_r[i] - pr + half_r; + int vc = env->zombie_c[i] - pc + half_c; + if (vr < 0 || vr >= CC_RENDER_ROWS || vc < 0 || vc >= CC_RENDER_COLS) continue; + cc_draw_tile(CC_TEX_MOB_ZOMBIE, vc * CC_TEX_DRAW_PX, vr * CC_TEX_DRAW_PX); + } + for (int i = 0; i < MAX_SKELETONS; i++) { + if (!env->skel_mask[i]) continue; + int vr = env->skel_r[i] - pr 
+ half_r; + int vc = env->skel_c[i] - pc + half_c; + if (vr < 0 || vr >= CC_RENDER_ROWS || vc < 0 || vc >= CC_RENDER_COLS) continue; + cc_draw_tile(CC_TEX_MOB_SKELETON, vc * CC_TEX_DRAW_PX, vr * CC_TEX_DRAW_PX); + } + for (int i = 0; i < MAX_COWS; i++) { + if (!env->cow_mask[i]) continue; + int vr = env->cow_r[i] - pr + half_r; + int vc = env->cow_c[i] - pc + half_c; + if (vr < 0 || vr >= CC_RENDER_ROWS || vc < 0 || vc >= CC_RENDER_COLS) continue; + cc_draw_tile(CC_TEX_MOB_COW, vc * CC_TEX_DRAW_PX, vr * CC_TEX_DRAW_PX); + } + for (int i = 0; i < MAX_ARROWS; i++) { + if (!env->arrow_mask[i]) continue; + int vr = env->arrow_r[i] - pr + half_r; + int vc = env->arrow_c[i] - pc + half_c; + if (vr < 0 || vr >= CC_RENDER_ROWS || vc < 0 || vc >= CC_RENDER_COLS) continue; + cc_draw_tile(cc_arrow_tex_id(env->arrow_dr[i], env->arrow_dc[i]), + vc * CC_TEX_DRAW_PX, vr * CC_TEX_DRAW_PX); + } + + // Player in center + cc_draw_tile(cc_player_tex_id(env->player_dir, env->is_sleeping), + half_c * CC_TEX_DRAW_PX, half_r * CC_TEX_DRAW_PX); + + // Night dim + if (env->light_level < 1.0f) { + unsigned char a = (unsigned char)((1.0f - env->light_level) * 140.0f); + DrawRectangle(0, 0, view_w, view_h, (Color){0, 0, 40, a}); + } + + // HUD + int hud_y = view_h; + DrawRectangle(0, hud_y, view_w, hud_h, (Color){20, 20, 20, 255}); + DrawText(TextFormat("HP:%d F:%d D:%d E:%d t:%d light:%.2f", + env->health, env->food, env->drink, env->energy, + env->timestep, env->light_level), + 4, hud_y + 4, 14, WHITE); + int ach_count = 0; + for (int i = 0; i < NUM_ACHIEVEMENTS; i++) ach_count += env->achievements[i] ? 
1 : 0; + DrawText(TextFormat("ach:%d/%d ret:%.2f len:%d", ach_count, NUM_ACHIEVEMENTS, + env->episode_return_accum, env->episode_length_accum), + 4, hud_y + 22, 14, (Color){180, 220, 180, 255}); + DrawText(TextFormat("inv: w=%d s=%d c=%d i=%d d=%d sap=%d pick w/s/i:%d/%d/%d sword w/s/i:%d/%d/%d", + env->inv[0], env->inv[1], env->inv[2], env->inv[3], env->inv[4], env->inv[5], + env->inv[6], env->inv[7], env->inv[8], env->inv[9], env->inv[10], env->inv[11]), + 4, hud_y + 40, 12, (Color){180, 180, 180, 255}); + EndDrawing(); +} diff --git a/resources/craftax/textures.bin b/resources/craftax/textures.bin new file mode 100644 index 0000000000..c13e14130a Binary files /dev/null and b/resources/craftax/textures.bin differ diff --git a/tests/craftax_convergence_bench.py b/tests/craftax_convergence_bench.py new file mode 100644 index 0000000000..b0aac95390 --- /dev/null +++ b/tests/craftax_convergence_bench.py @@ -0,0 +1,179 @@ +"""Compare convergence of Craftax Classic vs Full on overlapping achievements. + +Runs both envs through `uv run puffer train` back-to-back (default 10M env +steps each), then parses pufferlib's per-run JSON log and plots: + - mean episode score over env steps + - per-achievement unlock rate (for the 22 Classic-compatible achievements) + - wall-clock time to reach each score threshold + +The envs share the first 22 achievement IDs (Classic's entire set). Full +has 67 achievements total; the extra 45 are plotted separately so Full +isn't rewarded twice for reaching the same tier. 
+ +Usage: + uv run python tests/craftax_convergence_bench.py --timesteps 10_000_000 + uv run python tests/craftax_convergence_bench.py --skip-train +""" +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path + +import numpy as np + + +# Repository root (this file sits one directory below it) and the directory +# under which per-env run logs (*.json) are looked up. +REPO = Path(__file__).resolve().parent.parent +LOG_DIR = REPO / "logs" + +# Achievement names; each is read from a run's metrics as "env/<name>". +CLASSIC_ACHIEVEMENTS = [ + "collect_wood", "place_table", "eat_cow", "collect_sapling", "collect_drink", + "make_wood_pickaxe", "make_wood_sword", "place_plant", "defeat_zombie", + "collect_stone", "place_stone", "eat_plant", "defeat_skeleton", + "make_stone_pickaxe", "make_stone_sword", "wake_up", "place_furnace", + "collect_coal", "collect_iron", "collect_diamond", "make_iron_pickaxe", + "make_iron_sword", +] + +# Score levels reported in the time-to-threshold summary. +SCORE_THRESHOLDS = [1, 3, 5, 7, 10, 15] + + +def train(env_name, timesteps): + env_log_dir = LOG_DIR / env_name + env_log_dir.mkdir(parents=True, exist_ok=True) + before = {p.name for p in env_log_dir.glob("*.json")} + + # pufferlib._C is compiled for one env at a time; rebuild before each run. 
+ build_cmd = [ + "uv", "run", "--with", "pybind11", "--with", "rich_argparse", + "./build.sh", env_name, + ] + print(f"\n=== rebuilding pufferlib._C for {env_name} ===") + subprocess.check_call(build_cmd, cwd=REPO) + + cmd = [ + "uv", "run", "--with", "pybind11", "--with", "rich_argparse", + "puffer", "train", env_name, + "--train.total-timesteps", str(int(timesteps)), + ] + print(f"\n=== training {env_name} for {timesteps:,} steps ===") + print(" ".join(cmd)) + subprocess.check_call(cmd, cwd=REPO) + after = {p.name for p in env_log_dir.glob("*.json")} + new = sorted(after - before) + if not new: + raise RuntimeError(f"no new log file under {env_log_dir}") + return env_log_dir / new[-1] + + +def load_run(path): + with open(path) as f: + raw = json.load(f) + m = raw["metrics"] + steps = np.array(m["agent_steps"], dtype=np.float64) + uptime = np.array(m["uptime"], dtype=np.float64) + score = np.array(m.get("env/score", [np.nan] * len(steps)), dtype=np.float64) + ach = {} + for name in CLASSIC_ACHIEVEMENTS: + key = f"env/{name}" + if key in m: + ach[name] = np.array(m[key], dtype=np.float64) + return {"steps": steps, "uptime": uptime, "score": score, "ach": ach, "path": str(path)} + + +def time_to_threshold(steps, score, threshold): + above = np.nonzero(score >= threshold)[0] + if len(above) == 0: + return None + return float(steps[above[0]]) + + +def print_summary(label, run): + print(f"\n--- {label} ({run['path']}) ---") + total_steps = int(run["steps"][-1]) + wall = run["uptime"][-1] + peak = float(np.nanmax(run["score"])) if run["score"].size else float("nan") + final = float(run["score"][-1]) if run["score"].size else float("nan") + print(f"total env steps: {total_steps:,} wall: {wall/60:.1f}min " + f"final score: {final:.2f} peak: {peak:.2f}") + print(f"time to score threshold (env steps):") + for t in SCORE_THRESHOLDS: + s = time_to_threshold(run["steps"], run["score"], t) + if s is None: + print(f" >={t:>2}: NOT REACHED") + else: + wall_at = 
run["uptime"][np.nonzero(run["score"] >= t)[0][0]] + print(f" >={t:>2}: {int(s):>12,} steps ({wall_at/60:5.1f} min)") + if run["ach"]: + print("final per-achievement unlock rate (mean over eval episodes):") + for name in CLASSIC_ACHIEVEMENTS: + if name in run["ach"]: + print(f" {name:<22s} {run['ach'][name][-1]:.3f}") + + +def plot(runs, out_path): + try: + import matplotlib.pyplot as plt + except Exception as exc: + print(f"matplotlib unavailable ({exc}); skipping plot.") + return + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + for label, run in runs.items(): + axes[0].plot(run["steps"] / 1e6, run["score"], label=label) + axes[0].set_xlabel("env steps (M)") + axes[0].set_ylabel("mean episode score (achievements)") + axes[0].set_title("Convergence: score vs env steps") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + for label, run in runs.items(): + axes[1].plot(run["uptime"] / 60, run["score"], label=label) + axes[1].set_xlabel("wall time (min)") + axes[1].set_ylabel("mean episode score") + axes[1].set_title("Convergence: score vs wall time") + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(out_path, dpi=120) + print(f"\nwrote plot to {out_path}") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--timesteps", type=float, default=10_000_000, + help="env steps per training run") + ap.add_argument("--skip-train", action="store_true", + help="skip training; use most recent log in logs/{env}") + ap.add_argument("--classic-log", type=str, default=None, + help="explicit path to craftax_classic log json") + ap.add_argument("--full-log", type=str, default=None, + help="explicit path to craftax log json") + ap.add_argument("--out", type=str, default="craftax_convergence.png") + args = ap.parse_args() + + runs = {} + for label, env_name, override in [ + ("Classic", "craftax_classic", args.classic_log), + ("Full", "craftax", args.full_log), + ]: + if override: + path = Path(override) + elif args.skip_train: 
+ candidates = sorted((LOG_DIR / env_name).glob("*.json")) + if not candidates: + print(f"no logs for {env_name} under {LOG_DIR/env_name}; skipping.") + continue + path = candidates[-1] + else: + path = train(env_name, args.timesteps) + runs[label] = load_run(path) + print_summary(label, runs[label]) + + if len(runs) >= 1: + plot(runs, args.out) + + +if __name__ == "__main__": + main() diff --git a/tests/craftax_parity.py b/tests/craftax_parity.py new file mode 100644 index 0000000000..bdf9939c1f --- /dev/null +++ b/tests/craftax_parity.py @@ -0,0 +1,1347 @@ +import argparse +import ctypes +import os +import subprocess +import tempfile +from collections import deque +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORM_NAME", "cpu") +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + +import jax +import jax.numpy as jnp +import numpy as np + +from craftax.craftax_env import make_craftax_env_from_name +try: + from craftax_state_fixtures import ( + CraftaxState, + craftax_state_to_jax, + flatten_env_state, + ) +except ModuleNotFoundError: + from tests.craftax_state_fixtures import ( + CraftaxState, + craftax_state_to_jax, + flatten_env_state, + ) + + +OBS_SIZE = 8268 +NUM_ACTIONS = 43 + +OBS_ROWS = 9 +OBS_COLS = 11 +NUM_BLOCK_TYPES = 37 +NUM_ITEM_TYPES = 5 +NUM_MOB_CLASSES = 5 +NUM_MOB_TYPES = 8 +NUM_TILE_CHANNELS = NUM_BLOCK_TYPES + NUM_ITEM_TYPES + NUM_MOB_CLASSES * NUM_MOB_TYPES + 1 +MAP_OBS_SIZE = OBS_ROWS * OBS_COLS * NUM_TILE_CHANNELS +MAP_SIZE = 48 +NUM_LEVELS = 9 +MONSTERS_KILLED_TO_CLEAR_LEVEL = 8 + +NOOP = 0 +LEFT = 1 +RIGHT = 2 +UP = 3 +DOWN = 4 +DO = 5 +PLACE_STONE = 7 +PLACE_TABLE = 8 +PLACE_FURNACE = 9 +MAKE_WOOD_PICKAXE = 11 +MAKE_STONE_PICKAXE = 12 +MAKE_IRON_PICKAXE = 13 +MAKE_WOOD_SWORD = 14 +MAKE_STONE_SWORD = 15 +MAKE_IRON_SWORD = 16 +DESCEND = 18 +MAKE_DIAMOND_PICKAXE = 20 +MAKE_DIAMOND_SWORD = 21 +MAKE_IRON_ARMOUR = 22 +MAKE_DIAMOND_ARMOUR = 23 +SHOOT_ARROW = 24 +MAKE_ARROW = 25 +CAST_FIREBALL = 26 +CAST_ICEBALL = 27 
+PLACE_TORCH = 28 +MAKE_TORCH = 38 + +BLOCK_WATER = 3 +BLOCK_LAVA = 14 +ITEM_LADDER_DOWN = 2 + +MOVE_ACTIONS = np.asarray([LEFT, RIGHT, UP, DOWN], dtype=np.int32) +DIRS = { + LEFT: (0, -1), + RIGHT: (0, 1), + UP: (-1, 0), + DOWN: (1, 0), +} + +SOLID_BLOCKS = frozenset( + [ + 4, + 5, + 8, + 9, + 10, + 11, + 12, + 15, + 16, + 17, + 19, + 20, + 21, + 22, + 23, + 24, + 28, + 30, + 31, + 32, + 33, + 34, + 35, + ] +) + +INVENTORY_OBS_NAMES = [ + "inventory.wood", + "inventory.stone", + "inventory.coal", + "inventory.iron", + "inventory.diamond", + "inventory.sapphire", + "inventory.ruby", + "inventory.sapling", + "inventory.torches", + "inventory.arrows", + "inventory.books", + "inventory.pickaxe", + "inventory.sword", + "sword_enchantment", + "bow_enchantment", + "inventory.bow", + "inventory.potions.red", + "inventory.potions.green", + "inventory.potions.blue", + "inventory.potions.pink", + "inventory.potions.cyan", + "inventory.potions.yellow", + "player_health", + "player_food", + "player_drink", + "player_energy", + "player_mana", + "player_xp", + "player_dexterity", + "player_strength", + "player_intelligence", + "direction.left", + "direction.right", + "direction.up", + "direction.down", + "inventory.armour.0", + "inventory.armour.1", + "inventory.armour.2", + "inventory.armour.3", + "armour_enchantments.0", + "armour_enchantments.1", + "armour_enchantments.2", + "armour_enchantments.3", + "light_level", + "is_sleeping", + "is_resting", + "learned_spells.fireball", + "learned_spells.iceball", + "player_level", + "ladder_down_open", + "boss_vulnerable", +] + +MOB_CLASS_NAMES = [ + "melee_mobs", + "passive_mobs", + "ranged_mobs", + "mob_projectiles", + "player_projectiles", +] + +POLICIES = ("uniform", "combat", "descend", "suicide", "boss", "mixed") +MIXED_ORDER = ("uniform", "combat", "descend", "suicide", "boss") + + +def _preload_nccl(): + root = Path(__file__).resolve().parents[1] + nccl = root / ".venv/lib/python3.12/site-packages/nvidia/nccl/lib/libnccl.so.2" 
+ if nccl.exists(): + ctypes.CDLL(str(nccl), mode=ctypes.RTLD_GLOBAL) + + +def import_c_env(): + _preload_nccl() + import pufferlib._C as cmod + + env_name = getattr(cmod, "env_name", None) + if env_name != "craftax": + raise RuntimeError( + f"pufferlib._C is compiled for {env_name!r}, expected 'craftax'. " + "Run: uv run --with pybind11 --with rich_argparse ./build.sh craftax" + ) + return cmod + + +def float_view(ptr, count): + array_t = ctypes.c_float * count + return np.ctypeslib.as_array(array_t.from_address(ptr)) + + +def _stack_states(states): + return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states) + + +class JaxCraftaxBatch: + def __init__(self, seeds, resetter=None): + self.env = make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=False) + self.params = self.env.default_params + self.num_envs = len(seeds) + self.resetter = resetter + self.reset_keys = [] + rngs = [] + states = [] + obs = [] + for seed in seeds: + rng = jax.random.PRNGKey(int(seed)) + rng, reset_key = jax.random.split(rng) + env_obs, state = self.env.reset(reset_key, self.params) + rngs.append(rng) + self.reset_keys.append(np.asarray(reset_key, dtype=np.uint32)) + states.append(state) + obs.append(np.asarray(env_obs, dtype=np.float32).reshape(-1)) + + self.rngs = jnp.stack(rngs) + self.states = _stack_states(states) + self.obs = np.stack(obs, axis=0) + self._step_batch = self._make_step_batch() + + def _make_step_batch(self): + env = self.env + params = self.params + + def step_one(key, state, action): + step_rng, reset_key = jax.random.split(key, 2) + obs, next_state, reward, done, _info = env.step( + step_rng, + state, + action, + params, + ) + return obs, next_state, reward, done, reset_key + + def step_batch(rngs, states, actions): + split_keys = jax.vmap(lambda key: jax.random.split(key, 2))(rngs) + next_rngs = split_keys[:, 0] + step_keys = split_keys[:, 1] + obs, next_states, rewards, dones, reset_keys = jax.vmap(step_one)( + step_keys, states, actions + ) + 
return next_rngs, next_states, obs, rewards, dones, reset_keys + + return jax.jit(step_batch) + + def step(self, actions): + actions = jnp.asarray(actions, dtype=jnp.int32) + ( + self.rngs, + self.states, + obs, + rewards, + dones, + reset_keys, + ) = self._step_batch(self.rngs, self.states, actions) + self.obs = np.asarray(obs, dtype=np.float32).reshape(self.num_envs, -1).copy() + dones_np = np.asarray(dones, dtype=np.bool_) + reset_keys_np = np.asarray(reset_keys, dtype=np.uint32) + if self.resetter is not None and np.any(dones_np): + for env_i, done in enumerate(dones_np): + if not bool(done): + continue + reset_state, reset_obs = self.resetter.reset( + reset_keys_np[env_i], + self.state_at(env_i), + ) + self.states = jax.tree_util.tree_map( + lambda batched, value: batched.at[env_i].set(value), + self.states, + reset_state, + ) + self.obs[env_i] = reset_obs + return ( + self.obs, + np.asarray(rewards, dtype=np.float32), + dones_np, + reset_keys_np, + ) + + def state_at(self, env_i): + return jax.tree_util.tree_map(lambda leaf: leaf[env_i], self.states) + + +class PolicySnapshot: + def __init__(self, states): + self.level = np.asarray(states.player_level, dtype=np.int32) + self.position = np.asarray(states.player_position, dtype=np.int32) + self.direction = np.asarray(states.player_direction, dtype=np.int32) + self.health = np.asarray(states.player_health, dtype=np.float32) + self.mana = np.asarray(states.player_mana, dtype=np.int32) + self.learned_spells = np.asarray(states.learned_spells, dtype=np.bool_) + + self.inventory = states.inventory + self.wood = np.asarray(self.inventory.wood, dtype=np.int32) + self.stone = np.asarray(self.inventory.stone, dtype=np.int32) + self.coal = np.asarray(self.inventory.coal, dtype=np.int32) + self.iron = np.asarray(self.inventory.iron, dtype=np.int32) + self.diamond = np.asarray(self.inventory.diamond, dtype=np.int32) + self.bow = np.asarray(self.inventory.bow, dtype=np.int32) + self.arrows = 
np.asarray(self.inventory.arrows, dtype=np.int32) + self.torches = np.asarray(self.inventory.torches, dtype=np.int32) + + num_envs = int(self.level.shape[0]) + env_idx = np.arange(num_envs) + + full_map = np.asarray(states.map, dtype=np.int32) + full_item_map = np.asarray(states.item_map, dtype=np.int32) + full_mob_map = np.asarray(states.mob_map, dtype=np.bool_) + full_monsters_killed = np.asarray(states.monsters_killed, dtype=np.int32) + full_down_ladders = np.asarray(states.down_ladders, dtype=np.int32) + + self.map = full_map[env_idx, self.level] + self.item_map = full_item_map[env_idx, self.level] + self.mob_map = full_mob_map[env_idx, self.level] + self.monsters_killed = full_monsters_killed[env_idx, self.level] + self.down_ladders = full_down_ladders[env_idx, self.level] + + self.melee_pos, self.melee_mask, self.melee_type = self._take_mobs( + states.melee_mobs, env_idx + ) + self.passive_pos, self.passive_mask, self.passive_type = self._take_mobs( + states.passive_mobs, env_idx + ) + self.ranged_pos, self.ranged_mask, self.ranged_type = self._take_mobs( + states.ranged_mobs, env_idx + ) + ( + self.mob_projectile_pos, + self.mob_projectile_mask, + self.mob_projectile_type, + ) = self._take_mobs(states.mob_projectiles, env_idx) + ( + self.player_projectile_pos, + self.player_projectile_mask, + self.player_projectile_type, + ) = self._take_mobs(states.player_projectiles, env_idx) + + def _take_mobs(self, mobs, env_idx): + pos = np.asarray(mobs.position, dtype=np.int32)[env_idx, self.level] + mask = np.asarray(mobs.mask, dtype=np.bool_)[env_idx, self.level] + type_id = np.asarray(mobs.type_id, dtype=np.int32)[env_idx, self.level] + return pos, mask, type_id + + +class ResetVerifier: + def __init__(self): + root = Path(__file__).resolve().parents[1] + source = r""" + #include + #include + #define CRAFTAX_ENABLE_ENV_IMPL + #include "ocean/craftax/craftax.h" + #include "ocean/craftax/step_crafting.h" + #include "ocean/craftax/step_update_mobs.h" + #include 
"ocean/craftax/step_spawn_mobs.h" + + void reset_from_key( + uint32_t key0, + uint32_t key1, + CraftaxState* out, + float* obs + ) { + CraftaxThreefryKey reset_key = {{key0, key1}}; + craftax_reset_state_from_reset_key(out, reset_key); + craftax_encode_native_observation(out, obs); + } + """ + self._tmp = tempfile.TemporaryDirectory() + tmp_path = Path(self._tmp.name) + src = tmp_path / "craftax_reset_verify.c" + so = tmp_path / "craftax_reset_verify.so" + src.write_text(source) + subprocess.run( + [ + "cc", + "-std=c99", + "-O2", + "-shared", + "-fPIC", + "-I", + str(root), + "-I", + str(root / "raylib-5.5_linux_amd64/include"), + str(src), + "-lm", + "-o", + str(so), + ], + check=True, + cwd=root, + ) + self.lib = ctypes.CDLL(str(so)) + self.lib.reset_from_key.argtypes = [ + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.POINTER(CraftaxState), + ctypes.POINTER(ctypes.c_float), + ] + self.lib.reset_from_key.restype = None + + def reset(self, reset_key, template): + c_state = CraftaxState() + c_obs = np.empty(OBS_SIZE, dtype=np.float32) + key = np.asarray(reset_key, dtype=np.uint32) + self.lib.reset_from_key( + ctypes.c_uint32(int(key[0])), + ctypes.c_uint32(int(key[1])), + ctypes.byref(c_state), + c_obs.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + ) + return craftax_state_to_jax(c_state, template=template), c_obs + + def compare(self, jax_state, jax_obs, reset_key, seed, step, policy, atol): + c_jax_state, c_obs = self.reset(reset_key, jax_state) + + obs_diff = first_obs_diff(jax_obs, c_obs, atol) + state_diff = first_state_diff(jax_state, c_jax_state, atol) + if obs_diff is not None: + idx, max_diff, jax_value, c_value = obs_diff + key = np.asarray(reset_key, dtype=np.uint32) + print( + "RESET DIVERGENCE " + f"seed={seed} step={step} policy={policy} " + f"reset_key=[{int(key[0])},{int(key[1])}] " + f"obs_index={idx} section={section_for_index(idx)} " + f"subsystem={subsystem_for_section(section_for_index(idx))} " + f"abs_diff={max_diff:.8g} jax={jax_value:.8g} 
c={c_value:.8g}" + ) + if state_diff is not None: + name, index, state_max_diff, state_jax_value, state_c_value = state_diff + print( + "reset_state_first_diff: " + f"field={name} index={index} " + f"abs_diff={state_max_diff:.8g} " + f"jax={state_jax_value} c={state_c_value}" + ) + return False + + if state_diff is not None: + name, index, max_diff, jax_value, c_value = state_diff + key = np.asarray(reset_key, dtype=np.uint32) + print( + "RESET STATE DIVERGENCE " + f"seed={seed} step={step} policy={policy} " + f"reset_key=[{int(key[0])},{int(key[1])}] " + f"field={name} index={index} abs_diff={max_diff:.8g} " + f"jax={jax_value} c={c_value}" + ) + return False + return True + + +_RESET_VERIFIER = None + + +def get_reset_verifier(enabled): + global _RESET_VERIFIER + if not enabled: + return None + if _RESET_VERIFIER is None: + _RESET_VERIFIER = ResetVerifier() + return _RESET_VERIFIER + + +def make_c_vec(cmod, num_envs, seed_offset, num_threads=1): + args = { + "vec": { + "total_agents": num_envs, + "num_buffers": 1, + "num_threads": num_threads, + }, + "env": { + "seed_offset": seed_offset, + }, + } + vec = cmod.create_vec(args, 0) + if vec.obs_size != OBS_SIZE: + raise RuntimeError(f"C obs_size={vec.obs_size}, expected {OBS_SIZE}") + if vec.num_atns != 1: + raise RuntimeError(f"C num_atns={vec.num_atns}, expected 1") + if list(vec.act_sizes) != [NUM_ACTIONS]: + raise RuntimeError(f"C act_sizes={vec.act_sizes}, expected [{NUM_ACTIONS}]") + vec.reset() + obs = float_view(vec.obs_ptr, num_envs * OBS_SIZE).reshape(num_envs, OBS_SIZE) + rewards = float_view(vec.rewards_ptr, num_envs) + terminals = float_view(vec.terminals_ptr, num_envs) + return vec, obs, rewards, terminals + + +def action_plan(seeds, steps, action_seed): + rng = np.random.default_rng(action_seed) + return rng.integers(0, NUM_ACTIONS, size=(steps, len(seeds)), dtype=np.int32) + + +def first_obs_diff(ref, got, atol): + diff = np.abs(ref - got) + idx = int(np.argmax(diff)) + max_diff = float(diff[idx]) + 
if max_diff <= atol: + return None + return idx, max_diff, float(ref[idx]), float(got[idx]) + + +def _format_index(index): + index = np.asarray(index) + if index.ndim == 0: + return "scalar" + return ",".join(str(int(i)) for i in index) + + +def first_state_diff(jax_state, c_state, atol): + jax_flat = flatten_env_state(jax_state) + c_flat = flatten_env_state(c_state) + if jax_flat.keys() != c_flat.keys(): + missing = sorted(jax_flat.keys() - c_flat.keys()) + extra = sorted(c_flat.keys() - jax_flat.keys()) + return "state_keys", "scalar", 1.0, f"missing_c={missing}", f"extra_c={extra}" + + for name, jax_value in jax_flat.items(): + c_value = c_flat[name] + if np.asarray(jax_value).dtype.kind == "f": + diff = np.abs(np.asarray(jax_value) - np.asarray(c_value)) + if diff.size == 0: + continue + idx = np.unravel_index(int(np.argmax(diff)), diff.shape) + max_diff = float(diff[idx]) + if max_diff > atol: + return ( + name, + _format_index(np.asarray(idx)), + max_diff, + float(np.asarray(jax_value)[idx]), + float(np.asarray(c_value)[idx]), + ) + else: + neq = np.asarray(jax_value) != np.asarray(c_value) + if np.any(neq): + idx = np.argwhere(neq)[0] if np.asarray(neq).ndim else np.asarray(()) + idx_tuple = tuple(int(i) for i in np.asarray(idx).reshape(-1)) + return ( + name, + _format_index(idx), + 1.0, + np.asarray(jax_value)[idx_tuple].item() + if idx_tuple + else np.asarray(jax_value).item(), + np.asarray(c_value)[idx_tuple].item() + if idx_tuple + else np.asarray(c_value).item(), + ) + return None + + +def section_for_index(idx): + if idx < MAP_OBS_SIZE: + tile = idx // NUM_TILE_CHANNELS + channel = idx % NUM_TILE_CHANNELS + row = tile // OBS_COLS + col = tile % OBS_COLS + if channel < NUM_BLOCK_TYPES: + return f"map_one_hot[row={row},col={col},block={channel}]" + channel -= NUM_BLOCK_TYPES + if channel < NUM_ITEM_TYPES: + return f"item_one_hot[row={row},col={col},item={channel}]" + channel -= NUM_ITEM_TYPES + if channel < NUM_MOB_CLASSES * NUM_MOB_TYPES: + mob_class = 
channel // NUM_MOB_TYPES + mob_type = channel % NUM_MOB_TYPES + return ( + f"{MOB_CLASS_NAMES[mob_class]}_type_{mob_type}" + f"[row={row},col={col}]" + ) + return f"light[row={row},col={col}]" + + inv_idx = idx - MAP_OBS_SIZE + if 0 <= inv_idx < len(INVENTORY_OBS_NAMES): + return INVENTORY_OBS_NAMES[inv_idx] + return f"inventory_or_special[{inv_idx}]" + + +def subsystem_for_section(section): + if section.startswith("map_one_hot"): + return "symbolic_observation.map" + if section.startswith("item_one_hot"): + return "symbolic_observation.item_or_ladder" + if section.startswith("melee_mobs") or section.startswith("passive_mobs"): + return "mobs.update_or_observation" + if section.startswith("ranged_mobs") or section.startswith("mob_projectiles"): + return "projectiles_or_ranged_mobs" + if section.startswith("player_projectiles"): + return "player_projectiles" + if section.startswith("light[") or section == "light_level": + return "light" + if section.startswith("inventory."): + return "inventory" + if section.startswith("player_"): + return "player_intrinsics" + if section.startswith("direction."): + return "movement" + if section in {"ladder_down_open", "player_level"}: + return "floor_change" + if section == "boss_vulnerable": + return "boss_logic" + return "state_or_observation" + + +def compare_reset(ref_obs, c_obs, seeds, atol): + for env_i, seed in enumerate(seeds): + diff = first_obs_diff(ref_obs[env_i], c_obs[env_i], atol) + if diff is not None: + idx, max_diff, ref_value, c_value = diff + section = section_for_index(idx) + print( + "RESET DIVERGENCE " + f"seed={seed} obs_index={idx} section={section} " + f"subsystem={subsystem_for_section(section)} " + f"abs_diff={max_diff:.8g} jax={ref_value:.8g} c={c_value:.8g}" + ) + return False + return True + + +def _in_bounds(pos): + return 0 <= int(pos[0]) < MAP_SIZE and 0 <= int(pos[1]) < MAP_SIZE + + +def _action_toward_delta(delta): + dr, dc = int(delta[0]), int(delta[1]) + if abs(dr) > abs(dc): + return DOWN if 
dr > 0 else UP + if dc != 0: + return RIGHT if dc > 0 else LEFT + if dr != 0: + return DOWN if dr > 0 else UP + return NOOP + + +def _action_to_neighbor(start, target): + delta = np.asarray(target, dtype=np.int32) - np.asarray(start, dtype=np.int32) + if abs(int(delta[0])) + abs(int(delta[1])) != 1: + return None + return _action_toward_delta(delta) + + +def _passable_map(snapshot, env_i, allow_danger=False, allow_mobs=False): + level_map = snapshot.map[env_i] + passable = np.ones((MAP_SIZE, MAP_SIZE), dtype=np.bool_) + for block in SOLID_BLOCKS: + passable &= level_map != block + if not allow_danger: + passable &= level_map != BLOCK_WATER + passable &= level_map != BLOCK_LAVA + if not allow_mobs: + passable &= ~snapshot.mob_map[env_i] + return passable + + +def _valid_move_actions(snapshot, env_i, allow_danger=False): + pos = snapshot.position[env_i] + passable = _passable_map(snapshot, env_i, allow_danger=allow_danger) + actions = [] + for action, delta in DIRS.items(): + target = pos + np.asarray(delta, dtype=np.int32) + if _in_bounds(target) and passable[int(target[0]), int(target[1])]: + actions.append(action) + return actions + + +def _random_move(snapshot, env_i, rng, allow_danger=False): + actions = _valid_move_actions(snapshot, env_i, allow_danger=allow_danger) + if actions: + return int(rng.choice(actions)) + return int(rng.choice(MOVE_ACTIONS)) + + +def _bfs_first_action(snapshot, env_i, target, rng, allow_danger=False): + start = tuple(int(x) for x in snapshot.position[env_i]) + target = tuple(int(x) for x in np.asarray(target, dtype=np.int32)) + if start == target: + return NOOP + + passable = _passable_map(snapshot, env_i, allow_danger=allow_danger) + passable[start] = True + if not _in_bounds(target) or not passable[target]: + return _greedy_action(snapshot, env_i, np.asarray(target), rng, allow_danger) + + visited = np.zeros((MAP_SIZE, MAP_SIZE), dtype=np.bool_) + visited[start] = True + queue = deque() + for action in rng.permutation(MOVE_ACTIONS): 
+ delta = DIRS[int(action)] + row = start[0] + delta[0] + col = start[1] + delta[1] + if not (0 <= row < MAP_SIZE and 0 <= col < MAP_SIZE): + continue + if visited[row, col] or not passable[row, col]: + continue + if (row, col) == target: + return int(action) + visited[row, col] = True + queue.append((row, col, int(action))) + + while queue: + row, col, first_action = queue.popleft() + for action in MOVE_ACTIONS: + delta = DIRS[int(action)] + next_row = row + delta[0] + next_col = col + delta[1] + if not (0 <= next_row < MAP_SIZE and 0 <= next_col < MAP_SIZE): + continue + if visited[next_row, next_col] or not passable[next_row, next_col]: + continue + if (next_row, next_col) == target: + return int(first_action) + visited[next_row, next_col] = True + queue.append((next_row, next_col, first_action)) + + return _greedy_action(snapshot, env_i, np.asarray(target), rng, allow_danger) + + +def _greedy_action(snapshot, env_i, target, rng, allow_danger=False): + pos = snapshot.position[env_i] + actions = _valid_move_actions(snapshot, env_i, allow_danger=allow_danger) + if not actions: + return int(rng.choice(MOVE_ACTIONS)) + scored = [] + for action in actions: + delta = np.asarray(DIRS[action], dtype=np.int32) + next_pos = pos + delta + dist = int(np.abs(next_pos - target).sum()) + scored.append((dist, action)) + best_dist = min(dist for dist, _action in scored) + best = [action for dist, action in scored if dist == best_dist] + return int(rng.choice(best)) + + +def _nearest_target(snapshot, env_i, positions): + if len(positions) == 0: + return None + pos = snapshot.position[env_i] + positions = np.asarray(positions, dtype=np.int32) + distances = np.abs(positions - pos).sum(axis=1) + return positions[int(np.argmin(distances))] + + +def _live_mobs(snapshot, env_i, include_passive=True, include_projectiles=False): + groups = [ + (0, snapshot.melee_pos[env_i], snapshot.melee_mask[env_i], snapshot.melee_type[env_i]), + (2, snapshot.ranged_pos[env_i], 
snapshot.ranged_mask[env_i], snapshot.ranged_type[env_i]), + ] + if include_passive: + groups.append( + ( + 1, + snapshot.passive_pos[env_i], + snapshot.passive_mask[env_i], + snapshot.passive_type[env_i], + ) + ) + if include_projectiles: + groups.append( + ( + 3, + snapshot.mob_projectile_pos[env_i], + snapshot.mob_projectile_mask[env_i], + snapshot.mob_projectile_type[env_i], + ) + ) + + mobs = [] + for mob_class, positions, masks, type_ids in groups: + for index, mask in enumerate(masks): + if bool(mask): + mobs.append((mob_class, index, positions[index], int(type_ids[index]))) + return mobs + + +def _mob_positions(snapshot, env_i, include_passive=True, include_projectiles=False): + return [ + np.asarray(position, dtype=np.int32) + for _cls, _idx, position, _type_id in _live_mobs( + snapshot, + env_i, + include_passive=include_passive, + include_projectiles=include_projectiles, + ) + ] + + +def _projectile_slot_available(snapshot, env_i): + return int(np.count_nonzero(snapshot.player_projectile_mask[env_i])) < 3 + + +def _target_in_current_line(snapshot, env_i, target): + pos = snapshot.position[env_i] + direction = int(snapshot.direction[env_i]) + delta = np.asarray(target, dtype=np.int32) - pos + if direction == LEFT: + return int(delta[0]) == 0 and int(delta[1]) < 0 + if direction == RIGHT: + return int(delta[0]) == 0 and int(delta[1]) > 0 + if direction == UP: + return int(delta[1]) == 0 and int(delta[0]) < 0 + if direction == DOWN: + return int(delta[1]) == 0 and int(delta[0]) > 0 + return False + + +def _combat_action(snapshot, env_i, rng): + pos = snapshot.position[env_i] + mobs = _live_mobs(snapshot, env_i, include_passive=True) + mob_positions = [mob[2] for mob in mobs] + adjacent = [ + np.asarray(position, dtype=np.int32) + for position in mob_positions + if int(np.abs(np.asarray(position) - pos).sum()) == 1 + ] + + for target in adjacent: + action = _action_to_neighbor(pos, target) + if action == int(snapshot.direction[env_i]) and rng.random() < 
0.75: + return DO + if adjacent: + target = adjacent[int(rng.integers(0, len(adjacent)))] + return int(_action_to_neighbor(pos, target)) + + has_projectile_slot = _projectile_slot_available(snapshot, env_i) + projectile_actions = [] + if has_projectile_slot and int(snapshot.bow[env_i]) >= 1 and int(snapshot.arrows[env_i]) >= 1: + projectile_actions.append(SHOOT_ARROW) + if has_projectile_slot and int(snapshot.mana[env_i]) >= 2: + if bool(snapshot.learned_spells[env_i, 0]): + projectile_actions.append(CAST_FIREBALL) + if bool(snapshot.learned_spells[env_i, 1]): + projectile_actions.append(CAST_ICEBALL) + + if projectile_actions and mob_positions: + line_targets = [ + target + for target in mob_positions + if _target_in_current_line(snapshot, env_i, target) + ] + if line_targets and rng.random() < 0.8: + return int(rng.choice(projectile_actions)) + + axis_targets = [ + target + for target in mob_positions + if int(target[0]) == int(pos[0]) or int(target[1]) == int(pos[1]) + ] + if axis_targets: + target = _nearest_target(snapshot, env_i, axis_targets) + return _action_toward_delta(target - pos) + + if mob_positions: + target = _nearest_target(snapshot, env_i, mob_positions) + return _bfs_first_action(snapshot, env_i, target, rng) + + return _random_move(snapshot, env_i, rng) + + +def _craft_or_place_action(snapshot, env_i, rng): + options = [] + if int(snapshot.wood[env_i]) > 0: + options.extend([PLACE_TABLE, MAKE_WOOD_PICKAXE, MAKE_WOOD_SWORD]) + if int(snapshot.stone[env_i]) > 0: + options.append(PLACE_STONE) + if int(snapshot.stone[env_i]) >= 4: + options.append(PLACE_FURNACE) + if int(snapshot.stone[env_i]) > 0 and int(snapshot.wood[env_i]) > 0: + options.extend([MAKE_STONE_PICKAXE, MAKE_STONE_SWORD]) + if int(snapshot.iron[env_i]) > 0 and int(snapshot.wood[env_i]) > 0: + options.extend([MAKE_IRON_PICKAXE, MAKE_IRON_SWORD, MAKE_IRON_ARMOUR]) + if int(snapshot.diamond[env_i]) > 0 and int(snapshot.wood[env_i]) > 0: + options.extend([MAKE_DIAMOND_PICKAXE, 
MAKE_DIAMOND_SWORD, MAKE_DIAMOND_ARMOUR]) + if int(snapshot.wood[env_i]) > 0 and int(snapshot.stone[env_i]) > 0: + options.append(MAKE_ARROW) + if int(snapshot.coal[env_i]) > 0 and int(snapshot.wood[env_i]) > 0: + options.append(MAKE_TORCH) + if int(snapshot.torches[env_i]) > 0: + options.append(PLACE_TORCH) + if not options: + return None + return int(rng.choice(options)) + + +def _descend_action(snapshot, env_i, rng): + level = int(snapshot.level[env_i]) + pos = snapshot.position[env_i] + if level >= NUM_LEVELS - 1: + return _combat_action(snapshot, env_i, rng) + + row, col = int(pos[0]), int(pos[1]) + on_down_ladder = int(snapshot.item_map[env_i, row, col]) == ITEM_LADDER_DOWN + ladder_open = int(snapshot.monsters_killed[env_i]) >= MONSTERS_KILLED_TO_CLEAR_LEVEL + if on_down_ladder and ladder_open: + return DESCEND + + mobs = _mob_positions(snapshot, env_i, include_passive=False) + if not ladder_open and mobs: + return _combat_action(snapshot, env_i, rng) + + if rng.random() < 0.12: + craft_action = _craft_or_place_action(snapshot, env_i, rng) + if craft_action is not None: + return craft_action + + ladder = snapshot.down_ladders[env_i] + if ladder_open: + return _bfs_first_action(snapshot, env_i, ladder, rng) + + if mobs: + return _combat_action(snapshot, env_i, rng) + return _random_move(snapshot, env_i, rng) + + +def _danger_adjacent_action(snapshot, env_i, rng): + pos = snapshot.position[env_i] + level_map = snapshot.map[env_i] + dangerous_actions = [] + for action, delta in DIRS.items(): + target = pos + np.asarray(delta, dtype=np.int32) + if not _in_bounds(target): + continue + block = int(level_map[int(target[0]), int(target[1])]) + if block in (BLOCK_WATER, BLOCK_LAVA) or bool( + snapshot.mob_map[env_i, int(target[0]), int(target[1])] + ): + dangerous_actions.append(action) + if dangerous_actions: + return int(rng.choice(dangerous_actions)) + return None + + +def _suicide_action(snapshot, env_i, rng): + adjacent = _danger_adjacent_action(snapshot, env_i, 
rng) + if adjacent is not None: + return adjacent + + hostile_positions = _mob_positions( + snapshot, env_i, include_passive=False, include_projectiles=True + ) + danger_blocks = np.argwhere( + (snapshot.map[env_i] == BLOCK_LAVA) | (snapshot.map[env_i] == BLOCK_WATER) + ) + + targets = [] + targets.extend(hostile_positions) + if danger_blocks.size: + targets.extend([danger_blocks[i] for i in range(danger_blocks.shape[0])]) + + target = _nearest_target(snapshot, env_i, targets) + if target is None: + return _random_move(snapshot, env_i, rng, allow_danger=True) + + if int(np.abs(target - snapshot.position[env_i]).sum()) == 1: + return _action_toward_delta(target - snapshot.position[env_i]) + + passable = _passable_map(snapshot, env_i, allow_danger=False) + adjacent_cells = [] + for delta in DIRS.values(): + cell = target + np.asarray(delta, dtype=np.int32) + if _in_bounds(cell) and passable[int(cell[0]), int(cell[1])]: + adjacent_cells.append(cell) + adjacent_target = _nearest_target(snapshot, env_i, adjacent_cells) + if adjacent_target is not None: + return _bfs_first_action(snapshot, env_i, adjacent_target, rng) + return _greedy_action(snapshot, env_i, target, rng, allow_danger=True) + + +def _boss_action(snapshot, env_i, rng, step): + if step < 1000: + return _descend_action(snapshot, env_i, rng) + level = int(snapshot.level[env_i]) + if level >= NUM_LEVELS - 1: + return _combat_action(snapshot, env_i, rng) + pos = snapshot.position[env_i] + on_down_ladder = int(snapshot.item_map[env_i, int(pos[0]), int(pos[1])]) == ITEM_LADDER_DOWN + ladder_open = int(snapshot.monsters_killed[env_i]) >= MONSTERS_KILLED_TO_CLEAR_LEVEL + if on_down_ladder and ladder_open: + return DESCEND + if rng.random() < 0.25: + return DESCEND + return _descend_action(snapshot, env_i, rng) + + +class ActionPolicy: + def __init__(self, policy, action_seed, num_envs): + if policy not in POLICIES: + raise ValueError(f"unknown policy {policy!r}") + self.policy = policy + self.rng = 
np.random.default_rng(action_seed) + self.num_envs = num_envs + + def effective_policy(self, step): + if self.policy != "mixed": + return self.policy + return MIXED_ORDER[(step // 500) % len(MIXED_ORDER)] + + def actions(self, step, ref): + policy = self.effective_policy(step) + if policy == "uniform": + return ( + self.rng.integers(0, NUM_ACTIONS, size=self.num_envs, dtype=np.int32), + policy, + ) + + snapshot = PolicySnapshot(ref.states) + out = np.empty(self.num_envs, dtype=np.int32) + for env_i in range(self.num_envs): + if policy == "combat": + out[env_i] = _combat_action(snapshot, env_i, self.rng) + elif policy == "descend": + out[env_i] = _descend_action(snapshot, env_i, self.rng) + elif policy == "suicide": + out[env_i] = _suicide_action(snapshot, env_i, self.rng) + elif policy == "boss": + out[env_i] = _boss_action(snapshot, env_i, self.rng, step) + else: + raise AssertionError(policy) + return out, policy + + +def _print_step_divergence( + seed, + step, + action, + policy_name, + reward_diff, + ref_reward, + c_reward, + ref_done, + c_done, + obs_diff, + history, +): + terminal_delta = int(bool(c_done)) - int(bool(ref_done)) + print( + "STEP DIVERGENCE " + f"seed={seed} step={step} action={int(action)} policy={policy_name}" + ) + print( + f"reward_delta={reward_diff:.8g} " + f"reward: jax={float(ref_reward):.8g} c={float(c_reward):.8g}" + ) + print( + f"terminal_delta={terminal_delta} " + f"done: jax={bool(ref_done)} c={bool(c_done)}" + ) + if obs_diff is None: + print("obs: ok") + else: + idx, max_diff, ref_value, c_value = obs_diff + section = section_for_index(idx) + print( + "obs: " + f"index={idx} section={section} " + f"subsystem={subsystem_for_section(section)} " + f"abs_diff={max_diff:.8g} " + f"jax={ref_value:.8g} c={c_value:.8g}" + ) + print(f"last_10_actions={list(history)}") + + +def _print_terminal_reset_check( + reset_verifier, + ref, + ref_obs, + reset_key, + env_i, + seed, + step, + policy_name, + atol, +): + if reset_verifier is None: + 
return True + key = np.asarray(reset_key, dtype=np.uint32) + ok = reset_verifier.compare( + ref.state_at(env_i), + ref_obs[env_i], + reset_key, + int(seed), + step, + policy_name, + atol, + ) + if ok: + print( + "terminal_reset_reference: ok " + f"reset_key=[{int(key[0])},{int(key[1])}]" + ) + return ok + + +def _terminal_summary(seeds, terminal_counts, episode_length_sums): + total_terminals = int(np.sum(terminal_counts)) + per_seed = [] + for seed, count, length_sum in zip(seeds, terminal_counts, episode_length_sums): + if int(count) > 0: + mean_len = float(length_sum) / float(count) + per_seed.append(f"{int(seed)}:{int(count)}@{mean_len:.1f}") + else: + per_seed.append(f"{int(seed)}:0") + return total_terminals, " ".join(per_seed) + + +def _diagnose_isolated_replay(cmod, seed, actions, atol, num_threads, reset_verifier): + print( + "isolated_replay: start " + f"seed={int(seed)} steps={len(actions)}" + ) + trace_path = Path("build") / f"craftax_repro_seed_{int(seed)}_steps_{len(actions)}.txt" + trace_path.parent.mkdir(exist_ok=True) + trace_path.write_text("\n".join(str(int(action)) for action in actions) + "\n") + print(f"isolated_replay_actions={trace_path}") + ref = JaxCraftaxBatch(np.asarray([seed], dtype=np.int64), resetter=reset_verifier) + vec, c_obs, c_rewards, c_terminals = make_c_vec( + cmod, + 1, + int(seed), + num_threads=num_threads, + ) + try: + if not compare_reset(ref.obs, c_obs.copy(), np.asarray([seed]), atol): + print("isolated_replay: initial reset diverged") + return + action_buf = np.zeros((1, 1), dtype=np.float32) + for step, action in enumerate(actions): + action_buf[0, 0] = float(action) + ref_obs, ref_rewards, ref_dones, reset_keys = ref.step( + np.asarray([action], dtype=np.int32) + ) + vec.cpu_step(action_buf.ctypes.data) + c_obs_snapshot = c_obs.copy() + c_rewards_snapshot = c_rewards.copy() + c_dones_snapshot = c_terminals.copy().astype(bool) + reward_diff = abs(float(ref_rewards[0]) - float(c_rewards_snapshot[0])) + done_match = 
def run(args):
    """Step the JAX reference and the native C env in lockstep and compare them.

    Args:
        args: Namespace-like object.  ``seeds``, ``seed_start``, ``steps``,
            ``action_seed`` and ``atol`` are read directly; ``policy``
            (default ``"uniform"``), ``num_threads`` (default ``1``) and
            ``reset_on_done`` (default ``True``) are read with ``getattr`` so
            callers may omit them.

    Returns:
        ``0`` when every step matched within ``atol``; ``1`` when the initial
        reset, a reset-verifier comparison, or any step diverged.

    Raises:
        ValueError: if ``seeds`` is not positive, ``steps`` is negative,
            ``policy`` is not in ``POLICIES`` or ``num_threads`` is not
            positive.
    """
    if args.seeds <= 0:
        raise ValueError("--seeds must be positive")
    if args.steps < 0:
        raise ValueError("--steps must be non-negative")

    policy_name = getattr(args, "policy", "uniform")
    if policy_name not in POLICIES:
        raise ValueError(f"--policy must be one of {POLICIES}")

    num_threads = int(getattr(args, "num_threads", 1))
    if num_threads <= 0:
        raise ValueError("--num-threads must be positive")
    # setdefault: an OMP_NUM_THREADS already present in the environment wins.
    os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))

    reset_on_done = bool(getattr(args, "reset_on_done", True))
    # One environment per consecutive seed starting at seed_start.
    seeds = np.arange(args.seed_start, args.seed_start + args.seeds, dtype=np.int64)

    cmod = import_c_env()
    # May be None (every use below is guarded); the meaning of the True
    # argument is defined by get_reset_verifier -- TODO confirm.
    reset_verifier = get_reset_verifier(True)
    ref = JaxCraftaxBatch(seeds, resetter=reset_verifier)
    ref_obs = ref.obs

    # The C vec receives only the first seed; presumably it derives per-env
    # seeds from that base -- verify against make_c_vec.
    vec, c_obs, c_rewards, c_terminals = make_c_vec(
        cmod, len(seeds), int(seeds[0]), num_threads=num_threads
    )
    try:
        # Initial observations must already agree before any action is taken.
        if not compare_reset(ref_obs, c_obs.copy(), seeds, args.atol):
            return 1

        if reset_verifier is not None:
            for env_i, seed in enumerate(seeds):
                # "initial" is passed in the slot that later carries the step
                # index.
                if not reset_verifier.compare(
                    ref.state_at(env_i),
                    ref_obs[env_i],
                    ref.reset_keys[env_i],
                    int(seed),
                    "initial",
                    policy_name,
                    args.atol,
                ):
                    return 1

        policy = ActionPolicy(policy_name, args.action_seed, len(seeds))
        # Actions are handed to the C side as a float32 (num_envs, 1) buffer.
        action_buf = np.zeros((len(seeds), 1), dtype=np.float32)
        # Short window for the divergence report, full trace for the replay.
        histories = [deque(maxlen=10) for _seed in seeds]
        full_histories = [[] for _seed in seeds]
        terminal_counts = np.zeros(len(seeds), dtype=np.int64)
        episode_lengths = np.zeros(len(seeds), dtype=np.int64)
        episode_length_sums = np.zeros(len(seeds), dtype=np.int64)

        for step in range(args.steps):
            # effective_policy is the policy actually used this step (it can
            # differ from policy_name, e.g. for "mixed").
            step_actions, effective_policy = policy.actions(step, ref)
            action_buf[:, 0] = step_actions.astype(np.float32)
            for env_i, action in enumerate(step_actions):
                histories[env_i].append(int(action))
                full_histories[env_i].append(int(action))

            ref_obs, ref_rewards, ref_dones, reset_keys = ref.step(step_actions)
            vec.cpu_step(action_buf.ctypes.data)

            # Copy the C-side arrays before comparing; presumably they are
            # live buffers that the next cpu_step overwrites -- TODO confirm.
            c_obs_snapshot = c_obs.copy()
            c_rewards_snapshot = c_rewards.copy()
            c_dones_snapshot = c_terminals.copy().astype(bool)

            for env_i, seed in enumerate(seeds):
                reward_diff = abs(float(ref_rewards[env_i]) - float(c_rewards_snapshot[env_i]))
                done_match = bool(ref_dones[env_i]) == bool(c_dones_snapshot[env_i])
                obs_diff = first_obs_diff(ref_obs[env_i], c_obs_snapshot[env_i], args.atol)
                if reward_diff > args.atol or not done_match or obs_diff is not None:
                    # First diverging env: report it, replay its full action
                    # trace in a fresh single-env pair, then fail.
                    _print_step_divergence(
                        seed=seed,
                        step=step,
                        action=step_actions[env_i],
                        policy_name=effective_policy,
                        reward_diff=reward_diff,
                        ref_reward=ref_rewards[env_i],
                        c_reward=c_rewards_snapshot[env_i],
                        ref_done=ref_dones[env_i],
                        c_done=c_dones_snapshot[env_i],
                        obs_diff=obs_diff,
                        history=histories[env_i],
                    )
                    # Only when both sides terminated is the reset reference
                    # check also run for this env.
                    if bool(ref_dones[env_i]) and bool(c_dones_snapshot[env_i]):
                        _print_terminal_reset_check(
                            reset_verifier,
                            ref,
                            ref_obs,
                            reset_keys[env_i],
                            env_i,
                            seed,
                            step,
                            effective_policy,
                            args.atol,
                        )
                    _diagnose_isolated_replay(
                        cmod,
                        int(seed),
                        full_histories[env_i],
                        args.atol,
                        num_threads,
                        reset_verifier,
                    )
                    return 1

            # Terminal bookkeeping runs only with reset_on_done; with it
            # disabled the counters below stay at zero in the summary.
            episode_lengths += 1
            done_any = np.logical_or(ref_dones, c_dones_snapshot)
            if reset_on_done and np.any(done_any):
                for env_i, is_done in enumerate(done_any):
                    if not bool(is_done):
                        continue
                    terminal_counts[env_i] += 1
                    episode_length_sums[env_i] += episode_lengths[env_i]
                    if reset_verifier is not None:
                        if not reset_verifier.compare(
                            ref.state_at(env_i),
                            ref_obs[env_i],
                            reset_keys[env_i],
                            int(seeds[env_i]),
                            step,
                            effective_policy,
                            args.atol,
                        ):
                            return 1
                    episode_lengths[env_i] = 0

        total_terminals, per_seed_summary = _terminal_summary(
            seeds, terminal_counts, episode_length_sums
        )
        print(
            f"PASS craftax parity: seeds={args.seeds} steps={args.steps} "
            f"atol={args.atol:g} action_seed={args.action_seed}"
        )
        print(
            f"policy={policy_name} reset_on_done={reset_on_done} "
            f"terminal_count={total_terminals} "
            f"mean_episode_length_by_seed={per_seed_summary}"
        )
        return 0
    finally:
        # Runs on every exit path above, including the early "return 1"s.
        vec.close()
def main():
    """Run every entry of ``STRESS_CASES`` through ``run`` in order.

    Command-line flags: ``--atol``, ``--seed-start`` and ``--num-threads``
    (default: CPU count clamped to the range 1..16).  Progress lines are
    printed with ``flush=True``.  The first case whose status is non-zero
    prints a FAIL line and exits the process with that status.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--atol", type=float, default=1e-5)
    cli.add_argument("--seed-start", type=int, default=0)
    cli.add_argument(
        "--num-threads",
        type=int,
        default=max(1, min(16, os.cpu_count() or 1)),
    )
    opts = cli.parse_args()

    suite_t0 = time.monotonic()
    for spec in STRESS_CASES:
        label = spec["name"]
        case_t0 = time.monotonic()
        print(
            f"RUN craftax parity stress name={label} "
            f"seeds={spec['seeds']} steps={spec['steps']} "
            f"policy={spec['policy']} action_seed={spec['action_seed']} "
            f"atol={opts.atol:g}",
            flush=True,
        )
        case_args = SimpleNamespace(
            seeds=spec["seeds"],
            seed_start=opts.seed_start,
            steps=spec["steps"],
            action_seed=spec["action_seed"],
            atol=opts.atol,
            policy=spec["policy"],
            reset_on_done=True,
            num_threads=opts.num_threads,
        )
        exit_code = run(case_args)
        took = time.monotonic() - case_t0
        if exit_code == 0:
            print(
                f"PASS craftax parity stress case name={label} elapsed={took:.1f}s",
                flush=True,
            )
            continue
        print(
            f"FAIL craftax parity stress name={label} elapsed={took:.1f}s",
            flush=True,
        )
        raise SystemExit(exit_code)

    total = time.monotonic() - suite_t0
    print(
        f"PASS craftax parity stress: cases={len(STRESS_CASES)} elapsed={total:.1f}s",
        flush=True,
    )
__name__ == "__main__": + main() diff --git a/tests/craftax_state_fixtures.py b/tests/craftax_state_fixtures.py new file mode 100644 index 0000000000..3639965c8a --- /dev/null +++ b/tests/craftax_state_fixtures.py @@ -0,0 +1,620 @@ +import ctypes +import os +import pickle + +os.environ.setdefault("JAX_PLATFORM_NAME", "cpu") +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + +import jax.numpy as jnp +import numpy as np + +from craftax.craftax.craftax_state import EnvState, Inventory, Mobs + + +LEVELS = 9 +MAP_SIZE = 48 +ACHIEVEMENTS = 67 +MAX_MELEE_MOBS = 3 +MAX_PASSIVE_MOBS = 3 +MAX_RANGED_MOBS = 2 +MAX_MOB_PROJECTILES = 3 +MAX_PLAYER_PROJECTILES = 3 +MAX_GROWING_PLANTS = 10 + + +def _c_array(ctype, *shape): + array_type = ctype + for size in reversed(shape): + array_type = array_type * size + return array_type + + +class CraftaxInventory(ctypes.Structure): + _fields_ = [ + ("wood", ctypes.c_int32), + ("stone", ctypes.c_int32), + ("coal", ctypes.c_int32), + ("iron", ctypes.c_int32), + ("diamond", ctypes.c_int32), + ("sapling", ctypes.c_int32), + ("pickaxe", ctypes.c_int32), + ("sword", ctypes.c_int32), + ("bow", ctypes.c_int32), + ("arrows", ctypes.c_int32), + ("armour", _c_array(ctypes.c_int32, 4)), + ("torches", ctypes.c_int32), + ("ruby", ctypes.c_int32), + ("sapphire", ctypes.c_int32), + ("potions", _c_array(ctypes.c_int32, 6)), + ("books", ctypes.c_int32), + ] + + +class CraftaxMobs3(ctypes.Structure): + _fields_ = [ + ("position", _c_array(ctypes.c_int32, LEVELS, 3, 2)), + ("health", _c_array(ctypes.c_float, LEVELS, 3)), + ("mask", _c_array(ctypes.c_bool, LEVELS, 3)), + ("attack_cooldown", _c_array(ctypes.c_int32, LEVELS, 3)), + ("type_id", _c_array(ctypes.c_int32, LEVELS, 3)), + ] + + +class CraftaxMobs2(ctypes.Structure): + _fields_ = [ + ("position", _c_array(ctypes.c_int32, LEVELS, 2, 2)), + ("health", _c_array(ctypes.c_float, LEVELS, 2)), + ("mask", _c_array(ctypes.c_bool, LEVELS, 2)), + ("attack_cooldown", _c_array(ctypes.c_int32, 
LEVELS, 2)), + ("type_id", _c_array(ctypes.c_int32, LEVELS, 2)), + ] + + +class CraftaxState(ctypes.Structure): + _fields_ = [ + ("map", _c_array(ctypes.c_int32, LEVELS, MAP_SIZE, MAP_SIZE)), + ("item_map", _c_array(ctypes.c_int32, LEVELS, MAP_SIZE, MAP_SIZE)), + ("mob_map", _c_array(ctypes.c_bool, LEVELS, MAP_SIZE, MAP_SIZE)), + ("light_map", _c_array(ctypes.c_float, LEVELS, MAP_SIZE, MAP_SIZE)), + ("down_ladders", _c_array(ctypes.c_int32, LEVELS, 2)), + ("up_ladders", _c_array(ctypes.c_int32, LEVELS, 2)), + ("chests_opened", _c_array(ctypes.c_bool, LEVELS)), + ("monsters_killed", _c_array(ctypes.c_int32, LEVELS)), + ("player_position", _c_array(ctypes.c_int32, 2)), + ("player_level", ctypes.c_int32), + ("player_direction", ctypes.c_int32), + ("player_health", ctypes.c_float), + ("player_food", ctypes.c_int32), + ("player_drink", ctypes.c_int32), + ("player_energy", ctypes.c_int32), + ("player_mana", ctypes.c_int32), + ("is_sleeping", ctypes.c_bool), + ("is_resting", ctypes.c_bool), + ("player_recover", ctypes.c_float), + ("player_hunger", ctypes.c_float), + ("player_thirst", ctypes.c_float), + ("player_fatigue", ctypes.c_float), + ("player_recover_mana", ctypes.c_float), + ("player_xp", ctypes.c_int32), + ("player_dexterity", ctypes.c_int32), + ("player_strength", ctypes.c_int32), + ("player_intelligence", ctypes.c_int32), + ("inventory", CraftaxInventory), + ("melee_mobs", CraftaxMobs3), + ("passive_mobs", CraftaxMobs3), + ("ranged_mobs", CraftaxMobs2), + ("mob_projectiles", CraftaxMobs3), + ( + "mob_projectile_directions", + _c_array(ctypes.c_int32, LEVELS, MAX_MOB_PROJECTILES, 2), + ), + ("player_projectiles", CraftaxMobs3), + ( + "player_projectile_directions", + _c_array(ctypes.c_int32, LEVELS, MAX_PLAYER_PROJECTILES, 2), + ), + ( + "growing_plants_positions", + _c_array(ctypes.c_int32, MAX_GROWING_PLANTS, 2), + ), + ("growing_plants_age", _c_array(ctypes.c_int32, MAX_GROWING_PLANTS)), + ("growing_plants_mask", _c_array(ctypes.c_bool, MAX_GROWING_PLANTS)), + 
("potion_mapping", _c_array(ctypes.c_int32, 6)), + ("learned_spells", _c_array(ctypes.c_bool, 2)), + ("sword_enchantment", ctypes.c_int32), + ("bow_enchantment", ctypes.c_int32), + ("armour_enchantments", _c_array(ctypes.c_int32, 4)), + ("boss_progress", ctypes.c_int32), + ("boss_timesteps_to_spawn_this_round", ctypes.c_int32), + ("light_level", ctypes.c_float), + ("achievements", _c_array(ctypes.c_bool, ACHIEVEMENTS)), + ("state_rng", _c_array(ctypes.c_uint32, 2)), + ("timestep", ctypes.c_int32), + ("fractal_noise_angles", _c_array(ctypes.c_int32, 4)), + ] + + +def _np_array(value, dtype): + return np.ascontiguousarray(np.asarray(value, dtype=dtype)) + + +def _copy_to_c(c_array, value, dtype, shape): + array = _np_array(value, dtype) + if array.shape != shape: + raise ValueError(f"shape mismatch: got {array.shape}, expected {shape}") + ctypes.memmove(ctypes.addressof(c_array), array.ctypes.data, array.nbytes) + + +def _copy_from_c(c_array, dtype): + return np.asarray(np.ctypeslib.as_array(c_array), dtype=dtype).copy() + + +def _mobs_payload(mobs): + return { + "position": _np_array(mobs.position, np.int32), + "health": _np_array(mobs.health, np.float32), + "mask": _np_array(mobs.mask, np.bool_), + "attack_cooldown": _np_array(mobs.attack_cooldown, np.int32), + "type_id": _np_array(mobs.type_id, np.int32), + } + + +def _inventory_payload(inventory): + return { + "wood": int(inventory.wood), + "stone": int(inventory.stone), + "coal": int(inventory.coal), + "iron": int(inventory.iron), + "diamond": int(inventory.diamond), + "sapling": int(inventory.sapling), + "pickaxe": int(inventory.pickaxe), + "sword": int(inventory.sword), + "bow": int(inventory.bow), + "arrows": int(inventory.arrows), + "armour": _np_array(inventory.armour, np.int32), + "torches": int(inventory.torches), + "ruby": int(inventory.ruby), + "sapphire": int(inventory.sapphire), + "potions": _np_array(inventory.potions, np.int32), + "books": int(inventory.books), + } + + +def _fractal_payload(state): 
+ values = [] + for value in state.fractal_noise_angles: + values.append(0 if value is None else int(value)) + return np.asarray(values, dtype=np.int32) + + +def serialize_jax_state(state: EnvState) -> bytes: + payload = { + "map": _np_array(state.map, np.int32), + "item_map": _np_array(state.item_map, np.int32), + "mob_map": _np_array(state.mob_map, np.bool_), + "light_map": _np_array(state.light_map, np.float32), + "down_ladders": _np_array(state.down_ladders, np.int32), + "up_ladders": _np_array(state.up_ladders, np.int32), + "chests_opened": _np_array(state.chests_opened, np.bool_), + "monsters_killed": _np_array(state.monsters_killed, np.int32), + "player_position": _np_array(state.player_position, np.int32), + "player_level": int(state.player_level), + "player_direction": int(state.player_direction), + "player_health": float(state.player_health), + "player_food": int(state.player_food), + "player_drink": int(state.player_drink), + "player_energy": int(state.player_energy), + "player_mana": int(state.player_mana), + "is_sleeping": bool(state.is_sleeping), + "is_resting": bool(state.is_resting), + "player_recover": float(state.player_recover), + "player_hunger": float(state.player_hunger), + "player_thirst": float(state.player_thirst), + "player_fatigue": float(state.player_fatigue), + "player_recover_mana": float(state.player_recover_mana), + "player_xp": int(state.player_xp), + "player_dexterity": int(state.player_dexterity), + "player_strength": int(state.player_strength), + "player_intelligence": int(state.player_intelligence), + "inventory": _inventory_payload(state.inventory), + "melee_mobs": _mobs_payload(state.melee_mobs), + "passive_mobs": _mobs_payload(state.passive_mobs), + "ranged_mobs": _mobs_payload(state.ranged_mobs), + "mob_projectiles": _mobs_payload(state.mob_projectiles), + "mob_projectile_directions": _np_array( + state.mob_projectile_directions, np.int32 + ), + "player_projectiles": _mobs_payload(state.player_projectiles), + 
"player_projectile_directions": _np_array( + state.player_projectile_directions, np.int32 + ), + "growing_plants_positions": _np_array( + state.growing_plants_positions, np.int32 + ), + "growing_plants_age": _np_array(state.growing_plants_age, np.int32), + "growing_plants_mask": _np_array(state.growing_plants_mask, np.bool_), + "potion_mapping": _np_array(state.potion_mapping, np.int32), + "learned_spells": _np_array(state.learned_spells, np.bool_), + "sword_enchantment": int(state.sword_enchantment), + "bow_enchantment": int(state.bow_enchantment), + "armour_enchantments": _np_array(state.armour_enchantments, np.int32), + "boss_progress": int(state.boss_progress), + "boss_timesteps_to_spawn_this_round": int( + state.boss_timesteps_to_spawn_this_round + ), + "light_level": float(state.light_level), + "achievements": _np_array(state.achievements, np.bool_), + "state_rng": _np_array(state.state_rng, np.uint32), + "timestep": int(state.timestep), + "fractal_noise_angles": _fractal_payload(state), + } + return pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) + + +def _copy_inventory_to_c(c_inventory, payload): + for name in [ + "wood", + "stone", + "coal", + "iron", + "diamond", + "sapling", + "pickaxe", + "sword", + "bow", + "arrows", + "torches", + "ruby", + "sapphire", + "books", + ]: + setattr(c_inventory, name, int(payload[name])) + _copy_to_c(c_inventory.armour, payload["armour"], np.int32, (4,)) + _copy_to_c(c_inventory.potions, payload["potions"], np.int32, (6,)) + + +def _copy_mobs_to_c(c_mobs, payload, max_mobs): + _copy_to_c(c_mobs.position, payload["position"], np.int32, (LEVELS, max_mobs, 2)) + _copy_to_c(c_mobs.health, payload["health"], np.float32, (LEVELS, max_mobs)) + _copy_to_c(c_mobs.mask, payload["mask"], np.bool_, (LEVELS, max_mobs)) + _copy_to_c( + c_mobs.attack_cooldown, + payload["attack_cooldown"], + np.int32, + (LEVELS, max_mobs), + ) + _copy_to_c(c_mobs.type_id, payload["type_id"], np.int32, (LEVELS, max_mobs)) + + +def 
deserialize_jax_state_to_c(buffer: bytes) -> CraftaxState: + payload = pickle.loads(buffer) + state = CraftaxState() + + _copy_to_c(state.map, payload["map"], np.int32, (LEVELS, MAP_SIZE, MAP_SIZE)) + _copy_to_c( + state.item_map, payload["item_map"], np.int32, (LEVELS, MAP_SIZE, MAP_SIZE) + ) + _copy_to_c( + state.mob_map, payload["mob_map"], np.bool_, (LEVELS, MAP_SIZE, MAP_SIZE) + ) + _copy_to_c( + state.light_map, payload["light_map"], np.float32, (LEVELS, MAP_SIZE, MAP_SIZE) + ) + _copy_to_c(state.down_ladders, payload["down_ladders"], np.int32, (LEVELS, 2)) + _copy_to_c(state.up_ladders, payload["up_ladders"], np.int32, (LEVELS, 2)) + _copy_to_c(state.chests_opened, payload["chests_opened"], np.bool_, (LEVELS,)) + _copy_to_c(state.monsters_killed, payload["monsters_killed"], np.int32, (LEVELS,)) + + _copy_to_c(state.player_position, payload["player_position"], np.int32, (2,)) + state.player_level = int(payload["player_level"]) + state.player_direction = int(payload["player_direction"]) + state.player_health = float(payload["player_health"]) + state.player_food = int(payload["player_food"]) + state.player_drink = int(payload["player_drink"]) + state.player_energy = int(payload["player_energy"]) + state.player_mana = int(payload["player_mana"]) + state.is_sleeping = bool(payload["is_sleeping"]) + state.is_resting = bool(payload["is_resting"]) + state.player_recover = float(payload["player_recover"]) + state.player_hunger = float(payload["player_hunger"]) + state.player_thirst = float(payload["player_thirst"]) + state.player_fatigue = float(payload["player_fatigue"]) + state.player_recover_mana = float(payload["player_recover_mana"]) + state.player_xp = int(payload["player_xp"]) + state.player_dexterity = int(payload["player_dexterity"]) + state.player_strength = int(payload["player_strength"]) + state.player_intelligence = int(payload["player_intelligence"]) + + _copy_inventory_to_c(state.inventory, payload["inventory"]) + _copy_mobs_to_c(state.melee_mobs, 
payload["melee_mobs"], MAX_MELEE_MOBS) + _copy_mobs_to_c(state.passive_mobs, payload["passive_mobs"], MAX_PASSIVE_MOBS) + _copy_mobs_to_c(state.ranged_mobs, payload["ranged_mobs"], MAX_RANGED_MOBS) + _copy_mobs_to_c( + state.mob_projectiles, payload["mob_projectiles"], MAX_MOB_PROJECTILES + ) + _copy_to_c( + state.mob_projectile_directions, + payload["mob_projectile_directions"], + np.int32, + (LEVELS, MAX_MOB_PROJECTILES, 2), + ) + _copy_mobs_to_c( + state.player_projectiles, + payload["player_projectiles"], + MAX_PLAYER_PROJECTILES, + ) + _copy_to_c( + state.player_projectile_directions, + payload["player_projectile_directions"], + np.int32, + (LEVELS, MAX_PLAYER_PROJECTILES, 2), + ) + _copy_to_c( + state.growing_plants_positions, + payload["growing_plants_positions"], + np.int32, + (MAX_GROWING_PLANTS, 2), + ) + _copy_to_c( + state.growing_plants_age, + payload["growing_plants_age"], + np.int32, + (MAX_GROWING_PLANTS,), + ) + _copy_to_c( + state.growing_plants_mask, + payload["growing_plants_mask"], + np.bool_, + (MAX_GROWING_PLANTS,), + ) + _copy_to_c(state.potion_mapping, payload["potion_mapping"], np.int32, (6,)) + _copy_to_c(state.learned_spells, payload["learned_spells"], np.bool_, (2,)) + state.sword_enchantment = int(payload["sword_enchantment"]) + state.bow_enchantment = int(payload["bow_enchantment"]) + _copy_to_c( + state.armour_enchantments, payload["armour_enchantments"], np.int32, (4,) + ) + state.boss_progress = int(payload["boss_progress"]) + state.boss_timesteps_to_spawn_this_round = int( + payload["boss_timesteps_to_spawn_this_round"] + ) + state.light_level = float(payload["light_level"]) + _copy_to_c(state.achievements, payload["achievements"], np.bool_, (ACHIEVEMENTS,)) + _copy_to_c(state.state_rng, payload["state_rng"], np.uint32, (2,)) + state.timestep = int(payload["timestep"]) + _copy_to_c( + state.fractal_noise_angles, + payload["fractal_noise_angles"], + np.int32, + (4,), + ) + return state + + +def jax_state_to_c_state(state: EnvState) 
-> CraftaxState: + return deserialize_jax_state_to_c(serialize_jax_state(state)) + + +def _inventory_from_c(inventory): + return Inventory( + wood=int(inventory.wood), + stone=int(inventory.stone), + coal=int(inventory.coal), + iron=int(inventory.iron), + diamond=int(inventory.diamond), + sapling=int(inventory.sapling), + pickaxe=int(inventory.pickaxe), + sword=int(inventory.sword), + bow=int(inventory.bow), + arrows=int(inventory.arrows), + armour=jnp.asarray(_copy_from_c(inventory.armour, np.int32)), + torches=int(inventory.torches), + ruby=int(inventory.ruby), + sapphire=int(inventory.sapphire), + potions=jnp.asarray(_copy_from_c(inventory.potions, np.int32)), + books=int(inventory.books), + ) + + +def _mobs_from_c(mobs): + return Mobs( + position=jnp.asarray(_copy_from_c(mobs.position, np.int32)), + health=jnp.asarray(_copy_from_c(mobs.health, np.float32)), + mask=jnp.asarray(_copy_from_c(mobs.mask, np.bool_)), + attack_cooldown=jnp.asarray(_copy_from_c(mobs.attack_cooldown, np.int32)), + type_id=jnp.asarray(_copy_from_c(mobs.type_id, np.int32)), + ) + + +def _fractal_from_template(template): + if template is None: + return (None, None, None, None) + return template.fractal_noise_angles + + +def craftax_state_to_jax(state: CraftaxState, template: EnvState | None = None) -> EnvState: + return EnvState( + map=jnp.asarray(_copy_from_c(state.map, np.int32)), + item_map=jnp.asarray(_copy_from_c(state.item_map, np.int32)), + mob_map=jnp.asarray(_copy_from_c(state.mob_map, np.bool_)), + light_map=jnp.asarray(_copy_from_c(state.light_map, np.float32)), + down_ladders=jnp.asarray(_copy_from_c(state.down_ladders, np.int32)), + up_ladders=jnp.asarray(_copy_from_c(state.up_ladders, np.int32)), + chests_opened=jnp.asarray(_copy_from_c(state.chests_opened, np.bool_)), + monsters_killed=jnp.asarray(_copy_from_c(state.monsters_killed, np.int32)), + player_position=jnp.asarray(_copy_from_c(state.player_position, np.int32)), + player_level=int(state.player_level), + 
player_direction=int(state.player_direction), + player_health=float(state.player_health), + player_food=int(state.player_food), + player_drink=int(state.player_drink), + player_energy=int(state.player_energy), + player_mana=int(state.player_mana), + is_sleeping=bool(state.is_sleeping), + is_resting=bool(state.is_resting), + player_recover=float(state.player_recover), + player_hunger=float(state.player_hunger), + player_thirst=float(state.player_thirst), + player_fatigue=float(state.player_fatigue), + player_recover_mana=float(state.player_recover_mana), + player_xp=int(state.player_xp), + player_dexterity=int(state.player_dexterity), + player_strength=int(state.player_strength), + player_intelligence=int(state.player_intelligence), + inventory=_inventory_from_c(state.inventory), + melee_mobs=_mobs_from_c(state.melee_mobs), + passive_mobs=_mobs_from_c(state.passive_mobs), + ranged_mobs=_mobs_from_c(state.ranged_mobs), + mob_projectiles=_mobs_from_c(state.mob_projectiles), + mob_projectile_directions=jnp.asarray( + _copy_from_c(state.mob_projectile_directions, np.int32) + ), + player_projectiles=_mobs_from_c(state.player_projectiles), + player_projectile_directions=jnp.asarray( + _copy_from_c(state.player_projectile_directions, np.int32) + ), + growing_plants_positions=jnp.asarray( + _copy_from_c(state.growing_plants_positions, np.int32) + ), + growing_plants_age=jnp.asarray( + _copy_from_c(state.growing_plants_age, np.int32) + ), + growing_plants_mask=jnp.asarray( + _copy_from_c(state.growing_plants_mask, np.bool_) + ), + potion_mapping=jnp.asarray(_copy_from_c(state.potion_mapping, np.int32)), + learned_spells=jnp.asarray(_copy_from_c(state.learned_spells, np.bool_)), + sword_enchantment=int(state.sword_enchantment), + bow_enchantment=int(state.bow_enchantment), + armour_enchantments=jnp.asarray( + _copy_from_c(state.armour_enchantments, np.int32) + ), + boss_progress=int(state.boss_progress), + boss_timesteps_to_spawn_this_round=int( + 
state.boss_timesteps_to_spawn_this_round + ), + light_level=float(state.light_level), + achievements=jnp.asarray(_copy_from_c(state.achievements, np.bool_)), + state_rng=jnp.asarray(_copy_from_c(state.state_rng, np.uint32)), + timestep=int(state.timestep), + fractal_noise_angles=_fractal_from_template(template), + ) + + +def _flatten_mobs(prefix, mobs): + return { + f"{prefix}.position": np.asarray(mobs.position), + f"{prefix}.health": np.asarray(mobs.health), + f"{prefix}.mask": np.asarray(mobs.mask), + f"{prefix}.attack_cooldown": np.asarray(mobs.attack_cooldown), + f"{prefix}.type_id": np.asarray(mobs.type_id), + } + + +def _flatten_inventory(inventory): + return { + "inventory.wood": np.asarray(inventory.wood), + "inventory.stone": np.asarray(inventory.stone), + "inventory.coal": np.asarray(inventory.coal), + "inventory.iron": np.asarray(inventory.iron), + "inventory.diamond": np.asarray(inventory.diamond), + "inventory.sapling": np.asarray(inventory.sapling), + "inventory.pickaxe": np.asarray(inventory.pickaxe), + "inventory.sword": np.asarray(inventory.sword), + "inventory.bow": np.asarray(inventory.bow), + "inventory.arrows": np.asarray(inventory.arrows), + "inventory.armour": np.asarray(inventory.armour), + "inventory.torches": np.asarray(inventory.torches), + "inventory.ruby": np.asarray(inventory.ruby), + "inventory.sapphire": np.asarray(inventory.sapphire), + "inventory.potions": np.asarray(inventory.potions), + "inventory.books": np.asarray(inventory.books), + } + + +def flatten_env_state(state: EnvState): + flat = { + "map": np.asarray(state.map), + "item_map": np.asarray(state.item_map), + "mob_map": np.asarray(state.mob_map), + "light_map": np.asarray(state.light_map), + "down_ladders": np.asarray(state.down_ladders), + "up_ladders": np.asarray(state.up_ladders), + "chests_opened": np.asarray(state.chests_opened), + "monsters_killed": np.asarray(state.monsters_killed), + "player_position": np.asarray(state.player_position), + "player_level": 
np.asarray(state.player_level), + "player_direction": np.asarray(state.player_direction), + "player_health": np.asarray(state.player_health, dtype=np.float32), + "player_food": np.asarray(state.player_food), + "player_drink": np.asarray(state.player_drink), + "player_energy": np.asarray(state.player_energy), + "player_mana": np.asarray(state.player_mana), + "is_sleeping": np.asarray(state.is_sleeping), + "is_resting": np.asarray(state.is_resting), + "player_recover": np.asarray(state.player_recover, dtype=np.float32), + "player_hunger": np.asarray(state.player_hunger, dtype=np.float32), + "player_thirst": np.asarray(state.player_thirst, dtype=np.float32), + "player_fatigue": np.asarray(state.player_fatigue, dtype=np.float32), + "player_recover_mana": np.asarray( + state.player_recover_mana, dtype=np.float32 + ), + "player_xp": np.asarray(state.player_xp), + "player_dexterity": np.asarray(state.player_dexterity), + "player_strength": np.asarray(state.player_strength), + "player_intelligence": np.asarray(state.player_intelligence), + "mob_projectile_directions": np.asarray(state.mob_projectile_directions), + "player_projectile_directions": np.asarray( + state.player_projectile_directions + ), + "growing_plants_positions": np.asarray(state.growing_plants_positions), + "growing_plants_age": np.asarray(state.growing_plants_age), + "growing_plants_mask": np.asarray(state.growing_plants_mask), + "potion_mapping": np.asarray(state.potion_mapping), + "learned_spells": np.asarray(state.learned_spells), + "sword_enchantment": np.asarray(state.sword_enchantment), + "bow_enchantment": np.asarray(state.bow_enchantment), + "armour_enchantments": np.asarray(state.armour_enchantments), + "boss_progress": np.asarray(state.boss_progress), + "boss_timesteps_to_spawn_this_round": np.asarray( + state.boss_timesteps_to_spawn_this_round + ), + "light_level": np.asarray(state.light_level, dtype=np.float32), + "achievements": np.asarray(state.achievements), + "state_rng": 
def assert_env_states_equal(actual: EnvState, expected: EnvState, context: str):
    """Assert that two env states agree field by field.

    Both states are flattened with ``flatten_env_state``.  Fields whose
    expected array has a floating dtype are compared with an absolute
    tolerance of ``1e-6`` (no relative tolerance); all other fields must be
    exactly equal.  ``context`` is prefixed to every failure message.

    Raises:
        AssertionError: if the flattened key sets differ or any field
            mismatches.
    """
    got = flatten_env_state(actual)
    want = flatten_env_state(expected)

    got_keys, want_keys = got.keys(), want.keys()
    if got_keys != want_keys:
        missing = want_keys - got_keys
        extra = got_keys - want_keys
        raise AssertionError(f"{context}: state keys differ missing={missing} extra={extra}")

    for name in want:
        want_value = want[name]
        got_value = got[name]
        label = f"{context}: field {name}"
        if want_value.dtype.kind != "f":
            np.testing.assert_array_equal(got_value, want_value, err_msg=label)
            continue
        np.testing.assert_allclose(
            got_value,
            want_value,
            atol=1e-6,
            rtol=0.0,
            err_msg=label,
        )