15#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
33 if (
auto attr = dyn_cast<Attribute>(ofr)) {
34 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
38 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 <<
"requires exactly one parameter associated";
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 <<
"handle must be mapped to exactly one payload op";
60 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
67 transformOp.emitSilenceableError()
68 <<
"payload op must have exactly 1 index result";
76 return transformOp.emitSilenceableError()
77 <<
"requires param or handle to be the result of a constant like "
80 result.push_back(intAttr.getInt());
90 Value currentValue = val;
94 LDBG() <<
"Failed to find producer op, value has no uses.";
97 auto userOp = val.
getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
100 LDBG() <<
"Failed to find producer op, not in a loop.";
104 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
109 LDBG() <<
"Failed to find producer op, value not in init values.";
115 if (
auto matchingOp = dyn_cast<T>(producerOp))
128 return xegpu::LayoutAttr::get(
140 TransformOpInterface transformOp,
145 xegpu::LayoutAttr &layoutAttr) {
149 if (!status.succeeded())
153 if (!status.succeeded())
157 if (!status.succeeded())
159 auto maybeInstData = instData.empty()
161 : std::optional<ArrayRef<int32_t>>(instData);
169static xegpu::CreateNdDescOp
171 xegpu::CreateNdDescOp descOp,
172 xegpu::DistributeLayoutAttr layout) {
173 assert(descOp.getMixedOffsets().size() == 0 &&
174 "create desc op with offsets is not supported");
175 auto oldTensorDesc = descOp.getType();
176 auto descType = xegpu::TensorDescType::get(
177 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
178 oldTensorDesc.getArrayLength(),
179 oldTensorDesc.getBoundaryCheck(),
180 oldTensorDesc.getMemorySpace(),
185 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
186 descOp.getMixedStrides());
195 if (!llvm::hasSingleElement(targetValues)) {
197 <<
"requires exactly one target value handle (got "
198 << llvm::range_size(targetValues) <<
")";
205 <<
"Could not find a matching descriptor op when walking the "
206 "producer chain of the first operand.";
209 results.
set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
213void transform::SetDescLayoutOp::build(
OpBuilder &builder,
242 if (!llvm::hasSingleElement(targetOps)) {
244 << llvm::range_size(targetOps) <<
")";
248 xegpu::LayoutAttr layoutAttr =
nullptr;
250 getContext(), state, (*
this), getMixedSgLayout(), getMixedSgData(),
251 getMixedInstData(), getOrder(), layoutAttr);
252 if (!status.succeeded())
255 xegpu::DistributeLayoutAttr layout = layoutAttr;
256 auto sliceDims = getSliceDims();
257 if (sliceDims.size() > 0) {
259 layout = xegpu::SliceAttr::get(
264 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(
target);
267 <<
"Expected a xegpu.create_nd_desc op, but got: "
269 diag.attachNote(
target->getLoc()) <<
"target op";
277 results.
set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
282void transform::SetDescLayoutOp::getEffects(
292void transform::SetOpLayoutAttrOp::build(
302 build(builder, ostate,
target.getType(),
322 if (!llvm::hasSingleElement(targetOps)) {
324 << llvm::range_size(targetOps) <<
")";
328 bool resultTarget = getResult();
329 bool operandTarget = getOperand();
332 if (resultTarget &&
index >=
target->getNumResults()) {
334 <<
"Index exceeds the number of op results";
336 if (operandTarget &&
index >=
target->getNumOperands()) {
338 <<
"Index exceeds the number of op operands";
341 xegpu::LayoutAttr layoutAttr =
nullptr;
343 getContext(), state, (*
this), getMixedSgLayout(), getMixedSgData(),
344 getMixedInstData(), getOrder(), layoutAttr);
345 if (!status.succeeded())
348 xegpu::DistributeLayoutAttr layout = layoutAttr;
349 auto sliceDims = getSliceDims();
350 if (sliceDims.size() > 0) {
352 layout = xegpu::SliceAttr::get(
360 }
else if (operandTarget) {
363 }
else if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(
target)) {
366 dpasOp.getProperties().layout_a = layout;
368 dpasOp.getProperties().layout_b = layout;
370 dpasOp.getProperties().layout_cd = layout;
373 <<
"Invalid index for setting dpas op layout: " <<
index;
374 diag.attachNote(
target->getLoc()) <<
"target op";
379 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(
target);
382 <<
"Cannot set anchor layout to op: " <<
target->getName();
383 diag.attachNote(
target->getLoc()) <<
"target op";
386 anchorOp.setAnchorLayout(layout);
391void transform::SetOpLayoutAttrOp::getEffects(
400LogicalResult transform::SetOpLayoutAttrOp::verify() {
401 if (getResult() && getOperand()) {
402 return emitOpError(
"Cannot set both result and operand simultaneously.");
407void transform::SetGPULaunchThreadsOp::build(
413 build(builder, ostate,
target.getType(),
424 if (!llvm::hasSingleElement(targetOps)) {
426 << llvm::range_size(targetOps) <<
")";
430 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
433 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
434 diag.attachNote(
target->getLoc()) <<
"target op";
444 if (threads.size() != 3) {
446 <<
"Expected threads argument to consist of three values (got "
447 << threads.size() <<
")";
451 auto createConstValue = [&](
int value) {
456 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
457 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
458 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
463void transform::SetGPULaunchThreadsOp::getEffects(
475 if (!llvm::hasSingleElement(targetValues))
477 <<
"requires exactly one target value handle (got "
478 << llvm::range_size(targetValues) <<
")";
479 auto value = *targetValues.begin();
481 int64_t nbPrefetch = getStaticNbPrefetch();
482 if (getDynamicNbPrefetch()) {
486 {getDynamicNbPrefetch()});
489 if (dynamicNbPrefetch.size() != 1)
491 <<
"requires exactly one value for dynamic_nb_prefetch";
492 nbPrefetch = dynamicNbPrefetch[0];
496 <<
"nb_prefetch must be a positive integer.";
502 auto loadOp = *maybeLoadOp;
503 if (loadOp.getMixedOffsets().size() == 0) {
505 <<
"Load op must have offsets.";
506 diag.attachNote(loadOp.getLoc()) <<
"load op";
511 auto forOp = loadOp->getParentOfType<scf::ForOp>();
514 <<
"Load op is not contained in a scf.for loop.";
515 diag.attachNote(loadOp.getLoc()) <<
"load op";
523 auto descOp = *maybeDescOp;
524 if (descOp.getMixedOffsets().size() > 0) {
526 <<
"desc op with offsets is not supported.";
527 diag.attachNote(descOp.getLoc()) <<
"desc op";
533 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
540 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
541 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
542 forOp.getLoc(), forOp.getLowerBound(), nbStep);
544 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
545 initUpBound, forOp.getStep());
549 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
553 auto getPrefetchOffsets =
556 mapping.
map(forOp.getInductionVar(), replacementVal);
558 llvm::map_to_vector(loadOp.getOffsets(), [&](
Value v) {
559 return mapping.lookupOrDefault(v);
561 auto constOffsets = loadOp.getConstOffsets().value();
568 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
569 newDescOp.getResult(),
570 getPrefetchOffsets(initForOp.getInductionVar()),
571 readCacheHint, readCacheHint, readCacheHint,
577 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
578 forOp.getInductionVar(), nbStep);
580 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
581 newDescOp.getResult(),
582 getPrefetchOffsets(prefetchOffset), readCacheHint,
583 readCacheHint, readCacheHint,
nullptr);
589 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
594void transform::InsertPrefetchOp::getEffects(
602void transform::ConvertLayoutOp::build(
613 dynamicInputInstData;
615 staticInputSgLayout);
619 staticInputInstData);
621 staticTargetInstData;
623 dynamicTargetInstData;
625 staticTargetSgLayout);
629 staticTargetInstData);
630 build(builder, ostate,
target.getType(),
632 dynamicInputSgLayout,
634 dynamicInputInstData,
635 dynamicTargetSgLayout,
637 dynamicTargetInstData,
642 staticTargetSgLayout,
644 staticTargetInstData,
653 if (!llvm::hasSingleElement(targetValues))
655 <<
"requires exactly one target value handle (got "
656 << llvm::range_size(targetValues) <<
")";
657 auto value = *targetValues.begin();
660 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
662 getContext(), state, (*
this), getMixedInputSgLayout(),
663 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
668 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
670 getContext(), state, (*
this), getMixedTargetSgLayout(),
671 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
677 if (value.use_empty())
679 <<
"Value has no users to insert layout conversion.";
685 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
686 value, inputLayoutAttr, targetLayoutAttr);
689 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
690 return use.getOwner() != convLayoutOp.getOperation();
693 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
697void transform::ConvertLayoutOp::getEffects(
711class XeGPUTransformDialectExtension
713 XeGPUTransformDialectExtension> {
722void XeGPUTransformDialectExtension::init() {
723 declareGeneratedDialect<scf::SCFDialect>();
724 declareGeneratedDialect<arith::ArithDialect>();
725 declareGeneratedDialect<xegpu::XeGPUDialect>();
727 registerTransformOps<
729#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
734#define GET_OP_CLASSES
735#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry ®istry)
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.