MLIR 22.0.0git
XeGPUTransformOps.cpp
Go to the documentation of this file.
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::transform;
18
19/// Converts every entry of `ofrs` to an integer and appends it to `result`.
20/// Each entry is assumed to be an index attr, a param of index type, or a
21/// transform dialect handle mapped to exactly one op with one index result.
///
/// Failure modes visible in the body:
///  - definite failure: an attribute that is not an IntegerAttr, or a param
///    value that is not associated with exactly one parameter;
///  - silenceable error: a handle that is not mapped to exactly one payload
///    op, a payload op without exactly one index-typed result, or a result
///    that does not match a constant.
// NOTE(review): the line that opens this function's signature, the line
// declaring its trailing parameters, and the statement following the loop are
// not visible in this listing; they are left as-is rather than guessed.
23 transform::TransformState &state, TransformOpInterface transformOp,
25 for (OpFoldResult ofr : ofrs) {
26 // Attribute case: only an IntegerAttr is accepted here.
27 if (auto attr = dyn_cast<Attribute>(ofr)) {
28 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
29 result.push_back(intAttr.getInt());
30 continue;
31 }
32 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
33 }
34
35 // Transform param case: the value's type implements
// TransformParamTypeInterface and must carry exactly one parameter.
36 Value transformValue = cast<Value>(ofr);
37 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
38 ArrayRef<Attribute> params = state.getParams(transformValue);
39 if (params.size() != 1)
40 return transformOp.emitDefiniteFailure()
41 << "requires exactly one parameter associated";
// NOTE(review): unlike the attribute case above, the parameter is converted
// with cast<IntegerAttr> rather than dyn_cast, so a non-integer parameter is
// not reported through a diagnostic here -- confirm callers guarantee this.
42 result.push_back(
43 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
44 continue;
45 }
46
47 // Payload value case: the handle must resolve to exactly one payload op.
48 auto payloadOps = state.getPayloadOps(transformValue);
49 if (!llvm::hasSingleElement(payloadOps)) {
51 transformOp.emitSilenceableError()
52 << "handle must be mapped to exactly one payload op";
// The note is attached at the handle's location and reports how many payload
// ops it actually maps to.
53 diag.attachNote(transformValue.getLoc())
54 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
55 return diag;
56 }
57
// The single payload op must produce exactly one result, and that result must
// be of index type.
58 Operation *op = *payloadOps.begin();
59 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
61 transformOp.emitSilenceableError()
62 << "payload op must have exactly 1 index result";
// The note reports only the result count, even when the failing condition was
// the result type.
63 diag.attachNote(op->getLoc())
64 << "has " << op->getNumResults() << " results";
65 return diag;
66 }
67
// Only results that match m_Constant are accepted; the matched attribute is
// captured in intAttr and read back below.
68 IntegerAttr intAttr;
69 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
70 return transformOp.emitSilenceableError()
71 << "requires param or handle to be the result of a constant like "
72 "op";
73
74 result.push_back(intAttr.getInt());
75 }
77}
78
79/// Create a layout attribute from the given parameters.
///
/// `sgLayout` and `sgData` are always wrapped in a DenseI32ArrayAttr.
/// `instData` is wrapped only when the optional holds a value; otherwise a
/// null attribute is passed in its place. The lane_layout, lane_data and order
/// arguments are always passed as null.
// NOTE(review): the line carrying the function name and its leading
// parameters (`ctx`, `sgLayout`) is not visible in this listing.
80static xegpu::LayoutAttr
82 ArrayRef<int32_t> sgData,
83 std::optional<ArrayRef<int32_t>> instData) {
84 return xegpu::LayoutAttr::get(
85 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
86 DenseI32ArrayAttr::get(ctx, sgData),
87 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
88 /*lane_layout=*/nullptr,
89 /*lane_data=*/nullptr,
90 /*order=*/nullptr);
91}
92
93/// Replace xegpu.create_nd_desc op with a new one with the given layout.
///
/// The new tensor descriptor type copies the shape, element type, array
/// length, boundary check and memory space of the old one and differs only in
/// its layout. The replacement op is built from the old op's source, mixed
/// sizes and mixed strides, and is returned to the caller.
// NOTE(review): the line carrying the function name and the `rewriter`
// parameter is not visible in this listing.
94static xegpu::CreateNdDescOp
96 xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
// Descriptors created with offsets are not handled: the replacement below does
// not forward any offsets.
// NOTE(review): this precondition is enforced only by an assert; no diagnostic
// is emitted for it in the visible code.
97 assert(descOp.getMixedOffsets().size() == 0 &&
98 "create desc op with offsets is not supported");
99 auto oldTensorDesc = descOp.getType();
100 auto descType = xegpu::TensorDescType::get(
101 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
102 /*array_length=*/oldTensorDesc.getArrayLength(),
103 /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
104 /*memory_space=*/oldTensorDesc.getMemorySpace(),
105 /*layout=*/layout);
106
// Build the replacement right after the old op, then swap it in for the old
// op.
107 rewriter.setInsertionPointAfter(descOp);
108 auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
109 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
110 descOp.getMixedStrides());
111 return newDescOp;
112}
113
// Convenience builder taking sg_layout, sg_data and inst_data as mixed
// OpFoldResult lists. Each list is split into a static integer vector and a
// dynamic Value vector, and both halves are forwarded to another build
// overload. The op's result type is taken from the type of `target`.
// NOTE(review): the parameter line declaring `result` and `target` is not
// visible in this listing.
114void transform::SetDescLayoutOp::build(OpBuilder &builder,
116 ArrayRef<OpFoldResult> mixedSgLayout,
117 ArrayRef<OpFoldResult> mixedSgData,
118 ArrayRef<OpFoldResult> mixedInstData) {
119 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
120 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
// Separate each mixed list into its dynamic (Value) and static (int64_t) parts.
121 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
122 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
123 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
124 build(builder, result, target.getType(),
125 /*target=*/target,
126 /*sg_layout=*/dynamicSgLayout,
127 /*sg_data=*/dynamicSgData,
128 /*inst_data=*/dynamicInstData,
129 /*static_sg_layout=*/staticSgLayout,
130 /*static_sg_data=*/staticSgData,
131 /*static_inst_data=*/staticInstData);
132}
133
// Applies the transform: resolves the target handle to a single payload op,
// converts the mixed sg_layout / sg_data / inst_data operands to integers,
// checks that the target is an xegpu.create_nd_desc op, replaces it with a
// copy whose descriptor type carries the new layout, and maps the op's result
// handle to the replacement.
// NOTE(review): the return-type line, the trailing parameter lines (`results`,
// `state`), the declarations of `status` and `sgData`, and the final return
// statement are not visible in this listing.
135transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
// The target handle must map to exactly one payload op; anything else is a
// definite failure.
138 auto targetOps = state.getPayloadOps(getTarget());
139 if (!llvm::hasSingleElement(targetOps)) {
140 return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
141 << llvm::range_size(targetOps) << ")";
142 }
143 Operation *target = *targetOps.begin();
144
// Resolve each mixed operand list to integers; a failed conversion is returned
// to the caller unchanged.
145 SmallVector<int32_t> sgLayout;
147 convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
148 if (!status.succeeded())
149 return status;
150
152 status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
153 if (!status.succeeded())
154 return status;
155
156 SmallVector<int32_t> instData;
157 status =
158 convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
159 if (!status.succeeded())
160 return status;
// An empty inst_data list is passed on as std::nullopt, so the layout attribute
// is created without inst_data in that case.
161 auto maybeInstData = instData.empty()
162 ? std::nullopt
163 : std::optional<ArrayRef<int32_t>>(instData);
164
165 // For now only create_nd_desc op is supported.
// Note that this check runs after the operand conversions above; a wrong
// target kind is reported as a silenceable failure.
166 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
167 if (!descOp) {
168 auto diag = emitSilenceableFailure(getLoc())
169 << "Expected a xegpu.create_nd_desc op, but got: "
170 << target->getName();
171 diag.attachNote(target->getLoc()) << "target op";
172 return diag;
173 }
174
175 // Set layout attr in desc op's return type. Replaces old desc op.
176 auto layoutAttr =
177 createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
178 auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
179
180 // Map result handles: the transformed handle points at the replacement op.
181 results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
182
184}
185
// Declares the memory effects of this transform op: the target handle is
// consumed, the sg_layout / sg_data / inst_data operands are only read, every
// op result is a newly produced handle, and the payload IR is modified.
// NOTE(review): the parameter line declaring `effects` is not visible in this
// listing.
186void transform::SetDescLayoutOp::getEffects(
188 consumesHandle(getTargetMutable(), effects);
189 onlyReadsHandle(getSgLayoutMutable(), effects);
190 onlyReadsHandle(getSgDataMutable(), effects);
191 onlyReadsHandle(getInstDataMutable(), effects);
192 producesHandle(getOperation()->getOpResults(), effects);
193 modifiesPayload(effects);
194}
195
196namespace {
// Transform dialect extension for XeGPU, kept file-local by the anonymous
// namespace. Its init() declares the generated dialects and registers the
// transform ops listed in the generated .cpp.inc file.
// NOTE(review): the line naming the base class is not visible in this listing;
// only the closing template argument remains.
197class XeGPUTransformDialectExtension
199 XeGPUTransformDialectExtension> {
200public:
201 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
202
// Reuse the base class constructors.
203 using Base::Base;
204
205 void init();
206};
207
// Declares SCF, Arith and XeGPU as generated dialects, then registers every
// transform op named in the GET_OP_LIST section of the generated include.
208void XeGPUTransformDialectExtension::init() {
209 declareGeneratedDialect<scf::SCFDialect>();
210 declareGeneratedDialect<arith::ArithDialect>();
211 declareGeneratedDialect<xegpu::XeGPUDialect>();
212
213 registerTransformOps<
214#define GET_OP_LIST
215#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
216 >();
217}
218} // namespace
219
220#define GET_OP_CLASSES
221#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
222
// Adds XeGPUTransformDialectExtension to the given dialect registry.
// NOTE(review): the signature line of this function is not visible in this
// listing; presumably this is the body of registerTransformDialectExtension,
// which takes a `DialectRegistry &registry` -- confirm against the header.
224 registry.addExtensions<XeGPUTransformDialectExtension>();
225}
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout)
Replace xegpu.create_nd_desc op with a new one with the given layout.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, get that value and cast it to int type.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData)
Create a layout attribute from the given parameters.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:412
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIndex() const
Definition Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult(OpFoldResult ofr) for a single OpFoldResult.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.