15#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
33 if (
auto attr = dyn_cast<Attribute>(ofr)) {
34 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
38 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 <<
"requires exactly one parameter associated";
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 <<
"handle must be mapped to exactly one payload op";
60 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
67 transformOp.emitSilenceableError()
68 <<
"payload op must have exactly 1 index result";
76 return transformOp.emitSilenceableError()
77 <<
"requires param or handle to be the result of a constant like "
80 result.push_back(intAttr.getInt());
90 Value currentValue = val;
94 LDBG() <<
"Failed to find producer op, value has no uses.";
97 auto userOp = val.
getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
100 LDBG() <<
"Failed to find producer op, not in a loop.";
104 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
109 LDBG() <<
"Failed to find producer op, value not in init values.";
115 if (
auto matchingOp = dyn_cast<T>(producerOp))
128 return xegpu::LayoutAttr::get(
140 TransformOpInterface transformOp,
145 xegpu::LayoutAttr &layoutAttr) {
149 if (!status.succeeded())
153 if (!status.succeeded())
157 if (!status.succeeded())
159 auto maybeInstData = instData.empty()
161 : std::optional<ArrayRef<int32_t>>(instData);
173 if (!llvm::hasSingleElement(targetValues)) {
175 <<
"requires exactly one target value handle (got "
176 << llvm::range_size(targetValues) <<
")";
183 loadOp = maybeLoadNdOp->getOperation();
188 loadOp = maybeLoadOp->getOperation();
191 <<
"Could not find a matching xegpu.load_nd or xegpu.load op when "
193 "producer chain of the first operand.";
197 results.
set(llvm::cast<OpResult>(getResult()), {loadOp});
201void transform::SetAnchorLayoutOp::build(
211 build(builder, ostate,
target.getType(),
232 xegpu::LayoutAttr layoutAttr =
nullptr;
234 getContext(), state, (*
this), getMixedSgLayout(), getMixedSgData(),
235 getMixedInstData(), getOrder(), layoutAttr);
236 if (!status.succeeded())
239 xegpu::DistributeLayoutAttr layout = layoutAttr;
240 auto sliceDims = getSliceDims();
241 if (sliceDims.size() > 0) {
243 layout = xegpu::SliceAttr::get(
250 if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(
target)) {
253 dpasOp.getProperties().layout_a = layout;
255 dpasOp.getProperties().layout_b = layout;
257 dpasOp.getProperties().layout_cd = layout;
260 <<
"Invalid index for setting dpas op layout: " <<
index;
261 diag.attachNote(
target->getLoc()) <<
"target op";
266 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(
target);
269 <<
"Cannot set anchor layout to op: " <<
target->getName();
270 diag.attachNote(
target->getLoc()) <<
"target op";
273 anchorOp.setAnchorLayout(layout);
279void transform::SetAnchorLayoutOp::getEffects(
288void transform::SetGPULaunchThreadsOp::build(
294 build(builder, ostate,
target.getType(),
305 if (!llvm::hasSingleElement(targetOps)) {
307 << llvm::range_size(targetOps) <<
")";
311 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
314 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
315 diag.attachNote(
target->getLoc()) <<
"target op";
325 if (threads.size() != 3) {
327 <<
"Expected threads argument to consist of three values (got "
328 << threads.size() <<
")";
332 auto createConstValue = [&](
int value) {
337 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
338 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
339 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
344void transform::SetGPULaunchThreadsOp::getEffects(
356 if (!llvm::hasSingleElement(targetOps))
358 <<
"requires exactly one target op handle (got "
359 << llvm::range_size(targetOps) <<
")";
360 auto target = *targetOps.begin();
362 int64_t nbPrefetch = getStaticNbPrefetch();
363 if (getDynamicNbPrefetch()) {
367 {getDynamicNbPrefetch()});
370 if (dynamicNbPrefetch.size() != 1)
372 <<
"requires exactly one value for dynamic_nb_prefetch";
373 nbPrefetch = dynamicNbPrefetch[0];
377 <<
"nb_prefetch must be a positive integer.";
380 auto maybeLoadOp = dyn_cast<xegpu::LoadNdOp>(
target);
383 <<
"Expected xegpu.load_nd op, got " <<
target->getName();
385 auto loadOp = maybeLoadOp;
386 if (loadOp.getMixedOffsets().size() == 0) {
388 <<
"Load op must have offsets.";
389 diag.attachNote(loadOp.
getLoc()) <<
"load op";
397 <<
"Load op is not contained in a scf.for loop.";
398 diag.attachNote(loadOp.
getLoc()) <<
"load op";
407 auto descOp = *maybeDescOp;
408 if (descOp.getMixedOffsets().size() > 0) {
410 <<
"desc op with offsets is not supported.";
411 diag.attachNote(descOp.getLoc()) <<
"desc op";
417 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
424 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
425 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
426 forOp.getLoc(), forOp.getLowerBound(), nbStep);
428 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
429 initUpBound, forOp.getStep());
433 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
437 auto getPrefetchOffsets =
440 mapping.
map(forOp.getInductionVar(), replacementVal);
442 llvm::map_to_vector(loadOp.getOffsets(), [&](
Value v) {
443 return mapping.lookupOrDefault(v);
445 auto constOffsets = loadOp.getConstOffsets().value();
452 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
453 newDescOp.getResult(),
454 getPrefetchOffsets(initForOp.getInductionVar()),
455 readCacheHint, readCacheHint, readCacheHint,
461 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
462 forOp.getInductionVar(), nbStep);
464 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
465 newDescOp.getResult(),
466 getPrefetchOffsets(prefetchOffset), readCacheHint,
467 readCacheHint, readCacheHint,
nullptr);
473 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
478void transform::InsertPrefetchOp::getEffects(
486void transform::ConvertLayoutOp::build(
497 dynamicInputInstData;
499 staticInputSgLayout);
503 staticInputInstData);
505 staticTargetInstData;
507 dynamicTargetInstData;
509 staticTargetSgLayout);
513 staticTargetInstData);
514 build(builder, ostate,
target.getType(),
516 dynamicInputSgLayout,
518 dynamicInputInstData,
519 dynamicTargetSgLayout,
521 dynamicTargetInstData,
526 staticTargetSgLayout,
528 staticTargetInstData,
537 if (!llvm::hasSingleElement(targetValues))
539 <<
"requires exactly one target value handle (got "
540 << llvm::range_size(targetValues) <<
")";
541 auto value = *targetValues.begin();
544 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
546 getContext(), state, (*
this), getMixedInputSgLayout(),
547 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
552 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
554 getContext(), state, (*
this), getMixedTargetSgLayout(),
555 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
561 if (value.use_empty())
563 <<
"Value has no users to insert layout conversion.";
569 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
570 value, inputLayoutAttr, targetLayoutAttr);
573 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
574 return use.getOwner() != convLayoutOp.getOperation();
577 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
581void transform::ConvertLayoutOp::getEffects(
595class XeGPUTransformDialectExtension
597 XeGPUTransformDialectExtension> {
606void XeGPUTransformDialectExtension::init() {
607 declareGeneratedDialect<scf::SCFDialect>();
608 declareGeneratedDialect<arith::ArithDialect>();
609 declareGeneratedDialect<xegpu::XeGPUDialect>();
611 registerTransformOps<
613#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
618#define GET_OP_CLASSES
619#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.