Kodeclik Blog
Python assertEqual()
An assert statement is a way for your Python program to check that something is true and to raise an AssertionError if it is not. It's like double-checking that things are working the way you expect them to. Asserts are often used for testing and debugging your code. Used this way, they make your code safer and more predictable, because mistakes surface right where an assumption breaks instead of silently producing wrong results later.
Simple Python assert statements
Here is a simple use of assert statements. We do a calculation and use assertions to check that the input is valid and that the output is correct.
def calculate_average(numbers):
    assert len(numbers) != 0, "List is empty."
    return sum(numbers) / len(numbers)

# Test case
mylist = [1, 2, 3, 4]
assert calculate_average(mylist) == 2.5
The first use of the assert statement checks if the list is empty. If it is, an AssertionError is raised with the message "List is empty." The second use of the assert statement is to check the correctness of the function. If the returned average is not 2.5, the second assert statement will raise an error.
If we run the above code, nothing visible happens because both assertions pass: the average is calculated (but nothing is printed). On the other hand, if we update the code to:
def calculate_average(numbers):
    assert len(numbers) != 0, "List is empty."
    return sum(numbers) / len(numbers)

# Test case
mylist = [1, 2, 3, 4]
assert calculate_average(mylist) == 2.5

# Test case 2
secondlist = [2, 4, 6, 8]
assert calculate_average(secondlist) == 2.5

# Test case 3
calculate_average([])
Here we have added two more test cases. The second test case checks for an average of 2.5, which is wrong: the average of secondlist is actually 5.0. The function computes the average correctly, but the expected value in our assertion is incorrect, so that assertion fails. The third test case passes an empty list to calculate_average(), so it fails because of the assertion inside the function definition. When we run this code, we get:
Traceback (most recent call last):
File "main.py", line 11, in <module>
assert calculate_average(secondlist) == 2.5
AssertionError
exit status 1
as expected. Note that as soon as an assertion fails, program execution stops and the rest of the program is not executed, so test case 3 never runs here.
If we update the code to have only test case 3 (try it!), you will get:
Traceback (most recent call last):
File "main.py", line 6, in <module>
calculate_average([])
File "main.py", line 2, in calculate_average
assert len(numbers) != 0, "List is empty."
AssertionError: List is empty.
exit status 1
again as expected. With assert, we can thus check many kinds of conditions, both on a function's inputs and on its results.
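Because a failed assert stops the program immediately, test case 3 only ran once the other test cases were removed. Below is a minimal sketch of one way around this, assuming we simply wrap each check in a try/except block that catches AssertionError and records the outcome. The check_case_* functions and the results dictionary are hypothetical names chosen for illustration; the unittest module, discussed next, takes care of this kind of bookkeeping for you.

```python
def calculate_average(numbers):
    assert len(numbers) != 0, "List is empty."
    return sum(numbers) / len(numbers)


def check_case_1():
    assert calculate_average([1, 2, 3, 4]) == 2.5


def check_case_2():
    # Deliberately wrong expectation: the true average is 5.0
    assert calculate_average([2, 4, 6, 8]) == 2.5


def check_case_3():
    # Fails inside calculate_average() because the list is empty
    calculate_average([])


results = {}
for check in (check_case_1, check_case_2, check_case_3):
    try:
        check()
        results[check.__name__] = "passed"
    except AssertionError as err:
        # str(err) is the assertion's message, or "" if none was given
        results[check.__name__] = f"failed: {err}" if str(err) else "failed"

for name, outcome in results.items():
    print(name, outcome)
# check_case_1 passed
# check_case_2 failed
# check_case_3 failed: List is empty.
```

Every check now runs even though two of them fail, and the message from the assertion inside calculate_average() is kept in the report.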
How is assertEqual() different from assert?
Now, assertEqual() performs a specific kind of assertion: it checks whether two values are equal (that is, whether first == second), similar to the checks in the examples above.
assertEqual() is a method provided by the TestCase class in Python's unittest module. This module is used for writing and running unit tests, which check that individual parts of your code work correctly.
Here is a simple use of assertEqual():
import unittest

class MyTestCase(unittest.TestCase):
    def test_addition(self):
        result = 2 + 2
        self.assertEqual(result, 4)

if __name__ == '__main__':
    unittest.main()
In the above program, we first import the unittest module. We then create a test case class that inherits from unittest.TestCase. This class contains our test methods: each test method checks a specific part of our code, and unittest runs every method whose name starts with test. For now we have only one test method, test_addition(), which uses self.assertEqual() to check that adding 2 and 2 gives 4. This is a simple example, but it serves to illustrate the idea. When we run this code, we get:
----------------------------------------------------------------------
Ran 1 test in 0.000s
OK
meaning that all the tests (just one in this case) passed.
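A quick way to see the practical difference is to make the same wrong check both ways and compare the error text. The sketch below creates a bare unittest.TestCase instance outside a test runner (allowed since Python 3.2) and uses a helper, failure_message, which is our own illustrative name. A plain assert gives an empty message unless we write one ourselves, while assertEqual() reports both values automatically. One more difference worth knowing: Python removes assert statements when run with the -O (optimize) flag, whereas assertEqual() is an ordinary method call and always runs.

```python
import unittest

# TestCase can be instantiated directly (Python 3.2+); "runTest" is the
# default method name and does not need to exist.
case = unittest.TestCase()


def failure_message(check):
    """Run check() and return the text of the AssertionError it raises."""
    try:
        check()
    except AssertionError as err:
        return str(err)
    return None  # the check passed


def plain_assert_check():
    assert 2 + 2 == 5  # no message supplied


def assert_equal_check():
    case.assertEqual(2 + 2, 5)


plain = failure_message(plain_assert_check)
method = failure_message(assert_equal_check)

print(repr(plain))   # ''
print(repr(method))  # '4 != 5'
```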
Using assertEqual() in a shopping site
Here is a very practical situation where you can use assertEqual(). Suppose we run a shopping site and have a function that calculates the total cost of the items in a shopping cart. This function is used in various parts of the site, e.g., whenever the user adds items to the cart and we want to display the total cost of the current cart. To make sure the function works correctly, we prepare sample carts whose total cost we already know and use assertEqual() to check that the function agrees:
import unittest

def calculate_total_cost(prices, quantities):
    assert len(prices) == len(quantities), "Number of prices and quantities should be the same."
    total_cost = 0
    for price, quantity in zip(prices, quantities):
        total_cost += price * quantity
    return total_cost

class TestShoppingCart(unittest.TestCase):
    def test_calculate_total_cost(self):
        prices = [10, 20, 30]
        quantities = [2, 3, 4]
        expected_total_cost = 10*2 + 20*3 + 30*4
        self.assertEqual(calculate_total_cost(prices, quantities),
                         expected_total_cost)

if __name__ == '__main__':
    unittest.main()
In the above code, the calculate_total_cost() function takes two lists as input: prices and quantities. It first checks if the lengths of the two lists are the same using an assert statement. If they are not, an AssertionError is raised with the message "Number of prices and quantities should be the same." The function then calculates the total cost by multiplying each price with its corresponding quantity and adding them up.
The TestShoppingCart class is a subclass of unittest.TestCase and contains a single test method called test_calculate_total_cost(). This method tests the calculate_total_cost() function by providing it with a set of prices and quantities, and comparing the calculated total cost with the expected total cost using the assertEqual() method.
To run the test, we call unittest.main(), which finds the test methods (those whose names start with test) in the module's TestCase subclasses, here TestShoppingCart, and runs them. If all the assertions pass, the test is considered successful. If an assertion fails, assertEqual() raises an AssertionError, which unittest catches and reports as a failure, indicating that either the function or the test's expected value is wrong.
If you run this code you will get:
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
indicating that the test succeeded.
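unittest.main() is convenient, but it also parses command-line arguments and exits the program when it finishes. As a rough sketch of what it does for a single class, the code below loads the test methods of TestShoppingCart with unittest.TestLoader, runs them with unittest.TextTestRunner, and inspects the result object. As an extra illustration, it adds a hypothetical second test method, test_mismatched_lengths, which uses assertRaises() to check that the input assertion inside calculate_total_cost() fires when the two lists have different lengths.

```python
import io
import unittest


def calculate_total_cost(prices, quantities):
    assert len(prices) == len(quantities), \
        "Number of prices and quantities should be the same."
    total_cost = 0
    for price, quantity in zip(prices, quantities):
        total_cost += price * quantity
    return total_cost


class TestShoppingCart(unittest.TestCase):
    def test_calculate_total_cost(self):
        prices = [10, 20, 30]
        quantities = [2, 3, 4]
        expected_total_cost = 10*2 + 20*3 + 30*4
        self.assertEqual(calculate_total_cost(prices, quantities),
                         expected_total_cost)

    def test_mismatched_lengths(self):
        # Hypothetical extra test: the assert inside calculate_total_cost()
        # should raise AssertionError (only if Python was not run with -O)
        with self.assertRaises(AssertionError):
            calculate_total_cost([10, 20], [1])


# Roughly what unittest.main() does for this class, minus argument parsing
# and sys.exit(): collect the test_* methods, run them, keep the result.
suite = unittest.TestLoader().loadTestsFromTestCase(TestShoppingCart)
report = io.StringIO()  # capture the runner's text report instead of printing it
result = unittest.TextTestRunner(stream=report, verbosity=2).run(suite)

print(result.testsRun, result.wasSuccessful())  # 2 True
print(report.getvalue())
```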
To see what happens when a test fails, let us deliberately create a failing scenario by changing the class to:
class TestShoppingCart(unittest.TestCase):
    def test_calculate_total_cost(self):
        prices = [10, 20, 30]
        quantities = [2, 3, 4]
        expected_total_cost = 10*3 + 20*3 + 30*4
        self.assertEqual(calculate_total_cost(prices, quantities),
                         expected_total_cost)
Note that the expected_total_cost calculation is now incorrect (10*2 is replaced with 10*3), so the test expects 210 while the function correctly returns 200. If we run the code now (leaving the other parts intact), we get:
F
======================================================================
FAIL: test_calculate_total_cost (__main__.TestShoppingCart)
----------------------------------------------------------------------
Traceback (most recent call last):
File "main.py", line 15, in test_calculate_total_cost
self.assertEqual(calculate_total_cost(prices, quantities),
expected_total_cost)
AssertionError: 200 != 210
----------------------------------------------------------------------
Ran 1 test in 0.004s
FAILED (failures=1)
exit status 1
Using assertEqual() in a bookstore
Suppose we have a class called Book that represents a book with various attributes such as title, author, and price. Let us write unit tests for the class methods get_reading_time() and apply_discount() to ensure they return the expected values.
The full code is below:
import unittest

class Book:
    def __init__(self, title, author, price):
        self.title = title
        self.author = author
        self.price = price

    def get_reading_time(self):
        # Calculate the reading time based on the number of pages (self.pages is set by the tests)
        return f"{self.pages * 1.5} minutes"

    def apply_discount(self):
        # Apply a 5% discount to the price
        return f"${self.price - self.price * 0.05}"

class TestBook(unittest.TestCase):
    def setUp(self):
        self.book_1 = Book("Title 1", "Author 1", 15)
        self.book_2 = Book("Title 2", "Author 2", 16)

    def test_get_reading_time(self):
        self.book_1.pages = 304
        self.book_2.pages = 447
        self.assertEqual(self.book_1.get_reading_time(), f"{304 * 1.5} minutes")
        self.assertEqual(self.book_2.get_reading_time(), f"{447 * 1.5} minutes")

    def test_apply_discount(self):
        self.assertEqual(self.book_1.apply_discount(), f"${15 - 15 * 0.05}")
        self.assertEqual(self.book_2.apply_discount(), f"${16 - 16 * 0.15}")

if __name__ == '__main__':
    unittest.main()
In this example, we create two instances of the Book class with different prices. We then test the get_reading_time() and apply_discount() methods using the assertEqual() method to check if the return values are correct. If any of the assertions fail, an exception will be raised, indicating that the test has failed.
Note that the tests compute their expected values independently, from known page counts and prices, rather than calling the methods under test to produce them (as they should!).
Also note that the second check in test_apply_discount is incorrect: it expects a 15% discount rather than 5%. Here it is the test that is wrong, not the code, but the assertion fails all the same. When we run the above code, we get:
F.
======================================================================
FAIL: test_apply_discount (__main__.TestBook)
----------------------------------------------------------------------
Traceback (most recent call last):
File "main.py", line 30, in test_apply_discount
self.assertEqual(self.book_2.apply_discount(), f"${16 - 16 * 0.15}")
AssertionError: '$15.2' != '$13.6'
- $15.2
+ $13.6
----------------------------------------------------------------------
Ran 2 tests in 0.001s
FAILED (failures=1)
exit status 1
As the output shows, one of the two tests failed: test_apply_discount, because its second assertEqual() check expected the wrong value, while test_get_reading_time passed. This approach lets us easily write and run tests for our code and make sure it behaves as expected in different scenarios. When an assertion fails, we should check whether the assertion itself is incorrect or whether the original code is wrong.