#!/usr/bin/python

#################################################### 
# 
# Author: Brian Chan (birdchan@ucdavis.edu)
# Supervisor : Alex Kozik (akozik@atgc.org)
# Date: February 27 2004
# 
#################################################### 
# 
# In short, this script reads in the contig ID info
#    and contig mismatch info, then output details
#    about possible polymorphism. 
# 
############ Detailed description #################
# 
# It first asks for misMatch file, which is in this format:
#    ContigName \t seqName \t (index1:symbol)*
# For example, a line could be like the following:
#    QH_CA_Contig23 QHA10B27.yg.ab1+ 72:D|29:I
# The symbol D means deletion, S means substitution, I
#    means insertion. So, say 72:D means that at index
#    72 of such seq of such contig there is a deletion.
# 
# It then asks for a contig ID file, which is in this format:
#    ContigName \t numOfSeq \t (seqName w/ start and end indexes)*
# For example, a line could look like this:
#    QG_CA_Contig72 2 Seq1+(2,200)|Seq2+(1,234)
# 
# It then asks whether to use the prefix or the
#    suffix of a seq name to split seqs into two groups.
# 
# It then asks how long the pattern is. This pattern 
#    corresponds to the prefix or suffix we chose.
# 
# Then it asks for the pattern for group 1. Then for
#    group 2.
# 
# Say, if we input QGA, QGB, QGC, QGD and QGI for 
#    group 1, and QGE, QGF, QGG, QGH, and QGJ for
#    group 2. Then all the seq with the prefix or
#    suffix (depending what we chose) in the group
#    1 patterns will be placed in group 1. Same 
#    applies to group 2. For example, seq QGB1122
#    will be placed in group 1. All unidentified
#    seq will be ignored. 
# 
# After all these inputs, the script then analyzes
#    the content of the two input files. It forms
#    groups and checks whether polymorphism possibly
#    exists. If so, it writes the findings to the
#    corresponding output file.
# 
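# The output format is as follows (fields are tab-separated):
#    ContigName DIS ALL PartialFraction LowPriority Positions
# For example, the following are valid outputs (Positions omitted):
#    QH_CA_Contig921 S 1 2/5 1
#    QH_CA_Contig3499 I 7 1/13 0
# Positions lists the indexes at which the ALL count was raised,
#    with runs of consecutive indexes collapsed, e.g. |5-7:10|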
# 
# DIS is an element of {D, I, S}
# 
# ALL is the number of sequences in a group when all
#    of that group's sequences covering a particular
#    index have this DIS there, while none of the
#    sequences in the other group have this DIS at
#    that index. The largest such count is reported.
# 
# PartialFraction is similar to ALL, except that
#    only some of the group's sequences have that
#    particular DIS, while none of the sequences in
#    the other group have that DIS. It is written as
#    (sequences with DIS)/(sequences covering the index).
# 
# LowPriority is either 0 or 1. When ALL is exactly
#    1 (a single sequence supports the DIS), then
#    LowPriority is 1, meaning this contig has low
#    priority. Otherwise (ALL greater than 1, or ALL
#    equal to 0) LowPriority is 0.
# 
# As an illustration,
#    -----D---------
#    -----D-------------S---------D---
#    -----D-------------S---
#       --D---------
#      ========I==S=======D===
#      ========I==S====
#      ========I=====
# 
# The output would be
#    ContigName D 4 0/1 0
#    ContigName I 3 0/1 0
#    ContigName S 2 2/3 0
# 
# Notice, the default output of the partial fraction
#    is 0/1 . 
# 
# The following outputs are wrong
#    ContigName D 1 0/1 1  (since we have 4 D's for ALL)
#    ContigName S 0 2/3 0  (since we do have something for ALL)
# 
# As a rule, this script keeps the largest "ALL"
#    and the largest "PartialFraction" found.
# 
#################################################### 
# importing the libraries

import sys
from string import *
from string import atoi
from string import rstrip
from math import log10

####################################################
# class definition

def mycmp(left, right, prefixLen=12):
	"""cmp-style comparison of two contig IDs by their numeric part.

	IDs are assumed to be a fixed-length prefix followed by a number, e.g.
	"QH_CA_Contig921" (prefix "QH_CA_Contig" is 12 characters, the default
	prefixLen).  Returns -1, 0 or 1.  Equal numbers compare as 0 so the
	function is a consistent ordering for list.sort().
	"""
	lSeqNum = int(left[prefixLen:])    # numeric part of the left ID
	rSeqNum = int(right[prefixLen:])   # numeric part of the right ID
	return (lSeqNum > rSeqNum) - (lSeqNum < rSeqNum)

class seqInfo:
	"""Holds a sequence name and one info string (not referenced elsewhere in this file)."""
	def __init__(self, name):
		self.seqName, self.info = name, ""
	def setSeqName(self, seqName):
		self.seqName = seqName
	def addInfo(self, info):
		# replaces, rather than appends to, any earlier info
		self.info = info
	def printInfo(self):
		print(self.info)

class seqMMGroupClass:
	"""Mismatch data for one contig, split into group 1 (G1) and group 2 (G2).

	A sequence's group is chosen from the prefix or suffix of its name using
	the module globals contig_pre_or_suf ("p" = prefix, anything else =
	suffix), contig_pre_or_suf_size, contig_G1_def and contig_G2_def.
	Sequences that match neither group are ignored.
	"""
	def __init__(self, contigID):
		self.contigID = contigID
		self.seqList_G1 = {}        # seqName -> mismatch string ("" if none)
		self.seqList_G1_index = []  # distinct mismatch positions (strings) in G1, first-seen order
		self.seqList_G2 = {}        # same as above, for group 2
		self.seqList_G2_index = []

	def _addToGroup(self, seqDict, indexList, name, info):
		# Store one sequence in a group and record each new mismatch position.
		if info == "no polymorphism found":
			seqDict[name] = ""
			return
		seqDict[name] = info
		for entry in info.split("|"):           # e.g. "72:D|29:I"
			num, symbol = entry.split(":")
			if num not in indexList:
				indexList.append(num)

	def readinSeq(self, name, info):
		"""Add sequence `name` with mismatch string `info` to G1 or G2.

		info is "no polymorphism found" or "index:symbol" entries joined
		by "|", e.g. "72:D|29:I".
		"""
		if contig_pre_or_suf == "p":          # prefix
			type_str = name[:contig_pre_or_suf_size]
		else:                                 # otherwise treated as suffix
			type_str = name[contig_pre_or_suf_size * (-1):]
		if type_str in contig_G1_def:
			self._addToGroup(self.seqList_G1, self.seqList_G1_index, name, info)
		elif type_str in contig_G2_def:
			self._addToGroup(self.seqList_G2, self.seqList_G2_index, name, info)

	def get_G1_Dict(self):
		return self.seqList_G1
	def get_G2_Dict(self):
		return self.seqList_G2
	def get_G1_index(self):
		return self.seqList_G1_index
	def get_G2_index(self):
		return self.seqList_G2_index

	def clear(self):
		# Empty every container in place; Python 2 lists have no clear()
		# method, so the index lists are emptied with slice deletion.
		self.seqList_G1.clear()
		del self.seqList_G1_index[:]
		self.seqList_G2.clear()
		del self.seqList_G2_index[:]

	def printInfo(self):
		"""Debug dump: each group's sequences, sorted by name."""
		for label, seqDict in ((">>>> G1", self.seqList_G1), (">>>> G2", self.seqList_G2)):
			print(label)
			for name, info in sorted(seqDict.items(), key=lambda kv: kv[0]):
				print("%s %s" % (name, info))

class seqMMClass:
	"""Collects mismatch lines per contig and builds a seqMMGroupClass for each.

	Usage: setContigName(), then readinSeq() for each sequence of that
	contig, then finalizeGroup().
	"""
	def __init__(self):
		self.seqMMGroups = {}      # contigID -> seqMMGroupClass
		self.contigID = ""         # contig currently being read
		self.currMMGroupList = {}  # seqName -> mismatch info for the current contig
		self.allSeqsByContig = {}  # contigID -> every seqName -> info finalized so far
	def setContigName(self, contigID):
		self.contigID = contigID
	def readinSeq(self, seqName, info):
		self.currMMGroupList[seqName] = info
	def finalizeGroup(self):
		"""Build the group for the current contig and reset for the next one.

		If the same contig is finalized more than once (its lines were not
		contiguous in the input), the new sequences are merged with those
		finalized earlier instead of silently replacing them.
		"""
		merged = self.allSeqsByContig.setdefault(self.contigID, {})
		merged.update(self.currMMGroupList)
		MMGroup = seqMMGroupClass(self.contigID)
		for sn, info in sorted(merged.items(), key=lambda kv: kv[0]):
			MMGroup.readinSeq(sn, info)
		self.seqMMGroups[self.contigID] = MMGroup
		self.contigID = ""
		self.currMMGroupList = {}
	def getMMGroup(self, contigID):
		# raises KeyError for a contig that was never finalized
		return self.seqMMGroups[contigID]
	def clear(self):
		self.seqMMGroups = {}
		self.contigID = ""
		self.currMMGroupList = {}
		self.allSeqsByContig = {}
	def printInfo(self):
		"""Debug dump of every contig's groups, ordered with mycmp."""
		allMMGroups = self.seqMMGroups.items()
		allMMGroups.sort(lambda f, s: mycmp(f[0], s[0]))
		for contigID, group in allMMGroups:
			print("%s %s" % (contigID, "-------------------------"))
			group.printInfo()

class seqContigClass:
	"""Maps a contig name to its raw read-layout string from the contig info file."""
	def __init__(self):
		self.seqs = {}             # contigName -> "seq+(start,end)|..."
	def readinSeq(self, name, info):
		self.seqs[name] = info
	def getContig(self, name):
		layout = self.seqs[name]   # KeyError for an unknown contig
		return layout
	def getKeys(self):
		return self.seqs.keys()
	def clear(self):
		self.seqs.clear()
	def printInfo(self):
		"""Debug dump of all contigs, ordered with mycmp."""
		ordered = sorted(self.seqs.items(), cmp=lambda a, b: mycmp(a[0], b[0]))
		for pair in ordered:
			print("%s %s" % pair)

####################################################
# global state shared by the routines below

# File handles; 0 is a placeholder until the main section opens the files.
fp_in_MM = fp_in_contig = 0                # mismatch / contig-info input files
fp_out_del = fp_out_ins = fp_out_sub = 0   # deletion / insertion / substitution outputs

# in-memory stores filled by read_in_MM_data() and read_in_contig_data()
seqMM = seqMMClass()            # per-contig G1/G2 mismatch groups
seqContig = seqContigClass()    # per-contig read layout strings

# grouping rules entered interactively
contig_pre_or_suf = ""       # "p" = group by name prefix, "s" = by name suffix
contig_pre_or_suf_size = 0   # length of that prefix/suffix
contig_G1_def = []           # prefixes/suffixes that put a sequence in group 1
contig_G2_def = []           # prefixes/suffixes that put a sequence in group 2

####################################################
# function definition

def main_routine():
	"""Run the pipeline: load contig layouts, load mismatches, then analyse."""
	for stage in (read_in_contig_data, read_in_MM_data, check_polymorphism):
		stage()

#--------------------------------------

def read_in_contig_data():
	"""Load the contig info file into the global seqContig store.

	Each line is "ContigName<TAB>numOfSeq<TAB>seq+(start,end)|...".  The
	layout string is stored under ContigName; numOfSeq is not used.  Blank
	lines are skipped rather than crashing the three-field unpack.  Prints
	each line index as progress and exits (status 0) if the file is empty.
	"""
	lines = fp_in_contig.readlines()    # read all lines from the contig file
	if not lines:
		print("No input in the input file")
		sys.exit(0)  # if no input, exit

	seqContig.clear()       # drop anything from a previous load
	for i, raw in enumerate(lines):
		print(i)
		line = raw.rstrip()   # strip the line ending
		if not line:
			continue
		contigName, numSeq, info = line.split("\t")
		seqContig.readinSeq(contigName, info)

#--------------------------------------

def read_in_MM_data():
	"""Load the mismatch file into the global seqMM, one group per contig.

	Each line is "ContigName<TAB>seqName<TAB>info".  Consecutive lines that
	share a contig name form one run: the run's sequences are passed to
	seqMM.readinSeq() and the run is closed with seqMM.finalizeGroup().
	Prints the index of each run's first line as progress, and exits
	(status 0) if the file is empty.
	"""
	lines = fp_in_MM.readlines()
	if not lines:
		print("No input in the input file")
		sys.exit(0)

	seqMM.clear()
	runStart = 0
	while runStart < len(lines):
		print(runStart)
		runContig, unused1, unused2 = getLine(lines, runStart).split("\t")
		seqMM.setContigName(runContig)
		runEnd = next_diff_contig_line(lines, runStart, runContig)
		for lineNo in range(runStart, runEnd):
			contigName, seqName, info = getLine(lines, lineNo).split("\t")
			seqMM.readinSeq(seqName, info)
		seqMM.finalizeGroup()   # build the G1/G2 group for this contig
		runStart = runEnd

#--------------------------------------

def check_polymorphism():
	"""Run the D/I/S polymorphism checks for every contig in the contig file.

	For each contig, the G1/G2 mismatch data in seqMM is combined with the
	read coordinates in seqContig.  check_polymorphism_contig() then runs
	once per mismatch type over the union of mismatch positions seen in
	either group.  Contigs are processed in sorted order so the output files
	are reproducible.  A contig with no entry in the mismatch data is
	reported and skipped.
	"""
	for cid in sorted(seqContig.getKeys()):
		contigInfo = seqContig.getContig(cid)
		try:
			MMGroup = seqMM.getMMGroup(cid)
		except KeyError:
			print("No mismatch data for " + cid + ", skipping")
			continue

		contigInfoList = retrieveContigInfo(contigInfo)   # [(seqName, start, end), ...]

		DictG1 = MMGroup.get_G1_Dict()     # seqName -> mismatch string, group 1
		DictG2 = MMGroup.get_G2_Dict()     # seqName -> mismatch string, group 2
		G1_index = MMGroup.get_G1_index()  # positions with mismatches in group 1
		G2_index = MMGroup.get_G2_index()  # positions with mismatches in group 2

		# Union of positions.  Copy first: extending G1_index directly would
		# append to (and re-sort) the group's own index list.
		all_index = list(G1_index)
		seen = set(all_index)
		for i in G2_index:
			if i not in seen:
				seen.add(i)
				all_index.append(i)
		all_index.sort(key=int)   # positions are numeric strings

		print("Checking polymorphism for " + cid)
		for DIS in ("D", "I", "S"):
			check_polymorphism_contig(DIS, DictG1, DictG2, G1_index, G2_index, all_index, cid, contigInfoList)


#################################################
# helper functions

def getLine(lines, index):
	"""Return lines[index] with trailing whitespace (including the newline) removed."""
	return lines[index].rstrip()

#--------------------------------------

def retrieveContigInfo(contigInfo):
	"""Parse a contig's read-layout string into (seqName, start, end) tuples.

	contigInfo looks like "Seq1+(2,200)|Seq2-(1,234)".  Every "|"-separated
	entry becomes (name, start, end).  name keeps its trailing strand sign,
	so it matches the seq names used in the mismatch file.  start and end
	are ints.  Empty entries are skipped.

	Each entry is split at its last "(" rather than at a "+"/"-" found
	anywhere in it.  This keeps hyphens inside read names and negative
	start coordinates intact.

	Raises ValueError for an entry without "(" or without a "start,end" pair.

	>>> retrieveContigInfo("Seq1+(2,200)|Seq2-(1,234)")
	[('Seq1+', 2, 200), ('Seq2-', 1, 234)]
	"""
	contigInfoList = []  # each element is (seqName, start, end)
	for entry in contigInfo.split("|"):
		if not entry:
			continue
		paren = entry.rindex("(")          # start of the "(start,end)" part
		seqName = entry[:paren]            # e.g. "Seq1+" (strand sign included)
		start, end = entry[paren + 1:].rstrip(")").split(",")
		contigInfoList.append((seqName, int(start), int(end)))
	return contigInfoList

#---------------------------------

def check_polymorphism_contig(DIS, DictG1, DictG2, G1_index, G2_index, all_index, cid, contigInfoList):
	"""Score one mismatch type for one contig and write its result line.

	DIS            -- "D", "I" or "S": the mismatch symbol being scored
	DictG1, DictG2 -- seqName -> mismatch string for group 1 / group 2
	G1_index, G2_index -- accepted from the caller but not used here
	all_index      -- candidate positions (numeric strings), scanned in order
	cid            -- contig ID, passed through to the output line
	contigInfoList -- (seqName, start, end) for every read of the contig

	At each position, updateTotalOccur() counts the reads of each group that
	carry DIS there and discounts reads that do not cover the position.
	  * ALL: when every counted read of one group has DIS and no read of the
	    other group does, that group's count becomes ALL if it beats the
	    current maximum; the position is then appended to position_all.
	  * PartialFraction: when at least one, but not every, counted read of
	    one group has DIS and no read of the other group does, the fraction
	    replaces the stored one if larger (ties go to the larger
	    denominator).  With no such position it stays at the 0/1 default.
	LowPriority is 1 only when the final ALL is exactly 1.

	Reads that belong to neither group are ignored.  Such reads either
	matched no group definition or have no mismatch line.  The result is
	passed to OutputGood().
	"""
	low_priority = 0            # the LowPriority field
	max_occur_all = 0           # the ALL field
	max_occur_partial_num = 0   # numerator of PartialFraction
	max_occur_partial_den = 1   # denominator of PartialFraction (0/1 = none found)
	position_all = []           # positions at which max_occur_all was raised

	# group sizes before reads that miss a given position are discounted
	numOfSeqG1_const = len(DictG1)
	numOfSeqG2_const = len(DictG2)

	for i in all_index:   # left to right, looking for possible polymorphism
		i = int(i)          # positions arrive as numeric strings
		total_occur_G1 = 0  # reads in group 1 carrying DIS at i
		total_occur_G2 = 0  # reads in group 2 carrying DIS at i
		numOfSeqG1 = numOfSeqG1_const
		numOfSeqG2 = numOfSeqG2_const
		for c, startIndex, endIndex in contigInfoList:
			if c in DictG1:
				numOfSeqG1, total_occur_G1 = updateTotalOccur(DIS, DictG1, c, i, total_occur_G1, numOfSeqG1, startIndex, endIndex)
			elif c in DictG2:
				numOfSeqG2, total_occur_G2 = updateTotalOccur(DIS, DictG2, c, i, total_occur_G2, numOfSeqG2, startIndex, endIndex)
			# else: the read is in neither group (unidentified seq, or no
			# mismatch line) and takes no part in the comparison.

		# ALL: every counted read of one group has DIS, none of the other does
		if total_occur_G1 == numOfSeqG1 and total_occur_G1 > max_occur_all and total_occur_G2 == 0:
			max_occur_all = numOfSeqG1
			position_all.append(i)
		if total_occur_G2 == numOfSeqG2 and total_occur_G2 > max_occur_all and total_occur_G1 == 0:
			max_occur_all = numOfSeqG2
			position_all.append(i)

		# PartialFraction: some (but not all) reads of one group have DIS and
		# none of the other does.  Requiring at least one hit keeps positions
		# where nobody has DIS from replacing the 0/1 default with 0/N.
		if 0 < total_occur_G1 < numOfSeqG1 and total_occur_G2 == 0:
			fractionC = float(total_occur_G1) / numOfSeqG1
			fractionM = float(max_occur_partial_num) / max_occur_partial_den
			if fractionC > fractionM or (fractionC == fractionM and numOfSeqG1 > max_occur_partial_den):
				max_occur_partial_num = total_occur_G1
				max_occur_partial_den = numOfSeqG1
		if 0 < total_occur_G2 < numOfSeqG2 and total_occur_G1 == 0:
			fractionC = float(total_occur_G2) / numOfSeqG2
			fractionM = float(max_occur_partial_num) / max_occur_partial_den
			if fractionC > fractionM or (fractionC == fractionM and numOfSeqG2 > max_occur_partial_den):
				max_occur_partial_num = total_occur_G2
				max_occur_partial_den = numOfSeqG2

	# a single supporting read marks the contig as low priority
	if max_occur_all == 1:
		low_priority = 1

	OutputGood(DIS, cid, max_occur_all, max_occur_partial_num, max_occur_partial_den, low_priority, position_all)

#---------------------------------

def updateTotalOccur(DIS, Dict, c, i, total_occur_arg, numOfSeq_arg, start, end):
	"""Fold one sequence's state at position i into its group's running counts.

	Dict[c] is the sequence's mismatch string, e.g. "121:S|134:I|341:S"
	(empty when no mismatches were recorded).  start and end are the
	sequence's coordinates on the contig.

	Returns (numOfSeq, total_occur):
	  * if the sequence records a mismatch at exactly position i,
	    total_occur is incremented when one of those mismatches is DIS;
	  * otherwise, if i lies outside [start, end], numOfSeq is decremented
	    because the sequence does not cover that position;
	  * otherwise both counts are returned unchanged.

	Positions are compared as integers, not by substring search.  So
	position 21 does not match "121:S" and position 7 does not match "72:D".
	"""
	total_occur = total_occur_arg
	numOfSeq = numOfSeq_arg
	symbolsAt = []                      # mismatch symbols recorded at position i
	for entry in Dict[c].split("|"):
		if not entry:
			continue
		pos, symbol = entry.split(":")
		if int(pos) == i:
			symbolsAt.append(symbol)
	if symbolsAt:
		if DIS in symbolsAt:            # we got a hit
			total_occur = total_occur + 1
	elif start > i or end < i:          # position i is not covered by this seq
		numOfSeq = numOfSeq - 1         # treat this seq as not existent here
	return numOfSeq, total_occur

#---------------------------------

def next_diff_contig_line(lines, index, contigName):
	"""Return the index of the first line after `index` with a different contig.

	lines are tab-separated records whose first field is the contig name.
	The contig to match is read from lines[index] itself.  The contigName
	argument is accepted for compatibility but not consulted; the caller
	passes the same name.  Returns len(lines) when the run reaches the end.
	"""
	runName = lines[index].rstrip().split("\t", 1)[0]
	i = index + 1
	while i < len(lines) and lines[i].rstrip().split("\t", 1)[0] == runName:
		i = i + 1
	return i

#---------------------------------

def OutputGood(DIS, cid, All, Partial_num, Partial_den, low_priority, position_all):
	"""Format the ALL positions and hand one result line to the DIS writer.

	position_all is a list of int positions, taken in the order given.
	Runs of consecutive values are collapsed to "first-last", separated by
	":" and wrapped in "|", e.g. [5, 6, 7, 10] -> "|5-7:10|" and [] -> "||".
	DIS "D", "I" or "S" selects the deletion, insertion or substitution
	writer; any other value writes nothing.
	"""
	# group consecutive positions into [first, last] runs
	runs = []
	for p in position_all:
		if runs and p == runs[-1][1] + 1:
			runs[-1][1] = p
		else:
			runs.append([p, p])

	pieces = []
	for first, last in runs:
		if first == last:
			pieces.append(str(first))
		else:
			pieces.append(str(first) + "-" + str(last))
	position_str = "|" + ":".join(pieces) + "|"

	writers = {"D": OutputDeletionsGood, "I": OutputInsertionsGood, "S": OutputSubstitutionsGood}
	if DIS in writers:
		writers[DIS](cid, All, Partial_num, Partial_den, low_priority, position_str)

#---------------------------------

def OutputDeletionsGood(contigID, All, Partial_num, Partial_den, low_priority, position_str):
	"""Announce a deletion result and append its line to the global fp_out_del.

	Tab-separated fields: contigID, "D", All, "num/den", low_priority, position_str.
	"""
	fields = [contigID, "D", str(All), str(Partial_num) + "/" + str(Partial_den), str(low_priority), position_str]
	print("Congratulation: " + contigID + " contains deletion with possible polymorphism between parents!!!")
	fp_out_del.write("\t".join(fields) + "\n")

#---------------------------------

def OutputInsertionsGood(contigID, All, Partial_num, Partial_den, low_priority, position_str):
	"""Announce an insertion result and append its line to the global fp_out_ins.

	Tab-separated fields: contigID, "I", All, "num/den", low_priority, position_str.
	"""
	fields = [contigID, "I", str(All), str(Partial_num) + "/" + str(Partial_den), str(low_priority), position_str]
	print("Congratulation: " + contigID + " contains insertion with possible polymorphism between parents!!!")
	fp_out_ins.write("\t".join(fields) + "\n")

#---------------------------------

def OutputSubstitutionsGood(contigID, All, Partial_num, Partial_den, low_priority, position_str):
	"""Announce a substitution result and append its line to the global fp_out_sub.

	Tab-separated fields: contigID, "S", All, "num/den", low_priority, position_str.
	"""
	fields = [contigID, "S", str(All), str(Partial_num) + "/" + str(Partial_den), str(low_priority), position_str]
	print("Congratulation: " + contigID + " contains substitution with possible polymorphism between parents!!!")
	fp_out_sub.write("\t".join(fields) + "\n")


####################################################
# main
#
# Interactive driver: asks for the two input files and for the rules that
# split sequences into group 1 and group 2, then runs the analysis and
# writes deletions.good, insertions.good and substitutions.good.

# ask for the input files and open them
input_filename_MM = raw_input("Enter the mismatch file name : ")
fp_in_MM = open(input_filename_MM, "rb")
input_filename_contig = raw_input("Enter the contig info file name : ")
fp_in_contig = open(input_filename_contig, "rb")

# output file names
output_filename_deletion = "deletions.good"
output_filename_insertion = "insertions.good"
output_filename_substitution = "substitutions.good"

# ask whether group membership is decided by the name's prefix or suffix
while 1:
	contig_pre_or_suf = raw_input("Distinguish by (p)refix or (s)uffix? : ")
	if contig_pre_or_suf == "p" or contig_pre_or_suf == "s":
		break
	print("Please enter either 'p' or 's', but not " + contig_pre_or_suf)

# ask how long the prefix/suffix is; re-prompt on non-numeric or
# non-positive input instead of crashing on int()
if contig_pre_or_suf == "p":
	pre_or_suf_word = "prefix"
else:
	pre_or_suf_word = "suffix"
while 1:
	s = raw_input("Please input the size of the " + pre_or_suf_word + " : ")
	try:
		contig_pre_or_suf_size = int(s)
	except ValueError:
		print("Please enter a whole number")
		continue
	if contig_pre_or_suf_size > 0:
		break

# ask for the definition of G1 (name prefixes/suffixes that mean group 1)
print("Specify Group 1")
while 1:
	print("Group 1 has " + str(contig_G1_def))
	s = raw_input("Please input element : ")
	if len(s) != contig_pre_or_suf_size:
		print("element must be of size " + str(contig_pre_or_suf_size))
		continue
	if s in contig_G1_def:
		print("element is already in Group 1")
		continue
	contig_G1_def.append(s)
	s = raw_input("Do you want to add more? (y/n) : ")
	if s != "y":
		break

# ask for the definition of G2; an element may not belong to both groups
print("Specify Group 2")
while 1:
	print("Group 2 has " + str(contig_G2_def))
	s = raw_input("Please input element : ")
	if len(s) != contig_pre_or_suf_size:
		print("element must be of size " + str(contig_pre_or_suf_size))
		continue
	if s in contig_G1_def:
		print("element is already in Group 1")
		continue
	if s in contig_G2_def:
		print("element is already in Group 2")
		continue
	contig_G2_def.append(s)
	s = raw_input("Do you want to add more? (y/n) : ")
	if s != "y":
		break

# Open the output files only now, so aborting during the prompts does not
# truncate the results of a previous run.
fp_out_del = open(output_filename_deletion, "wb")
fp_out_ins = open(output_filename_insertion, "wb")
fp_out_sub = open(output_filename_substitution, "wb")

print("reading data form the input file ...")

try:
	main_routine()
finally:
	# close everything even if the analysis stops early (e.g. sys.exit on empty input)
	fp_in_MM.close()
	fp_in_contig.close()
	fp_out_del.close()
	fp_out_ins.close()
	fp_out_sub.close()

