@@ -76,11 +76,11 @@ class Item(Base):
7676
7777
7878def create_items ():
79- session = Session (engine )
80- session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
81- session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
82- session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
83- session .commit ()
79+ with Session (engine ) as session :
80+ session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
81+ session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
82+ session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
83+ session .commit ()
8484
8585
8686class TestSqlalchemy :
@@ -129,11 +129,11 @@ def test_orm(self):
129129 item2 = Item (embedding = [4 , 5 , 6 ])
130130 item3 = Item ()
131131
132- session = Session (engine )
133- session .add (item )
134- session .add (item2 )
135- session .add (item3 )
136- session .commit ()
132+ with Session (engine ) as session :
133+ session .add (item )
134+ session .add (item2 )
135+ session .add (item3 )
136+ session .commit ()
137137
138138 stmt = select (Item )
139139 with Session (engine ) as session :
@@ -148,11 +148,11 @@ def test_orm(self):
148148 assert items [2 ].embedding is None
149149
150150 def test_vector (self ):
151- session = Session (engine )
152- session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
153- session .commit ()
154- item = session .get (Item , 1 )
155- assert item .embedding .tolist () == [1 , 2 , 3 ]
151+ with Session (engine ) as session :
152+ session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
153+ session .commit ()
154+ item = session .get (Item , 1 )
155+ assert item .embedding .tolist () == [1 , 2 , 3 ]
156156
157157 def test_vector_l2_distance (self ):
158158 create_items ()
@@ -203,11 +203,11 @@ def test_vector_l1_distance_orm(self):
203203 assert [v .id for v in items ] == [1 , 3 , 2 ]
204204
205205 def test_halfvec (self ):
206- session = Session (engine )
207- session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
208- session .commit ()
209- item = session .get (Item , 1 )
210- assert item .half_embedding .to_list () == [1 , 2 , 3 ]
206+ with Session (engine ) as session :
207+ session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
208+ session .commit ()
209+ item = session .get (Item , 1 )
210+ assert item .half_embedding .to_list () == [1 , 2 , 3 ]
211211
212212 def test_halfvec_l2_distance (self ):
213213 create_items ()
@@ -258,11 +258,11 @@ def test_halfvec_l1_distance_orm(self):
258258 assert [v .id for v in items ] == [1 , 3 , 2 ]
259259
260260 def test_bit (self ):
261- session = Session (engine )
262- session .add (Item (id = 1 , binary_embedding = '101' ))
263- session .commit ()
264- item = session .get (Item , 1 )
265- assert item .binary_embedding == '101'
261+ with Session (engine ) as session :
262+ session .add (Item (id = 1 , binary_embedding = '101' ))
263+ session .commit ()
264+ item = session .get (Item , 1 )
265+ assert item .binary_embedding == '101'
266266
267267 def test_bit_hamming_distance (self ):
268268 create_items ()
@@ -289,11 +289,11 @@ def test_bit_jaccard_distance_orm(self):
289289 assert [v .id for v in items ] == [2 , 3 , 1 ]
290290
291291 def test_sparsevec (self ):
292- session = Session (engine )
293- session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
294- session .commit ()
295- item = session .get (Item , 1 )
296- assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
292+ with Session (engine ) as session :
293+ session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
294+ session .commit ()
295+ item = session .get (Item , 1 )
296+ assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
297297
298298 def test_sparsevec_l2_distance (self ):
299299 create_items ()
@@ -405,24 +405,24 @@ def test_sum_orm(self):
405405
406406 def test_bad_dimensions (self ):
407407 item = Item (embedding = [1 , 2 ])
408- session = Session (engine )
409- session .add (item )
410- with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
411- session .commit ()
408+ with Session (engine ) as session :
409+ session .add (item )
410+ with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
411+ session .commit ()
412412
413413 def test_bad_ndim (self ):
414414 item = Item (embedding = np .array ([[1 , 2 , 3 ]]))
415- session = Session (engine )
416- session .add (item )
417- with pytest .raises (StatementError , match = 'expected ndim to be 1' ):
418- session .commit ()
415+ with Session (engine ) as session :
416+ session .add (item )
417+ with pytest .raises (StatementError , match = 'expected ndim to be 1' ):
418+ session .commit ()
419419
420420 def test_bad_dtype (self ):
421421 item = Item (embedding = np .array (['one' , 'two' , 'three' ]))
422- session = Session (engine )
423- session .add (item )
424- with pytest .raises (StatementError , match = 'could not convert string to float' ):
425- session .commit ()
422+ with Session (engine ) as session :
423+ session .add (item )
424+ with pytest .raises (StatementError , match = 'could not convert string to float' ):
425+ session .commit ()
426426
427427 def test_inspect (self ):
428428 columns = inspect (engine ).get_columns ('sqlalchemy_orm_item' )
@@ -433,44 +433,48 @@ def test_literal_binds(self):
433433 assert "embedding <-> '[1.0,2.0,3.0]'" in str (sql )
434434
435435 def test_insert (self ):
436- session .execute (insert (Item ).values (embedding = np .array ([1 , 2 , 3 ])))
436+ with Session (engine ) as session :
437+ session .execute (insert (Item ).values (embedding = np .array ([1 , 2 , 3 ])))
437438
438439 def test_insert_bulk (self ):
439- session .execute (insert (Item ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
440+ with Session (engine ) as session :
441+ session .execute (insert (Item ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
440442
441443 # register_vector in psycopg2 tests change this behavior
442444 # def test_insert_text(self):
443- # session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
445+ # with Session(engine) as session:
446+ # session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
444447
445448 def test_automap (self ):
446449 metadata = MetaData ()
447450 metadata .reflect (engine , only = ['sqlalchemy_orm_item' ])
448451 AutoBase = automap_base (metadata = metadata )
449452 AutoBase .prepare ()
450453 AutoItem = AutoBase .classes .sqlalchemy_orm_item
451- session .execute (insert (AutoItem ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
452- item = session .query (AutoItem ).first ()
453- assert item .embedding .tolist () == [1 , 2 , 3 ]
454+ with Session (engine ) as session :
455+ session .execute (insert (AutoItem ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
456+ item = session .query (AutoItem ).first ()
457+ assert item .embedding .tolist () == [1 , 2 , 3 ]
454458
455459 def test_vector_array (self ):
456- session = Session (array_engine )
457- session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
458- session .commit ()
460+ with Session (array_engine ) as session :
461+ session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
462+ session .commit ()
459463
460- # this fails if the driver does not cast arrays
461- item = session .get (Item , 1 )
462- assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
463- assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
464+ # this fails if the driver does not cast arrays
465+ item = session .get (Item , 1 )
466+ assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
467+ assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
464468
465469 def test_halfvec_array (self ):
466- session = Session (array_engine )
467- session .add (Item (id = 1 , half_embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
468- session .commit ()
470+ with Session (array_engine ) as session :
471+ session .add (Item (id = 1 , half_embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
472+ session .commit ()
469473
470- # this fails if the driver does not cast arrays
471- item = session .get (Item , 1 )
472- assert item .half_embeddings [0 ].to_list () == [1 , 2 , 3 ]
473- assert item .half_embeddings [1 ].to_list () == [4 , 5 , 6 ]
474+ # this fails if the driver does not cast arrays
475+ item = session .get (Item , 1 )
476+ assert item .half_embeddings [0 ].to_list () == [1 , 2 , 3 ]
477+ assert item .half_embeddings [1 ].to_list () == [4 , 5 , 6 ]
474478
475479 def test_half_precision (self ):
476480 create_items ()
@@ -479,13 +483,12 @@ def test_half_precision(self):
479483 assert [v .id for v in items ] == [1 , 3 , 2 ]
480484
481485 def test_binary_quantize (self ):
482- session = Session (engine )
483- session .add (Item (id = 1 , embedding = [- 1 , - 2 , - 3 ]))
484- session .add (Item (id = 2 , embedding = [1 , - 2 , 3 ]))
485- session .add (Item (id = 3 , embedding = [1 , 2 , 3 ]))
486- session .commit ()
487-
488486 with Session (engine ) as session :
487+ session .add (Item (id = 1 , embedding = [- 1 , - 2 , - 3 ]))
488+ session .add (Item (id = 2 , embedding = [1 , - 2 , 3 ]))
489+ session .add (Item (id = 3 , embedding = [1 , 2 , 3 ]))
490+ session .commit ()
491+
489492 distance = func .cast (func .binary_quantize (Item .embedding ), BIT (3 )).hamming_distance (func .binary_quantize (func .cast ([3 , - 1 , 2 ], VECTOR (3 ))))
490493 items = session .query (Item ).order_by (distance ).all ()
491494 assert [v .id for v in items ] == [2 , 3 , 1 ]
0 commit comments