From 66cbf6b6c5da19a78d0f19030b88634ae2217393 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 14 Jan 2025 06:10:40 +0800 Subject: [PATCH] feat: support multivector type (#2005) Signed-off-by: BubbleCal --- docs/src/assets/maxsim.png | Bin 0 -> 10525 bytes docs/src/search.md | 30 ++++++++ docs/src/sql.md | 2 +- python/python/lancedb/query.py | 14 ++-- python/python/lancedb/table.py | 2 + python/python/lancedb/util.py | 26 ++++++- python/python/tests/test_query.py | 116 +++++++++++++++++++++++++++++- rust/lancedb/src/table.rs | 109 ++++++++++++++-------------- rust/lancedb/src/utils.rs | 28 +++++--- 9 files changed, 255 insertions(+), 72 deletions(-) create mode 100644 docs/src/assets/maxsim.png diff --git a/docs/src/assets/maxsim.png b/docs/src/assets/maxsim.png new file mode 100644 index 0000000000000000000000000000000000000000..c975afc6f72aac4a7995eb2efc1be27412457583 GIT binary patch literal 10525 zcmXY1bzGC*7e-MDiHU$9B?!{sNXa23Q$kYd7$GGfC5(^`=|&hxcZ`-A0ullnqZ^SL zjC7;xhkk#5yzzO@J?}l|dCs}_zVAH|uhn0WlRhLRARr(It0-y_5Zr`a-a5jD~LIOzfXV&f_<{-pUl@;1X; zAeZJ1DG!J5q3;yQ9{x~ND=(h9>v3C2Gs+@jDp#`{c&HpDqYP&-jEdw^K0#sqa3I$s z~6b95ntI&z+#ueI{VSuH_r)=$a&i zNci*rr9nDwm7`>nCOuPOJ>ELz&3yJ-_aQRrO8;K{5Jd(tRlyxU?ewhq5(Y1#lM2(n z?va=^-=!`-zi7X~y2*nGas_O#r71!IXjYigb^qzTw9)N9TsP{QvhMMNI3&S2w9klC zexEf();d~25L$);&L-hQUOckCe60P$S9&}#+YK^NOx6z7t`0D`V-%BxgCW?utC#Mb zLneI5`~6&y)7p7w&lS31aa@uqNR-V|z)`1XGg@t#y&pvHKGSUnHTKq*Cg5Z}u*7OA;MUq+gz!bea`)rmIyPvZ# z$iT*y@{a&Z0B(J5YATjFr?DB%soN1j`b;zDTZgO9gL%K2^@o%UEz-_TX{|V6aA-ek zqzyC}*o`rwX%pbw6dip}fSErxNub9GJa{D(yuj@&2q@5a+VUT<%7fWB>vQW3&u zt#LqNl}9!lAXjD{f5VKJJ30-t&Ik^z1G?D6Gp@^5E!$^p?|tvyYp!~*5A3H%#T&Nu z>T1tUqPHR1vn}8e`*aFd5%t(BH9ILEZC{yYfd|+&s=qNANA(HBDNk!Y)m%H4y@hLM z@4|R}0W*e1*Bp(uePJ`+HTO9 zMYlKe?OrR)&R3p7|M$74B_^%&;g;01v!`8BA}g&;W2T(G#iWfg+P&nH+lJUHl*j~X zI`#@g^9@9n<1cW@>UNythq;$`L=V-!Qr;~+kYFl18WDFtcLQ6`_jmf?~z&)3c7V{nHlp^2d$36@+wOC1*#9OnM& z;b^Qh!fSLY-{K*bIcT}9E|O(5G*7m z-7T;5zVU2ky@>^jqBji1Ind1hm*KYwSx~X;8wU<{0trnZ&Tg@DGs8 
zs!p1GpOWfn)-!MZm9<8P?c)2-#%I5XBIWR-85^gbDqW&$WhA#&v-=(XF9W=0R=^_c zLhUJ^9@dL!5SO;{aYvjz2U*knZ4ftd(iVxN$yk~pw)THeGWg0nw( z`=Q2(#=uogr~l9T+{6?jnrSs$g!pBnCsi#H_?cy2aWR1%j|OVJjii2#HN|Bz@bdaP zIbb{LDS-_U@b3xb5}=W(@g8sNWpvcV*DE%kak=>z_jQDjn3%h46J^u{0-7lMQ~!jjljh! zzGAaxWaS#e0;?BuuP7#$nG;X#rv$5J7hQsjSGSYml#RyePdx=M=FPeHFPh*G!Xu&A z{R1*N?;qsU4zzX=80qD+MHW$;2vyU?GEIbbd#XNH07aZwBusgF)$9z+47BhORy&9ZKCf|xJ zGqZ>smfD;JckFtOZFdo_Ll>qJv$mZHiGPsRx0NcKocv={&{6kRLCAWA*6B~#f=^3p ziR}BJx}&FXk}8(#IYMKXb5|%%YGZ%4u%W7Iw%6NF`0de;3BxcJZf|hW$3q!*KT}OS zoj>aVC#nAFhA8Dx{*Sso-|`%UVPZDr-!UxGoF4q7FF$1c`_T1Fkv+L_0HO#lc)90V zvWwcF(2gDUHC&Z`U|k1Mfs&vOck_oopBvg_A5=u0)KH)yhZN!H0qt2JA82wqW*Tq! z?Cizn$sKP+TydYTLoVWs9>nyQH|o3AYFBAgFo4j|c!qn^0;*ljtqrb~)<<|xylZ;0 zYpRhMPPC~r_Dy(a33XDp{+8_XVx@mq)gLk&7VqRM-u^h>@cVe?)Vtwm8Lioh>>{Q+ zVXx#gY_1lX{W)wHtkH|?w-V>=#I*Q}b|A{un$p9@##Kc)xl z9=P-t(w>|}YatM^N*Im22k*aR&h+6NQzcC?Ewe8kstd4-l6pV+B|e@4)~qkL+S0V8 zD-0bYV*QwPNQ~mm*w3=IU7gTb{Nq*v+S(4Bp7Ym1hjd^c7c6bLk---EF0I&zFeku! zB~><{&^ptIxy5O3HPaARgM6ysI7Jj3_2NNW=6cDm4zi;xdfQ;vLEzzlp1?Dk3Au-i zwp&WviWyVzzQv;texHeHOyR^A8?$~>;MAdqUhRI!Zvw!>st%X8b-(Y~X1#V51^4o_ z3F4Wh^?w|_*leV`;yHjYB!oUP8ct$uDgrGMP&MdkNm8qz2lx_fx`0iJz3r zm(6Dol}?>*w0$p2qxNapGw0C&r#Ug*8M<%UJS;Emixby(N^1TrYuieC#@bQA^6XAzL1GaqxMkUTH2={(P8-eIuk?B<=ZC9BJ!z85j9hq(QwrF!#P}pLjSZkWk6S zOOaC<01(ade5}N&$8e@q^We)Peq!qO(KIst1d|zq#Vl|<$vO9Wfg-0h*l8f&mTJk% zn`WN!=8@odcqjD|Ei^c-YBCE>;{_;6fHQbqrXi6%Vr2QMW z+*$%z^qUip&PFm`P3b57&(0z0&tpF!hzFy_d7svjv^mUYE?N{Hk&0R}cSoJ;#xTfR zLCc;1wr+zv-GbR>AXahx++Cvc{W4s{RbZwdsHe^8L5igwUltLs2cuv=)rVt-`_f2t zj%q<;zBJ;%uUkjh3;##`T7G6kY7TRctphFYtnBHqdL}g(nb&fV#I*V-H}e;TphzS} z=mvh9)6S^b?KrbfIA7*1EIXcNUcfHi3ZHl)FP7*9;?8-y&gH8juo{zY@AEt&;m@s( z393zSLt1Z@3TkmCfQBW>wxgB<{rNg2W?3+f=nt8XC8wo`pEA~KXToMr?N;aEQ60fx z>G>Mi=iq@>b~J~+#vBc7aa{##a4RPdAbF=&(KN+!{**(+l&RE;`@!NfQ7@khm4htgxME1+_*n8d%&2#E6=F$>oKuSMBwzY^9rBOTs z>?_K@JO#ZYc~p#es&GkCVdkU)D=SNJq_idK2s3c1M|*Gc$cQhHvqh14__Y)l9_XH$ zymrbIoNg~&7DZI0EbDuYL)%f@o;Tzvm~vL({J1zuT35f8gcMu@vR8PGTqC3gN5Idr 
zZ>Xv=-P!Ld{1N~otRHLx<4hW#(89JwWeV*L9DZ*pSZ2I(U#vfEAB6{3P^6U}jHyP1 zJ=@B&jdiuqj2ym6S!LVqSdK9QAD~2h+sl94iM$8PMeS)U#^@ecG}`z0flYVnwft#E zU!H22PLQ=GvJW_atLT1=XXI3VZ<$Lf&^WSFY_Vs!U-gn8zNT>@vBDW}*A6yDd|-RIGT8-(2FAv4+yv6Fp;rG7No8 zBY?NC*M6u{w`woYt!vNNHz%LOM)mD??&`1#d;#0o76b=gm~oS#8a2OPrIU^(;~$%X zu0Kdi;s&U+xfDS!X&OFyzIgRJ`>0LCProbXR2uN*?nX_o!JvwCgdw8_slndmqtx5Wa9tR)GJEH)SQL%c`(^2WDP_h#E^*UCf+ zW8G@UBs*#3gi~X7YW`^9IowNzj4a6rG4sQw9F1@><=nW%Be$!(gZ6jEo)Kd# z7qf3raP1y;#u<3=hhrfIn-bp;(y7A-&;|ljL$<4dzye_7V+}@P! ztZOjcrh}s`sl}5HA62R=hcfb;xudB97;Yx)w9{PzVvj-M1O5Y#l*9&Bzp~!)jK>PF zm2(yv6r};iaYMq-UuO5#_VfJ^4lfm9Z!Sx)p(C|Vh8@{kz$T>q@R3N$e8J_g2x*4K z!DE8~*YlE7pRmoAKL_vH%e;h#gm|?04L85}uryQzSM~3c2_G18bF7(3RnYY&q4k_@ zSd^(lCr3JKMXS0lVjiFSEY3fh2kA2WmhHXlJxEi#xp=_H-{_OP%5*wARlezqPdh4D zs`l}`pztVreL9y*;}V|n1F?nKxq;rW!2M1SU|!tL2Vn2J+dV)58}cl(``ohz^{2B1 zLeQPb2dx?y;bWdl1O|8(0n&QzIe(SfCki}0EEL-sb(b|-mqoiWYUX2Fhc<4Qs4xh) zetGTIFy87Yp_vU)mg96Ay{Og6JGKDMGTUcobQjFrc;^Lueq%1gD zJ2K2kL65ssSv`qsXp$!*w=5L>5RM$k2gogmf6HkjZMYr3P={^?O=NtGXHKNKvP(z| zn6ZL0Owlj(9u-)KRU&~CAE+Ixt6k4 zG!B>wcvlm(+8S&y-(pBJ*I5#rA|%fQUZtxB+aPrBHhq|8=2x^l81_u2%_6GPLZ9Ty zUg-v0D7b;VfuTe(5y>qB&uq^#RKB_U_V}t#ERS>rKG1o#d+~~)3;5ip$>wa2F&8Hw zvv%2W*{(jrLj7zlm~oZuVOOlRv`t3RvpGG~nPa|st46Gs^zdgpsa%tA5kHo3is;ub z`WU=p0+hm$1a$I*t}MtVgtyXL=8F-fKR^9=PO z0QWV1&&te{86AOMQ`4T^;yM6ZQN#3hqX(&Ackz(@eFn}Y1?h_3<0L;P-Hoc--dY+- zJ~zHXZGQ-(?lr1ljD>@L-lPGw=aAKvdc4cWWHL(hP_$jLRr}k_gM>fC*yGdG={J@S zZ&D5~wsI`(-A=^-ZInpsP`XJc-qmOCmO`@zhhG;t>QXvcnh&c(1?MPltoyS6$$;Y1 z-~kt}HZHH%FY`UN7wttG!<^0n$9-F{WeJB}ECK5eq0YyW>-8gbb-BVx>sx$Ucxv}> z5v>Qeyw4qWg+eWS+ov+5IysbyxXV9UNZoxT#@kjplgM-0LljIi7a{gOA0sVRjVMb5 zL5X7EZKe+tD{&Q;L?Naex=va7>oOpP!O%^+qxl8;DW?}6@?~P~=(;PRxI8I{hzu{=z>lKl)J{mg=(wbQgkMazQ$yvXSX*%1z=wpMU zpZBhzHB>PV25P%Cpa3wq-nO>)V2{$=oP6ma=j!omzFGU#s~YP|ZK0Un3m?JTi@Eft z(%6@CZwl$qUPdcv-{?bPX<+cqRa%)R1C6~jcmy)O5M@a%Yj2$&KvQva=df||=8fCK z7FJwg)=ds%`i*t;Co zO6$X9=#%#=j1%H`3|r-_9)mq}(K^$=pw+NvRCgm zud>@t2CJlhl=S*>6vb!BdH+{EuDMa$Tjy8LFO1*y56U)QZTCr3P#y%!!7eIyDAmqq 
z#PZbdMj$z~*f#FcEW8rS!MN?_6v_5xGOl{5hP_>CmbW|mUDA5o+8tIvr2OYq3dVt# z+<0YX!}BDlo3oyVS5kar{a7Ida}NAK#Vc8bWR(;j0HhaVmgvZ`6G0mT%`L=e5~Q<` zLROk@ODO^PjozN6$1xNU2MG_H>6{;3z_U}qid@MzIUFg&+pQ$ zcCc0OF_mK&Gg&3Ck`>(4`!pkr@}nzT>b~0`VCx?ib}0;c&$*WF)qfC1a{C*14eyFs z0bo$eaO9N|(p<@=PbL?lj!wT6do=R2nRauYIXO?jUx>}G!6S=bL~~^;${VQB$39#9 zjdi){WvFV6ttkJ8K|UF4T!$YPa)yesuh>ko%JE+Lq>{zUVL9UMJv2?<`xmD%9P_Bo zb!r^Hi$_#ONNQ~!l`@;=V~kA3Pi}6b2M7&<_C%ZH?oX3<2^IcST0cl|`Zzy1@Zit) zd4n40TgRnkY#Z}dzG+p*uD@VNgx+d?S^Jg7<2G6%Y`q}aR@kSj#;hWn6@zWO~b zztWhMeD)e1n3J;>)O^7i`*K%6Kei?0)>erXwof=sZ*$Y@WXduq{)u7w z_HjVFK{pW-EH%}oB#q-v*xs=U@$VCt5o^Bv9#vfSU~>KPW#^8!er|18*0MssM^|w^ zrtTh_H_;(w!qztd|4GP1f78TcDo+i?qUor<-B{PP|~U#2Vj*3rIVzB|EIcMPj8pXAlo-tuHz!K`~3M151t=p~oz zq~?gT$Lgl~9U1oJc4L2}fRzB5K+>_ZCquMNA|_F<6T?vF4c2MdNDujHDHrQZ?NAV(ODYc*<)yLmP(K&< zR%5<4<`^o&%OQETrot0)PBb1*C_G|A<7kUS<}At25Bp zymgI0K1ZG?z?nIV6m4L~zWR@b1;CEAUCE8%iw}P5M+izyY4{_oXg;emNC{NQrE%{c zXgttex^Fme;T1cq#9;-=F?ezD`qu)h1XTWI#2Z_~SbL(wDuMktmOcKSxRdiSBF~_r zeG9|O7q^h2LKC@n>)x5tXpsUAC*GX2>SfU;Q^Oo}se67q?`)s6`vQ~RB(c}UF-Tb2 zS8OeWkC8;*7OoSI-fGPCs^_7Fg}44uLkdyJ4$-Z_9@q}=Bx}Z2agkq8Iy$=^Wm&XA z#MtN1(Eym(xy6-^af?<&Zf%lekB3;dh2SiPE5pjIkZIR}frmVgd8|a1H2R^(m{XY= zz0$f|TjRPbqhUElT0#MLuG4RzN$1VQu;uE#B9FSIxPAwD(sJw8C*#`|rSH4gQ?b%) zYc#fWhs2Td$SDRE04!F?JNHfr!7G&Wj$p{wlLWLOt>N97P*nqgYX8*hF;=M`vF66G z4VLo`3tg@N*>JHd+t@zqg4gODr>xe!ZhO2J|JHvP91n$+-nIhDW*gUao@(TCLNdF< zK5dns6Tuw!?zc;8y*{)4&0o+K@%3N;! 
zgmG98M>;+x-(k^~aq;}Nr&*@I@aKb43y3lZur>o-XJtT>{DaBgQ+K0tj@*^ zK~8t6urf`7VwK!3RpVXxXfoU9&-G!DhU#d&P98g_a>M%FW>W zWoRwA_phFgef?b5>XdJNML(KE^cN?G|14C3Y~E2V`(`2Ug}=@Iyo4C{d6d7ROn>wj zV6ETpsiW_~Fa+m(9vx4#(f}i%$F;Id7vB>9W^>BakO06QR7r)_1(s4u@U+XaNIFlv0^4j$fVi(csw% zHhQ}NKWaFA(=PZ>Rd!e~FI6M|m)29+kvx3C7z>-kqV{*=lG~m|Iwl=}KbcXqE->mR zK|g+*zWmC`wuZ7LJlz3IA<9i~TI|5Q8_B9mCdc42-=oJkNM|krt6O%?bNVbGY$X8+*GNnlxfYP8ef~MgdatOBHyktc> zpM{%6M!)^WxV)kJ3fG(h9k6jzj&fO8HZJkD1y+0>d2EE0u<)r2hA7xXAVcaZ5i-dS zCV4g8vesKa-;L}8H8j>c<=67OTA8J}yB#O-`%M<@vx!4M(Q<%=;!-fw4rbrT zxwX^Vv71dCg=wV?*%(^0JmYVI>iXmd0TBiC_iB|F(Xq2K)|GxbDcQD9et6(hEZ%vP z(|Rc2(+Kg>weDwZTEw51X9@y-_FeY}5ILv!b%BW9pZZjem{^w*JwRs{^HUu8;0k_? z&8cT!WsCve9h~b<*y3Xe2#CA>)&igcSaC-Mmt81pfM?H*c|>!Bo1ALUbHzN zotjF7H<=3`EO@_&p)cC?$Zt43JYpOL8;VkHN?in5H|Q4p5~TVJ_ed|~m;beuB&;xd zr{^*rV)8nt$$5niUnyG^{h6LBJ5ES-)FRaH6JkPy&x52NO@Eu7L_7*HT2o@ft4?&x zfWUG)f!w|M)LkVB%6+R4vr{FnM^`63dFtkVvREy18dPpqtzuSVhMDpwzvPjR(bW5> zr$RUG?Ad9B!?sf30O41u+!e1c-&eA2y{s^Oz0*Z`Zq6SzUxWW(OQ+|QMIFk5wm)ka zC3M)n%dNWijQw>{yI@Ft>+YNQBtH-fH26r~AjUXp*{dq?nDjXlzOCCD);KEZQYdEd zG#`sCFnJXs{_t)~S(?-v(z_mWsd=gVVy%8nq^a>U_Zbbp@w{@Q7?F1cu2ahsDhKL0y)JL~KS@1^7jTz| zu{8~?b7a5_)teJf^zVOnhaWR9VRa_AeUG;Ji$ICp)9X_R8Ap zPhTPZZ>YO1rfik4>p7jQ5l=7tBPYYVsxDMP9qV&y9t}+e+S?DGW|yfLAd-uejM*2H z^R{eRn$}JY_Q~QyHDw<~ol02gWeK$ye_eRTdBVA3wUlx;kC$3hLM+87C7o>x|6Ki0 z;ES0z;W-HVqJZz)_a#GWrZkRAdgG#$r52>l?>EMGI`3~J$UtRdG1N3Iv#aldsh3A3 zD#FiebGlP`Ocb#deB#*!BD)poxaI-u{;ajzl!&t1rO_Y3t=HEp-LmNV%JOCd*<
7KSQ=wKjfJnPVJE5eL@a3F>S-*7 zPnmHoy^R9st46+_XrWl-&f4EX?Fh`pX>tz^v5FFOPxos1T@`(RTqr58NHcet!tBoH z%`xO%T})}wu4v7>)x$Htg6n8^l23Q+#>1n-g#=$*j9t$>X3?pSe}p5~KRfiT(7-P0 zgUlTLYU&jX9o3liejJ6HvmM{)wXke9OZa{wdb#kLxU1+rrS+20*Ed%N{B%`jAL8^V zj8vb#)Ea94VzeM+wD<1f}Z z)XhT;Qz-EJdNT1yQuF-`iLZwc)e~~oz#B#>k?GZ`FWlH_hF&B%`Q5S&h^nf&NvzP5 zwh{gw%1}f;zeW7puP8Vy3E1VR6OAal!Qrz~k*=u~9y^{}^4s&y(#!(|+fP zp4s6(a`30sZMK`WJDw%E(7@r3a->TkW@$9hO;f#ZA~4<6S87^k@73C~)D>TL{0zcy zFFwbzs#I%`C4Cn>+y@-8upxV#6t zmxD!2gi1c;z_#a1u8%7CTS-|?ov|uZmIuAOpwG~`4mMBkCK*Ip6}6;r+UxdT@Rdey zwsR|_BOF38!HC0VyE1C8`5&;w7$gAZ48~58{<)*LKBV4^*UgDmk)gAleHF=|Y4lTK z(&_Pa%2GA5fKW3;eojR-?wSYWK{r%GjlKp1K8v`v{ww`1VTkY(#Y~eR+tvgf>rQ5# zgP}St!=T*ST(j<#Q`!Zyu<-QVDTvagU(47o2-0wDYZ9FCMEH7(3tfV1qk2Yh-%{dD z0AM6nu56|byk_ zyH+7Ck&Mv}_!E8Ut3yxTR=77H{U(HeBFK;Xk&tm$8=uMJ4QK>=E%e3{rcTod*pJU7 zEaEBe^0*}bl{rv#l>cNAzGdb*RWF**{`Nk!6?b))J-uzGvM9QrtBXDTmQMLW{(pAY z2d$2*>E`@0`WjeDD<|jB_~t&)E!#E_tL65$&?etbuJ!pX%EkJhWC4w{6K39!8fsem zXPqIi?62DKJX^(@u=L2pv(_xj-;OplAd82AAz_~it( zd%$m@*(q%`w~8Cr5`$ z+@h%3bN4{o=PU4<2TpwOA^!^zYd!fyVzzy`aAM%UhDkw^5Ko2{qbo!dT#!9o<-gJ^ zEjd2-Yj?_dz*AohZkR{?LU^^q{Ss7}zXWDa^(yQudYA}R4f|9^Ha?nGxGnp zwEIImTpYja7!_UhtA2*Q#Ykx#sp|cgzqkp)(8+8OvGFM|O>@s9aOj*<;;F*k;g#d@ f!|P!2IcI4r9Bag3I#sXh9SOin>WY=}W`X|$LPN!! literal 0 HcmV?d00001 diff --git a/docs/src/search.md b/docs/src/search.md index ad5c8d9a..c4147e6b 100644 --- a/docs/src/search.md +++ b/docs/src/search.md @@ -138,6 +138,36 @@ LanceDB supports binary vectors as a data type, and has the ability to search bi --8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector" ``` +## Multivector type + +LanceDB supports multivector type, this is useful when you have multiple vectors for a single item (e.g. with ColBert and ColPali). + +You can index on a column with multivector type and search on it, the query can be single vector or multiple vectors. 
If the query consists of multiple vectors `mq`, the similarity (distance) from it to any multivector `mv` in the dataset is defined as: + +![maxsim](assets/maxsim.png) + +where `sim` is the similarity function (e.g. cosine). + +For now, only `cosine` metric is supported for multivector search. + +=== "Python" + + === "sync API" + + ```python + --8<-- "python/python/tests/docs/test_multivector.py:imports" + + --8<-- "python/python/tests/docs/test_multivector.py:sync_multivector" + ``` + + === "async API" + + ```python + --8<-- "python/python/tests/docs/test_multivector.py:imports" + + --8<-- "python/python/tests/docs/test_multivector.py:async_multivector" + ``` + ## Search with distance range You can also search for vectors within a specific distance range from the query vector. This is useful when you want to find vectors that are not just the nearest neighbors, but also those that are within a certain distance. This can be done by using the `distance_range` method. diff --git a/docs/src/sql.md b/docs/src/sql.md index 5d3bdb09..2eddf0e3 100644 --- a/docs/src/sql.md +++ b/docs/src/sql.md @@ -18,7 +18,7 @@ import numpy as np uri = "data/sample-lancedb" data = [{"vector": row, "item": f"item {i}", "id": i} - for i, row in enumerate(np.random.random((10_000, 2)).astype('int'))] + for i, row in enumerate(np.random.random((10_000, 2)))] # Synchronous client db = lancedb.connect(uri) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 68d7016c..a6d076ff 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1741,12 +1741,14 @@ class AsyncQuery(AsyncQueryBase): a default `limit` of 10 will be used. Typically, a single vector is passed in as the query. However, you can also - pass in multiple vectors. This can be useful if you want to find the nearest - vectors to multiple query vectors. This is not expected to be faster than - making multiple queries concurrently; it is just a convenience method. 
- If multiple vectors are passed in then an additional column `query_index` - will be added to the results. This column will contain the index of the - query vector that the result is nearest to. + pass in multiple vectors. When multiple vectors are passed in and the vector + column has a multivector type, the vectors are treated as a single query. + Otherwise, the vectors are treated as multiple queries, which can be useful + if you want to find the nearest vectors to multiple query vectors. + This is not expected to be faster than making multiple queries concurrently; + it is just a convenience method. When the vectors are treated as multiple + queries, an additional column `query_index` is added to the results; it + contains the index of the query vector that the result is nearest to. """ if query_vector is None: raise ValueError("query_vector can not be None") diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 264200d6..10bd3316 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2856,6 +2856,8 @@ class AsyncTable: async_query = async_query.with_row_id() if query.vector: + # we need the schema to get the vector column type + # to determine whether the vectors are batch queries or not async_query = ( async_query.nearest_to(query.vector) .distance_type(query.metric) diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 96337dd8..67c5f9c7 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -223,7 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str: vector_col_count = 0 for field_name in schema.names: field = schema.field(field_name) - if pa.types.is_fixed_size_list(field.type): + if is_vector_column(field.type): vector_col_count += 1 if vector_col_count > 1: raise ValueError( @@ -231,7 +231,6 @@ def inf_vector_column_query(schema: pa.Schema) -> str: "Please specify the vector column name " "for vector search" ) - break 
elif vector_col_count == 1: vector_col_name = field_name if vector_col_count == 0: @@ -242,6 +241,29 @@ def inf_vector_column_query(schema: pa.Schema) -> str: return vector_col_name +def is_vector_column(data_type: pa.DataType) -> bool: + """ + Check if the column is a vector column. + + Parameters + ---------- + data_type : pa.DataType + The data type of the column. + + Returns + ------- + bool: True if the column is a vector column. + """ + if pa.types.is_fixed_size_list(data_type) and ( + pa.types.is_floating(data_type.value_type) + or pa.types.is_uint8(data_type.value_type) + ): + return True + elif pa.types.is_list(data_type): + return is_vector_column(data_type.value_type) + return False + + def infer_vector_column_name( schema: pa.Schema, query_type: str, diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index bd40811a..a6af7561 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -68,6 +68,60 @@ async def table_struct_async(tmp_path) -> AsyncTable: return await conn.create_table("test_struct", table) +@pytest.fixture +def multivec_table() -> lancedb.table.Table: + db = lancedb.connect("memory://") + # Generate 256 rows of data + num_rows = 256 + + # Generate data for each column + vector_data = [ + [[i, i + 1], [i + 2, i + 3]] for i in range(num_rows) + ] # Adjust to match nested structure + id_data = list(range(1, num_rows + 1)) + float_field_data = [float(i) for i in range(1, num_rows + 1)] + + # Create the Arrow table + df = pa.table( + { + "vector": pa.array( + vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2)) + ), + "id": pa.array(id_data), + "float_field": pa.array(float_field_data), + } + ) + return db.create_table("test", df) + + +@pytest_asyncio.fixture +async def multivec_table_async(tmp_path) -> AsyncTable: + conn = await lancedb.connect_async( + "memory://", read_consistency_interval=timedelta(seconds=0) + ) + # Generate 256 rows of data + num_rows = 256 + + # 
Generate data for each column + vector_data = [ + [[i, i + 1], [i + 2, i + 3]] for i in range(num_rows) + ] # Adjust to match nested structure + id_data = list(range(1, num_rows + 1)) + float_field_data = [float(i) for i in range(1, num_rows + 1)] + + # Create the Arrow table + df = pa.table( + { + "vector": pa.array( + vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2)) + ), + "id": pa.array(id_data), + "float_field": pa.array(float_field_data), + } + ) + return await conn.create_table("test_async", df) + + def test_cast(table): class TestModel(LanceModel): vector: Vector(2) @@ -177,6 +231,62 @@ async def test_distance_range_async(table_async: AsyncTable): assert res["_distance"].to_pylist() == [min_dist, max_dist] +def test_multivector(multivec_table: lancedb.table.Table): + # create index on multivector + multivec_table.create_index( + metric="cosine", + vector_column_name="vector", + index_type="IVF_PQ", + num_partitions=1, + num_sub_vectors=2, + ) + + # query with single vector + q = [1, 2] + rs = multivec_table.search(q).to_arrow() + + # query with multiple vectors + q = [[1, 2], [1, 2]] + rs2 = multivec_table.search(q).to_arrow() + assert len(rs2) == len(rs) + for i in range(2): + assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2 + + # can't query with vector that dim not matched + with pytest.raises(Exception): + multivec_table.search([1, 2, 3]).to_arrow() + # can't query with vector list that some dim not matched + with pytest.raises(Exception): + multivec_table.search([[1, 2], [1, 2, 3]]).to_arrow() + + +@pytest.mark.asyncio +async def test_multivector_async(multivec_table_async: AsyncTable): + # create index on multivector + await multivec_table_async.create_index( + "vector", + config=IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2), + ) + + # query with single vector + q = [1, 2] + rs = await multivec_table_async.query().nearest_to(q).to_arrow() + + # query with multiple vectors + q = [[1, 2], [1, 2]] + 
rs2 = await multivec_table_async.query().nearest_to(q).to_arrow() + assert len(rs2) == len(rs) + for i in range(2): + assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2 + + # can't query with vector that dim not matched + with pytest.raises(Exception): + await multivec_table_async.query().nearest_to([1, 2, 3]).to_arrow() + # can't query with vector list that some dim not matched + with pytest.raises(Exception): + await multivec_table_async.query().nearest_to([[1, 2], [1, 2, 3]]).to_arrow() + + def test_vector_query_with_no_limit(table): with pytest.raises(ValueError): LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select( @@ -448,11 +558,13 @@ async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable): @pytest.mark.asyncio async def test_query_to_polars_async(table_async: AsyncTable): + schema = await table_async.schema() + num_columns = len(schema.names) df = await table_async.query().to_polars() - assert df.shape == (2, 5) + assert df.shape == (2, num_columns) df = await table_async.query().where("id < 0").to_polars() - assert df.shape == (0, 5) + assert df.shape == (0, num_columns) @pytest.mark.asyncio diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index ccc047dc..a972bd6e 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use std::path::Path; use std::sync::Arc; -use arrow::array::AsArray; +use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; use arrow::datatypes::{Float32Type, UInt8Type}; use arrow_array::{RecordBatchIterator, RecordBatchReader}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; @@ -1902,68 +1902,74 @@ impl TableInternal for NativeTable { options: QueryExecutionOptions, ) -> Result> { let ds_ref = self.dataset.get().await?; + let mut column = query.column.clone(); + let schema = ds_ref.schema(); + let mut query_vector = query.query_vector.first().cloned(); if query.query_vector.len() > 1 { 
- // If there are multiple query vectors, create a plan for each of them and union them. - let query_vecs = query.query_vector.clone(); - let plan_futures = query_vecs - .into_iter() - .map(|query_vector| { - let mut sub_query = query.clone(); - sub_query.query_vector = vec![query_vector]; - let options_ref = options.clone(); - async move { self.create_plan(&sub_query, options_ref).await } - }) - .collect::>(); - let plans = futures::future::try_join_all(plan_futures).await?; - return Table::multi_vector_plan(plans); + if column.is_none() { + // Infer a vector column with the same dimension of the query vector. + let arrow_schema = Schema::from(ds_ref.schema()); + column = Some(default_vector_column( + &arrow_schema, + Some(query.query_vector[0].len() as i32), + )?); + } + let vector_field = schema.field(column.as_ref().unwrap()).unwrap(); + if let DataType::List(_) = vector_field.data_type() { + // it's multivector, then the vectors should be treated as single query + // concatenate the vectors into a FixedSizeList> + // it's also possible to concatenate the vectors into a List>, + // but FixedSizeList is more efficient and easier to construct + let vectors = query + .query_vector + .iter() + .map(|arr| arr.as_ref()) + .collect::>(); + let dim = vectors[0].len(); + let mut fsl_builder = FixedSizeListBuilder::with_capacity( + Float32Builder::with_capacity(dim), + dim as i32, + vectors.len(), + ); + for vec in vectors { + fsl_builder + .values() + .append_slice(vec.as_primitive::().values()); + fsl_builder.append(true); + } + query_vector = Some(Arc::new(fsl_builder.finish())); + } else { + // If there are multiple query vectors, create a plan for each of them and union them. 
+ let query_vecs = query.query_vector.clone(); + let plan_futures = query_vecs + .into_iter() + .map(|query_vector| { + let mut sub_query = query.clone(); + sub_query.query_vector = vec![query_vector]; + let options_ref = options.clone(); + async move { self.create_plan(&sub_query, options_ref).await } + }) + .collect::>(); + let plans = futures::future::try_join_all(plan_futures).await?; + return Table::multi_vector_plan(plans); + } } let mut scanner: Scanner = ds_ref.scan(); - if let Some(query_vector) = query.query_vector.first() { + if let Some(query_vector) = query_vector { // If there is a vector query, default to limit=10 if unspecified - let column = if let Some(col) = query.column.as_ref() { - col.clone() + let column = if let Some(col) = column { + col } else { // Infer a vector column with the same dimension of the query vector. let arrow_schema = Schema::from(ds_ref.schema()); default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? }; - let field = ds_ref.schema().field(&column).ok_or(Error::Schema { - message: format!("Column {} not found in dataset schema", column), - })?; - - let mut is_binary = false; - if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() { - match element.data_type() { - e_type if e_type.is_floating() => {} - e_type if *e_type == DataType::UInt8 => { - is_binary = true; - } - _ => { - return Err(Error::InvalidInput { - message: format!( - "The data type of the vector column '{}' is not a floating point type", - column - ), - }); - } - } - if dim != query_vector.len() as i32 { - return Err(Error::InvalidInput { - message: format!( - "The dimension of the query vector does not match with the dimension of the vector column '{}': \ - query dim={}, expected vector dim={}", - column, - query_vector.len(), - dim, - ), - }); - } - } - + let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?; + let is_binary = matches!(element_type, DataType::UInt8); if is_binary { let 
query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?; let query_vector = query_vector.as_primitive::(); @@ -1973,10 +1979,9 @@ impl TableInternal for NativeTable { query.base.limit.unwrap_or(DEFAULT_TOP_K), )?; } else { - let query_vector = query_vector.as_primitive::(); scanner.nearest( &column, - query_vector, + query_vector.as_ref(), query.base.limit.unwrap_or(DEFAULT_TOP_K), )?; } diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index d1019a9f..c8455b4c 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -108,13 +108,8 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result let candidates = schema .fields() .iter() - .filter_map(|field| match field.data_type() { - arrow_schema::DataType::FixedSizeList(f, d) - if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8) - && dim.map(|expect| *d == expect).unwrap_or(true) => - { - Some(field.name()) - } + .filter_map(|field| match inf_vector_dim(field) { + Some(d) if dim.is_none() || dim == Some(d) => Some(field.name()), _ => None, }) .collect::>(); @@ -138,6 +133,20 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result } } +fn inf_vector_dim(field: &arrow_schema::Field) -> Option { + match field.data_type() { + arrow_schema::DataType::FixedSizeList(f, d) => { + if f.data_type().is_floating() || f.data_type() == &DataType::UInt8 { + Some(*d) + } else { + None + } + } + arrow_schema::DataType::List(f) => inf_vector_dim(f), + _ => None, + } +} + pub fn supported_btree_data_type(dtype: &DataType) -> bool { dtype.is_integer() || dtype.is_floating() @@ -171,9 +180,10 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool { pub fn supported_vector_data_type(dtype: &DataType) -> bool { match dtype { - DataType::FixedSizeList(inner, _) => { - DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8 + DataType::FixedSizeList(field, _) => { + field.data_type().is_floating() || 
field.data_type() == &DataType::UInt8 } + DataType::List(field) => supported_vector_data_type(field.data_type()), _ => false, } }