15#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
33 if (
auto attr = dyn_cast<Attribute>(ofr)) {
34 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
38 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 <<
"requires exactly one parameter associated";
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 <<
"handle must be mapped to exactly one payload op";
60 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
67 transformOp.emitSilenceableError()
68 <<
"payload op must have exactly 1 index result";
76 return transformOp.emitSilenceableError()
77 <<
"requires param or handle to be the result of a constant like "
80 result.push_back(intAttr.getInt());
90 Value currentValue = val;
94 LDBG() <<
"Failed to find producer op, value has no uses.";
97 auto userOp = val.
getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
100 LDBG() <<
"Failed to find producer op, not in a loop.";
104 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
109 LDBG() <<
"Failed to find producer op, value not in init values.";
115 if (
auto matchingOp = dyn_cast<T>(producerOp))
125static xegpu::LayoutAttr
129 return xegpu::LayoutAttr::get(
141 TransformOpInterface transformOp,
145 xegpu::LayoutAttr &layoutAttr) {
149 if (!status.succeeded())
153 if (!status.succeeded())
157 if (!status.succeeded())
159 auto maybeInstData = instData.empty()
161 : std::optional<ArrayRef<int32_t>>(instData);
169static xegpu::CreateNdDescOp
171 xegpu::CreateNdDescOp descOp,
172 xegpu::DistributeLayoutAttr layout) {
173 assert(descOp.getMixedOffsets().size() == 0 &&
174 "create desc op with offsets is not supported");
175 auto oldTensorDesc = descOp.getType();
176 auto descType = xegpu::TensorDescType::get(
177 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
178 oldTensorDesc.getArrayLength(),
179 oldTensorDesc.getBoundaryCheck(),
180 oldTensorDesc.getMemorySpace(),
185 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
186 descOp.getMixedStrides());
195 if (!llvm::hasSingleElement(targetValues)) {
197 <<
"requires exactly one target value handle (got "
198 << llvm::range_size(targetValues) <<
")";
205 <<
"Could not find a matching descriptor op when walking the "
206 "producer chain of the first operand.";
209 results.
set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
213void transform::SetDescLayoutOp::build(
OpBuilder &builder,
240 if (!llvm::hasSingleElement(targetOps)) {
242 << llvm::range_size(targetOps) <<
")";
246 xegpu::LayoutAttr layoutAttr =
nullptr;
248 getMixedSgLayout(), getMixedSgData(),
249 getMixedInstData(), layoutAttr);
250 if (!status.succeeded())
253 xegpu::DistributeLayoutAttr layout = layoutAttr;
254 auto sliceDims = getSliceDims();
255 if (sliceDims.size() > 0) {
257 layout = xegpu::SliceAttr::get(
262 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(
target);
265 <<
"Expected a xegpu.create_nd_desc op, but got: "
267 diag.attachNote(
target->getLoc()) <<
"target op";
275 results.
set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
280void transform::SetDescLayoutOp::getEffects(
290void transform::SetOpLayoutAttrOp::build(
300 build(builder, ostate,
target.getType(),
318 if (!llvm::hasSingleElement(targetOps)) {
320 << llvm::range_size(targetOps) <<
")";
324 bool resultTarget = getResult();
327 if (resultTarget &&
index >=
target->getNumResults()) {
329 <<
"Index exceeds the number of op results";
331 if (!resultTarget &&
index >=
target->getNumOperands()) {
333 <<
"Index exceeds the number of op operands";
336 xegpu::LayoutAttr layoutAttr =
nullptr;
338 getMixedSgLayout(), getMixedSgData(),
339 getMixedInstData(), layoutAttr);
340 if (!status.succeeded())
343 xegpu::DistributeLayoutAttr layout = layoutAttr;
344 auto sliceDims = getSliceDims();
345 if (sliceDims.size() > 0) {
347 layout = xegpu::SliceAttr::get(
359void transform::SetOpLayoutAttrOp::getEffects(
368void transform::SetGPULaunchThreadsOp::build(
374 build(builder, ostate,
target.getType(),
385 if (!llvm::hasSingleElement(targetOps)) {
387 << llvm::range_size(targetOps) <<
")";
391 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
394 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
395 diag.attachNote(
target->getLoc()) <<
"target op";
405 if (threads.size() != 3) {
407 <<
"Expected threads argument to consist of three values (got "
408 << threads.size() <<
")";
412 auto createConstValue = [&](
int value) {
417 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
418 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
419 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
424void transform::SetGPULaunchThreadsOp::getEffects(
436 if (!llvm::hasSingleElement(targetValues))
438 <<
"requires exactly one target value handle (got "
439 << llvm::range_size(targetValues) <<
")";
440 auto value = *targetValues.begin();
442 int64_t nbPrefetch = getStaticNbPrefetch();
443 if (getDynamicNbPrefetch()) {
447 {getDynamicNbPrefetch()});
450 if (dynamicNbPrefetch.size() != 1)
452 <<
"requires exactly one value for dynamic_nb_prefetch";
453 nbPrefetch = dynamicNbPrefetch[0];
457 <<
"nb_prefetch must be a positive integer.";
463 auto loadOp = *maybeLoadOp;
464 if (loadOp.getMixedOffsets().size() == 0) {
466 <<
"Load op must have offsets.";
467 diag.attachNote(loadOp.getLoc()) <<
"load op";
472 auto forOp = loadOp->getParentOfType<scf::ForOp>();
475 <<
"Load op is not contained in a scf.for loop.";
476 diag.attachNote(loadOp.getLoc()) <<
"load op";
484 auto descOp = *maybeDescOp;
485 if (descOp.getMixedOffsets().size() > 0) {
487 <<
"desc op with offsets is not supported.";
488 diag.attachNote(descOp.getLoc()) <<
"desc op";
494 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
501 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
502 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
503 forOp.getLoc(), forOp.getLowerBound(), nbStep);
505 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
506 initUpBound, forOp.getStep());
510 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
514 auto getPrefetchOffsets =
517 mapping.
map(forOp.getInductionVar(), replacementVal);
519 llvm::map_to_vector(loadOp.getOffsets(), [&](
Value v) {
520 return mapping.lookupOrDefault(v);
522 auto constOffsets = loadOp.getConstOffsets().value();
529 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
530 newDescOp.getResult(),
531 getPrefetchOffsets(initForOp.getInductionVar()),
532 readCacheHint, readCacheHint, readCacheHint,
538 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
539 forOp.getInductionVar(), nbStep);
541 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
542 newDescOp.getResult(),
543 getPrefetchOffsets(prefetchOffset), readCacheHint,
544 readCacheHint, readCacheHint,
nullptr);
550 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
555void transform::InsertPrefetchOp::getEffects(
563void transform::ConvertLayoutOp::build(
574 dynamicInputInstData;
576 staticInputSgLayout);
580 staticInputInstData);
582 staticTargetInstData;
584 dynamicTargetInstData;
586 staticTargetSgLayout);
590 staticTargetInstData);
591 build(builder, ostate,
target.getType(),
593 dynamicInputSgLayout,
595 dynamicInputInstData,
596 dynamicTargetSgLayout,
598 dynamicTargetInstData,
602 staticTargetSgLayout,
604 staticTargetInstData);
612 if (!llvm::hasSingleElement(targetValues))
614 <<
"requires exactly one target value handle (got "
615 << llvm::range_size(targetValues) <<
")";
616 auto value = *targetValues.begin();
619 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
621 getContext(), state, (*
this), getMixedInputSgLayout(),
622 getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
626 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
628 getContext(), state, (*
this), getMixedTargetSgLayout(),
629 getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
634 if (value.use_empty())
636 <<
"Value has no users to insert layout conversion.";
642 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
643 value, inputLayoutAttr, targetLayoutAttr);
646 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
647 return use.getOwner() != convLayoutOp.getOperation();
650 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
654void transform::ConvertLayoutOp::getEffects(
668class XeGPUTransformDialectExtension
670 XeGPUTransformDialectExtension> {
679void XeGPUTransformDialectExtension::init() {
680 declareGeneratedDialect<scf::SCFDialect>();
681 declareGeneratedDialect<arith::ArithDialect>();
682 declareGeneratedDialect<xegpu::XeGPUDialect>();
684 registerTransformOps<
686#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
691#define GET_OP_CLASSES
692#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry &registry)
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.