//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace {

/// Casts the given vector value `v` to the expected vector type `expectedTy`.
static Value castValueTo(ConversionPatternRewriter &rewriter,
                         TypedValue<VectorType> v, VectorType expectedTy) {
  // If the type matches, simply return the value itself.
  if (v.getType() == expectedTy)
    return v;
  // If only shape differs, use shape cast.
  if (isa<VectorType>(v.getType()) &&
      v.getType().getNumElements() == expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);

  // Else create an unrealized cast.
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, ValueRange{v});
  return newOp.getResult(0);
}

/// Checks if all XeGPU anchor ops and vector results have valid layouts.
static LogicalResult verifyLayouts(Operation *root) {
  auto walkResult = root->walk([&](Operation *nestedOp) -> WalkResult {
    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
      auto layout = anchorOp.getAnchorLayout();
      if (!layout) {
        nestedOp->emitError("expected anchor layout attribute on operation");
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    }
    // For each vector result, check if the op contains a result layout
    // attribute.
    for (OpResult result : nestedOp->getResults()) {
      if (isa<VectorType>(result.getType())) {
        auto layout = xegpu::getTemporaryLayout(result);
        if (!layout) {
          nestedOp->emitError(
              "expected result layout attribute on vector result");
          return WalkResult::interrupt();
        }
      }
    }
    return WalkResult::advance();
  });
  return walkResult.wasInterrupted() ? failure() : success();
}

/// A vector::MultiDimReductionOp at subgroup level is in the expected form if
/// it has exactly one reduction dimension, it has a valid result layout
/// attribute, and its result type can be distributed to lanes using that
/// layout.
static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  // If there is no subgroup layout, the op is not valid.
  if (!resLayout || !resLayout.isForSubgroup())
    return false;
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  if (!resTy)
    return false;
  // Compute the distributed result vector type based on the layout.
  FailureOr<VectorType> resDistTypeOrFailure =
      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  if (failed(resDistTypeOrFailure))
    return false;
  return op.getReductionDims().size() == 1;
}

/// A vector::MultiDimReductionOp performs a lane-local reduction if each
/// workitem does its own local reduction. In this case the result layout
/// ensures that the result vector is distributed to lanes, i.e. the result
/// vector type differs from the distributed result vector type.
static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
  // Must be a valid MultiDimReductionOp.
  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
                                                "MultiDimReductionOp");
  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  return resTy != resDistTypeOrFailure.value();
}

/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
/// op. This simply drops the layout attribute from the tensor descriptor type.
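/// An illustrative example (types and layout chosen for exposition only):
///   %0 = xegpu.create_nd_tdesc %src : memref<16x16xf32> ->
///        !xegpu.tensor_desc<16x16xf32,
///          #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
/// becomes:
///   %0 = xegpu.create_nd_tdesc %src : memref<16x16xf32> ->
///        !xegpu.tensor_desc<16x16xf32>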
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // If no layout, nothing to do.
    if (!resultType.getLayout())
      return failure();

    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
/// original rank.
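/// An illustrative example, assuming a 16-lane subgroup and
/// lane_layout = [1, 16]:
///   %v = xegpu.load_nd %tdesc : !xegpu.tensor_desc<16x16xf32, ...>
///        -> vector<16x16xf32>
/// becomes:
///   %0 = xegpu.load_nd %tdesc : !xegpu.tensor_desc<16x16xf32>
///        -> vector<16xf32>
///   %v = vector.shape_cast %0 : vector<16xf32> to vector<16x1xf32>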
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor is the same as the
    // anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp requires a target attribute to determine the "
              "transpose requirement");
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);
    // Set the packed attribute if the layout requires it.
    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
/// incoming value to 1D.
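/// An illustrative example, assuming a 16-lane subgroup and
/// lane_layout = [1, 16]:
///   xegpu.store_nd %v, %tdesc : vector<16x16xf32>,
///                               !xegpu.tensor_desc<16x16xf32, ...>
/// becomes:
///   %0 = vector.shape_cast %v : vector<16x1xf32> to vector<16xf32>
///   xegpu.store_nd %0, %tdesc : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>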
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layouts attached to the tensor descriptor and the stored
    // value are the same as the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute wi vector type for StoreNdOp value from tensor "
          "descriptor");

    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inputs
/// and the output of the workitem-level Dpas op are 1D. Necessary casts are
/// added to convert the inputs and output to/from 1D.
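/// An illustrative example, assuming a 16-lane subgroup (the exact per-lane
/// fragment shapes are determined by the A/B/CD layouts):
///   %d = xegpu.dpas %a, %b, %c : vector<8x16xf16>, vector<16x16xf16>,
///                                vector<8x16xf32> -> vector<8x16xf32>
/// becomes (per workitem):
///   %d = xegpu.dpas %a, %b, %c : vector<8xf16>, vector<16xf16>,
///                                vector<8xf32> -> vector<8xf32>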
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the op has A, B and CD layouts attached.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return failure();
    // Compute the workitem-level types for the result and operands.
    auto wiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto wiATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
    auto wiBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "from layouts");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp from "
              "lane layout");
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
                    wiATypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
                    wiBTypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
                    wiResultTyOrFailure.value()),
        /*layoutA=*/nullptr,
        /*layoutB=*/nullptr, /*layoutCd=*/nullptr);
    // Cast the result back to the expected workitem type to enable correct
    // type materializations.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes elementwise ops to workitem-level elementwise ops. This
/// currently handles only elementwise ops with a single result.
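/// An illustrative example, assuming lane_layout = [1, 16]:
///   %r = arith.addf %a, %b : vector<16x16xf32>
/// becomes:
///   %r = arith.addf %a, %b : vector<16x1xf32>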
struct SgToWiElementWise : public ConversionPattern {
  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with the elementwise trait and a single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);

    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");

    VectorType newResultType = wiShapeOrFailure.value();
    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    // Copy all attributes except for DistributeLayoutAttr.
    for (auto attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);

    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};

/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
/// ConstantOp.
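/// An illustrative example, assuming lane_layout = [1, 16]:
///   %c = arith.constant dense<1.0> : vector<16x16xf32>
/// becomes:
///   %c = arith.constant dense<1.0> : vector<16x1xf32>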
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return failure();

    // Only handle dense splat vector constants.
    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!dense)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);

    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");

    VectorType newResultType = wiShapeOrFailure.value();
    auto scalarValue = dense.getSplatValue<Attribute>();
    auto newDenseAttr = DenseElementsAttr::get(newResultType, scalarValue);

    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(),
                                           newResultType, newDenseAttr);
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();

    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern distributes a subgroup-level vector.reduction op to
/// workitem-level. This requires shuffling the data across the workitems
/// (using gpu::ShuffleOp) and reducing in stages until all workitems have the
/// final result.
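/// An illustrative example, assuming a 16-lane subgroup: for
///   %r = vector.reduction <add>, %v : vector<16xf32> into f32
/// each workitem holds one element, and the per-lane values are combined with
/// log2(16) = 4 shuffle-and-reduce stages so that every lane ends up with the
/// full sum.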
struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());

    // If no layout, nothing to do.
    if (!layout || !layout.isForSubgroup())
      return failure();

    VectorType srcVecType = op.getSourceVectorType();
    // Only rank 1 vectors are supported.
    if (srcVecType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Only rank 1 reductions can be distributed.");
    // Lane layout must have the same rank as the vector.
    if (layout.getRank() != srcVecType.getRank())
      return rewriter.notifyMatchFailure(
          op, "Layout rank does not match vector rank.");

    // Get the subgroup size from the layout.
    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
    const auto *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "vector::ReductionOp requires a target attribute to determine "
              "the subgroup size");

    // The lane layout must match the subgroup size and the reduction
    // dimension must be a multiple of it.
    if (sgSize != uArch->getSubgroupSize() ||
        srcVecType.getShape()[0] % sgSize != 0)
      return rewriter.notifyMatchFailure(
          op, "Invalid lane layout or the reduction vector dimension is not a "
              "multiple of the subgroup size.");

    if (!op.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Reduction distribution currently only supports float and "
              "integer types.");

    // Get the distributed vector (per work-item portion).
    Value laneValVec = adaptor.getVector();

    // Distribute and reduce across work-items in the subgroup.
    Value fullReduce = xegpu::subgroupReduction(
        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);

    // If there's an accumulator, combine it with the reduced value.
    if (adaptor.getAcc())
      fullReduce = vector::makeArithReduction(
          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());

    rewriter.replaceOp(op, fullReduce);
    return success();
  }
};

/// This pattern distributes a subgroup-level vector.multi_reduction op to
/// workitem-level only if the reduction is lane-local. This means that the
/// reduction dimension is not distributed to lanes and each lane does its own
/// local reduction.
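/// An illustrative example, assuming lane_layout = [16, 1] and reduction
/// dimension 1 (each lane owns a full row, so it reduces locally):
///   %r = vector.multi_reduction <add>, %src, %acc [1]
///        : vector<16x32xf32> to vector<16xf32>
/// becomes:
///   %r = vector.multi_reduction <add>, %src, %acc [1]
///        : vector<1x32xf32> to vector<1xf32>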
struct SgToWiMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only lane-local reduction is handled here.
    if (!isReductionLaneLocal(op))
      return rewriter.notifyMatchFailure(
          op, "Only lane-local reduction is supported; expected the reduction "
              "dimension to not be distributed.");
    auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
    VectorType resVecTy = dyn_cast<VectorType>(op.getType());
    auto resDistVecTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
    // Simply create a new MultiDimReductionOp using the adaptor operands and
    // the new result type.
    auto newOp = vector::MultiDimReductionOp::create(
        rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
        adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// This pattern rewrites a subgroup-level vector.multi_reduction op to a series
/// of vector.extract_strided_slice, vector.reduction and
/// vector.insert_strided_slice ops. This is used when the reduction dimension
/// is distributed to lanes and a naive (lane-local) distribution is not
/// possible. Then later on, these partially lowered subgroup-level ops are
/// further lowered to workitem-level by respective patterns.
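/// An illustrative example, assuming reduction dimension 1 of a
/// vector<2x32xf32> whose reduced rows span the lanes: each row %row_i is
/// extracted with vector.extract_strided_slice, reduced with
///   %s_i = vector.reduction <add>, %row_i : vector<32xf32> into f32
/// and the scalars are inserted back into the vector<2xf32> result, which the
/// patterns above then distribute to workitems.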
struct LowerVectorMultiReductionPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only non-lane-local reduction is handled here.
    if (isReductionLaneLocal(op))
      return rewriter.notifyMatchFailure(
          op, "Reduction is lane-local, it does not require rewrite.");
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    assert(
        reductionDims.size() == 1 &&
        "Expecting single reduction dimension for subgroup multi reduction op");

    // Rewrite the MultiDimReductionOp into a sequence of ReductionOps.
    Value result = xegpu::lowerToVectorReductions(
        cast<TypedValue<VectorType>>(op.getSource()),
        cast<TypedValue<VectorType>>(op.getAcc()), op.getKind(),
        reductionDims[0], op.getLoc(), rewriter);

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
  void runOnOperation() override;
};

} // namespace

void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {

  // Verify that all XeGPU anchor ops and vector ops have result layouts.
  // TODO: This can be removed once the full layout refactoring is done.
  Operation *root = getOperation();
  if (failed(verifyLayouts(root))) {
    LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
                         "verification failed\n");
    signalPassFailure();
    return;
  }
  // Collect existing UnrealizedConversionCastOps. These must be preserved.
  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
  root->walk(
      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
  // Perform a structural type conversion to convert structural ops to have WI
  // types. This will insert UnrealizedConversionCastOps to make the IR
  // valid.
  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
                             mlir::ValueRange inputs,
                             mlir::Location loc) -> mlir::Value {
    UnrealizedConversionCastOp castOp =
        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
    return castOp.getResult(0);
  };
  {
    ConversionTarget target(getContext());
    TypeConverter typeConverter;
    RewritePatternSet patterns(&getContext());
    typeConverter.addSourceMaterialization(materializeCast);
    typeConverter.addTargetMaterialization(materializeCast);
    xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
    scf::populateSCFStructuralTypeConversionsAndLegality(
        typeConverter, patterns, target);
    target.addLegalOp<UnrealizedConversionCastOp>();
    (void)applyPartialConversion(root, target, std::move(patterns));
  }
  // Structural type conversion can generate some redundant
  // UnrealizedConversionCastOps to materialize the SG type from the type
  // converted WI type. These are redundant at this point and can be eliminated
  // by inserting shape casts instead.
  // Example:
  //   %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
  //   %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
  // This can be replaced with:
  //   %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
  OpBuilder builder(root);
  root->walk([&](UnrealizedConversionCastOp op) {
    // If this op existed before, nothing to do.
    if (existingCasts.contains(op))
      return;
    // The number of inputs and outputs must be 1.
    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
      return;
    // Both input and output types must be vector types.
    auto singleInput = op.getInputs()[0];
    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
    if (!inputTy || !outputTy)
      return;

    // Check if the defining op of the input is also an
    // UnrealizedConversionCastOp and it has a single user (which is this
    // op).
    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
    if (!definingOp || !definingOp->hasOneUse())
      return;
    auto inputOfDefiningOp = definingOp.getInputs()[0];
    // If the input of the defining op and the output are both vector types
    // with the same number of elements, insert a shape cast.
    auto inputOfDefiningOpTy =
        dyn_cast<VectorType>(inputOfDefiningOp.getType());
    if (inputOfDefiningOpTy &&
        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
      builder.setInsertionPoint(op);
      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
                                                   outputTy, inputOfDefiningOp);
      op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
      return;
    }
  });
  // At this point, we will have some dead UnrealizedConversionCastOps. Just
  // erase them.
  bool changed = true;
  while (changed) {
    changed = false;
    root->walk([&](UnrealizedConversionCastOp op) {
      // Skip existing casts.
      if (existingCasts.contains(op))
        return;
      if (op.use_empty()) {
        op.erase();
        changed = true;
      }
    });
  }
}

void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
    TypeConverter &typeConverter) {
  // Any type other than TensorDescType and VectorType is legal as is.
  typeConverter.addConversion([](Type type) -> std::optional<Type> {
    if (!isa<TensorDescType, VectorType>(type))
      return type;
    return std::nullopt;
  });
  // For TensorDescType, drop the layout attribute if any.
  typeConverter.addConversion([](TensorDescType type) -> Type {
    if (type.getLayoutAttr()) {
      return type.dropLayouts();
    }
    return type;
  });
  // For VectorType, check if there is a distribute layout attribute on the
  // value. If so, convert to the distributed vector type based on the layout.
  typeConverter.addConversion([](Value v) -> std::optional<Type> {
    auto type = v.getType();
    // If the value is not a vector type, nothing to do.
    if (!isa<VectorType>(type))
      return std::nullopt;
    auto layout = xegpu::getDistributeLayoutAttr(v);
    if (!layout || !layout.isForSubgroup())
      return type;
    // The vector type is distributed based on the lane layout.
    auto newTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
    if (failed(newTyOrFailure))
      return type;
    return *newTyOrFailure;
  });
}

void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
  // CreateNdDescOp is legal only if its result type has no layout attribute.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
  // Any anchor XeGPU op is legal only if it has no anchor layout.
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
    if (!anchorOp)
      return true;
    return !anchorOp.getAnchorLayout();
  });
  // Arith constants are legal only if they have no temporary layout attribute.
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        // If the result type is not a vector, it's legal.
        if (!isa<VectorType>(op.getResult().getType()))
          return true;
        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
      });
  // In the math and arith dialects, only handle elementwise ops with a single
  // result and with a result layout attribute.
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise mappable ops.
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;
        // Only handle ops with a single vector result.
        if (op->getNumResults() != 1)
          return true;

        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // Check if all operands are vectors of the same shape.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType ||
              operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }
        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
      });
  // vector::ReductionOp is legal only if its source has no distribute layout
  // attribute.
  target.addDynamicallyLegalOp<vector::ReductionOp>(
      [=](vector::ReductionOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
        return !layout;
      });
  // vector::MultiDimReductionOp legality.
  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [=](vector::MultiDimReductionOp op) -> bool {
        // Check common conditions for a subgroup multi reduction op.
        if (!isValidSubgroupMultiReductionOp(op))
          return true;
        // Lane-local reductions are illegal at this point and must be lowered.
        return !isReductionLaneLocal(op);
      });
  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
               SgToWiVectorReduction, SgToWiMultiDimReduction>(
      typeConverter, patterns.getContext());
}

void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
    RewritePatternSet &patterns, ConversionTarget &target) {
  // vector::MultiDimReductionOp legality.
  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [&](vector::MultiDimReductionOp op) {
        // Check common conditions for a subgroup multi reduction op.
        if (!isValidSubgroupMultiReductionOp(op))
          return true;
        // Lane-local reductions are legal. We only rewrite non-lane-local
        // reductions.
        return isReductionLaneLocal(op);
      });
  // vector::ReductionOp is legal.
  target.addDynamicallyLegalOp<vector::ReductionOp>(
      [&](vector::ReductionOp op) { return true; });
  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
  patterns.add<LowerVectorMultiReductionPattern>(patterns.getContext());
}