#include "llvm/ADT/SetVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;
using namespace mlir::xegpu;

#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
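
/// Reconcile a value produced by a converted op with the vector type expected
/// by its users: identical types pass through, vectors with the same number of
/// elements are reshaped with vector.shape_cast, and anything else is bridged
/// with an unrealized_conversion_cast that is cleaned up after conversion.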
static Value castValueTo(ConversionPatternRewriter &rewriter, Value v,
                         VectorType expectedTy) {
  if (v.getType() == expectedTy)
    return v;
  if (isa<VectorType>(v.getType()) &&
      cast<VectorType>(v.getType()).getNumElements() ==
          expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, v);
  return newOp.getResult(0);
}
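
/// Pre-flight check: every anchor operation must carry an anchor layout
/// attribute and every vector-typed result must carry a distribute layout,
/// otherwise the distribution below has nothing to work from.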
static LogicalResult verifyLayouts(Operation *root) {
  auto walkResult = root->walk([&](Operation *nestedOp) -> WalkResult {
    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
      auto layout = anchorOp.getAnchorLayout();
      if (!layout) {
        nestedOp->emitError("expected anchor layout attribute on operation");
        return WalkResult::interrupt();
      }
    }
    for (OpResult result : nestedOp->getResults()) {
      if (isa<VectorType>(result.getType())) {
        if (!xegpu::getDistributeLayoutAttr(result)) {
          nestedOp->emitError(
              "expected result layout attribute on vector result");
          return WalkResult::interrupt();
        }
      }
    }
    return WalkResult::advance();
  });
  return walkResult.wasInterrupted() ? failure() : success();
}
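
/// Rewrites xegpu.create_nd_tdesc so that its result tensor descriptor type no
/// longer carries a layout attribute.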
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // Nothing to rewrite if the result type carries no layout.
    if (!resultType.getLayout())
      return rewriter.notifyMatchFailure(op, "result type has no layout");
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
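
/// Rewrites xegpu.load_nd to produce the per-workitem vector fragment implied
/// by the tensor descriptor, then casts it to the vector shape expected by the
/// lane layout of the anchor.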
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "expected a subgroup distribute layout on the anchor");
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    if (!xegpu::getChipStr(op))
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp requires a target attribute attached to "
              "determine transpose and packed requirements");
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
        layout, cast<VectorType>(op.getType()));
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), nullptr);
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
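
/// Rewrites xegpu.store_nd so that the stored value is the per-workitem
/// fragment implied by the tensor descriptor.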
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "expected a subgroup distribute layout on the anchor");
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    xegpu::DistributeLayoutAttr valueLayout =
        xegpu::getDistributeLayoutAttr(op.getValue());
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute wi vector type for StoreNdOp value from "
              "tensor descriptor");
    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, adaptor.getValue(),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
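
/// Rewrites xegpu.dpas so that the A, B and accumulator operands are the
/// per-workitem fragments implied by their respective layouts.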
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return rewriter.notifyMatchFailure(
          op, "expected layout attributes for the A, B and CD operands");
    auto wiResultTyOrFailure = xegpu::getDistributedVectorType(
        cast<VectorType>(op.getResult().getType()), layoutCd);
    auto wiATypeOrFailure = xegpu::getDistributedVectorType(
        cast<VectorType>(op.getLhs().getType()), layoutA);
    auto wiBTypeOrFailure = xegpu::getDistributedVectorType(
        cast<VectorType>(op.getRhs().getType()), layoutB);
    auto expectedWiResultTyOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
        layoutCd, cast<VectorType>(op.getResult().getType()));
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "operands");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp from "
              "lane layout");
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, adaptor.getLhs(), wiATypeOrFailure.value()),
        castValueTo(rewriter, adaptor.getRhs(), wiBTypeOrFailure.value()),
        castValueTo(rewriter, adaptor.getAcc(), wiResultTyOrFailure.value()),
        /*layout_a=*/nullptr, /*layout_b=*/nullptr, /*layout_cd=*/nullptr);
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
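
/// Generic pattern that distributes elementwise math/arith operations: the
/// single vector result is shrunk to the per-lane shape and any distribute
/// layout attributes are dropped from the re-created operation.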
struct SgToWiElementWise : public ConversionPattern {
  SgToWiElementWise(const TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return rewriter.notifyMatchFailure(
          op, "expected a single-result elementwise operation");
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();
    // Re-create the op with the converted operands and the per-lane result
    // type, dropping any distribute layout attributes.
    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    for (NamedAttribute attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};
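
/// Distributes splat vector arith.constant ops by re-creating the splat with
/// the per-lane vector type.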
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "only vector constants need to be distributed");
    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!dense)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();
    auto scalarValue = dense.getSplatValue<Attribute>();
    auto newValue = DenseElementsAttr::get(newResultType, scalarValue);
    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
                                           newValue);
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
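
/// Rewrites xegpu.prefetch_nd against the layout-free tensor descriptor;
/// prefetch has no results, so no further adjustment is needed.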
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "expected a subgroup distribute layout on the anchor");
    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
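
/// Pass driver: verify layouts, run the dialect conversion that distributes
/// subgroup-level XeGPU IR to workitem (SIMT) form, then fold or erase the
/// unrealized casts introduced along the way.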
struct XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
  void runOnOperation() override;
};

void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
  Operation *root = getOperation();
  if (failed(verifyLayouts(root))) {
    LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
                         "verification failed\n");
    signalPassFailure();
    return;
  }
  // Remember the unrealized casts that already exist so that the cleanup
  // below only touches casts introduced by this pass.
  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
  root->walk(
      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
  // Bridge type mismatches with unrealized casts; they are folded into
  // vector.shape_cast ops or erased once the conversion has finished.
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    UnrealizedConversionCastOp castOp =
        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
    return castOp.getResult(0);
  };
  TypeConverter typeConverter;
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  RewritePatternSet patterns(&getContext());
  ConversionTarget target(getContext());
  target.addLegalOp<UnrealizedConversionCastOp>();
  xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
      typeConverter, patterns, target);
  scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
                                                       target);
  if (failed(applyPartialConversion(root, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Fold the back-to-back unrealized casts produced by the conversion
  // (per-lane type -> subgroup type -> per-lane type) into a single
  // vector.shape_cast when the element counts allow it.
  OpBuilder builder(&getContext());
  root->walk([&](UnrealizedConversionCastOp op) {
    if (existingCasts.contains(op))
      return;
    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
      return;
    auto singleInput = op.getInputs()[0];
    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
    if (!inputTy || !outputTy)
      return;
    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
    if (!definingOp || !definingOp->hasOneUse())
      return;
    auto inputOfDefiningOp = definingOp.getInputs()[0];
    auto inputOfDefiningOpTy =
        dyn_cast<VectorType>(inputOfDefiningOp.getType());
    if (inputOfDefiningOpTy &&
        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
      builder.setInsertionPoint(op);
      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
                                                   outputTy, inputOfDefiningOp);
      op->replaceAllUsesWith(ValueRange{shapeCast.getResult()});
    }
  });
  // Erase any casts introduced by this pass that are now unused.
  root->walk([&](UnrealizedConversionCastOp op) {
    if (existingCasts.contains(op))
      return;
    if (op.use_empty())
      op.erase();
  });
}
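
/// Define only the type conversions needed for XeGPU subgroup to workitem
/// distribution.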
void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addConversion([](Type type) -> std::optional<Type> {
    // Types other than tensor descriptors and vectors are left untouched.
    if (!isa<TensorDescType, VectorType>(type))
      return type;
    return std::nullopt;
  });
  typeConverter.addConversion([](TensorDescType type) -> Type {
    if (type.getLayoutAttr()) {
      return type.dropLayouts();
    }
    return type;
  });
  typeConverter.addConversion([](Value v) -> std::optional<Type> {
    // Vector values with a subgroup layout convert to their per-lane type.
    Type type = v.getType();
    if (!isa<VectorType>(type))
      return std::nullopt;
    auto layout = xegpu::getDistributeLayoutAttr(v);
    if (!layout || !layout.isForSubgroup())
      return std::nullopt;
    auto newTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
    if (failed(newTyOrFailure))
      return std::nullopt;
    return *newTyOrFailure;
  });
}
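
/// Defines type conversions and legality for XeGPU subgroup to workitem
/// distribution and appends the required patterns.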
void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
  // A create_nd_tdesc op is legal once its result type carries no layout.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
  // Any other XeGPU anchor op is legal once its anchor layout is gone.
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
    if (!anchorOp)
      return true;
    return !anchorOp.getAnchorLayout();
  });
  // Vector constants that still carry a subgroup layout must be distributed.
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        if (!isa<VectorType>(op.getResult().getType()))
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return !layout || !layout.isForSubgroup();
      });
  // Elementwise math/arith ops over vectors with matching operand and result
  // shapes are distributed when their result carries a subgroup layout.
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;
        if (op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return !layout || !layout.isForSubgroup();
      });
  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
      typeConverter, patterns.getContext());
}