15#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
33 if (
auto attr = dyn_cast<Attribute>(ofr)) {
34 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
38 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 <<
"requires exactly one parameter associated";
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 <<
"handle must be mapped to exactly one payload op";
60 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
67 transformOp.emitSilenceableError()
68 <<
"payload op must have exactly 1 index result";
76 return transformOp.emitSilenceableError()
77 <<
"requires param or handle to be the result of a constant like "
80 result.push_back(intAttr.getInt());
90 Value currentValue = val;
94 LDBG() <<
"Failed to find producer op, value has no uses.";
97 auto userOp = val.
getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
100 LDBG() <<
"Failed to find producer op, not in a loop.";
104 if (
auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
109 LDBG() <<
"Failed to find producer op, value not in init values.";
115 if (
auto matchingOp = dyn_cast<T>(producerOp))
125static xegpu::LayoutAttr
129 return xegpu::LayoutAttr::get(
141 TransformOpInterface transformOp,
145 xegpu::LayoutAttr &layoutAttr) {
149 if (!status.succeeded())
153 if (!status.succeeded())
157 if (!status.succeeded())
159 auto maybeInstData = instData.empty()
161 : std::optional<ArrayRef<int32_t>>(instData);
169static xegpu::CreateNdDescOp
171 xegpu::CreateNdDescOp descOp,
172 xegpu::DistributeLayoutAttr layout) {
173 assert(descOp.getMixedOffsets().size() == 0 &&
174 "create desc op with offsets is not supported");
175 auto oldTensorDesc = descOp.getType();
176 auto descType = xegpu::TensorDescType::get(
177 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
178 oldTensorDesc.getArrayLength(),
179 oldTensorDesc.getBoundaryCheck(),
180 oldTensorDesc.getMemorySpace(),
185 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
186 descOp.getMixedStrides());
195 if (!llvm::hasSingleElement(targetValues)) {
197 <<
"requires exactly one target value handle (got "
198 << llvm::range_size(targetValues) <<
")";
205 <<
"Could not find a matching descriptor op when walking the "
206 "producer chain of the first operand.";
209 results.
set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
213void transform::SetDescLayoutOp::build(
OpBuilder &builder,
240 if (!llvm::hasSingleElement(targetOps)) {
242 << llvm::range_size(targetOps) <<
")";
246 xegpu::LayoutAttr layoutAttr =
nullptr;
248 getMixedSgLayout(), getMixedSgData(),
249 getMixedInstData(), layoutAttr);
250 if (!status.succeeded())
253 xegpu::DistributeLayoutAttr layout = layoutAttr;
254 auto sliceDims = getSliceDims();
255 if (sliceDims.size() > 0) {
257 layout = xegpu::SliceAttr::get(
262 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(
target);
265 <<
"Expected a xegpu.create_nd_desc op, but got: "
267 diag.attachNote(
target->getLoc()) <<
"target op";
275 results.
set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
280void transform::SetDescLayoutOp::getEffects(
290void transform::SetOpLayoutAttrOp::build(
294 bool result,
bool operand) {
300 build(builder, ostate,
target.getType(),
319 if (!llvm::hasSingleElement(targetOps)) {
321 << llvm::range_size(targetOps) <<
")";
325 bool resultTarget = getResult();
326 bool operandTarget = getOperand();
329 if (resultTarget &&
index >=
target->getNumResults()) {
331 <<
"Index exceeds the number of op results";
333 if (operandTarget &&
index >=
target->getNumOperands()) {
335 <<
"Index exceeds the number of op operands";
338 xegpu::LayoutAttr layoutAttr =
nullptr;
340 getMixedSgLayout(), getMixedSgData(),
341 getMixedInstData(), layoutAttr);
342 if (!status.succeeded())
345 xegpu::DistributeLayoutAttr layout = layoutAttr;
346 auto sliceDims = getSliceDims();
347 if (sliceDims.size() > 0) {
349 layout = xegpu::SliceAttr::get(
357 }
else if (operandTarget) {
360 }
else if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(
target)) {
363 dpasOp.getProperties().layout_a = layout;
365 dpasOp.getProperties().layout_b = layout;
367 dpasOp.getProperties().layout_cd = layout;
370 <<
"Invalid index for setting dpas op layout: " <<
index;
371 diag.attachNote(
target->getLoc()) <<
"target op";
376 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(
target);
379 <<
"Cannot set anchor layout to op: " <<
target->getName();
380 diag.attachNote(
target->getLoc()) <<
"target op";
383 anchorOp.setAnchorLayout(layout);
388void transform::SetOpLayoutAttrOp::getEffects(
397LogicalResult transform::SetOpLayoutAttrOp::verify() {
398 if (getResult() && getOperand()) {
399 return emitOpError(
"Cannot set both result and operand simultaneously.");
404void transform::SetGPULaunchThreadsOp::build(
410 build(builder, ostate,
target.getType(),
421 if (!llvm::hasSingleElement(targetOps)) {
423 << llvm::range_size(targetOps) <<
")";
427 auto launchOp = dyn_cast<gpu::LaunchOp>(
target);
430 <<
"Expected a gpu.launch op, but got: " <<
target->getName();
431 diag.attachNote(
target->getLoc()) <<
"target op";
441 if (threads.size() != 3) {
443 <<
"Expected threads argument to consist of three values (got "
444 << threads.size() <<
")";
448 auto createConstValue = [&](
int value) {
453 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
454 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
455 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
460void transform::SetGPULaunchThreadsOp::getEffects(
472 if (!llvm::hasSingleElement(targetValues))
474 <<
"requires exactly one target value handle (got "
475 << llvm::range_size(targetValues) <<
")";
476 auto value = *targetValues.begin();
478 int64_t nbPrefetch = getStaticNbPrefetch();
479 if (getDynamicNbPrefetch()) {
483 {getDynamicNbPrefetch()});
486 if (dynamicNbPrefetch.size() != 1)
488 <<
"requires exactly one value for dynamic_nb_prefetch";
489 nbPrefetch = dynamicNbPrefetch[0];
493 <<
"nb_prefetch must be a positive integer.";
499 auto loadOp = *maybeLoadOp;
500 if (loadOp.getMixedOffsets().size() == 0) {
502 <<
"Load op must have offsets.";
503 diag.attachNote(loadOp.getLoc()) <<
"load op";
508 auto forOp = loadOp->getParentOfType<scf::ForOp>();
511 <<
"Load op is not contained in a scf.for loop.";
512 diag.attachNote(loadOp.getLoc()) <<
"load op";
520 auto descOp = *maybeDescOp;
521 if (descOp.getMixedOffsets().size() > 0) {
523 <<
"desc op with offsets is not supported.";
524 diag.attachNote(descOp.getLoc()) <<
"desc op";
530 cast<xegpu::CreateNdDescOp>(rewriter.
clone(*descOp.getOperation()));
537 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
538 auto initUpBound = rewriter.
createOrFold<arith::AddIOp>(
539 forOp.getLoc(), forOp.getLowerBound(), nbStep);
541 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
542 initUpBound, forOp.getStep());
546 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
550 auto getPrefetchOffsets =
553 mapping.
map(forOp.getInductionVar(), replacementVal);
555 llvm::map_to_vector(loadOp.getOffsets(), [&](
Value v) {
556 return mapping.lookupOrDefault(v);
558 auto constOffsets = loadOp.getConstOffsets().value();
565 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
566 newDescOp.getResult(),
567 getPrefetchOffsets(initForOp.getInductionVar()),
568 readCacheHint, readCacheHint, readCacheHint,
574 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
575 forOp.getInductionVar(), nbStep);
577 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
578 newDescOp.getResult(),
579 getPrefetchOffsets(prefetchOffset), readCacheHint,
580 readCacheHint, readCacheHint,
nullptr);
586 results.
set(llvm::cast<OpResult>(getResult()), {newDescOp});
591void transform::InsertPrefetchOp::getEffects(
599void transform::ConvertLayoutOp::build(
610 dynamicInputInstData;
612 staticInputSgLayout);
616 staticInputInstData);
618 staticTargetInstData;
620 dynamicTargetInstData;
622 staticTargetSgLayout);
626 staticTargetInstData);
627 build(builder, ostate,
target.getType(),
629 dynamicInputSgLayout,
631 dynamicInputInstData,
632 dynamicTargetSgLayout,
634 dynamicTargetInstData,
638 staticTargetSgLayout,
640 staticTargetInstData);
648 if (!llvm::hasSingleElement(targetValues))
650 <<
"requires exactly one target value handle (got "
651 << llvm::range_size(targetValues) <<
")";
652 auto value = *targetValues.begin();
655 xegpu::LayoutAttr inputLayoutAttr =
nullptr;
657 getContext(), state, (*
this), getMixedInputSgLayout(),
658 getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
662 xegpu::LayoutAttr targetLayoutAttr =
nullptr;
664 getContext(), state, (*
this), getMixedTargetSgLayout(),
665 getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
670 if (value.use_empty())
672 <<
"Value has no users to insert layout conversion.";
678 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
679 value, inputLayoutAttr, targetLayoutAttr);
682 value, convLayoutOp.getResult(), [&](
OpOperand &use) {
683 return use.getOwner() != convLayoutOp.getOperation();
686 results.
set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
690void transform::ConvertLayoutOp::getEffects(
704class XeGPUTransformDialectExtension
706 XeGPUTransformDialectExtension> {
715void XeGPUTransformDialectExtension::init() {
716 declareGeneratedDialect<scf::SCFDialect>();
717 declareGeneratedDialect<arith::ArithDialect>();
718 declareGeneratedDialect<xegpu::XeGPUDialect>();
720 registerTransformOps<
722#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
727#define GET_OP_CLASSES
728#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of 'from' and replace them with 'to' if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
void registerTransformDialectExtension(DialectRegistry &registry)
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.