MLIR  22.0.0git
XeGPUBlocking.cpp
Go to the documentation of this file.
1 //===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
21 
22 namespace mlir {
23 namespace xegpu {
24 #define GEN_PASS_DEF_XEGPUBLOCKING
25 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
26 } // namespace xegpu
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "xegpu-blocking"
30 
31 using namespace mlir;
32 
33 namespace {
34 
35 // reslove the unrealized conversion cast ops generated when doing SCF
36 // Structural Type Conversion. It will have two formats, N:1 vector
37 // cast and 1:N vector cast. vector::insert_strided_slice ops will be
38 // used for the first case, and vector::extract_strided_slice ops will be
39 // used for the second case.
40 static void
41 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
42  ValueRange inputs = castOp.getInputs();
43  ValueRange outputs = castOp.getOutputs();
44 
45  auto hasIdenticalVectorTypes = [](ValueRange values) {
46  auto types = values.getTypes();
47  return llvm::all_of(types, [&](Type type) {
48  return isa<VectorType>(type) && type == types.front();
49  });
50  };
51 
52  // We only interest in the case where all inputs and outputs have the
53  // identical VectorTypes
54  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
55  LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
56  return;
57  }
58 
59  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
60  OpBuilder builder(castOp);
61  if (inputs.size() > 1 && outputs.size() == 1) {
62  // the castOp is emulating an unpack op
63  ArrayRef<int64_t> shape = outputTy.getShape();
65  builder, castOp.getLoc(), inputs, shape);
66  castOp->replaceAllUsesWith(ValueRange(result));
67  castOp->erase();
68  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
69  // the castOp is emulating a pack op
70  ArrayRef<int64_t> tileShape = outputTy.getShape();
72  builder, castOp.getLoc(), inputs[0], tileShape);
73  castOp->replaceAllUsesWith(results);
74  castOp->erase();
75  }
76 }
77 
78 // This pattern lowers ConvertLayoutOp by removing the inst_data field from the
79 // layout attributes. Since both producer and consumer operations handle data
80 // partitioning based on their own inst_data, while maintaining original input
81 // and output shape, ConvertLayoutOp does not need to manage inst_data.
82 struct ConvertLayoutOpPattern
83  : public OpRewritePattern<xegpu::ConvertLayoutOp> {
85  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
86  PatternRewriter &rewriter) const override {
87  xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
88  xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
89  if (!input_layout.getInstData() || !target_layout.getInstData())
90  return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
91 
92  input_layout = input_layout.dropInstData();
93  target_layout = target_layout.dropInstData();
94  auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
95  op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
96  rewriter.replaceOp(op, newOp);
97  return success();
98  }
99 };
100 
101 //===------------------------------------------------------------------------===//
102 // The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
103 // to partition operations that process large shapes into multiple operations on
104 // smaller shapes, as specified by the inst_data in the layout attribute. This
105 // enables each resulting operation to be efficiently mapped to a hardware
106 // instruction.
107 //===------------------------------------------------------------------------===//
108 
109 class XeGPUBlockingPass final
110  : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
111 public:
112  void runOnOperation() override;
113 
114 private:
115  // Get the tile shape for a given OpOperand or OpResult by examining the
116  // corresponding layout attribute. If layout is not present or is not a
117  // subgroup level layout, it returns std::nullopt.
118  template <typename T,
119  typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
120  std::is_same_v<T, OpResult>>>
121  std::optional<SmallVector<int64_t>>
122  getTileShape(const T &operandOrResult) const;
123 
124  // Get the tile shape for a given operation.
125  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
126 
127  // Determine if the operation requires unrolling. Return false if all operands
128  // and results have tile shapes identical to their original types. Otherwise,
129  // return true.
130  bool needsUnroll(Operation *op) const;
131 };
132 } // namespace
133 
134 template <typename T, typename>
135 std::optional<SmallVector<int64_t>>
136 XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
137  Value value;
138  if constexpr (std::is_same_v<T, OpOperand>)
139  value = operandOrResult.get();
140  else
141  value = (Value)operandOrResult;
142 
143  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
144  if (layout && layout.isSgLayout()) {
145  if (auto inst_data = layout.getInstData())
146  return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
147 
148  if (auto type = dyn_cast<ShapedType>(value.getType()))
149  return llvm::to_vector(type.getShape());
150  }
151  LDBG() << "failed to getTileShape for: " << value;
152  return std::nullopt;
153 }
154 
155 std::optional<SmallVector<int64_t>>
157  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
158  xegpu::UpdateOffsetOp>(op))
159  return getTileShape(op->getOpResult(0));
160  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
161  xegpu::LoadGatherOp>(op))
162  return getTileShape(op->getOpOperand(0));
163  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
164  return getTileShape(op->getOpOperand(1));
165 
166  if (isa<xegpu::DpasOp>(op)) {
167  std::optional<SmallVector<int64_t>> aTile =
168  getTileShape(op->getOpOperand(0));
169  std::optional<SmallVector<int64_t>> bTile =
170  getTileShape(op->getOpOperand(1));
171 
172  if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
173  return std::nullopt;
174 
175  // semantic check for A and B
176  if ((*aTile)[1] != (*bTile)[0])
177  return std::nullopt;
178 
179  // semantic check for C
180  if (op->getNumOperands() == 3) {
181  std::optional<SmallVector<int64_t>> cTile =
182  getTileShape(op->getOpOperand(2));
183  int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
184  if (!cTile || !llvm::equal(*cTile, expectedCTile))
185  return std::nullopt;
186  }
187 
188  return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
189  }
190 
192  return getTileShape(op->getOpResult(0));
193 
194  if (isa<vector::MultiDimReductionOp>(op))
195  return getTileShape(op->getOpOperand(0));
196 
197  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
198  return getTileShape(op->getOpResult(0));
199 
200  return std::nullopt;
201 }
202 
203 bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
204  // skip the op if any of its operands or results has workgroup level layouts
205  bool hasWgLayoutOperands =
206  llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
207  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
208  return layout && layout.isWgLayout();
209  });
210  bool hasWgLayoutResults =
211  llvm::any_of(op->getOpResults(), [](OpResult result) {
212  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
213  return layout && layout.isWgLayout();
214  });
215  if (hasWgLayoutOperands || hasWgLayoutResults) {
216  LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
217  return false;
218  }
219 
220  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
221  Type valTy = value.getType();
222  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
223  xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
224  return layout && layout.getInstData();
225  }
226  auto shapedType = dyn_cast<ShapedType>(valTy);
227  return shapedType && !llvm::equal(tileShape, shapedType.getShape());
228  };
229 
230  bool hasUnrollableOperands =
231  llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
232  std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
233  return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
234  });
235  bool hasUnrollableResults =
236  llvm::any_of(op->getOpResults(), [&](OpResult result) {
237  std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
238  return tileShape.has_value() && isUnrollable(result, *tileShape);
239  });
240  return hasUnrollableOperands || hasUnrollableResults;
241 }
242 
243 void XeGPUBlockingPass::runOnOperation() {
244  MLIRContext *ctx = &getContext();
245  Operation *op = getOperation();
246 
247  // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
248  // This ensures that the LayoutAttr remains accessible even if the defining
249  // operation is replaced.
250  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
251 
252  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
253  xegpu::LayoutAttr layout) {
254  int count = 1;
255  SmallVector<int64_t> tileShape(shape);
256  if (layout && layout.getInstData()) {
257  DenseI32ArrayAttr instData = layout.getInstData();
258  tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
259  count = computeProduct(shape) / computeProduct(tileShape);
260  }
261  return std::make_pair(tileShape, count);
262  };
263 
264  // Perform type conversion for SCF control folow ops
265  TypeConverter converter;
266  converter.addConversion([](Type type) -> Type { return type; });
267  converter.addConversion(
268  [&](RankedTensorType type,
269  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
270  Type elemTy = type.getElementType();
271  ArrayRef<int64_t> shape = type.getShape();
272 
273  auto layout =
274  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
275  if (layout && layout.isWgLayout())
276  return failure();
277 
278  int count;
279  SmallVector<int64_t> subShape;
280  std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
281  auto newTy = VectorType::get(subShape, elemTy);
282  result.append(count, newTy);
283  return success();
284  });
285  converter.addConversion(
286  [&](xegpu::TensorDescType type,
287  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
288  Type elemTy = type.getElementType();
289  ArrayRef<int64_t> shape = type.getShape();
290 
291  xegpu::LayoutAttr layout = type.getLayoutAttr();
292  if (layout && layout.isWgLayout())
293  return failure();
294 
295  int count;
296  SmallVector<int64_t> subShape;
297  std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
298 
299  if (layout)
300  layout = layout.dropInstData();
301 
302  auto newTy = xegpu::TensorDescType::get(
303  type.getContext(), subShape, elemTy, type.getEncoding(), layout);
304  result.append(count, newTy);
305  return success();
306  });
307 
309 
311  options.setFilterConstraint(
312  [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
313 
314  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
315 
316  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
317  Type elemTy = type.getElementType();
318  Type newTy;
319 
320  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
321 
322  Attribute encoding = tdescTy.getEncoding();
323  // If the encoding is a ScatterTensorDescAttr, we need to
324  // potentially adjust the chunk size based on the inst_data.
325  if (tdescTy.isScattered()) {
326  int64_t chunkSize = tdescTy.getChunkSizeAsInt();
327 
328  if (chunkSize > 1) {
329  int64_t blockedChunkSize = chunkSize;
330  auto instData = tdescTy.getLayoutAttr().getInstData();
331  if (!instData.empty())
332  blockedChunkSize = instData.asArrayRef().back();
333 
334  // To create a new attribute with a different chunk_size:
335  auto newEncoding = xegpu::ScatterTensorDescAttr::get(
336  ctx, tdescTy.getMemorySpace(), blockedChunkSize);
337 
338  encoding = newEncoding;
339  }
340  }
341 
342  newTy =
343  xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
344  tdescTy.getLayoutAttr().dropInstData());
345  } else {
346  newTy = type.clone(tileShape, elemTy);
347  }
348 
349  std::optional<SmallVector<int64_t>> ratio =
350  computeShapeRatio(type.getShape(), tileShape);
351  assert(ratio && "The shape of the type must be a multiple of tileShape.");
352  return SmallVector<Type>(computeProduct(*ratio), newTy);
353  });
354 
356  patterns.add<ConvertLayoutOpPattern>(ctx);
357 
358  vector::UnrollVectorOptions vectorOptions;
359  vectorOptions.setNativeShapeFn(options.nativeShape);
360 
362  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
363 
364  (void)applyPatternsGreedily(op, std::move(patterns));
365 
366  op->walk([](Operation *op) {
367  // Remove the layout attributes cached per operands.
368  for (OpOperand &opr : op->getOpOperands()) {
369  std::string name = xegpu::getLayoutName(opr);
370  if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
371  op->removeAttr(name);
372  }
373 
374  // Update the layout attributes per result.
375  for (OpResult result : op->getOpResults()) {
376  std::string name = xegpu::getLayoutName(result);
377  if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
378  op->removeAttr(name);
379  if (!isa<LoopLikeOpInterface>(op))
380  xegpu::setLayoutAttr(result, layout.dropInstData());
381  }
382  }
383 
384  // Resolve unrealized conversion cast ops emulating pack/unpack
385  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
386  resolveUnrealizedConversionCastOp(castOp);
387  });
388 }
static MLIRContext * getContext(OpFoldResult val)
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition: MMAUtils.cpp:37
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
bool hasAttrOfType(NameT &&name)
Definition: Operation.h:575
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getOpResults()
Definition: Operation.h:420
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:600
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
Definition: XeGPUUtils.cpp:237
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
LayoutAttr getLayoutAttr(const Value value)
Retrieves the LayoutAttr associated with a given Value.
Definition: XeGPUUtils.cpp:114
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
Definition: XeGPUUtils.cpp:160
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach LayoutAttr.
Definition: XeGPUUtils.cpp:103
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
Definition: XeGPUUtils.cpp:262
void setLayoutAttrs(Operation *op, function_ref< LayoutAttr(Value)> getLayoutImpl)
Set the LayoutAttr for each OpOperand and OpResult of the given operation.
Definition: XeGPUUtils.cpp:177
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Definition: XeGPUUtils.cpp:217
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Options to control the XeGPU unrolling.
Definition: Transforms.h:27