Model training
After validating your teacher model’s performance, the next step is to train your small language model (SLM) using distil labs’ knowledge distillation approach.
Understanding knowledge distillation
Knowledge distillation is the core technology behind distil labs’ ability to create high-performing small models with minimal training data. The process works as follows:
- Synthetic Data Generation: The large “teacher” model generates synthetic training data based on your problem definition, task description, and provided examples.
- Synthetic Data Validation: We validate the generated data to make sure the synthetic set is diverse and high-quality.
- Knowledge Transfer: The synthetic data is used to train the smaller “student” model with a loss function aligned with your specific task. This process enables the student model to emulate the teacher’s capabilities while maintaining a much smaller size.
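The knowledge-transfer step can be illustrated with a minimal distillation loss, where the student is trained to match the teacher’s softened output distribution. This is a generic sketch of the standard technique, not distil labs’ internal implementation:

```python
import numpy as np

def softmax(logits, temperature=1.0):
    # Temperature-scaled softmax; a higher temperature softens the distribution
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # KL divergence between softened teacher and student distributions,
    # scaled by T^2 as in standard knowledge distillation
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = (p_teacher * (np.log(p_teacher) - np.log(p_student))).sum(axis=-1)
    return (temperature ** 2) * kl.mean()

# A student that matches the teacher exactly incurs zero loss;
# a divergent student incurs a positive loss
teacher = np.array([[2.0, 0.5, -1.0]])
print(distillation_loss(teacher, teacher))           # ~0.0
print(distillation_loss(np.zeros((1, 3)), teacher))  # > 0
```

Minimizing this loss over the synthetic dataset is what lets the student reproduce the teacher’s behavior at a fraction of the size.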
Initiating model training
After completing teacher evaluation and confirming satisfactory performance, start the training process:
Using the CLI:

```shell
distil model run-training <model-id>
```

Or via the API:

```python
import json
import requests

# See Account and Authentication for distil_bearer_token() implementation
auth_header = {"Authorization": f"Bearer {distil_bearer_token()}"}

data = {"upload_id": upload_id}
response = requests.post(
    f"https://api.distillabs.ai/models/{model_id}/training",
    data=json.dumps(data),
    headers={"Content-Type": "application/json", **auth_header},
)
slm_training_job_id = response.json()["id"]
print(f"Training started with ID: {slm_training_job_id}")
```

Monitoring training status
The training process typically takes several hours to complete. Check the current status of your training job:
Using the CLI:

```shell
distil model training <model-id>
```

Or via the API, polling until the job leaves the pending/running states:

```python
import time
from pprint import pprint

while True:
    response = requests.get(
        f"https://api.distillabs.ai/trainings/{slm_training_job_id}/status",
        headers=auth_header,
    )
    status = response.json()["status"]
    print(f"Training status: {status}")
    if status not in ("JOB_RUNNING", "JOB_PENDING"):
        break
    time.sleep(60)

pprint(response.json())
```

Possible status values include:
- `JOB_PENDING` - Job is waiting to start
- `JOB_RUNNING` - Job is currently running
- `JOB_SUCCEEDED` - Job has finished successfully
- `JOB_FAILED` - Job encountered an error
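The polling logic can be factored into a small helper that accepts any status-fetching callable (for example, a wrapper around the `/trainings/<id>/status` endpoint); `wait_for_training` is an illustrative name, not part of the distil labs SDK:

```python
import time

# Terminal states from the list above; pending/running jobs keep polling
TERMINAL_STATUSES = {"JOB_SUCCEEDED", "JOB_FAILED"}

def wait_for_training(fetch_status, poll_interval=60, max_polls=1000):
    """Poll fetch_status() until the job reaches a terminal state."""
    for _ in range(max_polls):
        status = fetch_status()
        print(f"Training status: {status}")
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_interval)
    raise TimeoutError("Training did not finish within the polling budget")

# Example with a stubbed status sequence instead of real API calls
states = iter(["JOB_PENDING", "JOB_RUNNING", "JOB_SUCCEEDED"])
final = wait_for_training(lambda: next(states), poll_interval=0)
print(final)  # JOB_SUCCEEDED
```

Injecting the fetch function this way also makes the polling logic easy to test without network access.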
Retrieving evaluation results
When the training is complete, retrieve detailed evaluation results to compare the performance of your trained SLM against the teacher model. See Metrics for details on each metric and how to interpret them.
Using the CLI:

```shell
distil model training <model-id>
```

Or via the API:

```python
from pprint import pprint

response = requests.get(
    f"https://api.distillabs.ai/trainings/{slm_training_job_id}/evaluation-results",
    headers=auth_header,
)
pprint(response.json())
```

What makes a successful training?
- Comparison to Teacher: Your SLM should achieve performance reasonably close to the teacher model (typically within one standard deviation)
- Task Requirements: The absolute performance should meet your specific application needs
If your SLM performance is significantly below the teacher model, consider:
- Increasing the number of training examples
- Adjusting your task description to be more specific
- Modifying your configuration parameters (like increasing training epochs)
- Using a slightly larger student model
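The one-standard-deviation rule of thumb can be expressed as a quick programmatic check. The scores below are illustrative; use the actual metric values returned by the evaluation-results endpoint:

```python
def within_one_std(student_score, teacher_score, teacher_std):
    # Treat the student as acceptable if its score is no more than one
    # standard deviation below the teacher's score
    return student_score >= teacher_score - teacher_std

# Illustrative numbers, not real evaluation output
print(within_one_std(student_score=0.84, teacher_score=0.88, teacher_std=0.05))  # True
print(within_one_std(student_score=0.75, teacher_score=0.88, teacher_std=0.05))  # False
```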
Retrieving predictions (API only)
For more in-depth analysis, you can download the predictions on individual data points of the test dataset using the API. These predictions are generated using the fine-tuned student model.
```python
print(response.json()["finetuned_student_evaluation_predictions_download_url"])
```

Download and read the predictions file:

```shell
curl -o finetuned_student_evaluation_predictions.jsonl "<DOWNLOAD_URL>"
```

The file is in JSON Lines format and can be read using:

```python
import pandas as pd

df = pd.read_json("finetuned_student_evaluation_predictions.jsonl", lines=True)
```
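From there, a common next step is to inspect the examples where the student disagrees with the expected output. The records and column names below (`input`, `label`, `prediction`) are illustrative; check the actual keys in your downloaded file:

```python
import pandas as pd

# Illustrative records standing in for the downloaded JSONL contents
df = pd.DataFrame([
    {"input": "Refund not received", "label": "billing", "prediction": "billing"},
    {"input": "App crashes on login", "label": "bug", "prediction": "account"},
])

# Rows where the fine-tuned student's prediction differs from the label
mismatches = df[df["label"] != df["prediction"]]
print(f"{len(mismatches)} of {len(df)} examples disagree")
print(mismatches[["input", "label", "prediction"]])
```

Reviewing these mismatches often suggests which of the remedies above (more examples, a sharper task description) is most likely to help.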