> ## Documentation Index
> Fetch the complete documentation index at: https://docs.bluesight.ai/llms.txt
> Use this file to discover all available pages before exploring further.

# Train a ship classifier

> Start training satellite image classification models using the Bluesight API

This guide shows how to train the [Clay Foundation Model](https://clay-foundation.github.io/model/index.html) on a classification task (the [ships vs. non-ships dataset](https://www.kaggle.com/datasets/rhammell/ships-in-satellite-imagery)) and run inference with it using our API.

<a target="_blank" href="https://colab.research.google.com/github/bluesightai/docs/blob/main/guides/train-a-ship-classifier.ipynb">
  <img noZoom src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" />
</a>

```python
# Base URL that every request in this guide is sent to.
API_URL = "https://api.bluesight.ai"

# Sent as the "gsd" value of every image; presumably the ground sample
# distance of the imagery in metres per pixel — confirm against the API docs.
GSD = 3.0
# Side length, in pixels, of the square image chips (used when reshaping the raw data).
IMG_WIDTH = 80
# Band names attached to every image, in the channel order of the pixel arrays.
BANDS = ["red", "green", "blue"]

# Google Drive location of the dataset JSON and the local file name it is saved under.
DATA_URL = "https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB"
DATA_OUTPUT = "shipsnet.json"

# Seed passed to train_test_split so the train/test split is reproducible.
RANDOM_SEED = 42
```

```python
!pip install scikit-learn gdown h5py requests-toolbelt
```

## Downloading Data

```python
import json
from collections import Counter

import gdown
import numpy as np

# Download the dataset JSON from Google Drive to DATA_OUTPUT.
gdown.download(DATA_URL, DATA_OUTPUT, quiet=False)

with open(DATA_OUTPUT, "rb") as json_file:
    payload = json.load(json_file)

print(f"Keys: {list(payload.keys())}")

# Pull out the three fields the rest of the guide uses.
data = payload["data"]
labels = payload["labels"]
locations = payload["locations"]
print(f"Labels distribution: {dict(Counter(labels))}")

# Each sample is stored flat; reshape to (sample, channel, height, width).
pixels = np.array(data).astype("uint8").reshape(-1, 3, IMG_WIDTH, IMG_WIDTH)
```

```````output
Downloading...
From (original): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB
From (redirected): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB&confirm=t&uuid=61a02ee9-af08-4bc7-a1ed-3ede065033d5
To: /home/furiousteabag/Projects/docs/guides/shipsnet.json
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344M/344M [00:32<00:00, 10.5MB/s]
Keys: ['data', 'labels', 'locations', 'scene_ids']
Labels distribution: {1: 1000, 0: 3000}
```````

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle("Ships (top row) vs Non-Ships (bottom row)")


def _show_row(row_axes, sample_indices, make_title):
    """Draw one example image per axis in a row and give each a title."""
    for col, axis in enumerate(row_axes):
        # Stored as (channel, height, width); imshow expects channels last.
        axis.imshow(pixels[sample_indices[col]].transpose([1, 2, 0]))
        axis.axis('off')
        axis.set_title(make_title(col))


# Top row: ships (label 1)
ship_indices = [idx for idx, label in enumerate(labels) if label == 1]
_show_row(axes[0], ship_indices, lambda i: f"Ship {i+1}")

# Bottom row: non-ships (label 0)
non_ship_indices = [idx for idx, label in enumerate(labels) if label == 0]
_show_row(axes[1], non_ship_indices, lambda i: f"Non-Ship {i+1}")

plt.tight_layout()
plt.show()
```

![png](https://mintlify.s3.us-west-1.amazonaws.com/greenlens/images/notebooks/train-a-ship-classifier_files/train-a-ship-classifier_5_0.png)

## Preparing Data

```python
from datetime import datetime
from sklearn.model_selection import train_test_split

# Wrap every image chip in the dict the API takes: band names, GSD, raw
# pixels as nested lists, and the chip's location. Each `locations` entry is
# unpacked as (lon, lat) and stored as [lat, lon].
images = [
    {
        "bands": BANDS,
        "gsd": GSD,
        "pixels": chip.tolist(),
        "point": [lat, lon],
    }
    for chip, (lon, lat) in zip(pixels, locations)
]

# Stratified 80/20 split so both sets keep the same ship / non-ship ratio.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=RANDOM_SEED,
)

for line in (
    f"Training set size: {len(X_train)}",
    f"Test set shape: {len(X_test)}",
    f"Training labels shape: {len(y_train)}",
    f"Test labels shape: {len(y_test)}",
):
    print(line)
```

```output
Training set size: 3200
Test set shape: 800
Training labels shape: 3200
Test labels shape: 800
```

```python
from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm

num_samples = len(X_train)
num_bands = len(BANDS)
pixel_shape = (num_bands, IMG_WIDTH, IMG_WIDTH)

# Structured dtype describing one training record in the HDF5 file.
dt = np.dtype(
    [
        ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
        ("gsd", "float32"),
        ("pixels", "float32", pixel_shape),
        ("platform", "S20"),  # Adjust '20' if your platform names are longer
        ("point", "float32", (2,)),
        ("timestamp", "int64"),
        ("label", "int64"),
    ]
)

# The band names are identical for every record, so encode them once.
encoded_bands = [band.encode("ascii", "ignore") for band in BANDS]

