MLIR 23.0.0git
XeGPUTransformOps.cpp
Go to the documentation of this file.
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "llvm/ADT/SmallVectorExtras.h"
16
17#include <optional>
18
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
21
22using namespace mlir;
23using namespace mlir::transform;
24
25/// Assuming that `ofr` is an index attr or a param of index type
26/// or a transform dialect handle mapped to exactly one op
27/// with one index result, get that value and cast it to int type.
29 transform::TransformState &state, TransformOpInterface transformOp,
31 for (OpFoldResult ofr : ofrs) {
32 // Attribute case.
33 if (auto attr = dyn_cast<Attribute>(ofr)) {
34 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
36 continue;
37 }
38 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
39 }
40
41 // Transform param case.
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
44 ArrayRef<Attribute> params = state.getParams(transformValue);
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 << "requires exactly one parameter associated";
48 result.push_back(
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
50 continue;
51 }
52
53 // Payload value case.
54 auto payloadOps = state.getPayloadOps(transformValue);
55 if (!llvm::hasSingleElement(payloadOps)) {
57 transformOp.emitSilenceableError()
58 << "handle must be mapped to exactly one payload op";
59 diag.attachNote(transformValue.getLoc())
60 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
61 return diag;
62 }
63
64 Operation *op = *payloadOps.begin();
65 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
67 transformOp.emitSilenceableError()
68 << "payload op must have exactly 1 index result";
69 diag.attachNote(op->getLoc())
70 << "has " << op->getNumResults() << " results";
71 return diag;
72 }
73
74 IntegerAttr intAttr;
75 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
76 return transformOp.emitSilenceableError()
77 << "requires param or handle to be the result of a constant like "
78 "op";
79
80 result.push_back(intAttr.getInt());
81 }
83}
84
85/// Find producer operation of type T for the given value.
86/// It's assumed that producer ops are chained through their first operand.
87/// Producer chain is traced trough loop block arguments (init values).
88template <typename T>
89static std::optional<T> findProducerOfType(Value val) {
90 Value currentValue = val;
91 if (!currentValue.getDefiningOp()) {
92 // Value may be a block argument initialized outside a loop.
93 if (val.getNumUses() == 0) {
94 LDBG() << "Failed to find producer op, value has no uses.";
95 return std::nullopt;
96 }
97 auto userOp = val.getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
99 if (!parentLoop) {
100 LDBG() << "Failed to find producer op, not in a loop.";
101 return std::nullopt;
102 }
103 int64_t iterArgIdx;
104 if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
108 } else {
109 LDBG() << "Failed to find producer op, value not in init values.";
110 return std::nullopt;
111 }
112 }
113 Operation *producerOp = currentValue.getDefiningOp();
114
115 if (auto matchingOp = dyn_cast<T>(producerOp))
116 return matchingOp;
117
118 if (producerOp->getNumOperands() == 0)
119 return std::nullopt;
120
121 return findProducerOfType<T>(producerOp->getOperand(0));
122}
123
124/// Create a layout attribute from the given parameters.
125static xegpu::LayoutAttr createLayoutAttr(
126 MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData,
127 std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) {
128 return xegpu::LayoutAttr::get(
129 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
130 DenseI32ArrayAttr::get(ctx, sgData),
131 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
132 /*lane_layout=*/nullptr,
133 /*lane_data=*/nullptr,
134 /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order));
135}
136
137/// Generate `xegpu::LayoutAttr` from op mixed layout values.
140 TransformOpInterface transformOp,
141 ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
143 ArrayRef<::mlir::OpFoldResult> mixedInstData,
144 ArrayRef<int32_t> order,
145 xegpu::LayoutAttr &layoutAttr) {
146 SmallVector<int32_t> sgLayout, sgData, instData;
147 auto status =
148 convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
149 if (!status.succeeded())
150 return status;
151
152 status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
153 if (!status.succeeded())
154 return status;
155
156 status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
157 if (!status.succeeded())
158 return status;
159 auto maybeInstData = instData.empty()
160 ? std::nullopt
161 : std::optional<ArrayRef<int32_t>>(instData);
162
163 layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order);
164
166}
167
169transform::GetLoadOp::apply(transform::TransformRewriter &rewriter,
172 auto targetValues = state.getPayloadValues(getTarget());
173 if (!llvm::hasSingleElement(targetValues)) {
174 return emitDefiniteFailure()
175 << "requires exactly one target value handle (got "
176 << llvm::range_size(targetValues) << ")";
177 }
178
179 Operation *loadOp = nullptr;
180 auto maybeLoadNdOp =
181 findProducerOfType<xegpu::LoadNdOp>(*targetValues.begin());
182 if (maybeLoadNdOp) {
183 loadOp = maybeLoadNdOp->getOperation();
184 } else {
185 auto maybeLoadOp =
186 findProducerOfType<xegpu::LoadGatherOp>(*targetValues.begin());
187 if (maybeLoadOp) {
188 loadOp = maybeLoadOp->getOperation();
189 } else {
190 return emitSilenceableFailure(getLoc())
191 << "Could not find a matching xegpu.load_nd or xegpu.load op when "
192 "walking the "
193 "producer chain of the first operand.";
194 }
195 }
196
197 results.set(llvm::cast<OpResult>(getResult()), {loadOp});
199}
200
201void transform::SetAnchorLayoutOp::build(
202 OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
203 ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
204 ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
205 ArrayRef<int64_t> sliceDims) {
206 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
207 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
208 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
209 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
210 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
211 build(builder, ostate, target.getType(),
212 /*target=*/target,
213 /*index=*/index,
214 /*sg_layout=*/dynamicSgLayout,
215 /*sg_data=*/dynamicSgData,
216 /*inst_data=*/dynamicInstData,
217 /*static_sg_layout=*/staticSgLayout,
218 /*static_sg_data=*/staticSgData,
219 /*static_inst_data=*/staticInstData,
220 /*order=*/order,
221 /*slice_dims=*/sliceDims);
222}
223
225transform::SetAnchorLayoutOp::apply(transform::TransformRewriter &rewriter,
228 auto targetOps = state.getPayloadOps(getTarget());
230
231 // Construct layout attribute.
232 xegpu::LayoutAttr layoutAttr = nullptr;
233 auto status = getLayoutAttrFromOperands(
234 getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
235 getMixedInstData(), getOrder(), layoutAttr);
236 if (!status.succeeded())
237 return status;
238
239 xegpu::DistributeLayoutAttr layout = layoutAttr;
240 auto sliceDims = getSliceDims();
241 if (sliceDims.size() > 0) {
242 // Wrap layoutAttr in a slice attribute.
243 layout = xegpu::SliceAttr::get(
244 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
245 }
246
247 // Apply the layout to all target ops.
248 for (Operation *target : targetOps) {
249 // Set layout attribute
250 if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
251 // dpas op is a special case where layout needs to be set for A, B, and C
252 if (index == 0)
253 dpasOp.getProperties().layout_a = layout;
254 else if (index == 1)
255 dpasOp.getProperties().layout_b = layout;
256 else if (index == 2)
257 dpasOp.getProperties().layout_cd = layout;
258 else {
259 auto diag = emitSilenceableFailure(getLoc())
260 << "Invalid index for setting dpas op layout: " << index;
261 diag.attachNote(target->getLoc()) << "target op";
262 return diag;
263 }
264 } else {
265 // op's anchor layout.
266 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
267 if (!anchorOp) {
268 auto diag = emitSilenceableFailure(getLoc())
269 << "Cannot set anchor layout to op: " << target->getName();
270 diag.attachNote(target->getLoc()) << "target op";
271 return diag;
272 }
273 anchorOp.setAnchorLayout(layout);
274 }
275 }
277}
278
279void transform::SetAnchorLayoutOp::getEffects(
281 onlyReadsHandle(getTargetMutable(), effects);
282 onlyReadsHandle(getSgLayoutMutable(), effects);
283 onlyReadsHandle(getSgDataMutable(), effects);
284 onlyReadsHandle(getInstDataMutable(), effects);
285 modifiesPayload(effects);
286}
287
288void transform::SetGPULaunchThreadsOp::build(
289 OpBuilder &builder, OperationState &ostate, Value target,
290 ArrayRef<OpFoldResult> mixedThreads) {
291 SmallVector<int64_t> staticThreads;
292 SmallVector<Value> dynamicThreads;
293 dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
294 build(builder, ostate, target.getType(),
295 /*target=*/target,
296 /*threads=*/dynamicThreads,
297 /*static_threads=*/staticThreads);
298}
299
301transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
304 auto targetOps = state.getPayloadOps(getTarget());
305 if (!llvm::hasSingleElement(targetOps)) {
306 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
307 << llvm::range_size(targetOps) << ")";
308 }
309 Operation *target = *targetOps.begin();
310
311 auto launchOp = dyn_cast<gpu::LaunchOp>(target);
312 if (!launchOp) {
313 auto diag = emitSilenceableFailure(getLoc())
314 << "Expected a gpu.launch op, but got: " << target->getName();
315 diag.attachNote(target->getLoc()) << "target op";
316 return diag;
317 }
318
319 SmallVector<int32_t> threads;
321 convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
322 if (!status.succeeded())
323 return status;
324
325 if (threads.size() != 3) {
326 return emitSilenceableFailure(getLoc())
327 << "Expected threads argument to consist of three values (got "
328 << threads.size() << ")";
329 }
330
331 rewriter.setInsertionPoint(launchOp);
332 auto createConstValue = [&](int value) {
333 return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
334 };
335
336 // Replace threads in-place.
337 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
338 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
339 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
340
342}
343
344void transform::SetGPULaunchThreadsOp::getEffects(
346 onlyReadsHandle(getTargetMutable(), effects);
347 onlyReadsHandle(getThreadsMutable(), effects);
348 modifiesPayload(effects);
349}
350
352transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
355 auto targetOps = state.getPayloadOps(getTarget());
356 if (!llvm::hasSingleElement(targetOps))
357 return emitDefiniteFailure()
358 << "requires exactly one target op handle (got "
359 << llvm::range_size(targetOps) << ")";
360 auto target = *targetOps.begin();
361
362 int64_t nbPrefetch = getStaticNbPrefetch();
363 if (getDynamicNbPrefetch()) {
364 // Get dynamic prefetch count from transform param or handle.
365 SmallVector<int32_t> dynamicNbPrefetch;
366 auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
367 {getDynamicNbPrefetch()});
368 if (!status.succeeded())
369 return status;
370 if (dynamicNbPrefetch.size() != 1)
371 return emitDefiniteFailure()
372 << "requires exactly one value for dynamic_nb_prefetch";
373 nbPrefetch = dynamicNbPrefetch[0];
374 }
375 if (nbPrefetch <= 0)
376 return emitSilenceableFailure(getLoc())
377 << "nb_prefetch must be a positive integer.";
378
379 // Cast target to load op.
380 auto maybeLoadOp = dyn_cast<xegpu::LoadNdOp>(target);
381 if (!maybeLoadOp) {
382 return emitSilenceableFailure(getLoc())
383 << "Expected xegpu.load_nd op, got " << target->getName();
384 }
385 auto loadOp = maybeLoadOp;
386 if (loadOp.getMixedOffsets().size() == 0) {
387 auto diag = emitSilenceableFailure(getLoc())
388 << "Load op must have offsets.";
389 diag.attachNote(loadOp.getLoc()) << "load op";
390 return diag;
391 }
392
393 // Find the parent scf.for loop.
394 auto forOp = loadOp->getParentOfType<scf::ForOp>();
395 if (!forOp) {
396 auto diag = emitSilenceableFailure(getLoc())
397 << "Load op is not contained in a scf.for loop.";
398 diag.attachNote(loadOp.getLoc()) << "load op";
399 return diag;
400 }
401
402 // Find descriptor op.
403 auto maybeDescOp =
405 if (!maybeDescOp)
406 return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
407 auto descOp = *maybeDescOp;
408 if (descOp.getMixedOffsets().size() > 0) {
409 auto diag = emitSilenceableFailure(getLoc())
410 << "desc op with offsets is not supported.";
411 diag.attachNote(descOp.getLoc()) << "desc op";
412 }
413
414 // Clone desc op outside the loop.
415 rewriter.setInsertionPoint(forOp);
416 auto newDescOp =
417 cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
418
419 // Clone reduction loop to emit initial prefetches.
420 // Compute upper bound of the init loop: start + nbPrefetch * step.
421 auto nbPrefetchCst =
422 arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
423 auto nbStep = rewriter.createOrFold<arith::MulIOp>(
424 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
425 auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
426 forOp.getLoc(), forOp.getLowerBound(), nbStep);
427 auto initForOp =
428 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
429 initUpBound, forOp.getStep());
430
431 auto ctx = rewriter.getContext();
432 auto readCacheHint =
433 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
434
435 // Modify loadOp mixedOffsets by replacing the for loop induction variable
436 // with the given value.
437 auto getPrefetchOffsets =
438 [&](Value replacementVal) -> SmallVector<OpFoldResult> {
439 IRMapping mapping;
440 mapping.map(forOp.getInductionVar(), replacementVal);
441 SmallVector<Value> dynamicOffsets =
442 llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) {
443 return mapping.lookupOrDefault(v);
444 });
445 auto constOffsets = loadOp.getConstOffsets().value();
446 return getMixedValues(constOffsets, dynamicOffsets, ctx);
447 };
448
449 // Insert prefetch op in init loop.
450 // Replace induction var with the init loop induction var.
451 rewriter.setInsertionPointToStart(initForOp.getBody());
452 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
453 newDescOp.getResult(),
454 getPrefetchOffsets(initForOp.getInductionVar()),
455 readCacheHint, readCacheHint, readCacheHint,
456 /*layout=*/nullptr);
457
458 // Insert prefetch op in main loop.
459 // Calculate prefetch offset after the init prefetches have been issued.
460 rewriter.setInsertionPointToStart(forOp.getBody());
461 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
462 forOp.getInductionVar(), nbStep);
463 // Replace induction var with correct offset.
464 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
465 newDescOp.getResult(),
466 getPrefetchOffsets(prefetchOffset), readCacheHint,
467 readCacheHint, readCacheHint, /*layout=*/nullptr);
468
469 // Unroll the init loop.
470 if (failed(loopUnrollFull(initForOp)))
471 return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
472
473 results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
474
476}
477
478void transform::InsertPrefetchOp::getEffects(
480 onlyReadsHandle(getTargetMutable(), effects);
481 onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
482 producesHandle(getOperation()->getOpResults(), effects);
483 modifiesPayload(effects);
484}
485
486void transform::ConvertLayoutOp::build(
487 OpBuilder &builder, OperationState &ostate, Value target,
488 ArrayRef<OpFoldResult> mixedInputSgLayout,
489 ArrayRef<OpFoldResult> mixedInputSgData,
490 ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder,
491 ArrayRef<OpFoldResult> mixedTargetSgLayout,
492 ArrayRef<OpFoldResult> mixedTargetSgData,
493 ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) {
494 SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
495 staticInputInstData;
496 SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
497 dynamicInputInstData;
498 dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
499 staticInputSgLayout);
500 dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
501 staticInputSgData);
502 dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
503 staticInputInstData);
504 SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
505 staticTargetInstData;
506 SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
507 dynamicTargetInstData;
508 dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
509 staticTargetSgLayout);
510 dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
511 staticTargetSgData);
512 dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
513 staticTargetInstData);
514 build(builder, ostate, target.getType(),
515 /*target=*/target,
516 /*input_sg_layout=*/dynamicInputSgLayout,
517 /*input_sg_data=*/dynamicInputSgData,
518 /*input_inst_data=*/dynamicInputInstData,
519 /*target_sg_layout=*/dynamicTargetSgLayout,
520 /*target_sg_data=*/dynamicTargetSgData,
521 /*target_inst_data=*/dynamicTargetInstData,
522 /*input_order=*/inputOrder,
523 /*static_input_sg_layout=*/staticInputSgLayout,
524 /*static_input_sg_data=*/staticInputSgData,
525 /*static_input_inst_data=*/staticInputInstData,
526 /*static_target_sg_layout=*/staticTargetSgLayout,
527 /*static_target_sg_data=*/staticTargetSgData,
528 /*static_target_inst_data=*/staticTargetInstData,
529 /*target_order=*/targetOrder);
530}
531
533transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
536 auto targetValues = state.getPayloadValues(getTarget());
537 if (!llvm::hasSingleElement(targetValues))
538 return emitDefiniteFailure()
539 << "requires exactly one target value handle (got "
540 << llvm::range_size(targetValues) << ")";
541 auto value = *targetValues.begin();
542
543 // Construct layout attributes.
544 xegpu::LayoutAttr inputLayoutAttr = nullptr;
545 auto status = getLayoutAttrFromOperands(
546 getContext(), state, (*this), getMixedInputSgLayout(),
547 getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
548 inputLayoutAttr);
549 if (!status.succeeded())
550 return status;
551
552 xegpu::LayoutAttr targetLayoutAttr = nullptr;
554 getContext(), state, (*this), getMixedTargetSgLayout(),
555 getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
556 targetLayoutAttr);
557 if (!status.succeeded())
558 return status;
559
560 // Find first user op to define insertion point for layout conversion.
561 if (value.use_empty())
562 return emitSilenceableFailure(getLoc())
563 << "Value has no users to insert layout conversion.";
564 Operation *userOp = *value.getUsers().begin();
565
566 // Emit convert_layout op.
567 rewriter.setInsertionPoint(userOp);
568 auto convLayoutOp =
569 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
570 value, inputLayoutAttr, targetLayoutAttr);
571 // Replace load op result with the converted layout.
572 rewriter.replaceUsesWithIf(
573 value, convLayoutOp.getResult(), [&](OpOperand &use) {
574 return use.getOwner() != convLayoutOp.getOperation();
575 });
576
577 results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
579}
580
581void transform::ConvertLayoutOp::getEffects(
583 onlyReadsHandle(getTargetMutable(), effects);
584 onlyReadsHandle(getInputSgLayoutMutable(), effects);
585 onlyReadsHandle(getInputSgDataMutable(), effects);
586 onlyReadsHandle(getInputInstDataMutable(), effects);
587 onlyReadsHandle(getTargetSgLayoutMutable(), effects);
588 onlyReadsHandle(getTargetSgDataMutable(), effects);
589 onlyReadsHandle(getTargetInstDataMutable(), effects);
590 producesHandle(getOperation()->getOpResults(), effects);
591 modifiesPayload(effects);
592}
593
594namespace {
595class XeGPUTransformDialectExtension
597 XeGPUTransformDialectExtension> {
598public:
599 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
600
601 using Base::Base;
602
603 void init();
604};
605
606void XeGPUTransformDialectExtension::init() {
607 declareGeneratedDialect<scf::SCFDialect>();
608 declareGeneratedDialect<arith::ArithDialect>();
609 declareGeneratedDialect<xegpu::XeGPUDialect>();
610
611 registerTransformOps<
612#define GET_OP_LIST
613#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
614 >();
615}
616} // namespace
617
618#define GET_OP_CLASSES
619#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
620
622 registry.addExtensions<XeGPUTransformDialectExtension>();
623}
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< T > findProducerOfType(Value val)
Find producer operation of type T for the given value.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData, ArrayRef< int32_t > order)
Create a layout attribute from the given parameters.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<::mlir::OpFoldResult > mixedSgLayout, ArrayRef<::mlir::OpFoldResult > mixedSgData, ArrayRef<::mlir::OpFoldResult > mixedInstData, ArrayRef< int32_t > order, xegpu::LayoutAttr &layoutAttr)
Generate xegpu::LayoutAttr from op mixed layout values.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
bool isIndex() const
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:519
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.