Skip to content

Commit

Permalink
Support fallback arch specified by Rust with target-cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
denzp committed Jan 29, 2019
1 parent faba7c0 commit 367d5f6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 24 deletions.
10 changes: 8 additions & 2 deletions src/bin/legacy-ptx-linker/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ impl<'a> From<ArgMatches<'a>> for CommandLineRequest {
}
}

// Unfortunately, there is no way to get the fallback arch from Rust
// with the legacy approach.
session.set_fallback_arch("sm_30");

CommandLineRequest::Link(session)
}
}
Expand Down Expand Up @@ -202,7 +206,8 @@ mod tests {

let expected_session = Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand Down Expand Up @@ -234,7 +239,8 @@ mod tests {

let expected_session = Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::LTO,
debug_info: false,
Expand Down
75 changes: 60 additions & 15 deletions src/bin/rust-ptx-linker/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ fn get_app() -> App<'static, 'static> {
.long("debug")
.help("Emit debug info")
},
{
Arg::with_name("fallback_arch")
.long("fallback-arch")
.help("Rust own target architecture")
.takes_value(true)
.default_value("sm_30")
},
{
Arg::with_name("arch")
.short("a")
Expand Down Expand Up @@ -136,6 +143,10 @@ fn parse_session(matches: ArgMatches<'static>) -> Session {
}
}

if let Some(arch) = matches.value_of("fallback_arch") {
session.set_fallback_arch(arch);
}

session
}

Expand Down Expand Up @@ -164,7 +175,8 @@ mod tests {
parse_session(matches.unwrap()),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand Down Expand Up @@ -196,7 +208,8 @@ mod tests {
parse_session(matches.unwrap()),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand Down Expand Up @@ -224,7 +237,8 @@ mod tests {
parse_session(matches.unwrap()),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -244,7 +258,8 @@ mod tests {
parse_session(matches.unwrap()),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: true,
Expand All @@ -266,7 +281,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::LTO,
debug_info: false,
Expand All @@ -288,7 +304,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -307,7 +324,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -326,7 +344,8 @@ mod tests {
),
Session {
emit: vec![Output::IntermediateRepresentation],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -345,7 +364,8 @@ mod tests {
),
Session {
emit: vec![Output::Bitcode],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -370,7 +390,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly, Output::Bitcode],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -389,7 +410,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly, Output::Bitcode],
achitectures: vec![],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -406,12 +428,13 @@ mod tests {
assert_eq!(
parse_session(
get_app()
.get_matches_from_safe(vec!["rust-ptx-linker", "--arch", "sm_60"])
.get_matches_from_safe(vec!["rust-ptx-linker", "--arch", "sm_70"])
.unwrap()
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![String::from("sm_60")],
ptx_archs: vec![String::from("sm_70")],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -436,7 +459,8 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![String::from("sm_50"), String::from("sm_60")],
ptx_archs: vec![String::from("sm_50"), String::from("sm_60")],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,
Expand All @@ -455,7 +479,28 @@ mod tests {
),
Session {
emit: vec![Output::PTXAssembly],
achitectures: vec![String::from("sm_50"), String::from("sm_60")],
ptx_archs: vec![String::from("sm_50"), String::from("sm_60")],
ptx_fallback_arch: String::from("sm_30"),

opt_level: OptLevel::None,
debug_info: false,

output: None,
include_bitcode_modules: vec![],
include_rlibs: vec![],
}
);

assert_eq!(
parse_session(
get_app()
.get_matches_from_safe(vec!["rust-ptx-linker", "--fallback-arch", "sm_40"])
.unwrap()
),
Session {
emit: vec![Output::PTXAssembly],
ptx_archs: vec![],
ptx_fallback_arch: String::from("sm_40"),

opt_level: OptLevel::None,
debug_info: false,
Expand Down
9 changes: 4 additions & 5 deletions src/linker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,13 @@ impl Linker {
}

fn emit_asm(&self) -> Result<(), Error> {
if self.session.achitectures.len() > 1 {
if self.session.ptx_archs.len() > 1 {
bail!("More than 1 CUDA architecture is not yet supported with PTX output.");
}

// TOOD(denzp): is it possible to get architecture coming from Rust?
let arch = match self.session.achitectures.iter().next() {
Some(arch) => &arch,
None => "sm_20",
let arch = match self.session.ptx_archs.iter().next() {
Some(arch) => arch.as_str(),
None => self.session.ptx_fallback_arch.as_str(),
};

let path = CString::new(self.get_output_path()?.to_str().unwrap()).unwrap();
Expand Down
10 changes: 8 additions & 2 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ pub struct Session {
pub debug_info: bool,

pub emit: Vec<Output>,
pub achitectures: Vec<String>,
pub ptx_archs: Vec<String>,
pub ptx_fallback_arch: String,
}

impl Session {
Expand Down Expand Up @@ -83,7 +84,12 @@ impl Session {

/// Specify output architecture (e.g. `sm_60`).
pub fn add_output_arch(&mut self, arch: &str) {
self.achitectures.push(arch.into());
self.ptx_archs.push(arch.into());
}

/// Specify the fallback architecture if no other explicitly set.
pub fn set_fallback_arch(&mut self, arch: &str) {
self.ptx_fallback_arch = arch.into();
}

fn is_metadata_bitcode(&self, path: &Path) -> bool {
Expand Down

0 comments on commit 367d5f6

Please sign in to comment.