18#include "llvm/Support/DebugLog.h"
19#define DEBUG_TYPE "xegpu-transforms"
32 if (
auto attr = dyn_cast<Attribute>(ofr)) {
33 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
34 result.push_back(intAttr.getInt());
37 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
41 Value transformValue = cast<Value>(ofr);
42 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
44 if (params.size() != 1)
45 return transformOp.emitDefiniteFailure()
46 <<
"requires exactly one parameter associated";
48 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
54 if (!llvm::hasSingleElement(payloadOps)) {
56 transformOp.emitSilenceableError()
57 <<
"handle must be mapped to exactly one payload op";
59 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
66 transformOp.emitSilenceableError()
67 <<
"payload op must have exactly 1 index result";
75 return transformOp.emitSilenceableError()
76 <<
"requires param or handle to be the result of a constant like "
79 result.push_back(intAttr.getInt());
89 Value currentValue = val;
93 LDBG() <<
"Failed to find producer op, value has no uses.";
96 auto userOp = val.
getUsers().begin();
97 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
99 LDBG() <<
"Failed to find producer op, not in a loop.";
103 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
104 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
105 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
106 currentValue = parentLoop.getInits()[iterArgIdx];
108 LDBG() <<
"Failed to find producer op, value not in init values.";
114 if (
auto matchingOp = dyn_cast<T>(producerOp))
124static xegpu::LayoutAttr
128 return xegpu::LayoutAttr::get(
140 TransformOpInterface transformOp,
144 xegpu::LayoutAttr &layoutAttr) {
148 if (!status.succeeded())
152 if (!status.succeeded())
156 if (!status.succeeded())
158 auto maybeInstData = instData.empty()
160 : std::optional<ArrayRef<int32_t>>(instData);
168static xegpu::CreateNdDescOp
170 xegpu::CreateNdDescOp descOp,
171 xegpu::DistributeLayoutAttr layout) {
172 assert(descOp.getMixedOffsets().size() == 0 &&
173 "create desc op with offsets is not supported");
174 auto oldTensorDesc = descOp.getType();
175 auto descType = xegpu::TensorDescType::get(
176 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
177 oldTensorDesc.getArrayLength(),
178 oldTensorDesc.getBoundaryCheck(),
179 oldTensorDesc.getMemorySpace(),
184 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
185 descOp.getMixedStrides());
194 if (!llvm::hasSingleElement(targetValues)) {
196 <<
"requires exactly one target value handle (got "
197 << llvm::range_size(targetValues) <<
")";
204 <<
"Could not find a matching descriptor op when walking the "
205 "producer chain of the first operand.";
208 results.
set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
212void transform::SetDescLayoutOp::build(
OpBuilder &builder,
239 if (!llvm::hasSingleElement(targetOps)) {
241 << llvm::range_size(targetOps) <<
")";
245 xegpu::LayoutAttr layoutAttr =
nullptr;
247 getMixedSgLayout(), getMixedSgData(),
248 getMixedInstData(), layoutAttr);
249 if (!status.succeeded())
252 xegpu::DistributeLayoutAttr layout = layoutAttr;
253 auto sliceDims = getSliceDims();
254 if (sliceDims.size() > 0) {
256 layout = xegpu::SliceAttr::get(
261 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(
target);
264 <<
"Expected a xegpu.create_nd_desc op, but got: "
266 diag.attachNote(
target->getLoc()) <<
"target op";
274 results.
set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
279void transform::SetDescLayoutOp::getEffects(
289void transform::SetOpLayoutAttrOp::build(
299 build(builder, ostate,
target.getType(),
317 if (!llvm::hasSingleElement(targetOps)) {
319 << llvm::range_size(targetOps) <<
")";
323 bool resultTarget = getResult();
326 if (resultTarget &&
index >=
target->getNumResults()) {
328 <<
"Index exceeds the number of op results";
330 if (!resultTarget &&
index >=
target->getNumOperands()) {
332 <<
"Index exceeds the number of op operands";
335 xegpu::LayoutAttr layoutAttr =
nullptr;
337 getMixedSgLayout(), getMixedSgData(),
338 getMixedInstData(), layoutAttr);
339 if (!status.succeeded())
342 xegpu::DistributeLayoutAttr layout = layoutAttr;
343 auto sliceDims = getSliceDims();
344 if (sliceDims.size() > 0) {
346 layout = xegpu::SliceAttr::get(
358void transform::SetOpLayoutAttrOp::getEffects(
367void transform::SetGPULaunchThreadsOp::build(
373 build(builder, ostate,
target.getType(),
384 if (!llvm::hasSingleElement(targetOps)) {
386 << llvm::range_size(targetOps) <<
")";
390 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
393 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
394 diag.attachNote(
target->getLoc()) <<
"target op";
404 if (threads.size() != 3) {
406 <<
"Expected threads argument to consist of three values (got "
407 << threads.size() <<
")";
411 auto createConstValue = [&](
int value) {
416 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
417 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
418 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
423void transform::SetGPULaunchThreadsOp::getEffects(
435 if (!llvm::hasSingleElement(targetValues))
437 <<
"requires exactly one target value handle (got "
438 << llvm::range_size(targetValues) <<
")";
439 auto value = *targetValues.begin();
441 int64_t nbPrefetch = getStaticNbPrefetch();
442 if (getDynamicNbPrefetch()) {
446 {getDynamicNbPrefetch()});
449 if (dynamicNbPrefetch.size() != 1)
451 <<
"requires exactly one value for dynamic_nb_prefetch";
452 nbPrefetch = dynamicNbPrefetch[0];
456 <<
"nb_prefetch must be a positive integer.";
462 auto loadOp = *maybeLoadOp;
463 if (loadOp.getMixedOffsets().size() == 0) {
465 <<
"Load op must have offsets.";
466 diag.attachNote(loadOp.getLoc()) <<
"load op";
471 auto forOp = loadOp->getParentOfType<scf::ForOp>();
474 <<
"Load op is not contained in a scf.for loop.";
475 diag.attachNote(loadOp.getLoc()) <<
"load op";
483 auto descOp = *maybeDescOp;
484 if (descOp.getMixedOffsets().size() > 0) {
486 <<
"desc op with offsets is not supported.";
487 diag.attachNote(descOp.getLoc()) <<
"desc op";
493 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
500 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
501 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
502 forOp.getLoc(), forOp.getLowerBound(), nbStep);
504 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
505 initUpBound, forOp.getStep());
509 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
513 auto getPrefetchOffsets =
516 mapping.
map(forOp.getInductionVar(), replacementVal);
518 llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](
Value v) {
519 return mapping.lookupOrDefault(v);
521 auto constOffsets = loadOp.getConstOffsets().value();
528 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
529 newDescOp.getResult(),
530 getPrefetchOffsets(initForOp.getInductionVar()),
531 readCacheHint, readCacheHint, readCacheHint);
536 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
537 forOp.getInductionVar(), nbStep);
539 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
540 newDescOp.getResult(),
541 getPrefetchOffsets(prefetchOffset), readCacheHint,
542 readCacheHint, readCacheHint);
548 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
553void transform::InsertPrefetchOp::getEffects(
561void transform::ConvertLayoutOp::build(
572 dynamicInputInstData;
574 staticInputSgLayout);
578 staticInputInstData);
580 staticTargetInstData;
582 dynamicTargetInstData;
584 staticTargetSgLayout);
588 staticTargetInstData);
589 build(builder, ostate,
target.getType(),
591 dynamicInputSgLayout,
593 dynamicInputInstData,
594 dynamicTargetSgLayout,
596 dynamicTargetInstData,
600 staticTargetSgLayout,
602 staticTargetInstData);
610 if (!llvm::hasSingleElement(targetValues))
612 <<
"requires exactly one target value handle (got "
613 << llvm::range_size(targetValues) <<
")";
614 auto value = *targetValues.begin();
617 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
619 getContext(), state, (*
this), getMixedInputSgLayout(),
620 getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
624 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
626 getContext(), state, (*
this), getMixedTargetSgLayout(),
627 getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
632 if (value.use_empty())
634 <<
"Value has no users to insert layout conversion.";
640 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
641 value, inputLayoutAttr, targetLayoutAttr);
644 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
645 return use.getOwner() != convLayoutOp.getOperation();
648 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
652void transform::ConvertLayoutOp::getEffects(
666class XeGPUTransformDialectExtension
668 XeGPUTransformDialectExtension> {
677void XeGPUTransformDialectExtension::init() {
678 declareGeneratedDialect<scf::SCFDialect>();
679 declareGeneratedDialect<arith::ArithDialect>();
680 declareGeneratedDialect<xegpu::XeGPUDialect>();
682 registerTransformOps<
684#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
689#define GET_OP_CLASSES
690#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry ®istry)
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout, bool respectPermLayout=false)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attributes.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true will be replaced by dynamicValues.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.