/* wordsquare -- build word squares */
/* D. Eppstein, UC Irvine, 10 Mar 2001 */

/* run without any command-line arguments for usage information */


#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define MAX_WORD_LEN 20

/* ====================================================================== */
/* ==============================  globals  ============================= */
/* ====================================================================== */


/* One node of the word trie built by read_dict().
   mask    - bit k is set when letter 'A'+k can occur at this node's depth
             after the prefix that leads to the node.
   child   - first node of the next level, or 0 when there is none.
   sibling - next node in the parent's list of children, or 0 at the end.
   The children of a node appear in the same order as the set bits of its
   mask, lowest bit first. */
struct dictionary {
    long mask;
    struct dictionary *child;
    struct dictionary *sibling;
};

/* Root of the trie; stays 0 until read_dict() has run. */
struct dictionary *words = NULL;

/* Word list to read.  -w replaces the path; -W sets it to 0, in which case
   read_dict() reads the words from stdin as it stands. */
char *dict_file = "/usr/dict/words";

/* Argument of -d: letters forced onto the main diagonal (0 if not given). */
char *diag = 0;

/* Argument of -p: letters the finished square must be a permutation of
   (0 if not given). */
char *anagram = 0;

/* Number of rows and columns; set by -n or derived from -p or -d. */
int dimension = 0;

/* Letters every cell starts out with, one bit per letter with bit 0 = 'A'.
   All 26 letters unless -l narrows the set. */
long initial_letters = (1 << 26) - 1;

/* Shape options. */
int symmetric = 1;      /* cleared by -a and -A */
int asymmetric = 0;     /* set by -A */
int palindromic = 0;    /* set by -P */
int diag_word = 0;      /* set by -D */

/* Search state: the current dimension*dimension cells, stored row by row.
   Each cell is a bitmask of the letters still possible at that position.
   The pointer moves through a larger allocation made in main(): push()
   copies the current cells into the next slot and advances the pointer,
   pop() steps it back, which is how the backtracking search saves and
   restores its state. */

long *square = 0;


/* ====================================================================== */
/* =================  command line argument processing  ================= */
/* ====================================================================== */

/* Print the option summary on stdout and exit with status 0. */
void usage()
{
    static const char * const help[] = {
        "usage:\n",
        "-a : allow asymmetric result squares\n",
        "-A : force all squares to be asymmetric\n",
        "-d string : force letters on main diagonal\n",
        "-D : force main diagonal to be word\n",
        "-l string : restrict all letters to ones from the string\n",
        "-p string : restrict output to be permutation of these letters\n",
        "-P : force word square to be palindromic\n",
        "-w filename : set wordlist from filename\n",
        "-W : use stdin for wordlist\n",
        "-n nnn : set number of rows/columns\n"
    };
    unsigned int k;

    for (k = 0; k < sizeof help / sizeof help[0]; k++)
        fputs(help[k], stdout);
    exit(0);
}

/* Report a fatal error: write s followed by "!" and a newline to stderr,
   then exit with status 1.  Never returns. */
void complain(char * s)
{
    fputs(s, stderr);
    fputs("!\n", stderr);
    exit(1);
}

/* Convert a string of letters into a letter bitmask.
   Bit k of the result is set when 'A'+k occurs in s in either case.
   Any non-alphabetic character is reported through complain(). */
long letters_from(char * s)
{
    long allowed = 0;
    char * p;

    for (p = s; *p != '\0'; p++) {
        int index;

        if ('A' <= *p && *p <= 'Z') index = *p - 'A';
        else if ('a' <= *p && *p <= 'z') index = *p - 'a';
        else {
            complain("Letter restriction must be alphabetic");
            continue;
        }
        allowed |= (1 << index);
    }
    return allowed;
}

/* Derive the square size from the -p letter set.
   Does nothing when -p was not given.  Otherwise the number of letters
   must be a perfect square whose root agrees with any -n value already
   stored in `dimension`; the root then becomes `dimension`.  Combining -p
   with a narrowed -l letter set is rejected.  All failures go through
   complain(). */
