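Veja abaixo três implementações da multiplicação de matrizes. A primeira é a convencional (três laços aninhados, complexidade $O(n^3)$); a segunda é uma implementação recursiva por divisão em blocos, com a mesma complexidade da primeira; e a terceira é o algoritmo de Strassen, de complexidade $O(n^{\log_2 7}) \approx O(n^{2,81})$. As duas versões recursivas completam as matrizes com zeros até a próxima potência de 2.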
import numpy as np
import math as mt
# Conventional O(m*p*n) product; works for any conforming m x p and p x n
# matrices, not only square ones.
def matrixMultiplication(A, B):
    """Multiply two matrices with the textbook triple loop.

    Parameters
    ----------
    A : array_like
        Left operand, shape (m, p).
    B : array_like
        Right operand, shape (p, n).

    Returns
    -------
    numpy.ndarray
        The (m, n) product. Real inputs yield float64 (as before);
        complex inputs keep their imaginary parts.

    Raises
    ------
    ValueError
        If either operand is not 2-D or the inner dimensions differ.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("matrixMultiplication expects 2-D arrays, got %r and %r"
                         % (A.shape, B.shape))
    m, p = A.shape
    if B.shape[0] != p:
        # Without this check a short B raises IndexError and a tall B is
        # silently truncated to its first p rows.
        raise ValueError("inner dimensions differ: %r x %r" % (A.shape, B.shape))
    n = B.shape[1]
    # At least float64 (the previous fixed dtype), but complex/object inputs
    # are not silently downcast.
    C = np.empty((m, n), dtype=np.result_type(A, B, np.float64))
    zero = C.dtype.type(0)
    for i in range(m):
        for j in range(n):
            # Accumulate in a local of the result dtype instead of
            # re-reading/writing C[i, j] on every step.
            acc = zero
            for k in range(p):
                acc = acc + A[i, k] * B[k, j]
            C[i, j] = acc
    return C
def multiplicationRec(A, B):
    """Multiply A (m x p) by B (p x n) with the recursive block algorithm.

    matrixMultiplicationRec only handles square matrices whose order is a
    power of two, so both operands are zero-padded into the smallest such
    N x N matrices with N >= max(m, p, n). Zero padding leaves the top-left
    m x n block of the product unchanged, and only that block is returned.

    Parameters
    ----------
    A : array_like
        Left operand, shape (m, p).
    B : array_like
        Right operand, shape (p, n).

    Returns
    -------
    numpy.ndarray
        The (m, n) product, always 2-D and never padded.

    Raises
    ------
    ValueError
        If the operands are not 2-D or their inner dimensions differ.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[0]:
        raise ValueError("cannot multiply shapes %r and %r" % (A.shape, B.shape))
    m, p = A.shape
    n = B.shape[1]
    if min(m, p, n) == 0:
        # Empty operands: the product is an m x n matrix of zeros.
        return np.zeros((m, n), dtype=np.result_type(A, B, np.float64))
    # Smallest power of two covering every dimension. Integer doubling
    # avoids the rounding pitfalls of floor(log(n) / log(2)).
    size = 1
    while size < max(m, p, n):
        size *= 2
    if A.shape == (size, size) and B.shape == (size, size):
        Apad, Bpad = A, B
    else:
        # Keep complex dtypes intact while padding (default zeros are float64).
        Apad = np.zeros((size, size), dtype=np.result_type(A, np.float64))
        Bpad = np.zeros((size, size), dtype=np.result_type(B, np.float64))
        Apad[:m, :p] = A
        Bpad[:p, :n] = B
    C = matrixMultiplicationRec(Apad, Bpad)
    # matrixMultiplicationRec may return a bare scalar for 1 x 1 input;
    # reshape so callers always get a matrix, then drop the padding.
    return np.reshape(C, (size, size))[:m, :n]
def matrixMultiplicationRec(A, B):
    """Multiply two n x n matrices by recursive 2 x 2 block decomposition.

    Each operand is split into four n/2 x n/2 quadrants, and each result
    quadrant is the sum of two recursive products (eight products in
    total), so the cost is still O(n**3), like the triple loop.

    Parameters
    ----------
    A, B : numpy.ndarray
        Square operands of the same order n, where n is a power of two
        (or 0). The multiplicationRec wrapper zero-pads other sizes.

    Returns
    -------
    numpy.ndarray
        The n x n product (a 1 x 1 matrix when n == 1). Real inputs give
        float64; complex inputs keep their complex dtype.

    Raises
    ------
    ValueError
        If the operands are not equal-size square matrices, or their order
        is not a power of two (the quadrant split would silently drop terms).
    """
    if A.ndim != 2 or A.shape != B.shape or A.shape[0] != A.shape[1]:
        raise ValueError("matrixMultiplicationRec expects two n x n matrices "
                         "of the same n, got %r and %r" % (A.shape, B.shape))
    n = A.shape[0]
    if n & (n - 1):
        raise ValueError("matrix order must be a power of two, got %d "
                         "(use multiplicationRec to pad)" % n)
    dtype = np.result_type(A, B, np.float64)
    if n == 0:
        # Guard: without it, n // 2 == 0 would recurse forever.
        return np.empty((0, 0), dtype=dtype)
    if n == 1:
        # Return a 1 x 1 matrix rather than a bare scalar so the result is 2-D.
        return np.array([[A[0, 0] * B[0, 0]]], dtype=dtype)

    half = n // 2
    # Quadrants: A = [[a, b], [c, d]], B = [[e, f], [g, h]].
    a, b = A[:half, :half], A[:half, half:]
    c, d = A[half:, :half], A[half:, half:]
    e, f = B[:half, :half], B[:half, half:]
    g, h = B[half:, :half], B[half:, half:]

    C = np.empty((n, n), dtype=dtype)
    C[:half, :half] = matrixMultiplicationRec(a, e) + matrixMultiplicationRec(b, g)
    C[:half, half:] = matrixMultiplicationRec(a, f) + matrixMultiplicationRec(b, h)
    C[half:, :half] = matrixMultiplicationRec(c, e) + matrixMultiplicationRec(d, g)
    C[half:, half:] = matrixMultiplicationRec(c, f) + matrixMultiplicationRec(d, h)
    return C
def strassen(A, B):
    """Multiply A (m x p) by B (p x n) with Strassen's algorithm.

    matrixMultiplicationStrassen only handles square matrices whose order
    is a power of two, so both operands are zero-padded into the smallest
    such N x N matrices with N >= max(m, p, n). Zero padding leaves the
    top-left m x n block of the product unchanged, and only that block is
    returned.

    Parameters
    ----------
    A : array_like
        Left operand, shape (m, p).
    B : array_like
        Right operand, shape (p, n).

    Returns
    -------
    numpy.ndarray
        The (m, n) product, always 2-D and never padded.

    Raises
    ------
    ValueError
        If the operands are not 2-D or their inner dimensions differ.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[0]:
        raise ValueError("cannot multiply shapes %r and %r" % (A.shape, B.shape))
    m, p = A.shape
    n = B.shape[1]
    if min(m, p, n) == 0:
        # Empty operands: the product is an m x n matrix of zeros.
        return np.zeros((m, n), dtype=np.result_type(A, B, np.float64))
    # Smallest power of two covering every dimension. Integer doubling
    # avoids the rounding pitfalls of floor(log(n) / log(2)).
    size = 1
    while size < max(m, p, n):
        size *= 2
    if A.shape == (size, size) and B.shape == (size, size):
        Apad, Bpad = A, B
    else:
        # Keep complex dtypes intact while padding (default zeros are float64).
        Apad = np.zeros((size, size), dtype=np.result_type(A, np.float64))
        Bpad = np.zeros((size, size), dtype=np.result_type(B, np.float64))
        Apad[:m, :p] = A
        Bpad[:p, :n] = B
    C = matrixMultiplicationStrassen(Apad, Bpad)
    # matrixMultiplicationStrassen may return a 1-element 1-D array for
    # 1 x 1 input; reshape so callers always get a matrix, then drop padding.
    return np.reshape(C, (size, size))[:m, :n]
def matrixMultiplicationStrassen(A, B):
    """Multiply two n x n matrices with Strassen's algorithm.

    Each operand is split into four n/2 x n/2 quadrants. The result is
    assembled from seven recursive products (P1..P7) instead of eight,
    giving O(n**log2(7)) ~ O(n**2.81) multiplications.

    Parameters
    ----------
    A, B : numpy.ndarray
        Square operands of the same order n, where n is a power of two
        (or 0). The strassen() wrapper zero-pads other sizes.

    Returns
    -------
    numpy.ndarray
        The n x n product (a 1 x 1 matrix when n == 1). Real inputs give
        float64; complex inputs keep their complex dtype.

    Raises
    ------
    ValueError
        If the operands are not equal-size square matrices, or their order
        is not a power of two (the quadrant split would silently drop terms).
    """
    if A.ndim != 2 or A.shape != B.shape or A.shape[0] != A.shape[1]:
        raise ValueError("matrixMultiplicationStrassen expects two n x n "
                         "matrices of the same n, got %r and %r"
                         % (A.shape, B.shape))
    n = A.shape[0]
    if n & (n - 1):
        raise ValueError("matrix order must be a power of two, got %d "
                         "(use strassen to pad)" % n)
    dtype = np.result_type(A, B, np.float64)
    if n == 0:
        # Guard: without it, n // 2 == 0 would recurse forever.
        return np.empty((0, 0), dtype=dtype)
    if n == 1:
        # Return a proper 1 x 1 matrix (A[0] * B[0] would be a 1-D array).
        return np.array([[A[0, 0] * B[0, 0]]], dtype=dtype)

    half = n // 2
    # Quadrants: A = [[a, b], [c, d]], B = [[e, f], [g, h]].
    a, b = A[:half, :half], A[:half, half:]
    c, d = A[half:, :half], A[half:, half:]
    e, f = B[:half, :half], B[:half, half:]
    g, h = B[half:, :half], B[half:, half:]

    P1 = matrixMultiplicationStrassen(a, f - h)
    P2 = matrixMultiplicationStrassen(a + b, h)
    P3 = matrixMultiplicationStrassen(c + d, e)
    P4 = matrixMultiplicationStrassen(d, g - e)
    P5 = matrixMultiplicationStrassen(a + d, e + h)
    P6 = matrixMultiplicationStrassen(b - d, g + h)
    P7 = matrixMultiplicationStrassen(a - c, e + f)

    C = np.empty((n, n), dtype=dtype)
    C[:half, :half] = P5 + P4 - P2 + P6   # = ae + bg
    C[:half, half:] = P1 + P2             # = af + bh
    C[half:, :half] = P3 + P4             # = ce + dg
    C[half:, half:] = P5 + P1 - P3 - P7   # = cf + dh
    return C
# Demo: multiply the same 5 x 5 example with all three implementations.
# print(...) with a single argument behaves the same on Python 2 and 3.
A = np.ones([5, 5])
B = np.ones([5, 5])
B[1, 1] = 77
B[3, 2] = 103
A[4, 4] = 17
print("Conventional matrix multiplication:")
print(matrixMultiplication(A, B))
print("Recursive matrix multiplication:")
print(multiplicationRec(A, B))
print("Strassen matrix multiplication:")
print(strassen(A, B))