15#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
33 if (
auto attr = dyn_cast<Attribute>(ofr)) {
34 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
38 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 <<
"requires exactly one parameter associated";
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 <<
"handle must be mapped to exactly one payload op";
60 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
67 transformOp.emitSilenceableError()
68 <<
"payload op must have exactly 1 index result";
76 return transformOp.emitSilenceableError()
77 <<
"requires param or handle to be the result of a constant like "
80 result.push_back(intAttr.getInt());
90 Value currentValue = val;
94 LDBG() <<
"Failed to find producer op, value has no uses.";
97 auto userOp = val.
getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
100 LDBG() <<
"Failed to find producer op, not in a loop.";
104 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
109 LDBG() <<
"Failed to find producer op, value not in init values.";
115 if (
auto matchingOp = dyn_cast<T>(producerOp))
128 return xegpu::LayoutAttr::get(
140 TransformOpInterface transformOp,
145 xegpu::LayoutAttr &layoutAttr) {
149 if (!status.succeeded())
153 if (!status.succeeded())
157 if (!status.succeeded())
159 auto maybeInstData = instData.empty()
161 : std::optional<ArrayRef<int32_t>>(instData);
173 if (!llvm::hasSingleElement(targetValues)) {
175 <<
"requires exactly one target value handle (got "
176 << llvm::range_size(targetValues) <<
")";
183 loadOp = maybeLoadNdOp->getOperation();
188 loadOp = maybeLoadOp->getOperation();
191 <<
"Could not find a matching xegpu.load_nd or xegpu.load op when "
193 "producer chain of the first operand.";
197 results.
set(llvm::cast<OpResult>(getResult()), {loadOp});
201void transform::SetAnchorLayoutOp::build(
211 build(builder, ostate,
target.getType(),
232 xegpu::LayoutAttr layoutAttr =
nullptr;
234 getContext(), state, (*
this), getMixedSgLayout(), getMixedSgData(),
235 getMixedInstData(), getOrder(), layoutAttr);
236 if (!status.succeeded())
239 xegpu::DistributeLayoutAttr layout = layoutAttr;
240 auto sliceDims = getSliceDims();
241 if (sliceDims.size() > 0) {
243 layout = xegpu::SliceAttr::get(
250 if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(
target)) {
253 dpasOp.getProperties().layout_a = layout;
255 dpasOp.getProperties().layout_b = layout;
257 dpasOp.getProperties().layout_cd = layout;
260 <<
"Invalid index for setting dpas op layout: " <<
index;
261 diag.attachNote(
target->getLoc()) <<
"target op";
266 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(
target);
269 <<
"Cannot set anchor layout to op: " <<
target->getName();
270 diag.attachNote(
target->getLoc()) <<
"target op";
273 anchorOp.setAnchorLayout(layout);
279void transform::SetAnchorLayoutOp::getEffects(
288void transform::SetGPULaunchThreadsOp::build(
294 build(builder, ostate,
target.getType(),
305 if (!llvm::hasSingleElement(targetOps)) {
307 << llvm::range_size(targetOps) <<
")";
311 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
314 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
315 diag.attachNote(
target->getLoc()) <<
"target op";
325 if (threads.size() != 3) {
327 <<
"Expected threads argument to consist of three values (got "
328 << threads.size() <<
")";
332 auto createConstValue = [&](
int value) {
337 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
338 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
339 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
344void transform::SetGPULaunchThreadsOp::getEffects(
356 if (!llvm::hasSingleElement(targetOps))
358 <<
"requires exactly one target op handle (got "
359 << llvm::range_size(targetOps) <<
")";
360 auto target = *targetOps.begin();
362 int64_t nbPrefetch = getStaticNbPrefetch();
363 if (getDynamicNbPrefetch()) {
367 {getDynamicNbPrefetch()});
370 if (dynamicNbPrefetch.size() != 1)
372 <<
"requires exactly one value for dynamic_nb_prefetch";
373 nbPrefetch = dynamicNbPrefetch[0];
377 <<
"nb_prefetch must be a positive integer.";
380 auto maybeLoadOp = dyn_cast<xegpu::LoadNdOp>(
target);
383 <<
"Expected xegpu.load_nd op, got " <<
target->getName();
385 auto loadOp = maybeLoadOp;
386 if (loadOp.getMixedOffsets().size() == 0) {
388 <<
"Load op must have offsets.";
389 diag.attachNote(loadOp.
getLoc()) <<
"load op";
397 <<
"Load op is not contained in a scf.for loop.";
398 diag.attachNote(loadOp.
getLoc()) <<
"load op";
407 auto descOp = *maybeDescOp;
412 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
419 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
420 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
421 forOp.getLoc(), forOp.getLowerBound(), nbStep);
423 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
424 initUpBound, forOp.getStep());
428 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
432 auto getPrefetchOffsets =
435 mapping.
map(forOp.getInductionVar(), replacementVal);
437 llvm::map_to_vector(loadOp.getOffsets(), [&](
Value v) {
438 return mapping.lookupOrDefault(v);
440 auto constOffsets = loadOp.getConstOffsets();
447 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
448 newDescOp.getResult(),
449 getPrefetchOffsets(initForOp.getInductionVar()),
450 readCacheHint, readCacheHint, readCacheHint,
456 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
457 forOp.getInductionVar(), nbStep);
459 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
460 newDescOp.getResult(),
461 getPrefetchOffsets(prefetchOffset), readCacheHint,
462 readCacheHint, readCacheHint,
nullptr);
468 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
473void transform::InsertPrefetchOp::getEffects(
481void transform::ConvertLayoutOp::build(
492 dynamicInputInstData;
494 staticInputSgLayout);
498 staticInputInstData);
500 staticTargetInstData;
502 dynamicTargetInstData;
504 staticTargetSgLayout);
508 staticTargetInstData);
509 build(builder, ostate,
target.getType(),
511 dynamicInputSgLayout,
513 dynamicInputInstData,
514 dynamicTargetSgLayout,
516 dynamicTargetInstData,
521 staticTargetSgLayout,
523 staticTargetInstData,
532 if (!llvm::hasSingleElement(targetValues))
534 <<
"requires exactly one target value handle (got "
535 << llvm::range_size(targetValues) <<
")";
536 auto value = *targetValues.begin();
539 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
541 getContext(), state, (*
this), getMixedInputSgLayout(),
542 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
547 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
549 getContext(), state, (*
this), getMixedTargetSgLayout(),
550 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
556 if (value.use_empty())
558 <<
"Value has no users to insert layout conversion.";
564 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
565 value, inputLayoutAttr, targetLayoutAttr);
568 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
569 return use.getOwner() != convLayoutOp.getOperation();
572 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
576void transform::ConvertLayoutOp::getEffects(
590class XeGPUTransformDialectExtension
592 XeGPUTransformDialectExtension> {
601void XeGPUTransformDialectExtension::init() {
602 declareGeneratedDialect<scf::SCFDialect>();
603 declareGeneratedDialect<arith::ArithDialect>();
604 declareGeneratedDialect<xegpu::XeGPUDialect>();
606 registerTransformOps<
608#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
613#define GET_OP_CLASSES
614#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.