void size_anagram() {
    int side = 0;
    int length;

    if (anagram == 0) return;

    length = strlen(anagram);
    while (side*side < length) side++;      /* smallest side with side^2 >= length */
    if (side*side != length)
        complain("Permuted letter set does not have square length");
    if (dimension > 0 && dimension != side)
        complain("Permuted letter set length does not match word square size");
    dimension = side;

    /* initial_letters still holds its default unless -l changed it */
    if (initial_letters != (1<<26)-1)
        complain("Do not specify both -l and -p options");
}

/* Parse the command line into the option globals.
   ac, av - argument count and vector exactly as passed to main().
   With no arguments the usage text is printed and the program exits.
   Every option is a dash followed by exactly one letter; options that take
   a value read it from the following argument.  After the options have
   been read the square size is settled from -p (via size_anagram()), -n,
   or the length of the -d string, and must lie in 1..MAX_WORD_LEN.
   All errors are reported through complain(). */
void parse_args(int ac, char **av) {
    if (ac < 2) usage();
    ac--; av++;                 /* skip prog name */
    while (ac) {
        char opt;

        /* Test the second character before the third, so that a lone "-"
           is rejected without reading past the end of the string. */
        if ((*av)[0] != '-' || (*av)[1] == '\0' || (*av)[2] != '\0')
            complain("Argument must begin with dash");

        opt = (*av)[1];
        ac--; av++;             /* *av is now the option's value, if any */

        switch (opt) {
        case 'A':
            asymmetric = 1;
            /* fall through: forcing asymmetry also turns symmetry off */
        case 'a':               /* asymmetric */
            symmetric = 0;
            break;

        case 'd':               /* set diagonal */
            if (ac <= 0) complain("Diagonal string missing");
            diag = *av;
            ac--; av++;
            break;

        case 'D':
            diag_word = 1;
            break;

        case 'l':
            if (ac <= 0) complain("Letter set missing");
            initial_letters = letters_from(*av);
            ac--; av++;
            break;

        case 'p':
            if (ac <= 0) complain("Permuted letter set missing");
            anagram = *av;
            ac--; av++;
            break;

        case 'P':
            palindromic = 1;
            break;

        case 'w':               /* set dictionary */
            if (ac <= 0) complain("Dictionary filename missing");
            dict_file = *av;
            ac--; av++;
            break;

        case 'W':
            dict_file = 0;
            break;

        case 'n':               /* set number of rows and columns */
            if (ac <= 0) complain("Word square size missing");
            dimension = atoi(*av);
            ac--; av++;
            break;

        default:
            complain("Unrecognized command line option");
        }
    }

    size_anagram();
    if (dimension <= 0) {
        if (diag) dimension = strlen(diag);
        /* also catches an empty -d string, which would give a 0x0 square */
        if (dimension <= 0) complain("Word square size not specified");
    }
    if (dimension > MAX_WORD_LEN) complain("Word square size too big");
}

/* Apply the -d option: fix each cell of the main diagonal to the matching
   letter of `diag` (case-insensitive).  Does nothing without -d.  A length
   mismatch or a non-letter is reported through complain(). */
void set_diag() {
    int k;

    if (!diag) return;
    if (dimension != strlen(diag))
        complain("Diagonal length does not equal word square dimension");

    for (k = 0; k < dimension; k++) {
        int letter = diag[k];

        if ('a' <= letter && letter <= 'z')
            letter -= 'a' - 'A';            /* fold to upper case */
        else if (!('A' <= letter && letter <= 'Z'))
            complain("Nonalphabetic character in diagonal");

        /* cell (k,k) sits at index k*dimension + k */
        square[k*dimension + k] = 1 << (letter - 'A');
    }
}

/* ====================================================================== */
/* =======================  dictionary processing  ====================== */
/* ====================================================================== */

/* Allocate an empty trie node; reports failure through complain(). */
static struct dictionary * new_dict_node()
{
    struct dictionary * node = malloc(sizeof *node);
    if (node == 0) complain("Out of memory");
    node->mask = 0;
    node->child = 0;
    node->sibling = 0;
    return node;
}

