//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"

using namespace mlir;

namespace {

// Resolve the unrealized conversion cast ops generated when doing SCF
// structural type conversion. These casts come in two forms: N:1 vector
// casts and 1:N vector casts. vector::insert_strided_slice ops are used
// for the first case, and vector::extract_strided_slice ops for the
// second.
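// As a hypothetical illustration, an N:1 cast such as
//   %big = builtin.unrealized_conversion_cast %t0, %t1
//        : vector<8x16xf32>, vector<8x16xf32> to vector<16x16xf32>
// is replaced by vector.insert_strided_slice ops that assemble %t0 and %t1
// into a single vector<16x16xf32>; the inverse 1:N cast is replaced by
// vector.extract_strided_slice ops.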
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // We are only interested in the case where all inputs and all outputs have
  // identical VectorTypes.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // The castOp is emulating an unpack op.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // The castOp is emulating a pack op.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}

// This pattern lowers ConvertLayoutOp by removing the inst_data field from
// the layout attributes. Since producer and consumer operations each handle
// data partitioning based on their own inst_data, while preserving the
// original input and output shapes, ConvertLayoutOp does not need to manage
// inst_data itself.
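// As a hypothetical illustration, a layout such as
//   #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>
// becomes
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
// on both the input and the target side of the op.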
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};

//===------------------------------------------------------------------------===//
// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
// to partition operations that process large shapes into multiple operations
// on smaller shapes, as specified by the inst_data in the layout attribute.
// This enables each resulting operation to be efficiently mapped to a hardware
// instruction.
//===------------------------------------------------------------------------===//
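
// As a hypothetical illustration, a 32x32 load whose layout carries
// inst_data = [8, 16]:
//   %1 = xegpu.load_nd %0
//        : !xegpu.tensor_desc<32x32xf16, #xegpu.layout<inst_data = [8, 16]>>
//        -> vector<32x32xf16>
// is rewritten into eight 8x16 loads, each of which can map directly to a
// single hardware instruction.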

class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given OpOperand or OpResult by examining the
  // corresponding layout attribute. If the layout is not present or is not a
  // subgroup-level layout, it returns std::nullopt.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all
  // operands and results have tile shapes identical to their original types;
  // otherwise, return true.
  bool needsUnroll(Operation *op) const;
};
} // namespace

template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::DistributeLayoutAttr layout =
      xegpu::getDistributeLayoutAttr(operandOrResult);
  if (layout && layout.isForSubgroup()) {
    if (!layout.getEffectiveInstDataAsInt().empty()) {
      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
      // Remove leading unit dimensions from inst_data. For example, if the
      // inst_data is [1, 1, 32], [32] is passed as the unroll/blocking size.
      // Skip this for xegpu nd ops since their shapes will be 2D.
      // TODO: For vector ops, experiment with the upstream patterns that
      // remove leading unit dims,
      // populateCastAwayVectorLeadingOneDimPatterns.
      Operation *definingOp = value.getDefiningOp();
      bool skipLeadingUnitDimRemoval =
          definingOp &&
          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
      if (!skipLeadingUnitDimRemoval) {
        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
        instData.erase(instData.begin(), it);
      }
      return instData;
    }

    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::StoreMatrixOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

  // Handle LoadGatherOp and StoreScatterOp (with and without offsets).
  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
    if (loadGatherOp.getOffsets())
      return getTileShape(loadGatherOp->getOpResult(0));
    else
      return getTileShape(loadGatherOp->getOpOperand(0));
  }

  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
    return getTileShape(storeScatterOp.getOffsets()
                            ? storeScatterOp->getOpOperand(0)
                            : storeScatterOp->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // Semantic check for A and B: the shared K dimension must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // Semantic check for C: its tile must be MxN.
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }
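  // As an illustrative example: with aTile = [8, 16] and bTile = [16, 16],
  // the checks above pass (the shared K dimension is 16), and the returned
  // blocking shape is [M, K, N] = [8, 16, 16].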

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));

  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}

bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip the op if any of its operands or results has a workgroup-level
  // layout.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(opr);
        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(result);
        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }

  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && !layout.getEffectiveInstDataAsInt().empty();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}

void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Preserve the LayoutAttr of each operand in the owner's DictionaryAttr.
  // This ensures that the LayoutAttr remains accessible even if the defining
  // operation is replaced.
  xegpu::setDistributeLayoutAttrs(
      op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });

  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };
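
  // For example (illustrative numbers), a shape of [32, 64] with
  // inst_data = [8, 16] yields the tile shape [8, 16] and a count of
  // (32 * 64) / (8 * 16) = 16.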

  // Perform type conversion for SCF control flow ops.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        auto layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropInstData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

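  // As an illustrative example, a !xegpu.tensor_desc<32x32xf16,
  // #xegpu.layout<inst_data = [16, 16]>> is converted into four
  // !xegpu.tensor_desc<16x16xf16> values; the SCF structural type conversion
  // threads them through loop regions via unrealized_conversion_cast ops,
  // which are resolved at the end of this pass.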
  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  // Remove leading unit dimensions from vector ops and then do the
  // unrolling.
  {
    RewritePatternSet patterns(ctx);
    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }
  xegpu::UnrollOptions options;
  options.setFilterConstraint(
      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });

  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
                                 bool returnSingleType = false) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
      Attribute encoding = tdescTy.getEncoding();
      // If the encoding is a ScatterTensorDescAttr, we may need to adjust
      // the chunk size based on the inst_data.
      if (tdescTy.isScattered()) {
        int64_t chunkSize = tdescTy.getChunkSizeAsInt();

        if (chunkSize > 1) {
          int64_t blockedChunkSize = chunkSize;
          auto instData = tdescTy.getLayoutAttr().getInstData();
          if (!instData.empty())
            blockedChunkSize = instData.asArrayRef().back();

          // Create a new encoding with the adjusted chunk_size.
          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
              ctx, tdescTy.getMemorySpace(), blockedChunkSize);
          encoding = newEncoding;
        }
      }

      newTy =
          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                     tdescTy.getLayoutAttr().dropInstData());
    } else {
      newTy = VectorType::get(tileShape, elemTy);
    }

    if (returnSingleType)
      return SmallVector<Type>{newTy};
    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });
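
  // For example (illustrative numbers), unrolling a vector<32x32xf16> with
  // tileShape [8, 16] produces computeProduct([4, 2]) = 8 copies of
  // vector<8x16xf16>.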

  RewritePatternSet patterns(ctx);
  patterns.add<ConvertLayoutOpPattern>(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);

  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));

  op->walk([](Operation *op) {
    // Remove the layout attributes cached per operand.
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    // Update the layout attributes per result.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
      }
    }

    // Resolve unrealized conversion cast ops emulating pack/unpack.
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}