"""Matplotlib plotting utilities for electronic band structures (``EBSPlot``)."""

__author__ = "Pedram Tavadze and Logan Lang"
__maintainer__ = "Pedram Tavadze and Logan Lang"
__email__ = "petavazohi@mail.wvu.edu, lllang@mix.wvu.edu"
__date__ = "March 31, 2020"

# Standard library
import json
import os
from typing import List

# Third-party
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from matplotlib.collections import LineCollection
from matplotlib.ticker import (
    AutoMinorLocator,
    FormatStrFormatter,
    MultipleLocator,
)

# Project
from pyprocar.core import ElectronicBandStructure, KPath
class EBSPlot:
    """Plot an electronic band structure.

    Parameters
    ----------
    ebs : ElectronicBandStructure
        The band structure to plot (``pyprocar.core.ElectronicBandStructure``).
    kpath : KPath, optional
        The k-path (``pyprocar.core.KPath``). When given, and its number of
        segments equals ``len(kpath.ngrids)``, the x axis is the cumulative
        length along the path and the high-symmetry ticks are drawn;
        otherwise the x axis is the k-point index. The default is None.
    ax : mpl.axes.Axes, optional
        Axes to draw on. If None, a new figure of size
        ``config.figure_size`` with a single axes is created.
        The default is None.
    spins : List[int], optional
        Spin channels to plot. Defaults to every channel of ``ebs``; replaced
        by ``[0]`` when ``ebs.is_non_collinear`` is true. The default is None.
    kdirect : bool, optional
        If False, the k-path end points are multiplied by
        ``ebs.reciprocal_lattice`` before segment lengths are measured.
        The default is True.
    config : optional
        Plot configuration whose attributes (``figure_size``, ``color``,
        ``linewidth``, ``clim``, ...) drive all styling. It defaults to None,
        but the constructor reads ``config.figure_size``, so it must be
        provided in practice.
    """

    def __init__(
        self,
        ebs: ElectronicBandStructure,
        kpath: KPath = None,
        ax: mpl.axes.Axes = None,
        spins: List[int] = None,
        kdirect: bool = True,
        config=None,
    ):
        """Store the inputs, create or attach the axes and apply default styling."""
        self.config = config
        self.ebs = ebs
        self.kpath = kpath
        self.spins = spins
        self.kdirect = kdirect
        # Filled by the plot_* methods and consumed by export_data.
        self.values_dict = {}

        if self.spins is None:
            self.spins = range(self.ebs.nspins)
        # Counted before the non-collinear override below (original order).
        self.nspins = len(self.spins)
        if self.ebs.is_non_collinear:
            self.spins = [0]
        self.handles = []

        figsize = tuple(self.config.figure_size)
        if ax is None:
            self.fig = plt.figure(figsize=figsize)
            self.ax = self.fig.add_subplot(111)
        else:
            # Use the figure that owns ``ax`` rather than whichever figure
            # happens to be current, so colorbars land next to ``ax``.
            self.fig = ax.figure
            self.ax = ax

        self.x = self._get_x()
        self._initiate_plot_args()

    def _initiate_plot_args(self):
        """Apply the default ticks, axis labels and axis limits to ``self.ax``."""
        self.set_xticks()
        self.set_yticks()
        self.set_xlabel()
        self.set_ylabel()
        self.set_xlim()
        self.set_ylim()

    def _get_x(self):
        """Return the x-axis coordinate of every k-point.

        With a usable k-path the coordinate is the cumulative distance along
        the path (each segment sampled with ``kpath.ngrids[i]`` points);
        otherwise it is simply the k-point index.

        Returns
        -------
        np.ndarray
            1-D x-axis data.
        """
        kpath = self.kpath
        if (
            kpath is not None
            and kpath.nsegments == len(kpath.ngrids)
            and kpath.nsegments > 0
        ):
            pieces = []
            pos = 0
            for isegment in range(kpath.nsegments):
                kstart, kend = kpath.special_kpoints[isegment]
                kstart = np.asarray(kstart)
                kend = np.asarray(kend)
                if not self.kdirect:
                    kstart = np.dot(self.ebs.reciprocal_lattice, kstart)
                    kend = np.dot(self.ebs.reciprocal_lattice, kend)
                distance = np.linalg.norm(kend - kstart)
                pieces.append(
                    np.linspace(pos, pos + distance, kpath.ngrids[isegment])
                )
                pos += distance
            x = np.concatenate(pieces, axis=0)
        else:
            x = np.arange(0, self.ebs.kpoints.shape[0])
        return np.array(x).reshape(-1)

    def _kpath_tick_names(self):
        """Return one label per entry of ``self.x``: the k-path tick name at
        tick positions and ``''`` everywhere else (all ``''`` without a k-path).
        """
        names_by_position = {}
        if self.kpath is not None:
            # Later ticks at the same position win, as in the original loop.
            for position, name in zip(self.kpath.tick_positions, self.kpath.tick_names):
                names_by_position[position] = name
        return [names_by_position.get(i, '') for i in range(len(self.x))]

    def _store_values(self, values_dict):
        """Add the k-path columns to ``values_dict`` and keep it for export."""
        values_dict['kpath_values'] = self.x
        values_dict['kpath_tick_names'] = self._kpath_tick_names()
        self.values_dict = values_dict

    def _spin_color(self, ispin):
        """Return the line/marker colour for spin channel ``ispin``."""
        if len(self.spins) == 1:
            return self.config.color
        return self.config.spin_colors[ispin]

    def _masked_bands(self, width_weights, width_mask, color_weights, color_mask):
        """Return ``self.ebs.bands`` as a masked array.

        A point is masked when ``|width_weights| < width_mask`` or
        ``|color_weights| < color_mask`` (each test only when its mask is
        given). Both masks are combined instead of the later one silently
        replacing the earlier one.
        """
        mask = np.zeros(self.ebs.bands.shape, dtype=bool)
        if width_mask is not None:
            mask = mask | (np.abs(width_weights) < width_mask)
        if color_mask is not None:
            mask = mask | (np.abs(color_weights) < color_mask)
        return np.ma.masked_array(self.ebs.bands, mask)

    def _color_limits(self, color_weights, spins):
        """Return ``(vmin, vmax)`` for colour mapping.

        ``config.clim`` entries win; ``None`` entries are replaced by the
        min/max of ``color_weights`` over the plotted spin channels only
        (e.g. negative spin-projection values are meaningless for a density).
        """
        vmin, vmax = self.config.clim[0], self.config.clim[1]
        if color_weights is not None:
            if vmin is None:
                vmin = color_weights[:, :, spins].min()
            if vmax is None:
                vmax = color_weights[:, :, spins].max()
        return vmin, vmax

    def _default_energy_interval(self):
        """Return the band energy range padded by 10% of each end's magnitude.

        The padding is applied outward at both ends, so the range always
        contains every band (``max * 1.1`` would cut off all-negative bands).
        """
        emin = self.ebs.bands.min()
        emax = self.ebs.bands.max()
        return (emin - abs(emin) * 0.1, emax + abs(emax) * 0.1)

    @staticmethod
    def _auto_y_tick_spacing(span):
        """Return ``(major, minor)`` y-tick spacing for an energy ``span``,
        or ``(None, None)`` when ``span`` is outside ``[1, 30)``."""
        table = (
            (20, 30, 5, 1),
            (10, 20, 4, 0.5),
            (5, 10, 2, 0.2),
            (3, 5, 1, 0.1),
            (1, 3, 0.5, 0.1),
        )
        for low, high, major, minor in table:
            if low <= span < high:
                return major, minor
        return None, None

    def plot_bands(self):
        """Plot the plain band structure as one line per band and spin.

        Returns
        -------
        None.
        """
        values_dict = {}
        for ispin in self.spins:
            color = self._spin_color(ispin)
            for iband in range(self.ebs.nbands):
                band = self.ebs.bands[:, iband, ispin]
                handle = self.ax.plot(
                    self.x,
                    band,
                    color=color,
                    alpha=self.config.opacity[ispin],
                    linestyle=self.config.linestyle[ispin],
                    label=self.config.label[ispin],
                    linewidth=self.config.linewidth[ispin],
                )
                self.handles.append(handle)
                band_name = f'band-{iband}_spinChannel-{ispin}'
                # NOTE: single underscore ("bands_") here vs "bands__" in the
                # other plot methods; kept because it is an exported column name.
                values_dict[f'bands_{band_name}'] = band
        self._store_values(values_dict)

    def plot_scatter(
        self,
        width_mask: np.ndarray = None,
        color_mask: np.ndarray = None,
        spins: List[int] = None,
        width_weights: np.ndarray = None,
        color_weights: np.ndarray = None,
        labels=None,
    ):
        """Plot the bands as scatter points.

        Parameters
        ----------
        width_mask : np.ndarray, optional
            Points with ``|width_weights| < width_mask`` are hidden, by default None
        color_mask : np.ndarray, optional
            Points with ``|color_weights| < color_mask`` are hidden, by default None
        spins : List[int], optional
            Spin channels to plot, by default all (``[0]`` if non-collinear)
        width_weights : np.ndarray, optional
            Marker-size weight of each point, by default None (uniform)
        color_weights : np.ndarray, optional
            Colour-map value of each point, by default None (solid colour)
        labels : list of str, optional
            ``labels[0]`` names the projection in exported data, by default ``['']``
        """
        values_dict = {}
        if labels is None:
            labels = ['']
        if spins is None:
            spins = range(self.ebs.nspins)
        if self.ebs.is_non_collinear:
            spins = [0]

        if width_weights is None:
            width_weights = np.ones_like(self.ebs.bands)
            markersize = self.config.markersize
        else:
            markersize = [size * 30 for size in self.config.markersize]

        mbands = self._masked_bands(width_weights, width_mask, color_weights, color_mask)
        vmin, vmax = self._color_limits(color_weights, spins)
        projection_name = labels[0]

        sc = None
        for ispin in spins:
            color = self._spin_color(ispin)
            for iband in range(self.ebs.nbands):
                if color_weights is None:
                    color_kwargs = {"c": color}
                else:
                    color_kwargs = {
                        "c": color_weights[:, iband, ispin],
                        "cmap": self.config.cmap,
                        "vmin": vmin,
                        "vmax": vmax,
                    }
                sc = self.ax.scatter(
                    self.x,
                    mbands[:, iband, ispin],
                    s=width_weights[:, iband, ispin] * markersize[ispin],
                    linewidths=self.config.linewidth[ispin],
                    marker=self.config.marker[ispin],
                    alpha=self.config.opacity[ispin],
                    **color_kwargs,
                )
                band_name = f'band-{iband}_spinChannel-{ispin}'
                values_dict[f'bands__{band_name}'] = self.ebs.bands[:, iband, ispin]
                if color_weights is not None:
                    values_dict[f'projections__{projection_name}__{band_name}'] = color_weights[:, iband, ispin]

        if self.config.plot_color_bar and color_weights is not None and sc is not None:
            self.cb = self.fig.colorbar(sc, ax=self.ax)
        self._store_values(values_dict)

    def plot_parameteric(
        self,
        spins: List[int] = None,
        width_mask: np.ndarray = None,
        color_mask: np.ndarray = None,
        width_weights: np.ndarray = None,
        color_weights: np.ndarray = None,
        elimit: List[float] = None,
        labels=None,
    ):
        """Plot the bands as line collections with per-segment width/colour.

        Parameters
        ----------
        spins : List[int], optional
            Spin channels to plot, by default all (``[0]`` if non-collinear)
        width_mask : np.ndarray, optional
            Points with ``|width_weights| < width_mask`` are hidden, by default None
        color_mask : np.ndarray, optional
            Points with ``|color_weights| < color_mask`` are hidden, by default None
        width_weights : np.ndarray, optional
            Line-width weight of each point, by default None (uniform)
        color_weights : np.ndarray, optional
            Colour-map value of each point, by default None (solid colour)
        elimit : List[float], optional
            Energy range; only used for the single-k-point (atomic levels) case.
        labels : list of str, optional
            ``labels[0]`` names the projection in exported data, by default ``['']``
        """
        values_dict = {}
        if labels is None:
            labels = ['']

        # With a single k-point, plot_atomic_levels duplicates the k-point and
        # calls this method again to draw the actual plot.
        if len(self.ebs.kpoints) == 1:
            self.plot_atomic_levels(
                color_weights=color_weights,
                width_weights=width_weights,
                color_mask=color_mask,
                width_mask=width_mask,
                spins=spins,
                elimit=elimit,
                labels=labels,
            )
            return

        if width_weights is None:
            width_weights = np.ones_like(self.ebs.bands)
            linewidth = self.config.linewidth
        else:
            linewidth = [lw * 5 for lw in self.config.linewidth]

        if spins is None:
            spins = range(self.ebs.nspins)
        if self.ebs.is_non_collinear:
            spins = [0]

        mbands = self._masked_bands(width_weights, width_mask, color_weights, color_mask)

        norm = None
        if color_weights is not None:
            vmin, vmax = self._color_limits(color_weights, spins)
            norm = mpl.colors.Normalize(vmin, vmax)

        projection_name = labels[0]
        lc = None
        for ispin in spins:
            color = self._spin_color(ispin)
            handle = None
            for iband in range(self.ebs.nbands):
                # np.array() would drop the mask, so masked points become NaN,
                # which LineCollection does not draw.
                band = np.ma.filled(mbands[:, iband, ispin], np.nan)
                points = np.array([self.x, band]).T.reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                if color_weights is None:
                    lc = LineCollection(
                        segments, colors=color, linestyle=self.config.linestyle[ispin]
                    )
                else:
                    lc = LineCollection(
                        segments, cmap=plt.get_cmap(self.config.cmap), norm=norm
                    )
                    lc.set_array(color_weights[:, iband, ispin])
                lc.set_linewidth(width_weights[:, iband, ispin] * linewidth[ispin])
                lc.set_linestyle(self.config.linestyle[ispin])
                handle = self.ax.add_collection(lc)

                band_name = f'band-{iband}_spinChannel-{ispin}'
                values_dict[f'bands__{band_name}'] = self.ebs.bands[:, iband, ispin]
                if color_weights is not None:
                    values_dict[f'projections__{projection_name}__{band_name}'] = color_weights[:, iband, ispin]

            if handle is not None:
                # NOTE(review): this resets the widths of the last band drawn
                # for this spin to the configured width list, presumably to
                # give its legend entry a uniform width. Kept as-is.
                handle.set_linewidth(linewidth)
                self.handles.append(handle)

        if self.config.plot_color_bar and color_weights is not None and lc is not None:
            self.cb = self.fig.colorbar(lc, ax=self.ax)
        self._store_values(values_dict)

    def plot_parameteric_overlay(
        self,
        spins: List[int] = None,
        weights: np.ndarray = None,
        labels: str = None,
    ):
        """Overlay one parametric plot per projection, each with its own colour map.

        Parameters
        ----------
        spins : List[int], optional
            Spin channels to plot, by default all (``[0]`` if non-collinear)
        weights : np.ndarray, optional
            Sequence of weight arrays; ``weights[i][:, iband, ispin]`` sets
            both the colour value and the line width of projection ``i``.
            None plots nothing. By default None.
        labels : list of str, optional
            Projection names used in exported data; a missing entry falls
            back to the projection index. By default ``['']``.
        """
        values_dict = {}
        if labels is None:
            labels = ['']
        if weights is None:
            weights = []
        linewidth = [lw * 7 for lw in self.config.linewidth]
        if isinstance(self.config.cmap, str):
            color_map = ['Reds', "Blues", "Greens", "Purples", "Oranges", "Greys"]
        else:
            color_map = self.config.cmap

        if spins is None:
            spins = range(self.ebs.nspins)
        if self.ebs.is_non_collinear:
            spins = [0]

        vmin, vmax = self.config.clim[0], self.config.clim[1]
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        norm = mpl.colors.Normalize(vmin, vmax)

        # Segments joining two points with the same x (repeated high-symmetry
        # points at segment boundaries) are removed.
        x = self.x
        duplicate_x = np.where(x[1:] == x[:-1])[0]

        for iweight, weight in enumerate(weights):
            projection_name = labels[iweight] if iweight < len(labels) else str(iweight)
            lc = None
            for ispin in spins:
                handle = None
                for iband in range(self.ebs.nbands):
                    points = np.array(
                        [self.x, self.ebs.bands[:, iband, ispin]]
                    ).T.reshape(-1, 1, 2)
                    segments = np.concatenate([points[:-1], points[1:]], axis=1)
                    segments = np.delete(segments, duplicate_x, axis=0)
                    lc = LineCollection(
                        segments,
                        cmap=plt.get_cmap(color_map[iweight]),
                        norm=norm,
                        alpha=self.config.opacity[ispin],
                    )
                    lc.set_array(weight[:, iband, ispin])
                    lc.set_linewidth(weight[:, iband, ispin] * linewidth[ispin])
                    handle = self.ax.add_collection(lc)

                    band_name = f'band-{iband}_spinChannel-{ispin}'
                    values_dict[f'bands__{band_name}'] = self.ebs.bands[:, iband, ispin]
                    values_dict[f'projections__{projection_name}__{band_name}'] = weight[:, iband, ispin]

                if handle is not None:
                    # NOTE(review): a solid colour derived from the colormap
                    # name ("Reds" -> "red") and a uniform width are applied
                    # to the last band of this spin, presumably for the legend.
                    handle.set_color(color_map[iweight][:-1].lower())
                    handle.set_linewidth(linewidth)
                    self.handles.append(handle)

            if self.config.plot_color_bar and lc is not None:
                self.cb = self.fig.colorbar(lc, ax=self.ax)
        self._store_values(values_dict)

    def plot_atomic_levels(
        self,
        spins: List[int] = None,
        width_mask: np.ndarray = None,
        color_mask: np.ndarray = None,
        width_weights: np.ndarray = None,
        color_weights: np.ndarray = None,
        elimit: List[float] = None,
        labels=None,
    ):
        """Plot single-k-point energy levels with a text label per level.

        The single k-point is duplicated in ``self.ebs`` (bands, projected,
        kpoints are modified in place) so a parametric plot can be drawn.
        Levels inside the energy window are labelled ``"s-<spin> : b-<band>"``
        and overlapping labels are shifted sideways.

        Parameters
        ----------
        spins : List[int], optional
            Spin channels to plot, by default all (``[0]`` if non-collinear)
        width_mask, color_mask : np.ndarray, optional
            Forwarded to :meth:`plot_parameteric`.
        width_weights, color_weights : np.ndarray, optional
            Forwarded to :meth:`plot_parameteric`.
        elimit : List[float], optional
            Energy window ``[emin, emax]``; by default the full band range.
        labels : list of str, optional
            Forwarded to :meth:`plot_parameteric`, by default ``['']``.
        """
        if labels is None:
            labels = ['']
        self.ebs.bands = np.vstack((self.ebs.bands, self.ebs.bands))
        self.ebs.projected = np.vstack((self.ebs.projected, self.ebs.projected))
        self.ebs.kpoints = np.vstack((self.ebs.kpoints, self.ebs.kpoints))
        # Make the duplicated k-point distinct from the original.
        self.ebs.kpoints[0][-1] += 1
        self.x = self._get_x()
        print("Atomic plot: bands.shape :", self.ebs.bands.shape)
        print("Atomic plot: spd.shape :", self.ebs.projected.shape)
        print("Atomic plot: kpoints.shape:", self.ebs.kpoints.shape)

        self.ax.xaxis.set_major_locator(plt.NullLocator())

        if elimit:
            emin, emax = elimit[0], elimit[1]
        else:
            emin, emax = np.min(self.ebs.bands), np.max(self.ebs.bands)

        if spins is None:
            spins = range(self.ebs.nspins)
        if self.ebs.is_non_collinear:
            spins = [0]

        # Label entries: [x, energy, text]; inclusive bounds so the extreme
        # levels are labelled when no window is given.
        texts = []
        for ispin in spins:
            for i, energy in enumerate(self.ebs.bands[0, :, ispin]):
                if emin <= energy <= emax:
                    texts.append([0, energy, f"s-{ispin} : b-{i + 1}"])
        texts.sort(key=lambda entry: entry[1])

        # The limits must be set before measuring text in data coordinates.
        self.set_ylim(elimit)
        self.set_xlim()

        if texts:
            # Measure one label (in data units) to know how far to shift.
            probe = self.ax.text(*texts[-1])
            bbox = probe.get_window_extent()
            bbox_data = self.ax.transData.inverted().transform_bbox(bbox)
            w, h = bbox_data.width, bbox_data.height
            probe.remove()

            shift = 0
            self.ax.text(*texts[0])
            for previous, current in zip(texts, texts[1:]):
                if current[1] < previous[1] + h:
                    # Vertical overlap: alternate between shift 1 and 0.
                    shift += 1
                    if shift == 2:
                        shift = 0
                    current[0] = current[0] + w * 1.5 * shift
                else:
                    shift = 0
                self.ax.text(*current)

        self.plot_parameteric(
            color_weights=color_weights,
            width_weights=width_weights,
            color_mask=color_mask,
            width_mask=width_mask,
            spins=spins,
            labels=labels,
        )

    def set_xticks(
        self,
        tick_positions: List[int] = None,
        tick_names: List[str] = None,
        color: str = "black",
    ):
        """Draw vertical lines and x tick labels at the given k-point indices.

        Parameters
        ----------
        tick_positions : List[int], optional
            Indices into ``self.x``, by default the k-path tick positions
        tick_names : List[str], optional
            Tick labels, by default the k-path tick names
        color : str, optional
            Colour of the vertical lines, by default "black"
        """
        if self.kpath is not None:
            if tick_positions is None:
                tick_positions = self.kpath.tick_positions
            if tick_names is None:
                tick_names = self.kpath.tick_names
        if tick_positions is not None:
            for ipos in tick_positions:
                self.ax.axvline(self.x[ipos], color=color)
            self.ax.set_xticks(self.x[tick_positions])
            if tick_names is not None:
                self.ax.set_xticklabels(tick_names)
        self.ax.tick_params(**self.config.major_x_tick_params)

    def set_yticks(self, major: float = None, minor: float = None, interval: List[float] = None):
        """Set the y tick locators.

        Precedence: ``config.major_y_locator``/``minor_y_locator`` objects,
        then ``config.multiple_locator_y_*_value``, then the ``major``/``minor``
        arguments, then a spacing chosen from the energy span.

        Parameters
        ----------
        major : float, optional
            Major tick spacing, by default None
        minor : float, optional
            Minor tick spacing, by default None
        interval : List[float], optional
            ``[low, high]`` energy range used to choose a spacing, by default
            the padded band range
        """
        if interval is None:
            interval = self._default_energy_interval()
        span = abs(interval[1] - interval[0])
        auto_major, auto_minor = self._auto_y_tick_spacing(span)
        # Explicit arguments are no longer overridden by the automatic choice.
        if major is None:
            major = auto_major
        if minor is None:
            minor = auto_minor

        if self.config.multiple_locator_y_major_value is not None:
            major = self.config.multiple_locator_y_major_value
        if self.config.multiple_locator_y_minor_value is not None:
            minor = self.config.multiple_locator_y_minor_value

        if self.config.major_y_locator is not None or self.config.minor_y_locator is not None:
            if self.config.major_y_locator is not None:
                self.ax.yaxis.set_major_locator(self.config.major_y_locator)
            if self.config.minor_y_locator is not None:
                self.ax.yaxis.set_minor_locator(self.config.minor_y_locator)
        else:
            if major is not None:
                self.ax.yaxis.set_major_locator(MultipleLocator(major))
            if minor is not None:
                self.ax.yaxis.set_minor_locator(MultipleLocator(minor))

        self.ax.tick_params(**self.config.major_y_tick_params)
        self.ax.tick_params(**self.config.minor_y_tick_params)

    def set_xlim(self, interval: List[float] = None):
        """Set the x limits.

        Parameters
        ----------
        interval : List[float], optional
            ``[start, end]``, by default the first and last x values
        """
        if interval is None:
            interval = (self.x[0], self.x[-1])
        self.ax.set_xlim(interval)

    def set_ylim(self, interval: List[float] = None):
        """Set the y limits.

        Parameters
        ----------
        interval : List[float], optional
            ``[start, end]``, by default the band range padded outward by 10%
            of each end's magnitude
        """
        if interval is None:
            interval = self._default_energy_interval()
        self.ax.set_ylim(interval)

    def set_xlabel(self, label: str = "K vector"):
        """Set the x label.

        Parameters
        ----------
        label : str, optional
            The x label text, by default "K vector"
        """
        self.ax.set_xlabel(label, **self.config.x_label_params)

    def set_ylabel(self, label: str = r"E - E$_F$ (eV)"):
        """Set the y label.

        Parameters
        ----------
        label : str, optional
            The y label text, by default r"E - E$_F$ (eV)"
        """
        self.ax.set_ylabel(label, **self.config.y_label_params)

    def set_title(self, title: str = "Band Structure"):
        """Set the axes title from ``config.title`` when it is non-empty.

        Parameters
        ----------
        title : str, optional
            Currently unused: only ``config.title`` is displayed.
            By default "Band Structure".
        """
        if self.config.title:
            self.ax.set_title(label=self.config.title, **self.config.title_params)

    def set_colorbar_title(self, title: str = None):
        """Style the colour bar ticks and set its label.

        Requires a colour bar (``self.cb``) created by a previous plot call.

        Parameters
        ----------
        title : str, optional
            Colour bar label, by default ``config.colorbar_title``
        """
        title = title or self.config.colorbar_title
        if self.config.colorbar_tick_params:
            self.cb.ax.tick_params(**self.config.colorbar_tick_params)
        else:
            self.cb.ax.tick_params(labelsize=self.config.colorbar_tick_labelsize)
        if self.config.colorbar_label_params:
            self.cb.set_label(title, **self.config.colorbar_label_params)
        else:
            self.cb.set_label(
                title,
                size=self.config.colorbar_title_size,
                rotation=270,
                labelpad=self.config.colorbar_title_padding,
            )

    def legend(self, labels: List[str] = None):
        """Draw a legend of ``self.handles`` when ``config.legend`` is set.

        Parameters
        ----------
        labels : List[str], optional
            Legend labels, by default ``config.label``
        """
        if labels is None:
            labels = self.config.label
        if self.config.legend:
            self.ax.legend(self.handles, labels)

    def draw_fermi(self, fermi_level: float = 0):
        """Draw a horizontal line at the Fermi level.

        Parameters
        ----------
        fermi_level : float, optional
            Energy at which to draw the line, by default 0
        """
        self.ax.axhline(
            y=fermi_level,
            color=self.config.fermi_color,
            linestyle=self.config.fermi_linestyle,
            linewidth=self.config.fermi_linewidth,
        )

    def grid(self):
        """Draw a grid when ``config.grid`` is set."""
        if self.config.grid:
            self.ax.grid(
                self.config.grid,
                which=self.config.grid_which,
                color=self.config.grid_color,
                linestyle=self.config.grid_linestyle,
                linewidth=self.config.grid_linewidth,
            )

    def show(self):
        """Show the plot with ``plt.show()``."""
        plt.show()

    def save(self, filename: str = 'bands.pdf'):
        """Save the current pyplot figure, then clear it.

        Parameters
        ----------
        filename : str, optional
            Output file name, by default 'bands.pdf'
        """
        plt.savefig(filename, dpi=self.config.dpi, bbox_inches="tight")
        plt.clf()

    def _sorted_column_names(self):
        """Order export columns: k-path columns first, then spin-channel
        columns grouped by channel (original order kept within a channel),
        then anything else."""
        prefix = 'spinChannel-'

        def spin_index(name):
            suffix = name.split('_')[-1]
            if suffix.startswith(prefix):
                try:
                    return int(suffix[len(prefix):])
                except ValueError:
                    return None
            return None

        leading = [name for name in ('kpath_values', 'kpath_tick_names') if name in self.values_dict]
        rest = [name for name in self.values_dict if name not in leading]
        spin_columns = sorted(
            (name for name in rest if spin_index(name) is not None), key=spin_index
        )
        others = [name for name in rest if spin_index(name) is None]
        return leading + spin_columns + others

    def export_data(self, filename):
        """Export the data of the last plot call to a file.

        The format follows the extension: ``csv`` (comma), ``txt`` (tab),
        ``dat`` (space) tables, or ``json`` (a dict of lists).

        Parameters
        ----------
        filename : str
            The file name to export the data to

        Raises
        ------
        ValueError
            If the extension is unsupported or nothing has been plotted yet.
        """
        possible_file_types = ['csv', 'txt', 'json', 'dat']
        file_type = filename.split('.')[-1]
        if file_type not in possible_file_types:
            raise ValueError(f"The file type must be {possible_file_types}")
        if not self.values_dict:
            raise ValueError("The data has not been plotted yet")

        if file_type == 'json':
            # Convert without mutating values_dict; np.asarray handles both
            # arrays and the plain list of tick names.
            serializable = {
                key: np.asarray(value).tolist() for key, value in self.values_dict.items()
            }
            with open(filename, 'w') as outfile:
                json.dump(serializable, outfile)
            return

        separators = {'csv': ',', 'txt': '\t', 'dat': ' '}
        df = pd.DataFrame(self.values_dict)
        df.to_csv(
            filename,
            columns=self._sorted_column_names(),
            sep=separators[file_type],
            index=False,
        )