/* Read the word list and build the trie rooted at `words`.
   If `dict_file` is set, stdin is reopened on that file; otherwise stdin is
   read as it stands.  Each line contributes its leading run of letters
   (folded to upper case) provided that run is exactly `dimension` letters
   long and is followed by a space, a line end or the end of the input;
   every other line is skipped.
   The words must arrive in alphabetical order: the children of a node are
   chained in arrival order and restrict_word() pairs them with the bits of
   the node's mask from lowest to highest, so a word that sorts before an
   earlier one is reported through complain(). */
void read_dict() {
    char buffer[MAX_WORD_LEN];
    struct dictionary * stack[MAX_WORD_LEN];   /* last node visited per depth */

    /* set up stdin to point to dictionary file */
    if (dict_file && !freopen(dict_file, "r", stdin))
        complain("Unable to open dictionary file");

    words = stack[0] = new_dict_node();

    /* loop reading words */
    for (;;) {
        int c = 'A';
        int i;

        /* loop reading word letters and upcasing them */
        for (i = 0; i < dimension && c >= 'A' && c <= 'Z'; i++) {
            buffer[i] = c = getchar();
            if (c >= 'a' && c <= 'z') buffer[i] = c = c + 'A' - 'a';
        }

        if (i == dimension && c >= 'A' && c <= 'Z') {
            /* here when we have a long enough word */
            c = getchar();
            /* EOF counts as a terminator so that a final word without a
               trailing newline is not lost */
            if (c == ' ' || c == '\n' || c == '\r' || c == EOF) {
                /* here when word length is just right */

                /* look for branch point in trie */
                for (i = 0; i < dimension &&
                         (stack[i]->mask & (1L<<(buffer[i]-'A'))) != 0; i++)
                    if (stack[i]->mask >= (1L<<(buffer[i]-'A'+1)))
                        complain("Dictionary is out of order");

                /* The letter added at the branch point must sort after
                   every letter already present there; otherwise the new
                   sibling would be chained out of step with the mask. */
                if (i < dimension &&
                    stack[i]->mask >= (1L<<(buffer[i]-'A')))
                    complain("Dictionary is out of order");

                /* add new path to trie */
                while (i < dimension) {
                    struct dictionary * newdict;

                    stack[i]->mask |= (1L<<(buffer[i]-'A'));
                    i++;
                    if (i == dimension) break;  /* last letter needs no node */

                    newdict = new_dict_node();
                    if (stack[i-1]->child == 0) stack[i-1]->child = newdict;
                    else stack[i]->sibling = newdict;
                    stack[i] = newdict;
                }
            }
        }

        /* finish line of dict */
        while (c != '\n' && c != EOF) c = getchar();
        if (c == EOF) return;
    }
}

/* Which characters match words from the dictionary?

   Looks at one line of the square: the `dimension` cells
   square[start + i*offset] for i = 0 .. dimension-1.  The callers in this
   file use offset 1 (a row), offset dimension (a column) and offset
   dimension+1 (the main diagonal).

   The trie is walked depth first.  Whenever a complete path is found whose
   letters are all permitted by the corresponding cells, those letters are
   collected in restrictions[].  Afterwards every cell is intersected with
   what was collected for its position.

   Returns nonzero if at least one cell lost a candidate letter, 0 if the
   line is unchanged.

   NOTE(review): a path is only recorded when g != 0, and g stays 0 while
   i == 0, so with dimension == 1 nothing is ever recorded and the cell is
   emptied -- confirm whether 1x1 squares are meant to be supported. */

