27 if (
auto attr = dyn_cast<Attribute>(ofr)) {
28 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
29 result.push_back(intAttr.getInt());
32 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
36 Value transformValue = cast<Value>(ofr);
37 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
39 if (params.size() != 1)
40 return transformOp.emitDefiniteFailure()
41 <<
"requires exactly one parameter associated";
43 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
49 if (!llvm::hasSingleElement(payloadOps)) {
51 transformOp.emitSilenceableError()
52 <<
"handle must be mapped to exactly one payload op";
54 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
61 transformOp.emitSilenceableError()
62 <<
"payload op must have exactly 1 index result";
70 return transformOp.emitSilenceableError()
71 <<
"requires param or handle to be the result of a constant like "
74 result.push_back(intAttr.getInt());
80static xegpu::LayoutAttr
84 return xegpu::LayoutAttr::get(
94static xegpu::CreateNdDescOp
96 xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
97 assert(descOp.getMixedOffsets().size() == 0 &&
98 "create desc op with offsets is not supported");
99 auto oldTensorDesc = descOp.getType();
100 auto descType = xegpu::TensorDescType::get(
101 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
102 oldTensorDesc.getArrayLength(),
103 oldTensorDesc.getBoundaryCheck(),
104 oldTensorDesc.getMemorySpace(),
109 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
110 descOp.getMixedStrides());
114void transform::SetDescLayoutOp::build(
OpBuilder &builder,
139 if (!llvm::hasSingleElement(targetOps)) {
141 << llvm::range_size(targetOps) <<
")";
161 auto maybeInstData = instData.empty()
163 : std::optional<ArrayRef<int32_t>>(instData);
166 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(
target);
169 <<
"Expected a xegpu.create_nd_desc op, but got: "
171 diag.attachNote(
target->getLoc()) <<
"target op";
178 auto newdescOp =
setDescLayout(rewriter, descOp, layoutAttr);
181 results.
set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
186void transform::SetDescLayoutOp::getEffects(
197class XeGPUTransformDialectExtension
199 XeGPUTransformDialectExtension> {
208void XeGPUTransformDialectExtension::init() {
209 declareGeneratedDialect<scf::SCFDialect>();
210 declareGeneratedDialect<arith::ArithDialect>();
211 declareGeneratedDialect<xegpu::XeGPUDialect>();
213 registerTransformOps<
215#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
220#define GET_OP_CLASSES
221#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumResults()
Return the number of results held by this operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void registerTransformDialectExtension(DialectRegistry ®istry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.