# Octree

## Octree Data Structure

```cpp
/// Octree spatial index over triangles referenced by integer index
/// (resolved through getTriangle()). Implements the Spatial interface
/// (QueryAABB / Raycast are marked `override`).
class Octree : public Spatial
{
public:
    /// One octree cell.
    struct Node
    {
        AABB box;                 // Bounds of this cell; tested against query boxes and rays.
        std::vector<int> tris;    // Triangle indices stored directly in this cell.
        std::shared_ptr<Node> child[8] = {nullptr};  // Sub-cells; all null for a leaf (InsertTri calls Subdivide() when child[0] is null).
    };

    std::shared_ptr<Node> root = nullptr;  // Starts null: no nodes until a root is created.
    int maxDepth = 8;     // Depth at which InsertTri stops subdividing and stores unconditionally.
    int maxPerNode = 16;  // Triangles a cell keeps itself before pushing new ones into children.
    
    ...
};
```

## Insert Triangle

```cpp
    void InsertTri(std::shared_ptr<Node> n, int triIndex, int depth)
    {
        if (depth == maxDepth) {
            n->tris.push_back(triIndex);
            return;
        }

        Triangle t = getTriangle(triIndex);
        glm::vec3 triMin = glm::min(t.v0, glm::min(t.v1, t.v2));
        glm::vec3 triMax = glm::max(t.v0, glm::max(t.v1, t.v2));

        if (n->tris.size() < maxPerNode) {
            n->tris.push_back(triIndex);
            return;
        }

        if (n->child[0] == nullptr)
            Subdivide(n);

        for (int i = 0; i < 8; i++)
        {
            const AABB &b = n->child[i]->box;
            // AABB intersection
            if (!(triMax.x < b.min.x || triMin.x > b.max.x ||
                  triMax.y < b.min.y || triMin.y > b.max.y ||
                  triMax.z < b.min.z || triMin.z > b.max.z))
            {
                InsertTri(n->child[i], triIndex, depth + 1);
            }
        }
    }
```

## QueryAABB

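`QueryAABB` collects the indices of triangles that may overlap a query box. Starting at the root, it checks whether the query AABB intersects the current octree node's box. A node that does not intersect is skipped together with its whole subtree.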

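If there is an intersection, it appends all triangles stored in the current node to the output and recursively checks the node's children. The result is a conservative candidate list, not an exact one. A triangle is reported because its node overlaps the query, not because the triangle itself was tested. Also, a triangle that straddles child boundaries is stored in several nodes during insertion, so its index can appear more than once.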

```cpp
    void QueryAABB(const AABB &box, std::vector<int> &out) const override
    {
        QueryNode(root, box, out);
    }
    
    void QueryNode(std::shared_ptr<Node> n, const AABB &box, std::vector<int> &out) const
    {
        if (n == nullptr)
            return;

        // AABB no intersection
        if (n->box.max.x < box.min.x || n->box.min.x > box.max.x ||
            n->box.max.y < box.min.y || n->box.min.y > box.max.y ||
            n->box.max.z < box.min.z || n->box.min.z > box.max.z)
            return;

        out.insert(out.end(), n->tris.begin(), n->tris.end());

        for (int i = 0; i < 8; i++)
            QueryNode(n->child[i], box, out);
    }
```

## Raycast

```cpp
bool Raycast(const Ray &ray, HitInfo &outHit) override
    {
        HitInfo best;
        best.t = FLT_MAX;
        best.triIndex = -1;

        bool result = RaycastNode(root, ray, best);
        if (result)
            outHit = best;
        return result;
    }
    
/// Tests `ray` against every triangle in the subtree rooted at `n` and keeps
/// the closest hit in `best`.
///
/// @param n     Subtree root; null is allowed (e.g. an empty tree with a null root).
/// @param ray   Ray to trace.
/// @param best  In/out: closest hit so far. Updated only by strictly closer hits.
/// @return true if this subtree improved `best`.
bool RaycastNode(const std::shared_ptr<Node> &n, const Ray &ray, HitInfo &best)
    {
        // Guard against a null node. Raycast passes `root`, which is null for an
        // empty tree; dereferencing it would crash.
        if (!n)
            return false;

        // Skip the whole subtree when the ray misses this node's box.
        // NOTE(review): if RayAABB's `t` is the entry distance, subtrees with
        // t > best.t could also be pruned. Confirm RayAABB's contract first.
        float t = 0.0f;
        if (!RayAABB(ray.origin, ray.dir, n->box.min, n->box.max, t))
            return false;

        bool hit = false;

        for (int triIdx : n->tris)
        {
            float tt = 0.0f;
            Triangle tri = getTriangle(triIdx);
            // Strict `<`: on equal distance the first triangle found is kept.
            if (RayTriangle(ray, tri, tt) && tt < best.t)
            {
                best.t = tt;
                best.triIndex = triIdx;
                hit = true;
            }
        }

        // Child order is fixed (not front-to-back), so every intersected child is visited.
        for (const auto &c : n->child)
            if (c)
                hit |= RaycastNode(c, ray, best);

        return hit;
    }
```