int restrict_word(int start, int offset)
{
    int i = 0;
    struct dictionary * stack[MAX_WORD_LEN];	/* trie node at each depth */
    long masks[MAX_WORD_LEN];	/* letters of stack[i] not yet finished;
				   the lowest set bit is the branch
				   currently being explored */
    long restrictions[MAX_WORD_LEN];	/* letters seen in matching words */
    int retval = 0;

    /* assume initially that no words match */
    for (i = 0; i < dimension; i++) restrictions[i] = 0;

    /* traverse trie */
    stack[0] = words;
    i=0;
    for (;;) {
	/* m: letters offered by this node that cell i still allows */
	long m = stack[i]->mask & square[start+i*offset];
	/* g: the letter that led to this node (lowest bit of the parent's
	   remaining mask) if cell i-1 allows it, otherwise 0 */
	long g = 0;
	if (i > 0)
	    g = (masks[i-1] & ~(masks[i-1] - 1) & square[start+(i-1)*offset]);
	if (stack[i]->child != 0 && m != 0 && (i == 0 || g != 0)) {
	    /* go into child of stack[i]? */
	    masks[i] = stack[i]->mask;
	    stack[i+1] = stack[i]->child;
	    i++;
	} else {
	    /* handle leaf node */
	    if (g != 0 && i == dimension-1 && m != 0) {
		int j;
		/* record the single letter chosen at each earlier position
		   and every allowed letter at the last position */
		for (j = 0; j < i; j++)
		    restrictions[j] |= (masks[j] &~ (masks[j]-1));
		restrictions[i] |= m;
	    }

	    /* go on to next sibling at lowest level with a sibling */
	    /* climb while none of the parent's remaining letters (all but
	       the lowest bit) is allowed by the parent's cell */
	    while (i > 0 && (masks[i-1] & (masks[i-1]-1) &
			     square[start+(i-1)*offset]) == 0)
		i--;
	    if (i == 0) break;
	    masks[i-1] &= masks[i-1]-1;	/* drop the branch just finished */
	    stack[i] = stack[i]->sibling;
	}
    }

    /* narrow the cells and note whether anything changed */
    for (i = 0; i < dimension; i++) {
	restrictions[i] &= square[start+i*offset];
	if (!retval) retval = (restrictions[i] != square[start+i*offset]);
	square[start+i*offset] &= restrictions[i];
    }
    return retval;
}

/* Run one round of dictionary restriction over the whole square: the main
   diagonal (only with -D), then every column and every row.  If anything
   changed and -P is in effect, each cell is also intersected with the cell
   at the mirrored index so the two stay equal.
   Returns nonzero if any restrict_word() call reported a change.

   NOTE: `restrict` is a keyword from C99 on, so this file has to be built
   as C89/C90 unless this function and its caller are renamed together. */
int restrict() {
    int changed = 0;
    int k;

    if (diag_word)
        changed |= restrict_word(0, dimension+1);

    for (k = 0; k < dimension; k++) {
        changed |= restrict_word(k, dimension);         /* column k */
        changed |= restrict_word(k*dimension, 1);       /* row k */
    }

    if (palindromic && changed) {
        int last = dimension * dimension - 1;
        int lo;

        for (lo = 0; lo <= last; lo++) {
            int hi = last - lo;
            long common = square[lo] & square[hi];

            square[hi] = common;
            square[lo] = common;
        }
    }
    return changed;
}

/* ====================================================================== */
/* ========================  backtracking search  ======================= */
/* ====================================================================== */

/* save and restore search state */

/* Save the search state: copy the current cells into the slot directly
   after them and make that copy the current state. */
void push()
{
    int cells = dimension * dimension;
    long * from = square;
    long * to = square + cells;
    int k;

    for (k = 0; k < cells; k++)
        to[k] = from[k];
    square = to;
}

/* Restore the search state saved by the matching push(): step the current
   state pointer back by one square's worth of cells. */
void pop()
{
    int cells = dimension * dimension;

    square = square - cells;
}

/* does this solution really match the -p argument?
   Returns 1 when -p was not given, or when the letters in the square use
   no letter more often than it occurs in `anagram`; returns 0 otherwise.
   Expects every cell to hold exactly one letter, as relax() guarantees
   before it calls this function. */