output_file = Path("ships_train_data.h5")
with h5py.File(output_file, "w") as f:

    dataset = f.create_dataset(
        "data", (num_samples,), dtype=dt, compression="gzip", compression_opts=1, chunks=(min(1, num_samples),)
    )

    for i, (X, y) in tqdm(enumerate(zip(X_train, y_train))):

        dataset[i] = (
            encoded_bands,
            GSD,
            np.array(X["pixels"]).astype("float32"),
            b"",  # platform is left empty in this guide
            # Keep coordinates as floats: casting to an integer type here would
            # drop the fractional degrees before they reach the float32 field.
            np.array(X["point"]).astype("float32"),
            0,  # timestamp is left as 0 in this guide
            np.array(y).astype("int64"),
        )
```

```output
3200it [00:05, 542.14it/s]
```

## Uploading Data to Bluesight

```python
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder

def upload_file(file_path):
    """Upload a local file to the API's ``/files`` endpoint.

    The file is sent as a multipart form field named ``file``, read through
    ``MultipartEncoder`` from the open handle.

    Args:
        file_path: Path of the file to upload.

    Returns:
        The decoded JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed).
    """
    file_name = os.path.basename(file_path)

    # Open the file in a context manager so the handle is closed once the
    # request has completed, instead of leaking it.
    with open(file_path, 'rb') as file_handle:
        encoder = MultipartEncoder(
            fields={'file': (file_name, file_handle, 'application/octet-stream')}
        )

        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={'Content-Type': encoder.content_type}
        )

    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
```

```python
%%time
# Upload the HDF5 training file; upload_file returns the response JSON, or None on failure.
file_object = upload_file(output_file)
# sort_dicts=False keeps the keys in the order the response provided them.
pprint(file_object, sort_dicts=False)
```

```output
File uploaded successfully!
{'id': 'file-v54z7pws',
 'bytes': 72732113,
 'created_at': 1722454847,
 'filename': 'ships_train_data.h5'}
CPU times: user 394 ms, sys: 90.4 ms, total: 485 ms
Wall time: 35.8 s
```

## Starting Training

Now we need to load the data into the format accepted by our API. Essentially, it is the raw pixel data from all bands, plus some metadata.

```python
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the API's ``/training/jobs`` endpoint.

    Args:
        task: Task name sent as the ``task`` field (this guide uses
            ``"classification"``).
        training_file: Identifier of the uploaded training file.
        validation_file: Optional identifier of an uploaded validation file;
            sent as ``None`` when omitted.

    Returns:
        The decoded JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed).
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}

    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={'Content-Type': "application/json"}
    )

    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
```

```python
# Submit a classification job that trains on the uploaded file; no validation file is passed.
training_job = create_training_job("classification", file_object["id"])
# sort_dicts=False keeps the keys in the order the response provided them.
pprint(training_job, sort_dicts=False)
```

```output
Training job submitted successfully!
{'task': 'classification',
 'training_file': 'file-v54z7pws',
 'validation_file': None,
 'hyperparameters': None,
 'id': 'trainingjob-o4rqy3iw',
 'created_at': 1722454848,
 'status': 'initializing',
 'error': None,
 'trained_model': None,
 'finished_at': None}
```

```python
def retrieve_training_job(job):
    """Fetch the record of a training job from ``/training/jobs/{job}``.

    Args:
        job: Identifier of the training job to look up.

    Returns:
        The decoded JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed).
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)

    # Report anything other than a plain 200 and bail out early.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
```

```python
%%time
import time

status = None
while status not in ("succeeded", "failed", "cancelled"):
    training_job = retrieve_training_job(training_job["id"])
    status = training_job["status"]
    print(status)
    time.sleep(3)
pprint(training_job)
```

```output
initializing
initializing
initializing
running
running
running
running
running
running
running
succeeded
{'created_at': 1722454848,
 'error': None,
 'finished_at': 1722454883,
 'hyperparameters': None,
 'id': 'trainingjob-o4rqy3iw',
 'status': 'succeeded',
 'task': 'classification',
 'trained_model': 'model:classification-xvxor2y2',
 'training_file': 'file-v54z7pws',
 'validation_file': None}
CPU times: user 76.4 ms, sys: 22.5 ms, total: 98.9 ms
Wall time: 39.1 s
```

## Running Inference

For inference we need only images.

```python
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Identifier of the trained model, sent as the ``model`` field.
        images: List of image payloads, sent as the ``images`` field.

    Returns:
        The decoded JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed).
    """
    response = requests.post(
        f"{API_URL}/inference",
        json={"model": model, "images": images},
        headers={'Content-Type': "application/json"},
    )

    # Report anything other than a plain 200 and bail out early.
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
```

```python
%%time
# Run the model produced by the training job on the held-out test images.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
```

```output
CPU times: user 1.69 s, sys: 687 ms, total: 2.37 s
Wall time: 35.2 s
```

```python
from sklearn.metrics import classification_report, accuracy_score

# The inference response carries the predicted classes under "labels".
y_pred = y_pred["labels"]

# Count the positions where prediction and ground truth agree.
expected = np.array(y_test)
predicted = np.array(y_pred)
match = np.count_nonzero(expected == predicted)
print(f"Matched {match} out of {len(y_test)} correctly")

accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
report = classification_report(y_test, y_pred)
print(report)
```

```output
Matched 779 out of 800 correctly
Accuracy: 0.97375
              precision    recall  f1-score   support

           0       0.98      0.98      0.98       600
           1       0.94      0.95      0.95       200

    accuracy                           0.97       800
   macro avg       0.96      0.97      0.97       800
weighted avg       0.97      0.97      0.97       800
```
