// MLIR 23.0.0git
// XeGPUTransformOps.cpp
// (Doxygen source listing; see the documentation of this file.)
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "llvm/ADT/SmallVectorExtras.h"
16
17#include <optional>
18
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
21
22using namespace mlir;
23using namespace mlir::transform;
24
25/// Assuming that `ofr` is an index attr or a param of index type
26/// or a transform dialect handle mapped to exactly one op
27/// with one index result, get that value and cast it to int type.
29 transform::TransformState &state, TransformOpInterface transformOp,
31 for (OpFoldResult ofr : ofrs) {
32 // Attribute case.
33 if (auto attr = dyn_cast<Attribute>(ofr)) {
34 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
36 continue;
37 }
38 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
39 }
40
41 // Transform param case.
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
44 ArrayRef<Attribute> params = state.getParams(transformValue);
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 << "requires exactly one parameter associated";
48 result.push_back(
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
50 continue;
51 }
52
53 // Payload value case.
54 auto payloadOps = state.getPayloadOps(transformValue);
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 << "handle must be mapped to exactly one payload op";
59 diag.attachNote(transformValue.getLoc())
60 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
61 return diag;
62 }
63
64 Operation *op = *payloadOps.begin();
65 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
67 transformOp.emitSilenceableError()
68 << "payload op must have exactly 1 index result";
69 diag.attachNote(op->getLoc())
70 << "has " << op->getNumResults() << " results";
71 return diag;
72 }
73
74 IntegerAttr intAttr;
75 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
76 return transformOp.emitSilenceableError()
77 << "requires param or handle to be the result of a constant like "
78 "op";
79
80 result.push_back(intAttr.getInt());
81 }
83}
84
85/// Find producer operation of type T for the given value.
86/// It's assumed that producer ops are chained through their first operand.
87/// Producer chain is traced trough loop block arguments (init values).
88template <typename T>
89static std::optional<T> findProducerOfType(Value val) {
90 Value currentValue = val;
91 if (!currentValue.getDefiningOp()) {
92 // Value may be a block argument initialized outside a loop.
93 if (val.getNumUses() == 0) {
94 LDBG() << "Failed to find producer op, value has no uses.";
95 return std::nullopt;
96 }
97 auto userOp = val.getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
99 if (!parentLoop) {
100 LDBG() << "Failed to find producer op, not in a loop.";
101 return std::nullopt;
102 }
103 int64_t iterArgIdx;
104 if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
108 } else {
109 LDBG() << "Failed to find producer op, value not in init values.";
110 return std::nullopt;
111 }
112 }
113 Operation *producerOp = currentValue.getDefiningOp();
114
115 if (auto matchingOp = dyn_cast<T>(producerOp))
116 return matchingOp;
117
118 if (producerOp->getNumOperands() == 0)
119 return std::nullopt;
120
121 return findProducerOfType<T>(producerOp->getOperand(0));
122}
123
124/// Create a layout attribute from the given parameters.
125static xegpu::LayoutAttr createLayoutAttr(
126 MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData,
127 std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) {
128 return xegpu::LayoutAttr::get(
129 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
130 DenseI32ArrayAttr::get(ctx, sgData),
131 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
132 /*lane_layout=*/nullptr,
133 /*lane_data=*/nullptr,
134 /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order));
135}
136
137/// Generate `xegpu::LayoutAttr` from op mixed layout values.
140 TransformOpInterface transformOp,
141 ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
143 ArrayRef<::mlir::OpFoldResult> mixedInstData,
144 ArrayRef<int32_t> order,
145 xegpu::LayoutAttr &layoutAttr) {
146 SmallVector<int32_t> sgLayout, sgData, instData;
147 auto status =
148 convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
149 if (!status.succeeded())
150 return status;
151
152 status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
153 if (!status.succeeded())
154 return status;
155
156 status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
157 if (!status.succeeded())
158 return status;
159 auto maybeInstData = instData.empty()
160 ? std::nullopt
161 : std::optional<ArrayRef<int32_t>>(instData);
162
163 layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order);
164
166}
167
169transform::GetLoadOp::apply(transform::TransformRewriter &rewriter,
172 auto targetValues = state.getPayloadValues(getTarget());
173 if (!llvm::hasSingleElement(targetValues)) {
174 return emitDefiniteFailure()
175 << "requires exactly one target value handle (got "
176 << llvm::range_size(targetValues) << ")";
177 }
178
179 Operation *loadOp = nullptr;
180 auto maybeLoadNdOp =
181 findProducerOfType<xegpu::LoadNdOp>(*targetValues.begin());
182 if (maybeLoadNdOp) {
183 loadOp = maybeLoadNdOp->getOperation();
184 } else {
185 auto maybeLoadOp =
186 findProducerOfType<xegpu::LoadGatherOp>(*targetValues.begin());
187 if (maybeLoadOp) {
188 loadOp = maybeLoadOp->getOperation();
189 } else {
190 return emitSilenceableFailure(getLoc())
191 << "Could not find a matching xegpu.load_nd or xegpu.load op when "
192 "walking the "
193 "producer chain of the first operand.";
194 }
195 }
196
197 results.set(llvm::cast<OpResult>(getResult()), {loadOp});
199}
200
201void transform::SetAnchorLayoutOp::build(
202 OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
203 ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
204 ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
205 ArrayRef<int64_t> sliceDims) {
206 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
207 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
208 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
209 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
210 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
211 build(builder, ostate, target.getType(),
212 /*target=*/target,
213 /*index=*/index,
214 /*sg_layout=*/dynamicSgLayout,
215 /*sg_data=*/dynamicSgData,
216 /*inst_data=*/dynamicInstData,
217 /*static_sg_layout=*/staticSgLayout,
218 /*static_sg_data=*/staticSgData,
219 /*static_inst_data=*/staticInstData,
220 /*order=*/order,
221 /*slice_dims=*/sliceDims);
222}
223
225transform::SetAnchorLayoutOp::apply(transform::TransformRewriter &rewriter,
228 auto targetOps = state.getPayloadOps(getTarget());
230
231 // Construct layout attribute.
232 xegpu::LayoutAttr layoutAttr = nullptr;
233 auto status = getLayoutAttrFromOperands(
234 getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
235 getMixedInstData(), getOrder(), layoutAttr);
236 if (!status.succeeded())
237 return status;
238
239 xegpu::DistributeLayoutAttr layout = layoutAttr;
240 auto sliceDims = getSliceDims();
241 if (sliceDims.size() > 0) {
242 // Wrap layoutAttr in a slice attribute.
243 layout = xegpu::SliceAttr::get(
244 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
245 }
246
247 // Apply the layout to all target ops.
248 for (Operation *target : targetOps) {
249 // Set layout attribute
250 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
251 // dpas op is a special case where layout needs to be set for A, B, and C
252 if (index == 0)
253 dpasOp.getProperties().layout_a = layout;
254 else if (index == 1)
255 dpasOp.getProperties().layout_b = layout;
256 else if (index == 2)
257 dpasOp.getProperties().layout_cd = layout;
258 else {
259 auto diag = emitSilenceableFailure(getLoc())
260 << "Invalid index for setting dpas op layout: " << index;
261 diag.attachNote(target->getLoc()) << "target op";
262 return diag;
263 }
264 } else {
265 // op's anchor layout.
266 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
267 if (!anchorOp) {
268 auto diag = emitSilenceableFailure(getLoc())
269 << "Cannot set anchor layout to op: " << target->getName();
270 diag.attachNote(target->getLoc()) << "target op";
271 return diag;
272 }
273 anchorOp.setAnchorLayout(layout);
274 }
275 }
277}
278
279void transform::SetAnchorLayoutOp::getEffects(
281 onlyReadsHandle(getTargetMutable(), effects);
282 onlyReadsHandle(getSgLayoutMutable(), effects);
283 onlyReadsHandle(getSgDataMutable(), effects);
284 onlyReadsHandle(getInstDataMutable(), effects);
285 modifiesPayload(effects);
286}
287
288void transform::SetGPULaunchThreadsOp::build(
289 OpBuilder &builder, OperationState &ostate, Value target,
290 ArrayRef<OpFoldResult> mixedThreads) {
291 SmallVector<int64_t> staticThreads;
292 SmallVector<Value> dynamicThreads;
293 dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
294 build(builder, ostate, target.getType(),
295 /*target=*/target,
296 /*threads=*/dynamicThreads,
297 /*static_threads=*/staticThreads);
298}
299
301transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
304 auto targetOps = state.getPayloadOps(getTarget());
305 if (!llvm::hasSingleElement(targetOps)) {
306 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
307 << llvm::range_size(targetOps) << ")";
308 }
309 Operation *target = *targetOps.begin();
310
311 auto launchOp = dyn_cast<gpu::LaunchOp>(target);
312 if (!launchOp) {
313 auto diag = emitSilenceableFailure(getLoc())
314 << "Expected a gpu.launch op, but got: " << target->getName();
315 diag.attachNote(target->getLoc()) << "target op";
316 return diag;
317 }
318
319 SmallVector<int32_t> threads;
321 convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
322 if (!status.succeeded())
323 return status;
324
325 if (threads.size() != 3) {
326 return emitSilenceableFailure(getLoc())
327 << "Expected threads argument to consist of three values (got "
328 << threads.size() << ")";
329 }
330
331 rewriter.setInsertionPoint(launchOp);
332 auto createConstValue = [&](int value) {
333 return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
334 };
335
336 // Replace threads in-place.
337 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
338 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
339 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
340
342}
343
344void transform::SetGPULaunchThreadsOp::getEffects(
346 onlyReadsHandle(getTargetMutable(), effects);
347 onlyReadsHandle(getThreadsMutable(), effects);
348 modifiesPayload(effects);
349}
350
352transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
355 auto targetOps = state.getPayloadOps(getTarget());
356 if (!llvm::hasSingleElement(targetOps))
357 return emitDefiniteFailure()
358 << "requires exactly one target op handle (got "
359 << llvm::range_size(targetOps) << ")";
360 auto target = *targetOps.begin();
361
362 int64_t nbPrefetch = getStaticNbPrefetch();
363 if (getDynamicNbPrefetch()) {
364 // Get dynamic prefetch count from transform param or handle.
365 SmallVector<int32_t> dynamicNbPrefetch;
366 auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
367 {getDynamicNbPrefetch()});
368 if (!status.succeeded())
369 return status;
370 if (dynamicNbPrefetch.size() != 1)
371 return emitDefiniteFailure()
372 << "requires exactly one value for dynamic_nb_prefetch";
373 nbPrefetch = dynamicNbPrefetch[0];
374 }
375 if (nbPrefetch <= 0)
376 return emitSilenceableFailure(getLoc())
377 << "nb_prefetch must be a positive integer.";
378
379 // Cast target to load op.
380 auto maybeLoadOp = dyn_cast<xegpu::LoadNdOp>(target);
381 if (!maybeLoadOp) {
382 return emitSilenceableFailure(getLoc())
383 << "Expected xegpu.load_nd op, got " << target->getName();
384 }
385 auto loadOp = maybeLoadOp;
386 if (loadOp.getMixedOffsets().size() == 0) {
387 auto diag = emitSilenceableFailure(getLoc())
388 << "Load op must have offsets.";
389 diag.attachNote(loadOp.getLoc()) << "load op";
390 return diag;
391 }
392
393 // Find the parent scf.for loop.
394 auto forOp = loadOp->getParentOfType<scf::ForOp>();
395 if (!forOp) {
396 auto diag = emitSilenceableFailure(getLoc())
397 << "Load op is not contained in a scf.for loop.";
398 diag.attachNote(loadOp.getLoc()) << "load op";
399 return diag;
400 }
401
402 // Find descriptor op.
403 auto maybeDescOp =
405 if (!maybeDescOp)
406 return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
407 auto descOp = *maybeDescOp;
408
409 // Clone desc op outside the loop.
410 rewriter.setInsertionPoint(forOp);
411 auto newDescOp =
412 cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
413
414 // Clone reduction loop to emit initial prefetches.
415 // Compute upper bound of the init loop: start + nbPrefetch * step.
416 auto nbPrefetchCst =
417 arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
418 auto nbStep = rewriter.createOrFold<arith::MulIOp>(
419 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
420 auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
421 forOp.getLoc(), forOp.getLowerBound(), nbStep);
422 auto initForOp =
423 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
424 initUpBound, forOp.getStep());
425
426 auto ctx = rewriter.getContext();
427 auto readCacheHint =
428 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
429
430 // Modify loadOp mixedOffsets by replacing the for loop induction variable
431 // with the given value.
432 auto getPrefetchOffsets =
433 [&](Value replacementVal) -> SmallVector<OpFoldResult> {
434 IRMapping mapping;
435 mapping.map(forOp.getInductionVar(), replacementVal);
436 SmallVector<Value> dynamicOffsets =
437 llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) {
438 return mapping.lookupOrDefault(v);
439 });
440 auto constOffsets = loadOp.getConstOffsets();
441 return getMixedValues(constOffsets, dynamicOffsets, ctx);
442 };
443
444 // Insert prefetch op in init loop.
445 // Replace induction var with the init loop induction var.
446 rewriter.setInsertionPointToStart(initForOp.getBody());
447 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
448 newDescOp.getResult(),
449 getPrefetchOffsets(initForOp.getInductionVar()),
450 readCacheHint, readCacheHint, readCacheHint,
451 /*layout=*/nullptr);
452
453 // Insert prefetch op in main loop.
454 // Calculate prefetch offset after the init prefetches have been issued.
455 rewriter.setInsertionPointToStart(forOp.getBody());
456 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
457 forOp.getInductionVar(), nbStep);
458 // Replace induction var with correct offset.
459 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
460 newDescOp.getResult(),
461 getPrefetchOffsets(prefetchOffset), readCacheHint,
462 readCacheHint, readCacheHint, /*layout=*/nullptr);
463
464 // Unroll the init loop.
465 if (failed(loopUnrollFull(initForOp)))
466 return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
467
468 results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
469
471}
472
473void transform::InsertPrefetchOp::getEffects(
475 onlyReadsHandle(getTargetMutable(), effects);
476 onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
477 producesHandle(getOperation()->getOpResults(), effects);
478 modifiesPayload(effects);
479}
480
481void transform::ConvertLayoutOp::build(
482 OpBuilder &builder, OperationState &ostate, Value target,
483 ArrayRef<OpFoldResult> mixedInputSgLayout,
484 ArrayRef<OpFoldResult> mixedInputSgData,
485 ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder,
486 ArrayRef<OpFoldResult> mixedTargetSgLayout,
487 ArrayRef<OpFoldResult> mixedTargetSgData,
488 ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) {
489 SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
490 staticInputInstData;
491 SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
492 dynamicInputInstData;
493 dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
494 staticInputSgLayout);
495 dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
496 staticInputSgData);
497 dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
498 staticInputInstData);
499 SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
500 staticTargetInstData;
501 SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
502 dynamicTargetInstData;
503 dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
504 staticTargetSgLayout);
505 dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
506 staticTargetSgData);
507 dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
508 staticTargetInstData);
509 build(builder, ostate, target.getType(),
510 /*target=*/target,
511 /*input_sg_layout=*/dynamicInputSgLayout,
512 /*input_sg_data=*/dynamicInputSgData,
513 /*input_inst_data=*/dynamicInputInstData,
514 /*target_sg_layout=*/dynamicTargetSgLayout,
515 /*target_sg_data=*/dynamicTargetSgData,
516 /*target_inst_data=*/dynamicTargetInstData,
517 /*input_order=*/inputOrder,
518 /*static_input_sg_layout=*/staticInputSgLayout,
519 /*static_input_sg_data=*/staticInputSgData,
520 /*static_input_inst_data=*/staticInputInstData,
521 /*static_target_sg_layout=*/staticTargetSgLayout,
522 /*static_target_sg_data=*/staticTargetSgData,
523 /*static_target_inst_data=*/staticTargetInstData,
524 /*target_order=*/targetOrder);
525}
526
528transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
531 auto targetValues = state.getPayloadValues(getTarget());
532 if (!llvm::hasSingleElement(targetValues))
533 return emitDefiniteFailure()
534 << "requires exactly one target value handle (got "
535 << llvm::range_size(targetValues) << ")";
536 auto value = *targetValues.begin();
537
538 // Construct layout attributes.
539 xegpu::LayoutAttr inputLayoutAttr = nullptr;
540 auto status = getLayoutAttrFromOperands(
541 getContext(), state, (*this), getMixedInputSgLayout(),
542 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
543 inputLayoutAttr);
544 if (!status.succeeded())
545 return status;
546
547 xegpu::LayoutAttr targetLayoutAttr = nullptr;
549 getContext(), state, (*this), getMixedTargetSgLayout(),
550 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
551 targetLayoutAttr);
552 if (!status.succeeded())
553 return status;
554
555 // Find first user op to define insertion point for layout conversion.
556 if (value.use_empty())
557 return emitSilenceableFailure(getLoc())
558 << "Value has no users to insert layout conversion.";
559 Operation *userOp = *value.getUsers().begin();
560
561 // Emit convert_layout op.
562 rewriter.setInsertionPoint(userOp);
563 auto convLayoutOp =
564 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
565 value, inputLayoutAttr, targetLayoutAttr);
566 // Replace load op result with the converted layout.
567 rewriter.replaceUsesWithIf(
568 value, convLayoutOp.getResult(), [&](OpOperand &use) {
569 return use.getOwner() != convLayoutOp.getOperation();
570 });
571
572 results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
574}
575
576void transform::ConvertLayoutOp::getEffects(
578 onlyReadsHandle(getTargetMutable(), effects);
579 onlyReadsHandle(getInputSgLayoutMutable(), effects);
580 onlyReadsHandle(getInputSgDataMutable(), effects);
581 onlyReadsHandle(getInputInstDataMutable(), effects);
582 onlyReadsHandle(getTargetSgLayoutMutable(), effects);
583 onlyReadsHandle(getTargetSgDataMutable(), effects);
584 onlyReadsHandle(getTargetInstDataMutable(), effects);
585 producesHandle(getOperation()->getOpResults(), effects);
586 modifiesPayload(effects);
587}
588
589namespace {
590class XeGPUTransformDialectExtension
592 XeGPUTransformDialectExtension> {
593public:
594 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
595
596 using Base::Base;
597
598 void init();
599};
600
601void XeGPUTransformDialectExtension::init() {
602 declareGeneratedDialect<scf::SCFDialect>();
603 declareGeneratedDialect<arith::ArithDialect>();
604 declareGeneratedDialect<xegpu::XeGPUDialect>();
605
606 registerTransformOps<
607#define GET_OP_LIST
608#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
609 >();
610}
611} // namespace
612
613#define GET_OP_CLASSES
614#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
615
617 registry.addExtensions<XeGPUTransformDialectExtension>();
618}
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< T > findProducerOfType(Value val)
Find producer operation of type T for the given value.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData, ArrayRef< int32_t > order)
Create a layout attribute from the given parameters.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<::mlir::OpFoldResult > mixedSgLayout, ArrayRef<::mlir::OpFoldResult > mixedSgData, ArrayRef<::mlir::OpFoldResult > mixedInstData, ArrayRef< int32_t > order, xegpu::LayoutAttr &layoutAttr)
Generate xegpu::LayoutAttr from op mixed layout values.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
bool isIndex() const
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:519
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.