MLIR  20.0.0git
BasicPtxBuilderInterface.cpp
Go to the documentation of this file.
1 //===- BasicPtxBuilderInterface.cpp - PTX builder interface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements the interface that builds PTX (Parallel Thread Execution) inline
10 // assembly from NVVM ops automatically. It is used by the NVVM-to-LLVM pass.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #define DEBUG_TYPE "ptx-builder"
17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
18 #define DBGSNL() (llvm::dbgs() << "\n")
19 
20 //===----------------------------------------------------------------------===//
21 // BasicPtxBuilderInterface
22 //===----------------------------------------------------------------------===//
23 
24 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
25 
26 using namespace mlir;
27 using namespace NVVM;
28 
29 static constexpr int64_t kSharedMemorySpace = 3;
30 
31 static char getRegisterType(Type type) {
32  if (type.isInteger(1))
33  return 'b';
34  if (type.isInteger(16))
35  return 'h';
36  if (type.isInteger(32))
37  return 'r';
38  if (type.isInteger(64))
39  return 'l';
40  if (type.isF32())
41  return 'f';
42  if (type.isF64())
43  return 'd';
44  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
45  // Shared address spaces is addressed with 32-bit pointers.
46  if (ptr.getAddressSpace() == kSharedMemorySpace) {
47  return 'r';
48  }
49  return 'l';
50  }
51  // register type for struct is not supported.
52  llvm_unreachable("The register type could not deduced from MLIR type");
53  return '?';
54 }
55 
56 static char getRegisterType(Value v) {
57  if (v.getDefiningOp<LLVM::ConstantOp>())
58  return 'n';
59  return getRegisterType(v.getType());
60 }
61 
63  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
64  auto getModifier = [&]() -> const char * {
65  if (itype == PTXRegisterMod::ReadWrite) {
66  assert(false && "Read-Write modifier is not supported. Try setting the "
67  "same value as Write and Read separately.");
68  return "+";
69  }
70  if (itype == PTXRegisterMod::Write) {
71  return "=";
72  }
73  return "";
74  };
75  auto addValue = [&](Value v) {
76  if (itype == PTXRegisterMod::Read) {
77  ptxOperands.push_back(v);
78  return;
79  }
80  if (itype == PTXRegisterMod::ReadWrite)
81  ptxOperands.push_back(v);
82  hasResult = true;
83  };
84 
85  llvm::raw_string_ostream ss(registerConstraints);
86  // Handle Structs
87  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
88  if (itype == PTXRegisterMod::Write) {
89  addValue(v);
90  }
91  for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
92  if (itype != PTXRegisterMod::Write) {
93  Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
94  interfaceOp->getLoc(), v, idx);
95  addValue(extractValue);
96  }
97  if (itype == PTXRegisterMod::ReadWrite) {
98  ss << idx << ",";
99  } else {
100  ss << getModifier() << getRegisterType(t) << ",";
101  }
102  ss.flush();
103  }
104  return;
105  }
106  // Handle Scalars
107  addValue(v);
108  ss << getModifier() << getRegisterType(v) << ",";
109  ss.flush();
110 }
111 
112 LLVM::InlineAsmOp PtxBuilder::build() {
113  auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
114  LLVM::AsmDialect::AD_ATT);
115 
116  auto resultTypes = interfaceOp->getResultTypes();
117 
118  // Remove the last comma from the constraints string.
119  if (!registerConstraints.empty() &&
120  registerConstraints[registerConstraints.size() - 1] == ',')
121  registerConstraints.pop_back();
122 
123  std::string ptxInstruction = interfaceOp.getPtx();
124 
125  // Add the predicate to the asm string.
126  if (interfaceOp.getPredicate().has_value() &&
127  interfaceOp.getPredicate().value()) {
128  std::string predicateStr = "@%";
129  predicateStr += std::to_string((ptxOperands.size() - 1));
130  ptxInstruction = predicateStr + " " + ptxInstruction;
131  }
132 
133  // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
134  // Replace all % with $
135  std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
136 
137  return rewriter.create<LLVM::InlineAsmOp>(
138  interfaceOp->getLoc(),
139  /*result types=*/resultTypes,
140  /*operands=*/ptxOperands,
141  /*asm_string=*/llvm::StringRef(ptxInstruction),
142  /*constraints=*/registerConstraints.data(),
143  /*has_side_effects=*/interfaceOp.hasSideEffect(),
144  /*is_align_stack=*/false,
145  /*asm_dialect=*/asmDialectAttr,
146  /*operand_attrs=*/ArrayAttr());
147 }
148 
150  LLVM::InlineAsmOp inlineAsmOp = build();
151  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
152  if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
153  rewriter.replaceOp(interfaceOp, inlineAsmOp);
154  } else {
155  rewriter.eraseOp(interfaceOp);
156  }
157 }
static char getRegisterType(Type type)
#define DBGS()
void insertValue(Value v, PTXRegisterMod itype=PTXRegisterMod::Read)
Add an operand with the read/write input type.
LLVM::InlineAsmOp build()
Builds the inline assembly Op and returns it.
void buildAndReplaceOp()
Shortcut to build the inline assembly Op and replace the original op with it (or erase the original op if the result counts differ).
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:53
bool isF32() const
Definition: Types.cpp:52
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:59
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
PTXRegisterMod
Register read/write modifier to build constraint string for PTX inline https://docs....
@ Write
Write register with '=' modifier.
@ ReadWrite
Read and write register with '+' modifier.
@ Read
Read register with no modifier.
@ kSharedMemorySpace
Shared memory space identifier.
Definition: NVVMDialect.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...