Menu
×
   ❮     
HTML CSS JAVASCRIPT SQL PYTHON JAVA PHP HOW TO W3.CSS C C++ C# BOOTSTRAP REACT MYSQL JQUERY EXCEL XML DJANGO NUMPY PANDAS NODEJS R TYPESCRIPT ANGULAR GIT POSTGRESQL MONGODB ASP AI GO KOTLIN SASS VUE DSA GEN AI SCIPY AWS CYBERSECURITY DATA SCIENCE
     ❯   

Python Tutorial

Python HOME Python Intro Python Get Started Python Syntax Python Comments Python Variables Python Data Types Python Numbers Python Casting Python Strings Python Booleans Python Operators Python Lists Python Tuples Python Sets Python Dictionaries Python If...Else Python While Loops Python For Loops Python Functions Python Lambda Python Arrays Python Classes/Objects Python Inheritance Python Iterators Python Polymorphism Python Scope Python Modules Python Dates Python Math Python JSON Python RegEx Python PIP Python Try...Except Python User Input Python String Formatting

File Handling

Python File Handling Python Read Files Python Write/Create Files Python Delete Files

Python Modules

NumPy Tutorial Pandas Tutorial SciPy Tutorial Django Tutorial

Python Matplotlib

Matplotlib Intro Matplotlib Get Started Matplotlib Pyplot Matplotlib Plotting Matplotlib Markers Matplotlib Line Matplotlib Labels Matplotlib Grid Matplotlib Subplot Matplotlib Scatter Matplotlib Bars Matplotlib Histograms Matplotlib Pie Charts

Machine Learning

Getting Started Mean Median Mode Standard Deviation Percentile Data Distribution Normal Data Distribution Scatter Plot Linear Regression Polynomial Regression Multiple Regression Scale Train/Test Decision Tree Confusion Matrix Hierarchical Clustering Logistic Regression Grid Search Categorical Data K-means Bootstrap Aggregation Cross Validation AUC - ROC Curve K-nearest neighbors

Python MySQL

MySQL Get Started MySQL Create Database MySQL Create Table MySQL Insert MySQL Select MySQL Where MySQL Order By MySQL Delete MySQL Drop Table MySQL Update MySQL Limit MySQL Join

Python MongoDB

MongoDB Get Started MongoDB Create DB MongoDB Collection MongoDB Insert MongoDB Find MongoDB Query MongoDB Sort MongoDB Delete MongoDB Drop Collection MongoDB Update MongoDB Limit

Python Reference

Python Overview Python Built-in Functions Python String Methods Python List Methods Python Dictionary Methods Python Tuple Methods Python Set Methods Python File Methods Python Keywords Python Exceptions Python Glossary

Module Reference

Random Module Requests Module Statistics Module Math Module cMath Module

Python How To

Remove List Duplicates Reverse a String Add Two Numbers

Python Examples

Python Examples Python Compiler Python Exercises Python Quiz Python Server Python Bootcamp Python Certificate

Machine Learning - Train/Test


Evaluate Your Model

In Machine Learning we create models to predict the outcome of certain events, like in the previous chapter where we predicted the CO2 emission of a car when we knew the weight and engine size.

To measure if the model is good enough, we can use a method called Train/Test.


What is Train/Test

Train/Test is a method to measure the accuracy of your model.

It is called Train/Test because you split the data set into two sets: a training set and a testing set.

80% for training, and 20% for testing.

You train the model using the training set.

You test the model using the testing set.

Train the model means create the model.

Test the model means test the accuracy of the model.


Start With a Data Set

Start with a data set you want to test.

Our data set illustrates 100 customers in a shop, and their shopping habits.

Example

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

plt.scatter(x, y)
plt.show()

Result:

The x axis represents the number of minutes before making a purchase.

The y axis represents the amount of money spent on the purchase.

Run example »



Split Into Train/Test

The training set should be a random selection of 80% of the original data.

The testing set should be the remaining 20%.

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]


Display the Training Set

Display the same scatter plot with the training set:

Example

plt.scatter(train_x, train_y)
plt.show()

Result:

It looks like the original data set, so it seems to be a fair selection:

Run example »


Display the Testing Set

To make sure the testing set is not completely different, we will take a look at the testing set as well.

Example

plt.scatter(test_x, test_y)
plt.show()

Result:

The testing set also looks like the original data set:

Run example »


Fit the Data Set

What does the data set look like? In my opinion I think the best fit would be a polynomial regression, so let us draw a line of polynomial regression.

To draw a line through the data points, we use the plot() method of the matplotlib module:

Example

Draw a polynomial regression line through the data points:

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

myline = numpy.linspace(0, 6, 100)

plt.scatter(train_x, train_y)
plt.plot(myline, mymodel(myline))
plt.show()

Result:

Run example »

The result can back my suggestion of the data set fitting a polynomial regression, even though it would give us some weird results if we try to predict values outside of the data set. Example: the line indicates that a customer spending 6 minutes in the shop would make a purchase worth 200. That is probably a sign of overfitting.

But what about the R-squared score? The R-squared score is a good indicator of how well my data set is fitting the model.


R2

Remember R2, also known as R-squared?

It measures the relationship between the x axis and the y axis, and the value ranges from 0 to 1, where 0 means no relationship, and 1 means totally related.

The sklearn module has a method called r2_score() that will help us find this relationship.

In this case we would like to measure the relationship between the minutes a customer stays in the shop and how much money they spend.

Example

How well does my training data fit in a polynomial regression?

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

r2 = r2_score(train_y, mymodel(train_x))

print(r2)
Try it Yourself »

Note: The result 0.799 shows that there is a OK relationship.

Bring in the Testing Set

Now we have made a model that is OK, at least when it comes to training data.

Now we want to test the model with the testing data as well, to see if gives us the same result.

Example

Let us find the R2 score when using testing data:

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

r2 = r2_score(test_y, mymodel(test_x))

print(r2)
Try it Yourself »

Note: The result 0.809 shows that the model fits the testing set as well, and we are confident that we can use the model to predict future values.


Predict Values

Now that we have established that our model is OK, we can start predicting new values.

Example

How much money will a buying customer spend, if she or he stays in the shop for 5 minutes?

print(mymodel(5))
Run example »

The example predicted the customer to spend 22.88 dollars, as seems to correspond to the diagram:


W3Schools is optimized for learning and training. Examples might be simplified to improve reading and learning. Tutorials, references, and examples are constantly reviewed to avoid errors, but we cannot warrant full correctness of all content. While using W3Schools, you agree to have read and accepted our terms of use, cookie and privacy policy.

Copyright 1999-2024 by Refsnes Data. All Rights Reserved. W3Schools is Powered by W3.CSS.