int check_anagram() {
    int cells = dimension * dimension;
    int counts[26] = {0};       /* int: a letter may occur up to 400 times,
                                   more than a char can be relied on to hold */
    int i;

    if (anagram == 0) return 1; /* bugfix 2008.10.26 from Michael Brudno */

    /* tally the letters that are available */
    for (i = 0; i < cells; i++) {
        char ch = anagram[i];

        if (ch >= 'A' && ch <= 'Z')
            counts[ch-'A']++;
        else if (ch >= 'a' && ch <= 'z')
            counts[ch-'a']++;
    }

    /* spend one letter per cell */
    for (i = 0; i < cells; i++) {
        int letter = 0;

        /* index of the lowest set bit, bounded to the alphabet so an empty
           mask can never lead to an out-of-range shift or index */
        while (letter < 26 && !(square[i] & (1L << letter))) letter++;
        if (letter == 26) return 0;
        if (--counts[letter] < 0) return 0;
    }
    return 1;
}

/* Decide whether a finished square may be printed as far as symmetry goes.
   - default (symmetric) mode: always 1.
   - -A: 1 only if every cell above the main diagonal differs from its
     transposed cell.
   - -a: 0 for a square equal to its transpose; otherwise the first
     differing pair decides, and 1 is returned only when the upper cell's
     mask is the smaller one, so a square and its flipped version are never
     both passed. */

int check_sym() {
    int row, col;

    if (symmetric) return 1;

    for (row = 0; row < dimension; row++) {
        for (col = row+1; col < dimension; col++) {
            long upper = square[row*dimension+col];
            long lower = square[row+dimension*col];

            if (asymmetric) {
                if (upper == lower) return 0;
            } else if (upper != lower) {
                return (upper < lower);
            }
        }
    }
    return asymmetric ? 1 : 0;
}

/* found a pattern, output it */

/* Translate a mask with exactly one of the 26 letter bits set into the
   corresponding upper-case letter.  Any other mask is reported through
   complain(). */
char charfor(long mask)
{
    int letter = 0;

    while (letter < 26) {
        if ((1<<letter) == mask) return ('A' + letter);
        letter++;
    }
    complain("Unable to translate mask to character");
    return 0;
}

/* Print the current square on stdout: a blank line, then one line per row
   with each letter followed by a space. */
void success() {
    int row, col;

    putchar('\n');
    for (row = 0; row < dimension; row++) {
        long * cells = square + row*dimension;

        for (col = 0; col < dimension; col++) {
            putchar(charfor(cells[col]));
            putchar(' ');
        }
        putchar('\n');
    }
}

/* Here is where the actual solving gets done.
 * We use the dictionary to rule out the possibility of using certain
 * letters at certain positions.  When no more restrictions are possible,
 * we resort to a more brute-force backtracking-search approach.
 *
 * Each pass of the outer loop:
 *   1. calls restrict() until it reports no further change;
 *   2. with -A, removes from every off-diagonal cell the letter already
 *      fixed in its transposed cell;
 *   3. looks for the undecided cell with the fewest candidate letters,
 *      returning at once if some cell has no candidates left;
 *   4. if every cell is decided, prints the square when check_anagram()
 *      and check_sym() both accept it, then returns;
 *   5. otherwise tries the lowest candidate letter of the chosen cell in a
 *      recursive call on a pushed copy of the state, then removes that
 *      letter from the cell in the current state and starts over. */

