#include "llvm/ADT/SetVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
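// Casts `v` to `expectedTy`. Values that already have the expected type are
// returned unchanged; vectors with a matching element count are reshaped with
// vector.shape_cast; everything else goes through an
// unrealized_conversion_cast that is resolved later in the pass.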
static Value castValueTo(ConversionPatternRewriter &rewriter, Value v,
                         VectorType expectedTy) {
  if (v.getType() == expectedTy)
    return v;
  if (isa<VectorType>(v.getType()) &&
      cast<VectorType>(v.getType()).getNumElements() ==
          expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, v);
  return newOp.getResult(0);
}
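// Walks the region rooted at `root` and verifies that every anchor operation
// carries an anchor layout attribute and that every vector-typed result has a
// layout assigned. Returns failure as soon as a violation is found.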
static LogicalResult verifyLayouts(Operation *root) {
  auto walkResult = root->walk([&](Operation *nestedOp) {
    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
      auto layout = anchorOp.getAnchorLayout();
      if (!layout) {
        nestedOp->emitError("expected anchor layout attribute on operation");
        return WalkResult::interrupt();
      }
    }
    for (OpResult result : nestedOp->getResults()) {
      if (isa<VectorType>(result.getType())) {
        if (!xegpu::getDistributeLayoutAttr(result)) {
          nestedOp->emitError(
              "expected result layout attribute on vector result");
          return WalkResult::interrupt();
        }
      }
    }
    return WalkResult::advance();
  });
  return walkResult.wasInterrupted() ? failure() : success();
}
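// Returns true if `op` is a vector.multi_reduction whose result carries a
// subgroup distribute layout, whose result type can be distributed according
// to the lane layout, and which reduces along exactly one dimension.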
static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
  xegpu::DistributeLayoutAttr resLayout =
      xegpu::getDistributeLayoutAttr(op.getResult());
  if (!resLayout || !resLayout.isForSubgroup())
    return false;
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  if (!resTy)
    return false;
  FailureOr<VectorType> resDistTypeOrFailure =
      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  if (failed(resDistTypeOrFailure))
    return false;
  return op.getReductionDims().size() == 1;
}
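// A multi-dim reduction is "lane local" when distributing its result across
// the lane layout changes the result type, i.e. each lane can reduce its own
// slice of the data without any cross-lane communication.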
static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
                                                "MultiDimReductionOp");
  xegpu::DistributeLayoutAttr resLayout =
      xegpu::getDistributeLayoutAttr(op.getResult());
  VectorType resTy = dyn_cast<VectorType>(op.getType());
  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
  return resTy != resDistTypeOrFailure.value();
}
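// Converts xegpu.create_nd_tdesc to its workitem form by dropping the layout
// from the result tensor descriptor type.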
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    if (!resultType.getLayout())
      return rewriter.notifyMatchFailure(
          op, "expected a layout on the result tensor descriptor type");
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
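// Converts xegpu.load_nd: the load is re-created on the layout-free tensor
// descriptor with the per-lane result type, and the result is cast back to the
// type implied by the lane layout when the two differ.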
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto chipStr = xegpu::getChipStr(op);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp require target attribute attached to "
              "determine transpose requirement");
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), nullptr);
    // The hardware-native result type may differ from the type implied by the
    // lane layout; cast to the expected workitem type.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
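// Converts xegpu.store_nd: the stored value is cast to the per-lane vector
// type expected by the tensor descriptor and the store is re-created without
// layout information.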
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    xegpu::DistributeLayoutAttr valueLayout =
        xegpu::getDistributeLayoutAttr(op.getValue());
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute wi vector type for StoreNdOp value from tensor "
              "descriptor");
    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, adaptor.getValue(),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
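// Converts xegpu.dpas: operands and accumulator are cast to the per-lane
// vector types derived from their layouts, and the result is cast back to the
// type implied by the lane layout.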
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return rewriter.notifyMatchFailure(
          op, "expected layout attributes on all DpasOp operands and result");
    auto wiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getResultType(), layoutCd);
    auto wiATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhsType(), layoutA);
    auto wiBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhsType(), layoutB);
    auto expectedWiResultTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutCd, op.getResultType());
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "operands");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp from "
              "lane layout");
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, adaptor.getLhs(), wiATypeOrFailure.value()),
        castValueTo(rewriter, adaptor.getRhs(), wiBTypeOrFailure.value()),
        castValueTo(rewriter, adaptor.getAcc(), wiResultTyOrFailure.value()));
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
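// Distributes elementwise math/arith operations whose single vector result
// carries a subgroup distribute layout: the op is re-created with converted
// operands, the per-lane result type, and all layout attributes dropped.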
struct SgToWiElementWise : public ConversionPattern {
  SgToWiElementWise(const TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(
          op, "expected an elementwise operation with a single result");
    if (!isa<VectorType>(op->getResult(0).getType()))
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiShapeOrFailure = getDistVecTypeBasedOnLaneLayout(
        layout, cast<VectorType>(op->getResult(0).getType()));
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();
    // Re-create the op with the converted operands and the distributed result
    // type; layout attributes are dropped in the process.
    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    for (NamedAttribute attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};
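// Distributes dense splat vector constants by re-creating them with the
// per-lane vector type derived from the result's subgroup distribute layout.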
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "result is not a vector type");
    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!dense)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiShapeOrFailure = getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();
    // Re-splat the constant to the distributed (per-lane) vector type.
    auto newAttr = DenseElementsAttr::get(newResultType,
                                          dense.getSplatValue<Attribute>());
    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
                                           newAttr);
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
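// Converts xegpu.prefetch_nd by re-creating it on the layout-free tensor
// descriptor; prefetch has no results, so no value casting is needed.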
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
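// Distributes a rank-1 vector.reduction: the per-lane slice produced by the
// distribution is combined into the full reduction value with the
// subgroupReduction helper, and the accumulator is folded in afterwards if
// present.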
struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "reduction source does not have a subgroup distribute layout");
    VectorType srcVecType = op.getSourceVectorType();
    if (srcVecType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "Only rank 1 reductions can be distributed.");
    if (layout.getRank() != srcVecType.getRank())
      return rewriter.notifyMatchFailure(
          op, "Layout rank does not match vector rank.");
    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
    auto chipStr = xegpu::getChipStr(op);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          op, "xegpu::ReductionOp require target attribute attached to "
              "determine subgroup size");
    const auto *uArch = getUArch(chipStr.value());
    if (sgSize != uArch->getSubgroupSize() ||
        srcVecType.getShape()[0] % sgSize != 0)
      return rewriter.notifyMatchFailure(op,
                                         "Invalid layout or reduction vector "
                                         "dimension must match subgroup size.");
    if (!op.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Reduction distribution currently only supports floats and "
              "integers.");
    // Each lane holds a slice of the source vector after distribution.
    Value laneValVec = adaptor.getVector();
    Value fullReduce = subgroupReduction(
        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
    // Fold in the accumulator, if present.
    if (adaptor.getAcc())
      fullReduce = vector::makeArithReduction(
          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
    rewriter.replaceOp(op, fullReduce);
    return success();
  }
};
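// Distributes a vector.multi_reduction that is lane local: each lane performs
// the same reduction on its own slice, so only the result type changes.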
struct SgToWiMultiDimReduction
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isReductionLaneLocal(op))
      return rewriter.notifyMatchFailure(
          op, "Only lane-local reduction is supported, expected reduction "
              "to be local to each lane.");
    xegpu::DistributeLayoutAttr resLayout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    VectorType resVecTy = dyn_cast<VectorType>(op.getType());
    auto resDistVecTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
    if (failed(resDistVecTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    auto newOp = vector::MultiDimReductionOp::create(
        rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
        adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
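// Rewrites a vector.multi_reduction that is not lane local into rank-1
// vector.reduction ops so that the cross-lane part can later be handled by
// SgToWiVectorReduction.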
struct LowerVectorMultiReductionPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (isReductionLaneLocal(op))
      return rewriter.notifyMatchFailure(
          op, "Reduction is lane-local, it does not require rewrite.");
    ArrayRef<int64_t> reductionDims = op.getReductionDims();
    assert(
        reductionDims.size() == 1 &&
        "Expecting single reduction dimension for subgroup multi reduction op");
    // Lower the multi-dimensional reduction to a set of rank-1
    // vector.reduction ops.
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(op.getSource()),
        cast<TypedValue<VectorType>>(op.getAcc()), op.getKind(),
        reductionDims[0], op.getLoc(), rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
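// Pass that rewrites XeGPU subgroup-level IR into workitem (SIMT) form using
// the conversion patterns defined above.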
struct XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
  void runOnOperation() override;
};
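// Pass driver: verify that layouts are assigned, run the dialect conversion,
// and then clean up the unrealized casts introduced while converting.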
void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
  Operation *root = getOperation();
  if (failed(verifyLayouts(root))) {
    LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
                         "verification failed\n");
    signalPassFailure();
    return;
  }
  // Remember the unrealized casts that already exist so that the cleanup below
  // only touches casts introduced by this pass.
  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
  root->walk(
      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
  // Materialize type-converter casts as unrealized_conversion_cast ops; they
  // are folded into shape casts or erased after the conversion.
  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
                             mlir::ValueRange inputs,
                             mlir::Location loc) -> mlir::Value {
    UnrealizedConversionCastOp castOp =
        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
    return castOp.getResult(0);
  };
  TypeConverter typeConverter;
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  ConversionTarget target(getContext());
  RewritePatternSet patterns(&getContext());
  target.addLegalOp<UnrealizedConversionCastOp>();
  scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
                                                       target);
  xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
      typeConverter, patterns, target);
  if (failed(applyPartialConversion(root, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Fold back-to-back unrealized casts introduced by the conversion into a
  // single vector.shape_cast when the element counts match.
  OpBuilder builder(root);
  root->walk([&](UnrealizedConversionCastOp op) {
    if (existingCasts.contains(op))
      return;
    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
      return;
    auto singleInput = op.getInputs()[0];
    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
    if (!inputTy || !outputTy)
      return;
    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
    if (!definingOp || !definingOp->hasOneUse())
      return;
    auto inputOfDefiningOp = definingOp.getInputs()[0];
    auto inputOfDefiningOpTy =
        dyn_cast<VectorType>(inputOfDefiningOp.getType());
    if (inputOfDefiningOpTy &&
        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
      builder.setInsertionPoint(op);
      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
                                                   outputTy, inputOfDefiningOp);
      op->replaceAllUsesWith(ValueRange{shapeCast.getResult()});
    }
  });
  // Finally erase any casts introduced by this pass that are left without
  // users.
  root->walk([&](UnrealizedConversionCastOp op) {
    if (existingCasts.contains(op))
      return;
    if (op.use_empty()) {
      op->erase();
    }
  });
}
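// Registers only the type conversions needed to map XeGPU subgroup types to
// their workitem equivalents: tensor descriptors lose their layout, and vector
// values with a subgroup distribute layout are narrowed to their per-lane
// shape.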
void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
    TypeConverter &typeConverter) {
  // Any type other than tensor descriptors and vectors converts to itself.
  typeConverter.addConversion([](Type type) -> std::optional<Type> {
    if (!isa<TensorDescType, VectorType>(type))
      return type;
    return std::nullopt;
  });
  // Tensor descriptors simply drop their layout.
  typeConverter.addConversion([](TensorDescType type) -> Type {
    if (type.getLayoutAttr()) {
      return type.dropLayouts();
    }
    return type;
  });
  // Vector values with a subgroup distribute layout convert to the per-lane
  // vector type implied by that layout.
  typeConverter.addConversion([](Value v) -> std::optional<Type> {
    Type type = v.getType();
    if (!isa<VectorType>(type))
      return std::nullopt;
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(v);
    if (!layout || !layout.isForSubgroup())
      return std::nullopt;
    auto newTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
    if (failed(newTyOrFailure))
      return std::nullopt;
    return *newTyOrFailure;
  });
}
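// Sets up the type conversion, legality rules, and patterns for the
// subgroup-to-workitem distribution: XeGPU anchor ops, splat constants,
// elementwise math/arith ops, and lane-local reductions remain illegal while
// they still carry subgroup layouts, and the corresponding rewrite patterns
// are appended.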
void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
    if (!anchorOp)
      return true;
    return !anchorOp.getAnchorLayout();
  });
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        if (!isa<VectorType>(op.getResult().getType()))
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return !layout || !layout.isForSubgroup();
      });
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only elementwise ops with a single vector result and operands of the
        // same shape are candidates for distribution.
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;
        if (op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return !layout || !layout.isForSubgroup();
      });
  target.addDynamicallyLegalOp<vector::ReductionOp>(
      [=](vector::ReductionOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
        return !layout || !layout.isForSubgroup();
      });
  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [=](vector::MultiDimReductionOp op) -> bool {
        if (!isValidSubgroupMultiReductionOp(op))
          return true;
        return !isReductionLaneLocal(op);
      });
  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
               SgToWiVectorReduction, SgToWiMultiDimReduction>(
      typeConverter, patterns.getContext());
}
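// Marks vector.multi_reduction ops that are not lane local as illegal and
// appends the pattern that rewrites them into rank-1 vector.reduction ops;
// everything else is left legal.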
void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
    RewritePatternSet &patterns, ConversionTarget &target) {
  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [&](vector::MultiDimReductionOp op) {
        if (!isValidSubgroupMultiReductionOp(op))
          return true;
        return isReductionLaneLocal(op);
      });
  target.addDynamicallyLegalOp<vector::ReductionOp>(
      [&](vector::ReductionOp op) { return true; });
  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
  patterns.add<LowerVectorMultiReductionPattern>(patterns.getContext());
}