//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Matchers.h"

#include "llvm/ADT/SmallVectorExtras.h"

#include <optional>

#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "xegpu-transforms"

using namespace mlir;
using namespace mlir::transform;

/// Assuming that `ofr` is an index attr or a param of index type
/// or a transform dialect handle mapped to exactly one op
/// with one index result, get that value and cast it to int type.
static DiagnosedSilenceableFailure convertMixedValuesToInt(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult ofr : ofrs) {
    // Attribute case.
    if (auto attr = dyn_cast<Attribute>(ofr)) {
      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
        result.push_back(intAttr.getInt());
        continue;
      }
      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
    }

    // Transform param case.
    Value transformValue = cast<Value>(ofr);
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      ArrayRef<Attribute> params = state.getParams(transformValue);
      if (params.size() != 1)
        return transformOp.emitDefiniteFailure()
               << "requires exactly one parameter associated";
      result.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    // Payload value case.
    auto payloadOps = state.getPayloadOps(transformValue);
    if (!llvm::hasSingleElement(payloadOps)) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(transformValue.getLoc())
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
      return diag;
    }

    Operation *op = *payloadOps.begin();
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }

    IntegerAttr intAttr;
    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
      return transformOp.emitSilenceableError()
             << "requires param or handle to be the result of a constant like "
                "op";

    result.push_back(intAttr.getInt());
  }
  return DiagnosedSilenceableFailure::success();
}
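
// Illustrative example (values hypothetical): with `ofrs` holding the
// attribute `2 : index`, a transform param associated with `4 : index`, and a
// handle mapped to a single `arith.constant 8 : index`, this helper appends
// {2, 4, 8} to `result`.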

/// Find producer operation of type T for the given value.
/// It's assumed that producer ops are chained through their first operand.
/// The producer chain is traced through loop block arguments (init values).
template <typename T>
static std::optional<T> findProducerOfType(Value val) {
  Value currentValue = val;
  if (!currentValue.getDefiningOp()) {
    // Value may be a block argument initialized outside a loop.
    if (val.getNumUses() == 0) {
      LDBG() << "Failed to find producer op, value has no uses.";
      return std::nullopt;
    }
    auto userOp = val.getUsers().begin();
    auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
    if (!parentLoop) {
      LDBG() << "Failed to find producer op, not in a loop.";
      return std::nullopt;
    }
    int64_t iterArgIdx;
    if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
      auto numInductionVars = parentLoop.getLoopInductionVars()->size();
      iterArgIdx = iterArg.getArgNumber() - numInductionVars;
      currentValue = parentLoop.getInits()[iterArgIdx];
    } else {
      LDBG() << "Failed to find producer op, value not in init values.";
      return std::nullopt;
    }
  }
  Operation *producerOp = currentValue.getDefiningOp();

  if (auto matchingOp = dyn_cast<T>(producerOp))
    return matchingOp;

  if (producerOp->getNumOperands() == 0)
    return std::nullopt;

  return findProducerOfType<T>(producerOp->getOperand(0));
}
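
// Sketch of a traced chain (IR abridged, names hypothetical): starting from
// the loop-carried value %arg, the search follows the iter_arg back to its
// init value and reaches the descriptor producer:
//   %desc = xegpu.create_nd_tdesc %src : memref<128x128xf16>
//       -> !xegpu.tensor_desc<8x16xf16>
//   scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %desc) -> (...) {
//     %v = xegpu.load_nd %arg[...] : ... -> vector<8x16xf16>
//     ...
//   }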

/// Create a layout attribute from the given parameters.
static xegpu::LayoutAttr
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
                 ArrayRef<int32_t> sgData,
                 std::optional<ArrayRef<int32_t>> instData) {
  return xegpu::LayoutAttr::get(
      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
      DenseI32ArrayAttr::get(ctx, sgData),
      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
      /*lane_layout=*/nullptr,
      /*lane_data=*/nullptr,
      /*order=*/nullptr);
}
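
// For example (illustrative values), sgLayout = {8, 4}, sgData = {32, 64},
// and instData = {8, 16} produce roughly:
//   #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>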

/// Generate `xegpu::LayoutAttr` from op mixed layout values.
DiagnosedSilenceableFailure getLayoutAttrFromOperands(
    MLIRContext *ctx, transform::TransformState &state,
    TransformOpInterface transformOp,
    ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
    ArrayRef<::mlir::OpFoldResult> mixedSgData,
    ArrayRef<::mlir::OpFoldResult> mixedInstData,
    xegpu::LayoutAttr &layoutAttr) {
  SmallVector<int32_t> sgLayout, sgData, instData;
  auto status =
      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
  if (!status.succeeded())
    return status;
  auto maybeInstData = instData.empty()
                           ? std::nullopt
                           : std::optional<ArrayRef<int32_t>>(instData);

  layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);

  return DiagnosedSilenceableFailure::success();
}

/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
              xegpu::CreateNdDescOp descOp,
              xegpu::DistributeLayoutAttr layout) {
  assert(descOp.getMixedOffsets().size() == 0 &&
         "create desc op with offsets is not supported");
  auto oldTensorDesc = descOp.getType();
  auto descType = xegpu::TensorDescType::get(
      oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
      /*array_length=*/oldTensorDesc.getArrayLength(),
      /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
      /*memory_space=*/oldTensorDesc.getMemorySpace(),
      /*layout=*/layout);

  rewriter.setInsertionPointAfter(descOp);
  auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
      descOp.getMixedStrides());
  return newDescOp;
}
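
// Sketch of the rewrite (types abridged): only the result type changes, the
// layout being attached to the tensor descriptor:
//   before: xegpu.create_nd_tdesc %src : memref<...> -> !xegpu.tensor_desc<8x16xf16>
//   after:  xegpu.create_nd_tdesc %src : memref<...>
//               -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<...>>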

DiagnosedSilenceableFailure
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues)) {
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  }

  auto maybeDescOp =
      findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
  if (!maybeDescOp) {
    return emitSilenceableFailure(getLoc())
           << "Could not find a matching descriptor op when walking the "
              "producer chain of the first operand.";
  }

  results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
  return DiagnosedSilenceableFailure::success();
}

void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                       OperationState &result, Value target,
                                       ArrayRef<OpFoldResult> mixedSgLayout,
                                       ArrayRef<OpFoldResult> mixedSgData,
                                       ArrayRef<OpFoldResult> mixedInstData,
                                       ArrayRef<int64_t> sliceDims) {
  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
  build(builder, result, target.getType(),
        /*target=*/target,
        /*sg_layout=*/dynamicSgLayout,
        /*sg_data=*/dynamicSgData,
        /*inst_data=*/dynamicInstData,
        /*static_sg_layout=*/staticSgLayout,
        /*static_sg_data=*/staticSgData,
        /*static_inst_data=*/staticInstData,
        /*slice_dims=*/sliceDims);
}

DiagnosedSilenceableFailure
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  xegpu::LayoutAttr layoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                          getMixedSgLayout(), getMixedSgData(),
                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;

  xegpu::DistributeLayoutAttr layout = layoutAttr;
  auto sliceDims = getSliceDims();
  if (sliceDims.size() > 0) {
    // Wrap layoutAttr in a slice attribute.
    layout = xegpu::SliceAttr::get(
        getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
  }

  // For now only create_nd_desc op is supported.
  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
  if (!descOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a xegpu.create_nd_desc op, but got: "
                << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Set layout attr in desc op's return type. Replaces old desc op.
  auto newDescOp = setDescLayout(rewriter, descOp, layout);

  // Map result handles.
  results.set(cast<OpResult>(getTransformed()), {newDescOp.getOperation()});

  return DiagnosedSilenceableFailure::success();
}
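
// When slice_dims is non-empty, the layout handed to setDescLayout is wrapped
// in a slice attribute, roughly (sketch):
//   #xegpu.slice<#xegpu.layout<sg_layout = [...], sg_data = [...]>, dims = [0]>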

void transform::SetDescLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

void transform::SetOpLayoutAttrOp::build(
    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
    ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
    bool result) {
  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*index=*/index,
        /*sg_layout=*/dynamicSgLayout,
        /*sg_data=*/dynamicSgData,
        /*inst_data=*/dynamicInstData,
        /*static_sg_layout=*/staticSgLayout,
        /*static_sg_data=*/staticSgData,
        /*static_inst_data=*/staticInstData,
        /*slice_dims=*/sliceDims,
        /*result=*/result);
}

DiagnosedSilenceableFailure
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  bool resultTarget = getResult();
  int64_t index = getIndex();
  if (resultTarget && index >= target->getNumResults()) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op results";
  }
  if (!resultTarget && index >= target->getNumOperands()) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op operands";
  }

  xegpu::LayoutAttr layoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                          getMixedSgLayout(), getMixedSgData(),
                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;

  xegpu::DistributeLayoutAttr layout = layoutAttr;
  auto sliceDims = getSliceDims();
  if (sliceDims.size() > 0) {
    // Wrap layoutAttr in a slice attribute.
    layout = xegpu::SliceAttr::get(
        getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
  }

  // Set layout attribute for the op result or operand.
  if (resultTarget)
    xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
  else
    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
  return DiagnosedSilenceableFailure::success();
}
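
// Illustrative effect (sketch; the exact attribute name chosen by
// xegpu::setDistributeLayoutAttr is an implementation detail): for index = 0
// on a result, the layout ends up attached to the target op, e.g.
//   %r = "some.op"(...) {layout_result_0 = #xegpu.layout<...>} : ...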

void transform::SetOpLayoutAttrOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);
  modifiesPayload(effects);
}

void transform::SetGPULaunchThreadsOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedThreads) {
  SmallVector<int64_t> staticThreads;
  SmallVector<Value> dynamicThreads;
  dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*threads=*/dynamicThreads,
        /*static_threads=*/staticThreads);
}

DiagnosedSilenceableFailure
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  auto launchOp = dyn_cast<gpu::LaunchOp>(target);
  if (!launchOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a gpu.launch op, but got: " << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  SmallVector<int32_t> threads;
  auto status =
      convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
  if (!status.succeeded())
    return status;

  if (threads.size() != 3) {
    return emitSilenceableFailure(getLoc())
           << "Expected threads argument to consist of three values (got "
           << threads.size() << ")";
  }

  rewriter.setInsertionPoint(launchOp);
  auto createConstValue = [&](int value) {
    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
  };

  // Replace threads in-place.
  launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
  launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
  launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));

  return DiagnosedSilenceableFailure::success();
}
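
// E.g. threads = {32, 4, 1} rewrites the launch block sizes to constants
// (sketch, other operands elided):
//   %c32 = arith.constant 32 : index
//   %c4 = arith.constant 4 : index
//   %c1 = arith.constant 1 : index
//   gpu.launch blocks(...) in (...) threads(%tx, %ty, %tz)
//       in (%sz_x = %c32, %sz_y = %c4, %sz_z = %c1) { ... }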

void transform::SetGPULaunchThreadsOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getThreadsMutable(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  auto value = *targetValues.begin();

  int64_t nbPrefetch = getStaticNbPrefetch();
  if (getDynamicNbPrefetch()) {
    // Get dynamic prefetch count from transform param or handle.
    SmallVector<int32_t> dynamicNbPrefetch;
    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
                                          {getDynamicNbPrefetch()});
    if (!status.succeeded())
      return status;
    if (dynamicNbPrefetch.size() != 1)
      return emitDefiniteFailure()
             << "requires exactly one value for dynamic_nb_prefetch";
    nbPrefetch = dynamicNbPrefetch[0];
  }
  if (nbPrefetch <= 0)
    return emitSilenceableFailure(getLoc())
           << "nb_prefetch must be a positive integer.";

  // Find load operation of the operand.
  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
  if (!maybeLoadOp)
    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
  auto loadOp = *maybeLoadOp;
  if (loadOp.getMixedOffsets().size() == 0) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Load op must have offsets.";
    diag.attachNote(loadOp.getLoc()) << "load op";
    return diag;
  }

  // Find the parent scf.for loop.
  auto forOp = loadOp->getParentOfType<scf::ForOp>();
  if (!forOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Load op is not contained in a scf.for loop.";
    diag.attachNote(loadOp.getLoc()) << "load op";
    return diag;
  }

  // Find descriptor op.
  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
  if (!maybeDescOp)
    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
  auto descOp = *maybeDescOp;
  if (descOp.getMixedOffsets().size() > 0) {
    auto diag = emitSilenceableFailure(getLoc())
                << "desc op with offsets is not supported.";
    diag.attachNote(descOp.getLoc()) << "desc op";
    return diag;
  }

  // Clone desc op outside the loop.
  rewriter.setInsertionPoint(forOp);
  auto newDescOp =
      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));

  // Clone reduction loop to emit initial prefetches.
  // Compute upper bound of the init loop: start + nbPrefetch * step.
  auto nbPrefetchCst =
      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
      forOp.getLoc(), forOp.getLowerBound(), nbStep);
  auto initForOp =
      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
                         initUpBound, forOp.getStep());

  auto ctx = rewriter.getContext();
  auto readCacheHint =
      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);

  // Modify loadOp mixedOffsets by replacing the for loop induction variable
  // with the given value.
  auto getPrefetchOffsets =
      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
    IRMapping mapping;
    mapping.map(forOp.getInductionVar(), replacementVal);
    SmallVector<Value> dynamicOffsets =
        llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) {
          return mapping.lookupOrDefault(v);
        });
    auto constOffsets = loadOp.getConstOffsets().value();
    return getMixedValues(constOffsets, dynamicOffsets, ctx);
  };

  // Insert prefetch op in init loop.
  // Replace induction var with the init loop induction var.
  rewriter.setInsertionPointToStart(initForOp.getBody());
  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                              newDescOp.getResult(),
                              getPrefetchOffsets(initForOp.getInductionVar()),
                              readCacheHint, readCacheHint, readCacheHint,
                              /*layout=*/nullptr);

  // Insert prefetch op in main loop.
  // Calculate prefetch offset after the init prefetches have been issued.
  rewriter.setInsertionPointToStart(forOp.getBody());
  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
                                              forOp.getInductionVar(), nbStep);
  // Replace induction var with correct offset.
  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                              newDescOp.getResult(),
                              getPrefetchOffsets(prefetchOffset), readCacheHint,
                              readCacheHint, readCacheHint, /*layout=*/nullptr);

  // Unroll the init loop.
  if (failed(loopUnrollFull(initForOp)))
    return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";

  results.set(llvm::cast<OpResult>(getResult()), {newDescOp});

  return DiagnosedSilenceableFailure::success();
}
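
// Resulting structure for nb_prefetch = 2 (sketch, after unrolling the init
// loop; offsets and types abridged):
//   %desc = xegpu.create_nd_tdesc ...           // clone hoisted above the loop
//   xegpu.prefetch_nd %desc[%lb] ...            // init prefetch 1
//   xegpu.prefetch_nd %desc[%lb + %step] ...    // init prefetch 2
//   scf.for %iv = %lb to %ub step %step ... {
//     xegpu.prefetch_nd %desc[%iv + 2 * %step] ...
//     %v = xegpu.load_nd ... [%iv] ...
//     ...
//   }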

void transform::InsertPrefetchOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

void transform::ConvertLayoutOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedInputSgLayout,
    ArrayRef<OpFoldResult> mixedInputSgData,
    ArrayRef<OpFoldResult> mixedInputInstData,
    ArrayRef<OpFoldResult> mixedTargetSgLayout,
    ArrayRef<OpFoldResult> mixedTargetSgData,
    ArrayRef<OpFoldResult> mixedTargetInstData) {
  SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
      staticInputInstData;
  SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
      dynamicInputInstData;
  dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
                             staticInputSgLayout);
  dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
                             staticInputSgData);
  dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
                             staticInputInstData);
  SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
      staticTargetInstData;
  SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
      dynamicTargetInstData;
  dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
                             staticTargetSgLayout);
  dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
                             staticTargetSgData);
  dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
                             staticTargetInstData);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*input_sg_layout=*/dynamicInputSgLayout,
        /*input_sg_data=*/dynamicInputSgData,
        /*input_inst_data=*/dynamicInputInstData,
        /*target_sg_layout=*/dynamicTargetSgLayout,
        /*target_sg_data=*/dynamicTargetSgData,
        /*target_inst_data=*/dynamicTargetInstData,
        /*static_input_sg_layout=*/staticInputSgLayout,
        /*static_input_sg_data=*/staticInputSgData,
        /*static_input_inst_data=*/staticInputInstData,
        /*static_target_sg_layout=*/staticTargetSgLayout,
        /*static_target_sg_data=*/staticTargetSgData,
        /*static_target_inst_data=*/staticTargetInstData);
}

DiagnosedSilenceableFailure
transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  auto value = *targetValues.begin();

  // Construct layout attributes.
  xegpu::LayoutAttr inputLayoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedInputSgLayout(),
      getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
  if (!status.succeeded())
    return status;

  xegpu::LayoutAttr targetLayoutAttr = nullptr;
  status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedTargetSgLayout(),
      getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
  if (!status.succeeded())
    return status;

  // Find first user op to define insertion point for layout conversion.
  if (value.use_empty())
    return emitSilenceableFailure(getLoc())
           << "Value has no users to insert layout conversion.";
  Operation *userOp = *value.getUsers().begin();

  // Emit convert_layout op.
  rewriter.setInsertionPoint(userOp);
  auto convLayoutOp =
      xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
                                     value, inputLayoutAttr, targetLayoutAttr);
  // Replace all other uses of the value with the converted layout.
  rewriter.replaceUsesWithIf(
      value, convLayoutOp.getResult(), [&](OpOperand &use) {
        return use.getOwner() != convLayoutOp.getOperation();
      });

  results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
  return DiagnosedSilenceableFailure::success();
}
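
// Sketch of the inserted conversion (attribute syntax abridged): all prior
// uses of %v are rerouted through the new op:
//   %converted = xegpu.convert_layout %v <{input_layout = #xegpu.layout<...>,
//                                          target_layout = #xegpu.layout<...>}>
//       : vector<...>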

void transform::ConvertLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getInputSgLayoutMutable(), effects);
  onlyReadsHandle(getInputSgDataMutable(), effects);
  onlyReadsHandle(getInputInstDataMutable(), effects);
  onlyReadsHandle(getTargetSgLayoutMutable(), effects);
  onlyReadsHandle(getTargetSgDataMutable(), effects);
  onlyReadsHandle(getTargetInstDataMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

namespace {
class XeGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          XeGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)

  using Base::Base;

  void init();
};

void XeGPUTransformDialectExtension::init() {
  declareGeneratedDialect<scf::SCFDialect>();
  declareGeneratedDialect<arith::ArithDialect>();
  declareGeneratedDialect<xegpu::XeGPUDialect>();

  registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
      >();
}
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"

void mlir::xegpu::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<XeGPUTransformDialectExtension>();
}
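
// Typical use from a tool's registration code (sketch; variable names are
// illustrative):
//   mlir::DialectRegistry registry;
//   mlir::xegpu::registerTransformDialectExtension(registry);
//   context.appendDialectRegistry(registry);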