> ## Documentation Index
> Fetch the complete documentation index at: https://docs.bluesight.ai/llms.txt
> Use this file to discover all available pages before exploring further.

# Train a forest fire classifier

> Start training satellite image classification models using the Bluesight API

This guide shows how to train the [Clay Foundation Model](https://clay-foundation.github.io/model/index.html) for a classification task and run inference with it using our API.

<a target="_blank" href="https://colab.research.google.com/github/bluesightai/docs/blob/main/guides/train-a-forest-fire-classifier.ipynb">
  <img noZoom src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
</a>

```python
# Bluesight REST API endpoint and the JSON content-type header.
API_URL = "https://api.bluesight.ai"
HEADERS = {"Content-Type": "application/json"}

# STAC catalogue and the collection the imagery is searched in.
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Point of interest, expressed as latitude / longitude in EPSG:4326.
lat = 37.30939
lon = -8.57207

# Date range (inclusive start / end strings) used for the scene search.
start = "2018-07-01"
end = "2018-09-01"
```

```python
!pip install pystac-client geopandas stackstac h5py requests-toolbelt
```

## Downloading Data

This step is copied directly from the data preparation step of the [Clay example notebook](https://clay-foundation.github.io/model/tutorials/clay-v1-wall-to-wall.html), so you probably don't need it if you have your own data. In that case, jump to the next section.

For a more detailed explanation of what is going on, see the original source.

```python
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystac_client
import stackstac
from rasterio.enums import Resampling
from shapely import Point

# Search the STAC catalogue for scenes that cover the point of interest within
# the date range and report less than 80% cloud cover.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

all_items = search.item_collection()

# Keep only the first item returned for each calendar date. A dict keyed by
# date gives constant-time duplicate checks (instead of scanning a list on
# every iteration) and preserves the order in which dates first appear.
first_item_by_date = {}
for item in all_items:
    first_item_by_date.setdefault(item.datetime.date(), item)

dates = list(first_item_by_date)
items = list(first_item_by_date.values())

print(f"Found {len(items)} items")

# Fail here with a clear message rather than with a bare IndexError later,
# when the first item is used to pick the projection.
if not items:
    raise RuntimeError(
        f"No {COLLECTION} items found around ({lat}, {lon}) between {start} and {end}"
    )

# Take the projection code from the first scene's properties.
epsg = items[0].properties["proj:epsg"]

# Reproject the point of interest from EPSG:4326 into that projection.
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
).to_crs(epsg)

coords = poidf.iloc[0].geometry.coords[0]

# Square window centred on the point: `size` cells per side at a resolution of
# `gsd` units per cell.
size = 256
gsd = 10
half_extent = (size * gsd) // 2
center_x, center_y = coords[0], coords[1]
bounds = (
    center_x - half_extent,
    center_y - half_extent,
    center_x + half_extent,
    center_y + half_extent,
)

# Build the image stack for the four selected assets, clipped to `bounds`,
# and evaluate it immediately.
stack = stackstac.stack(
    items,
    assets=["blue", "green", "red", "nir"],
    epsg=epsg,
    resolution=gsd,
    bounds=bounds,
    snap_bounds=False,
    resampling=Resampling.nearest,
    dtype="float64",
    rescale=False,
    fill_value=0,
).compute()

# Plot every time step as an RGB image, six per row.
rgb = stack.sel(band=["red", "green", "blue"])
rgb.plot.imshow(row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6)
plt.show()
```

```output
Found 12 items
```

![png](https://mintlify.s3.us-west-1.amazonaws.com/greenlens/images/notebooks/train-a-forest-fire-classifier_files/train-a-forest-fire-classifier_4_1.png)

```python
# One manually assigned class label (0, 1 or 2) per image in `stack`, in time
# order; the 12 entries match the 12 items found above.
# NOTE(review): the upstream Clay tutorial appears to use 0 = cloud,
# 1 = forest, 2 = fire — confirm before relying on these meanings.
labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])

# Manual train / test split by image index; every class appears in both sets.
train_idxs = [0, 1, 3, 4, 7, 8, 9]
test_idxs = [2, 5, 6, 10, 11]
```

## Preparing Data

Now we need to convert the data into the format accepted by our API. Essentially, it is the raw pixel data from all bands plus some metadata.

```python
from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm

num_samples = len(train_idxs)

# The first image in the stack is the template for the band count and the
# shape of the per-image pixel array.
first_image = stack[0]
num_bands = len(first_image.band.values)
pixel_shape = first_image.values.shape

# Structured record layout: one record per training image.
record_fields = [
    ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
    ("gsd", "float32"),
    ("pixels", "float32", pixel_shape),
    ("platform", "S20"),  # Adjust '20' if your platform names are longer
    ("point", "float32", (2,)),
    ("timestamp", "int64"),
    ("label", "int64"),
]
dt = np.dtype(record_fields)


output_file = Path("forest_fire_train_data.h5")
with h5py.File(output_file, "w") as h5_file:
    # One gzip-compressed dataset named "data" holding one record per image.
    dataset = h5_file.create_dataset(
        "data",
        (num_samples,),
        dtype=dt,
        chunks=(min(1, num_samples),),
        compression="gzip",
        compression_opts=1,
    )

    for row, stack_idx in enumerate(tqdm(train_idxs)):
        scene = stack[stack_idx]

        band_names = [name.encode("ascii", "ignore") for name in scene.band.values]
        platform = COLLECTION.encode("ascii", "ignore")
        # Acquisition time as whole seconds, stored as a plain Python int.
        timestamp = scene.time.values.astype("datetime64[s]").astype(np.int64).item()

        # Field order must match the order of the fields in `dt`.
        dataset[row] = (
            band_names,
            scene.gsd.values.item(),
            scene.values.astype("float32"),
            platform,
            [lat, lon],
            timestamp,
            labels[stack_idx].item(),
        )
```

```output
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 49.55it/s]
```

## Uploading Data to Bluesight

```python
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder

def upload_file(file_path):
    """Upload a local file to the Bluesight ``/files`` endpoint.

    The file is sent as a multipart form upload under the ``file`` field.

    Args:
        file_path: Path of the file to upload.

    Returns:
        The decoded JSON response on HTTP 200. On any other status code the
        status and response body are printed and ``None`` is returned.
    """
    file_name = os.path.basename(file_path)

    # The encoder reads from the file handle while the request is being sent,
    # so the POST must happen inside the `with` block; the handle is then
    # closed even if the request raises.
    with open(file_path, "rb") as file_handle:
        encoder = MultipartEncoder(
            fields={"file": (file_name, file_handle, "application/octet-stream")}
        )
        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={"Content-Type": encoder.content_type},
        )

    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
```

```python
%%time
# upload_file returns the parsed JSON response on HTTP 200 and None on any
# other status, so a None here means the upload failed.
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
```

```output
File uploaded successfully!
{'id': 'file-lw3zjxrg',
 'bytes': 3669052,
 'created_at': 1722451315,
 'filename': 'forest_fire_train_data.h5'}
CPU times: user 16.5 ms, sys: 7.79 ms, total: 24.3 ms
Wall time: 6.66 s
```

## Starting Training

```python
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the Bluesight ``/training/jobs`` endpoint.

    Args:
        task: Task name sent in the payload (e.g. ``"classification"``).
        training_file: Identifier of the uploaded training file.
        validation_file: Identifier of an uploaded validation file; sent as
            ``None`` when omitted.

    Returns:
        The decoded JSON response on HTTP 200. On any other status code the
        status and response body are printed and ``None`` is returned.
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}

    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={"Content-Type": "application/json"},
    )

    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
```

```python
# Submit a classification job on the uploaded file; no validation file is
# passed, so validation_file keeps its default of None.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
```

```output
Training job submitted successfully!
{'task': 'classification',
 'training_file': 'file-lw3zjxrg',
 'validation_file': None,
 'hyperparameters': None,
 'id': 'trainingjob-5ms9ja6h',
 'created_at': 1722451316,
 'status': 'initializing',
 'error': None,
 'trained_model': None,
 'finished_at': None}
```

```python
def retrieve_training_job(job):
    """Fetch a training job from ``/training/jobs/{job}``.

    Args:
        job: Identifier of the training job to look up.

    Returns:
        The decoded JSON response on HTTP 200. On any other status code the
        status and response body are printed and ``None`` is returned.
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)

    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
```

```python
%%time
import time

# Poll the job every 3 seconds until it reaches a terminal status.
job_id = training_job["id"]
terminal_statuses = ("succeeded", "failed", "cancelled")
while True:
    job = retrieve_training_job(job_id)
    # retrieve_training_job returns None on a non-200 response; stop with a
    # clear error instead of overwriting training_job and failing on None["status"].
    if job is None:
        raise RuntimeError(f"Could not retrieve training job {job_id}")
    training_job = job
    status = training_job["status"]
    print(status)
    # Exit before sleeping so a finished job is not followed by a needless wait.
    if status in terminal_statuses:
        break
    time.sleep(3)
pprint(training_job)
```

```output
initializing
initializing
initializing
running
succeeded
{'created_at': 1722451316,
 'error': None,
 'finished_at': 1722451329,
 'hyperparameters': None,
 'id': 'trainingjob-5ms9ja6h',
 'status': 'succeeded',
 'task': 'classification',
 'trained_model': 'model:classification-lljzufea',
 'training_file': 'file-lw3zjxrg',
 'validation_file': None}
CPU times: user 26.1 ms, sys: 4.69 ms, total: 30.8 ms
Wall time: 18.2 s
```

## Running Inference

For inference we only need the images (no labels).

```python
# One JSON-serialisable dict per held-out image: the same fields as the
# training records, minus the label.
X_test = [
    {
        "bands": image.band.values.tolist(),
        "gsd": image.gsd.values.item(),
        "pixels": image.values.tolist(),
        "platform": COLLECTION,
        "point": [lat, lon],
        "timestamp": image.time.values.astype('datetime64[s]').astype(np.int64).item(),
    }
    for image in stack[test_idxs]
]
```

```python
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Identifier of the trained model.
        images: List of image dicts to send in the request payload.

    Returns:
        The decoded JSON response on HTTP 200. On any other status code the
        status and response body are printed and ``None`` is returned.
    """
    response = requests.post(
        f"{API_URL}/inference",
        json={"model": model, "images": images},
        headers={'Content-Type': "application/json"},
    )

    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
```

```python
%%time
# Use the model identifier stored under "trained_model" in the training job;
# the call returns None if the request does not come back with HTTP 200.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
pprint(y_pred)
```

```output
{'labels': [0, 1, 1, 2, 2]}
CPU times: user 256 ms, sys: 23.9 ms, total: 279 ms
Wall time: 11.3 s
```

## Evaluating results

```python
# Ground-truth labels of the held-out images, selected with the same
# test_idxs (and therefore the same order) used to build X_test.
y_true = labels[test_idxs]
# Rebinds y_pred from the response dict to an array of its "labels" list.
# Run this cell only once per inference call: after the rebind, y_pred is an
# array and can no longer be indexed with "labels".
y_pred = np.array(y_pred["labels"])
```

```python
# Count the positions where the prediction equals the ground truth.
is_correct = y_true == y_pred
match = is_correct.sum()
print(f"Matched {match} out of {len(y_true)} correctly")
```

```output
Matched 5 out of 5 correctly
```