void relax()
{
    for (;;) {			/* loop here for second branch of backtrack */
	int best_choice = -1;	/* index of the cell to branch on, -1 if none */
	int nchoices = 27;	/* more than the 26 letters any cell can hold */
	int i = 0;
	int bcsym;		/* cell that must receive the same letter */

	while (restrict()) ;	/* repeat restriction loop til no progress */

	/* force asymmetry */
	if (asymmetric) {
	    int j;
	    for (i = 0; i < dimension; i++)
		for (j = 0; j < dimension; j++) if (i != j) {
		    long x = square[i*dimension+j];
		    /* x has at most one bit set: forbid it in the
		       transposed cell */
		    if ((x & (x-1)) == 0)
			square[j*dimension+i] &=~ x;
		}
	}

	/* search for position with multiple choices but fewest of them */
	for (i = 0; i < dimension * dimension; i++) {
	    int n = 0;
	    long x = square[i];
	    /* counting stops once n reaches the best count found so far */
	    while (x && n < nchoices) { n++; x &= (x-1); } /* count nonzero bits */
	    if (n == 0) return;	/* dead end */
	    if (n > 1 && n < nchoices) {
		best_choice = i;
		nchoices = n;
	    }
	}

	if (best_choice < 0) {	/* only one char in all positions? */
	    if (check_anagram() && check_sym()) success();
	    return;
	}
	/* in symmetric mode the transposed cell gets the same letter */
	if (symmetric) bcsym = (best_choice/dimension)
			   + (best_choice % dimension)*dimension;
	else bcsym = best_choice;

	/* branch and recurse */
	push();
	/* first branch: keep only the lowest candidate letter */
	square[best_choice] &=~ (square[best_choice] - 1);
	square[bcsym] = square[best_choice];
	if (palindromic) {
	   /* copy the choice to the cells at the mirrored indices */
	   int d = dimension * dimension - 1;
	   square[d - best_choice] = square[d - bcsym] = square[best_choice];
	}
	relax();
	pop();
	/* second branch: remove the letter just tried and loop again */
	square[best_choice] &= (square[best_choice] - 1);
	square[bcsym] = square[best_choice];
	if (palindromic) {
	   int d = dimension * dimension - 1;
	   square[d - best_choice] = square[d - bcsym] = square[best_choice];
    }
    }
}

/* ====================================================================== */
/* ============================  main entry  ============================ */
/* ====================================================================== */

/* Set up the initial cell masks from the -p letter set.
   Does nothing when -p was not given.  Non-letters in `anagram` are
   reported through complain().

   Without symmetry every cell may hold any letter of the set.

   With symmetry each off-diagonal letter is used twice (once on each side
   of the main diagonal), so off-diagonal cells only admit letters that
   occur at least twice.  A letter occurring an odd number of times must
   appear on the diagonal.  If those odd letters already fill the whole
   diagonal, the diagonal is limited to them; otherwise spare diagonal
   cells remain and can hold any letter of the set (for example NO/ON puts
   the even-count letter N on the diagonal twice).  check_anagram() still
   validates the exact counts of every finished square. */
void init_anagram() {
    int cells = dimension * dimension;
    int counts[26] = {0};       /* int: a letter may occur up to 400 times,
                                   more than a char can be relied on to hold */
    long present = 0;           /* letters occurring at least once */
    long repeated = 0;          /* letters occurring at least twice */
    long odd = 0;               /* letters occurring an odd number of times */
    int n_odd = 0;
    long diag_mask;
    int i;

    if (anagram == 0) return;

    for (i = 0; i < cells; i++) {
        char ch = anagram[i];

        if (ch >= 'A' && ch <= 'Z')
            counts[ch-'A']++;
        else if (ch >= 'a' && ch <= 'z')
            counts[ch-'a']++;
        else complain("Permuted letter set must be alphabetic");
    }

    for (i = 0; i < 26; i++) {
        long bit = 1L << i;

        if (counts[i] >= 1) present |= bit;
        if (counts[i] >= 2) repeated |= bit;
        if (counts[i] & 1) { odd |= bit; n_odd++; }
    }

    if (!symmetric) {
        for (i = 0; i < cells; i++) square[i] = present;
        return;
    }

    for (i = 0; i < cells; i++) square[i] = repeated;
    diag_mask = (n_odd >= dimension) ? odd : present;
    for (i = 0; i < dimension; i++) square[i*(dimension+1)] = diag_mask;
}

/* Program entry: parse the options, allocate the search state, seed the
   cells from -l / -p / -d, load the dictionary and run the search.
   The allocation holds cells+1 copies of the square (cells being
   dimension*dimension), giving push() room for one saved copy per level
   of the backtracking search.
   Returns 0 once the search has finished; fatal errors leave through
   complain(). */
int main(int ac, char **av) {
    int cells;
    int i;

    parse_args(ac, av);
    cells = dimension * dimension;

    square = malloc(sizeof *square * cells * (cells + 1));
    if (square == 0) complain("Unable to allocate search state");

    for (i = 0; i < cells; i++) square[i] = initial_letters;
    init_anagram();
    set_diag();
    read_dict();

    relax();
    return 0;
}
