```python
#!/usr/bin/env python
"""shamir_threshold_scheme.py

Shamir's (k, n) threshold scheme. See "The Handbook of Applied Cryptography"
or Shamir's 1979 paper, "How to Share a Secret."

GRE, 6/11/11
"""

from __future__ import print_function  # lets the script run on Python 2 or 3

from functools import reduce  # a builtin in Python 2, moved here in Python 3
from operator import mul
from random import randrange, sample

try:
    xrange  # Python 2: lazy range; range(1, p) would build a huge list
except NameError:
    xrange = range  # Python 3: range is already lazy

####### Preliminaries

def gcd(a, b):
    """Greatest common divisor of a and b"""
    while b:
        a, b = b, a % b
    return a

def mod_inv(x, p):
    """x^{-1} mod p for prime p, per Programming Praxis's comment on
    http://programmingpraxis.com/2009/07/07/modular-arithmetic/"""
    assert gcd(x, p) == 1, "Divisor %d not coprime to modulus %d" % (x, p)
    # Invariant: a * x == z (mod p); each step replaces z with p mod z,
    # which only reaches 1 (never 0) because p is prime.
    z, a = (x % p), 1
    while z != 1:
        q = -(p // z)
        z, a = (p + q * z), (q * a) % p
    return a

def prod(nums):
    """Product of nums"""
    return reduce(mul, nums, 1)

#######

def horner_mod(coeffs, mod):
    """Polynomial of degree len(coeffs) - 1, evaluated via Horner's rule with
    modular arithmetic. For example, if coeffs = [1, 2, 3] and mod = 5, this
    returns the function x --> (x, y) where y = 1 + 2x + 3x^2 mod 5."""
    # Reduce after every step so intermediate values stay small
    return lambda x: (x, reduce(lambda a, b: (a * x + b) % mod,
                                reversed(coeffs), 0))

def shamir_threshold(S, k, n, p):
    """Shamir's simple (k, n) threshold scheme. Returns n xy-pairs generated
    by a secret polynomial of degree k - 1 (mod p) whose constant term is S
    (so S must be less than p). Each pair goes to a different person; any k
    of them together are enough to reconstruct the secret."""
    coeffs = [S]
    # Independent but not necessarily unique; choose k - 1 coefficients from [1, p)
    coeffs.extend(randrange(1, p) for _ in xrange(k - 1))
    # x values are unique and nonzero (x = 0 would reveal S directly)
    return list(map(horner_mod(coeffs, p), sample(xrange(1, p), n)))

def interp_const(xy_pairs, k, p):
    """Use Lagrange interpolation to find the constant term of the degree
    k - 1 polynomial (mod p) that generated the xy-pairs. Since only the
    constant term is needed, we can take a shortcut and evaluate the
    interpolating polynomial at x = 0 directly."""
    assert len(xy_pairs) >= k, "Not enough points for interpolation"
    x = lambda i: xy_pairs[i][0]
    y = lambda i: xy_pairs[i][1]
    # f(0) = sum_i y_i * prod_{j != i} x_j / (x_j - x_i) (mod p), first k pairs
    return sum(y(i) * prod(x(j) * mod_inv(x(j) - x(i), p) % p
                           for j in xrange(k) if j != i)
               for i in xrange(k)) % p

if __name__ == "__main__":
    from pprint import pprint  # Pretty printing
    S = int("PRAXIS", 36)
    print(S)  # Prints 1557514036
    n, k, p = 20, 5, 1557514061  # p is the next prime after S
    xy_pairs = shamir_threshold(S, k, n, p)
    pprint(xy_pairs)  # Prints all 20 (x, y) pairs
    print(interp_const(xy_pairs, k, p))  # Should print 1557514036
```
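On Python 3.8 and later, the built-in three-argument `pow(x, -1, p)` returns a modular inverse, so it can stand in for `gcd` and `mod_inv`. The sketch below redoes the same split-and-recover round trip that way. It is a minimal illustration under stated assumptions, not the author's code: the names `make_shares` and `recover` are hypothetical, the modulus is swapped for the Mersenne prime 2^31 - 1 (a known prime larger than the secret), and the RNG is seeded only so that runs repeat. Unlike `interp_const`, `recover` interpolates over every share it is handed.

```python
import random


def make_shares(secret, k, n, p, rng=random):
    """Split `secret` into n (x, y) shares, any k of which recover it.

    Hypothetical helper; assumes p is prime, 0 <= secret < p, and n < p.
    """
    # Random polynomial of degree k - 1 whose constant term is the secret.
    coeffs = [secret] + [rng.randrange(1, p) for _ in range(k - 1)]

    def f(x):
        # Horner's rule, reducing mod p at every step.
        y = 0
        for c in reversed(coeffs):
            y = (y * x + c) % p
        return y

    # Distinct, nonzero x values (f(0) would be the secret itself).
    return [(x, f(x)) for x in rng.sample(range(1, p), n)]


def recover(shares, p):
    """Hypothetical helper: Lagrange interpolation at x = 0 over all shares, mod p."""
    total = 0
    for i, (xi, yi) in enumerate(shares):
        term = yi
        for j, (xj, _) in enumerate(shares):
            if j != i:
                # pow(d, -1, p) is the inverse of d mod p (Python 3.8+).
                term = term * xj * pow((xj - xi) % p, -1, p) % p
        total = (total + term) % p
    return total


secret = int("PRAXIS", 36)      # 1557514036, the listing's secret
k, n, p = 5, 20, 2**31 - 1      # 2**31 - 1 is a Mersenne prime above the secret
rng = random.Random(2011)       # seeded only so the example repeats
shares = make_shares(secret, k, n, p, rng)

print(recover(shares[:k], p))   # prints 1557514036 (first k shares)
print(recover(shares[-k:], p))  # prints 1557514036 (last k shares)
```

Because `make_shares` takes the generator as a parameter, passing `rng=random.SystemRandom()` would draw coefficients and x values from the operating system's randomness source instead of the seeded Mersenne Twister.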
```
1557514036
[(697286162, 445615394L),
 (471866046, 757728985L),
 (112045393, 1132162792L),
 (397324764, 486286231L),
 (135120894, 1142009194L),
 (508637994, 1556915744L),
 (488738532, 834401917L),
 (1369874096, 1345716686L),
 (91597754, 487556032L),
 (970187759, 341284274L),
 (1102805729, 224871713L),
 (245100902, 1306749801L),
 (413372256, 568733054L),
 (1218343037, 63534734L),
 (442535975, 1060000953L),
 (1173207231, 400308586L),
 (515043844, 141960722L),
 (1162691976, 374990038L),
 (73252341, 785232686L),
 (934671161, 486917357L)]
1557514036
```
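In the run above, `interp_const` uses only the first `k` of the 20 pairs. The threshold property is stronger: every set of `k` shares recovers the secret, while `k - 1` shares only fix a lower-degree curve whose intercept has no reason to match it. A small deterministic check, assuming a toy polynomial f(x) = 42 + 7x + 13x^2 (mod 257) with `k = 3`, chosen purely for illustration (257 is prime); the names `f` and `const_term` are hypothetical.

```python
from itertools import combinations

P = 257                  # small prime modulus (toy value)
COEFFS = [42, 7, 13]     # f(x) = 42 + 7x + 13x^2: secret is 42, k = 3
K = len(COEFFS)


def f(x):
    """Evaluate the toy polynomial mod P with Horner's rule."""
    y = 0
    for c in reversed(COEFFS):
        y = (y * x + c) % P
    return y


def const_term(points):
    """Lagrange interpolation at x = 0 over the given (x, y) points, mod P."""
    total = 0
    for i, (xi, yi) in enumerate(points):
        term = yi
        for j, (xj, _) in enumerate(points):
            if j != i:
                # pow(d, -1, P) is the inverse of d mod P (Python 3.8+).
                term = term * xj * pow((xj - xi) % P, -1, P) % P
        total = (total + term) % P
    return total


shares = [(x, f(x)) for x in range(1, 7)]  # six shares at x = 1..6
print(shares)  # [(1, 62), (2, 108), (3, 180), (4, 21), (5, 145), (6, 38)]

# Every one of the C(6, 3) = 20 triples recovers the secret.
print(all(const_term(c) == 42 for c in combinations(shares, K)))  # True

# Two shares only determine a line; its intercept here is 16, not 42.
print(const_term(shares[:2]))  # 16
```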