MLIR 22.0.0git
XeGPUTransformOps.cpp
Go to the documentation of this file.
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
16#include <optional>
17
18#include "llvm/Support/DebugLog.h"
19#define DEBUG_TYPE "xegpu-transforms"
20
21using namespace mlir;
22using namespace mlir::transform;
23
24/// Assuming that `ofr` is an index attr or a param of index type
25/// or a transform dialect handle mapped to exactly one op
26/// with one index result, get that value and cast it to int type.
28 transform::TransformState &state, TransformOpInterface transformOp,
30 for (OpFoldResult ofr : ofrs) {
31 // Attribute case.
32 if (auto attr = dyn_cast<Attribute>(ofr)) {
33 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
34 result.push_back(intAttr.getInt());
35 continue;
36 }
37 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
38 }
39
40 // Transform param case.
41 Value transformValue = cast<Value>(ofr);
42 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
43 ArrayRef<Attribute> params = state.getParams(transformValue);
44 if (params.size() != 1)
45 return transformOp.emitDefiniteFailure()
46 << "requires exactly one parameter associated";
47 result.push_back(
48 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
49 continue;
50 }
51
52 // Payload value case.
53 auto payloadOps = state.getPayloadOps(transformValue);
54 if (!llvm::hasSingleElement(payloadOps)) {
56 transformOp.emitSilenceableError()
57 << "handle must be mapped to exactly one payload op";
58 diag.attachNote(transformValue.getLoc())
59 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
60 return diag;
61 }
62
63 Operation *op = *payloadOps.begin();
64 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
66 transformOp.emitSilenceableError()
67 << "payload op must have exactly 1 index result";
68 diag.attachNote(op->getLoc())
69 << "has " << op->getNumResults() << " results";
70 return diag;
71 }
72
73 IntegerAttr intAttr;
74 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
75 return transformOp.emitSilenceableError()
76 << "requires param or handle to be the result of a constant like "
77 "op";
78
79 result.push_back(intAttr.getInt());
80 }
82}
83
84/// Find producer operation of type T for the given value.
85/// It's assumed that producer ops are chained through their first operand.
86/// Producer chain is traced trough loop block arguments (init values).
87template <typename T>
88static std::optional<T> findProducerOfType(Value val) {
89 Value currentValue = val;
90 if (!currentValue.getDefiningOp()) {
91 // Value may be a block argument initialized outside a loop.
92 if (val.getNumUses() == 0) {
93 LDBG() << "Failed to find producer op, value has no uses.";
94 return std::nullopt;
95 }
96 auto userOp = val.getUsers().begin();
97 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
98 if (!parentLoop) {
99 LDBG() << "Failed to find producer op, not in a loop.";
100 return std::nullopt;
101 }
102 int64_t iterArgIdx;
103 if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
104 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
105 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
106 currentValue = parentLoop.getInits()[iterArgIdx];
107 } else {
108 LDBG() << "Failed to find producer op, value not in init values.";
109 return std::nullopt;
110 }
111 }
112 Operation *producerOp = currentValue.getDefiningOp();
113
114 if (auto matchingOp = dyn_cast<T>(producerOp))
115 return matchingOp;
116
117 if (producerOp->getNumOperands() == 0)
118 return std::nullopt;
119
120 return findProducerOfType<T>(producerOp->getOperand(0));
121}
122
123/// Create a layout attribute from the given parameters.
124static xegpu::LayoutAttr
126 ArrayRef<int32_t> sgData,
127 std::optional<ArrayRef<int32_t>> instData) {
128 return xegpu::LayoutAttr::get(
129 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
130 DenseI32ArrayAttr::get(ctx, sgData),
131 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
132 /*lane_layout=*/nullptr,
133 /*lane_data=*/nullptr,
134 /*order=*/nullptr);
135}
136
137/// Generate `xegpu::LayoutAttr` from op mixed layout values.
140 TransformOpInterface transformOp,
141 ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
143 ArrayRef<::mlir::OpFoldResult> mixedInstData,
144 xegpu::LayoutAttr &layoutAttr) {
145 SmallVector<int32_t> sgLayout, sgData, instData;
146 auto status =
147 convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
148 if (!status.succeeded())
149 return status;
150
151 status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
152 if (!status.succeeded())
153 return status;
154
155 status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
156 if (!status.succeeded())
157 return status;
158 auto maybeInstData = instData.empty()
159 ? std::nullopt
160 : std::optional<ArrayRef<int32_t>>(instData);
161
162 layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
163
165}
166
167/// Replace xegpu.create_nd_desc op with a new one with the given layout.
168static xegpu::CreateNdDescOp
170 xegpu::CreateNdDescOp descOp,
171 xegpu::DistributeLayoutAttr layout) {
172 assert(descOp.getMixedOffsets().size() == 0 &&
173 "create desc op with offsets is not supported");
174 auto oldTensorDesc = descOp.getType();
175 auto descType = xegpu::TensorDescType::get(
176 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
177 /*array_length=*/oldTensorDesc.getArrayLength(),
178 /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
179 /*memory_space=*/oldTensorDesc.getMemorySpace(),
180 /*layout=*/layout);
181
182 rewriter.setInsertionPointAfter(descOp);
183 auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
184 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
185 descOp.getMixedStrides());
186 return newDescOp;
187}
188
190transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
193 auto targetValues = state.getPayloadValues(getTarget());
194 if (!llvm::hasSingleElement(targetValues)) {
195 return emitDefiniteFailure()
196 << "requires exactly one target value handle (got "
197 << llvm::range_size(targetValues) << ")";
198 }
199
200 auto maybeDescOp =
201 findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
202 if (!maybeDescOp) {
203 return emitSilenceableFailure(getLoc())
204 << "Could not find a matching descriptor op when walking the "
205 "producer chain of the first operand.";
206 }
207
208 results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
210}
211
212void transform::SetDescLayoutOp::build(OpBuilder &builder,
214 ArrayRef<OpFoldResult> mixedSgLayout,
215 ArrayRef<OpFoldResult> mixedSgData,
216 ArrayRef<OpFoldResult> mixedInstData,
217 ArrayRef<int64_t> sliceDims) {
218 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
219 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
220 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
221 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
222 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
223 build(builder, result, target.getType(),
224 /*target=*/target,
225 /*sg_layout=*/dynamicSgLayout,
226 /*sg_data=*/dynamicSgData,
227 /*inst_data=*/dynamicInstData,
228 /*static_sg_layout=*/staticSgLayout,
229 /*static_sg_data=*/staticSgData,
230 /*static_inst_data=*/staticInstData,
231 /*slice_dims=*/sliceDims);
232}
233
235transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
238 auto targetOps = state.getPayloadOps(getTarget());
239 if (!llvm::hasSingleElement(targetOps)) {
240 return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
241 << llvm::range_size(targetOps) << ")";
242 }
243 Operation *target = *targetOps.begin();
244
245 xegpu::LayoutAttr layoutAttr = nullptr;
246 auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
247 getMixedSgLayout(), getMixedSgData(),
248 getMixedInstData(), layoutAttr);
249 if (!status.succeeded())
250 return status;
251
252 xegpu::DistributeLayoutAttr layout = layoutAttr;
253 auto sliceDims = getSliceDims();
254 if (sliceDims.size() > 0) {
255 // Wrap layoutAttr in a slice attribute.
256 layout = xegpu::SliceAttr::get(
257 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
258 }
259
260 // For now only create_nd_desc op is supported.
261 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
262 if (!descOp) {
263 auto diag = emitSilenceableFailure(getLoc())
264 << "Expected a xegpu.create_nd_desc op, but got: "
265 << target->getName();
266 diag.attachNote(target->getLoc()) << "target op";
267 return diag;
268 }
269
270 // Set layout attr in desc op's return type. Replaces old desc op.
271 auto newdescOp = setDescLayout(rewriter, descOp, layout);
272
273 // Map result handles.
274 results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
275
277}
278
279void transform::SetDescLayoutOp::getEffects(
281 consumesHandle(getTargetMutable(), effects);
282 onlyReadsHandle(getSgLayoutMutable(), effects);
283 onlyReadsHandle(getSgDataMutable(), effects);
284 onlyReadsHandle(getInstDataMutable(), effects);
285 producesHandle(getOperation()->getOpResults(), effects);
286 modifiesPayload(effects);
287}
288
289void transform::SetOpLayoutAttrOp::build(
290 OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
291 ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
292 ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
293 bool result) {
294 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
295 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
296 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
297 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
298 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
299 build(builder, ostate, target.getType(),
300 /*target=*/target,
301 /*index=*/index,
302 /*sg_layout=*/dynamicSgLayout,
303 /*sg_data=*/dynamicSgData,
304 /*inst_data=*/dynamicInstData,
305 /*static_sg_layout=*/staticSgLayout,
306 /*static_sg_data=*/staticSgData,
307 /*static_inst_data=*/staticInstData,
308 /*slice_dims=*/sliceDims,
309 /*result=*/result);
310}
311
313transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
316 auto targetOps = state.getPayloadOps(getTarget());
317 if (!llvm::hasSingleElement(targetOps)) {
318 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
319 << llvm::range_size(targetOps) << ")";
320 }
321 Operation *target = *targetOps.begin();
322
323 bool resultTarget = getResult();
324
326 if (resultTarget && index >= target->getNumResults()) {
327 return emitSilenceableFailure(getLoc())
328 << "Index exceeds the number of op results";
329 }
330 if (!resultTarget && index >= target->getNumOperands()) {
331 return emitSilenceableFailure(getLoc())
332 << "Index exceeds the number of op operands";
333 }
334
335 xegpu::LayoutAttr layoutAttr = nullptr;
336 auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
337 getMixedSgLayout(), getMixedSgData(),
338 getMixedInstData(), layoutAttr);
339 if (!status.succeeded())
340 return status;
341
342 xegpu::DistributeLayoutAttr layout = layoutAttr;
343 auto sliceDims = getSliceDims();
344 if (sliceDims.size() > 0) {
345 // Wrap layoutAttr in a slice attribute.
346 layout = xegpu::SliceAttr::get(
347 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
348 }
349
350 // Set layout attribute for the op result or operand
351 if (resultTarget)
352 xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
353 else
354 xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
356}
357
358void transform::SetOpLayoutAttrOp::getEffects(
360 onlyReadsHandle(getTargetMutable(), effects);
361 onlyReadsHandle(getSgLayoutMutable(), effects);
362 onlyReadsHandle(getSgDataMutable(), effects);
363 onlyReadsHandle(getInstDataMutable(), effects);
364 modifiesPayload(effects);
365}
366
367void transform::SetGPULaunchThreadsOp::build(
368 OpBuilder &builder, OperationState &ostate, Value target,
369 ArrayRef<OpFoldResult> mixedThreads) {
370 SmallVector<int64_t> staticThreads;
371 SmallVector<Value> dynamicThreads;
372 dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
373 build(builder, ostate, target.getType(),
374 /*target=*/target,
375 /*threads=*/dynamicThreads,
376 /*static_threads=*/staticThreads);
377}
378
380transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
383 auto targetOps = state.getPayloadOps(getTarget());
384 if (!llvm::hasSingleElement(targetOps)) {
385 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
386 << llvm::range_size(targetOps) << ")";
387 }
388 Operation *target = *targetOps.begin();
389
390 auto launchOp = dyn_cast<gpu::LaunchOp>(target);
391 if (!launchOp) {
392 auto diag = emitSilenceableFailure(getLoc())
393 << "Expected a gpu.launch op, but got: " << target->getName();
394 diag.attachNote(target->getLoc()) << "target op";
395 return diag;
396 }
397
398 SmallVector<int32_t> threads;
400 convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
401 if (!status.succeeded())
402 return status;
403
404 if (threads.size() != 3) {
405 return emitSilenceableFailure(getLoc())
406 << "Expected threads argument to consist of three values (got "
407 << threads.size() << ")";
408 }
409
410 rewriter.setInsertionPoint(launchOp);
411 auto createConstValue = [&](int value) {
412 return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
413 };
414
415 // Replace threads in-place.
416 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
417 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
418 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
419
421}
422
423void transform::SetGPULaunchThreadsOp::getEffects(
425 onlyReadsHandle(getTargetMutable(), effects);
426 onlyReadsHandle(getThreadsMutable(), effects);
427 modifiesPayload(effects);
428}
429
431transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
434 auto targetValues = state.getPayloadValues(getTarget());
435 if (!llvm::hasSingleElement(targetValues))
436 return emitDefiniteFailure()
437 << "requires exactly one target value handle (got "
438 << llvm::range_size(targetValues) << ")";
439 auto value = *targetValues.begin();
440
441 int64_t nbPrefetch = getStaticNbPrefetch();
442 if (getDynamicNbPrefetch()) {
443 // Get dynamic prefetch count from transform param or handle.
444 SmallVector<int32_t> dynamicNbPrefetch;
445 auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
446 {getDynamicNbPrefetch()});
447 if (!status.succeeded())
448 return status;
449 if (dynamicNbPrefetch.size() != 1)
450 return emitDefiniteFailure()
451 << "requires exactly one value for dynamic_nb_prefetch";
452 nbPrefetch = dynamicNbPrefetch[0];
453 }
454 if (nbPrefetch <= 0)
455 return emitSilenceableFailure(getLoc())
456 << "nb_prefetch must be a positive integer.";
457
458 // Find load operation of the operand.
459 auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
460 if (!maybeLoadOp)
461 return emitSilenceableFailure(getLoc()) << "Could not find load op.";
462 auto loadOp = *maybeLoadOp;
463 if (loadOp.getMixedOffsets().size() == 0) {
464 auto diag = emitSilenceableFailure(getLoc())
465 << "Load op must have offsets.";
466 diag.attachNote(loadOp.getLoc()) << "load op";
467 return diag;
468 }
469
470 // Find the parent scf.for loop.
471 auto forOp = loadOp->getParentOfType<scf::ForOp>();
472 if (!forOp) {
473 auto diag = emitSilenceableFailure(getLoc())
474 << "Load op is not contained in a scf.for loop.";
475 diag.attachNote(loadOp.getLoc()) << "load op";
476 return diag;
477 }
478
479 // Find descriptor op.
480 auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
481 if (!maybeDescOp)
482 return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
483 auto descOp = *maybeDescOp;
484 if (descOp.getMixedOffsets().size() > 0) {
485 auto diag = emitSilenceableFailure(getLoc())
486 << "desc op with offsets is not supported.";
487 diag.attachNote(descOp.getLoc()) << "desc op";
488 }
489
490 // Clone desc op outside the loop.
491 rewriter.setInsertionPoint(forOp);
492 auto newDescOp =
493 cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
494
495 // Clone reduction loop to emit initial prefetches.
496 // Compute upper bound of the init loop: start + nbPrefetch * step.
497 auto nbPrefetchCst =
498 arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
499 auto nbStep = rewriter.createOrFold<arith::MulIOp>(
500 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
501 auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
502 forOp.getLoc(), forOp.getLowerBound(), nbStep);
503 auto initForOp =
504 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
505 initUpBound, forOp.getStep());
506
507 auto ctx = rewriter.getContext();
508 auto readCacheHint =
509 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
510
511 // Modify loadOp mixedOffsets by replacing the for loop induction variable
512 // with the given value.
513 auto getPrefetchOffsets =
514 [&](Value replacementVal) -> SmallVector<OpFoldResult> {
515 IRMapping mapping;
516 mapping.map(forOp.getInductionVar(), replacementVal);
517 SmallVector<Value> dynamicOffsets =
518 llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
519 return mapping.lookupOrDefault(v);
520 }));
521 auto constOffsets = loadOp.getConstOffsets().value();
522 return getMixedValues(constOffsets, dynamicOffsets, ctx);
523 };
524
525 // Insert prefetch op in init loop.
526 // Replace induction var with the init loop induction var.
527 rewriter.setInsertionPointToStart(initForOp.getBody());
528 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
529 newDescOp.getResult(),
530 getPrefetchOffsets(initForOp.getInductionVar()),
531 readCacheHint, readCacheHint, readCacheHint);
532
533 // Insert prefetch op in main loop.
534 // Calculate prefetch offset after the init prefetches have been issued.
535 rewriter.setInsertionPointToStart(forOp.getBody());
536 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
537 forOp.getInductionVar(), nbStep);
538 // Replace induction var with correct offset.
539 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
540 newDescOp.getResult(),
541 getPrefetchOffsets(prefetchOffset), readCacheHint,
542 readCacheHint, readCacheHint);
543
544 // Unroll the init loop.
545 if (failed(loopUnrollFull(initForOp)))
546 return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
547
548 results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
549
551}
552
553void transform::InsertPrefetchOp::getEffects(
555 onlyReadsHandle(getTargetMutable(), effects);
556 onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
557 producesHandle(getOperation()->getOpResults(), effects);
558 modifiesPayload(effects);
559}
560
561void transform::ConvertLayoutOp::build(
562 OpBuilder &builder, OperationState &ostate, Value target,
563 ArrayRef<OpFoldResult> mixedInputSgLayout,
564 ArrayRef<OpFoldResult> mixedInputSgData,
565 ArrayRef<OpFoldResult> mixedInputInstData,
566 ArrayRef<OpFoldResult> mixedTargetSgLayout,
567 ArrayRef<OpFoldResult> mixedTargetSgData,
568 ArrayRef<OpFoldResult> mixedTargetInstData) {
569 SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
570 staticInputInstData;
571 SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
572 dynamicInputInstData;
573 dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
574 staticInputSgLayout);
575 dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
576 staticInputSgData);
577 dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
578 staticInputInstData);
579 SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
580 staticTargetInstData;
581 SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
582 dynamicTargetInstData;
583 dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
584 staticTargetSgLayout);
585 dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
586 staticTargetSgData);
587 dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
588 staticTargetInstData);
589 build(builder, ostate, target.getType(),
590 /*target=*/target,
591 /*input_sg_layout=*/dynamicInputSgLayout,
592 /*input_sg_data=*/dynamicInputSgData,
593 /*input_inst_data=*/dynamicInputInstData,
594 /*target_sg_layout=*/dynamicTargetSgLayout,
595 /*target_sg_data=*/dynamicTargetSgData,
596 /*target_inst_data=*/dynamicTargetInstData,
597 /*static_input_sg_layout=*/staticInputSgLayout,
598 /*static_input_sg_data=*/staticInputSgData,
599 /*static_input_inst_data=*/staticInputInstData,
600 /*static_target_sg_layout=*/staticTargetSgLayout,
601 /*static_target_sg_data=*/staticTargetSgData,
602 /*static_target_inst_data=*/staticTargetInstData);
603}
604
606transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
609 auto targetValues = state.getPayloadValues(getTarget());
610 if (!llvm::hasSingleElement(targetValues))
611 return emitDefiniteFailure()
612 << "requires exactly one target value handle (got "
613 << llvm::range_size(targetValues) << ")";
614 auto value = *targetValues.begin();
615
616 // Construct layout attributes.
617 xegpu::LayoutAttr inputLayoutAttr = nullptr;
618 auto status = getLayoutAttrFromOperands(
619 getContext(), state, (*this), getMixedInputSgLayout(),
620 getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
621 if (!status.succeeded())
622 return status;
623
624 xegpu::LayoutAttr targetLayoutAttr = nullptr;
626 getContext(), state, (*this), getMixedTargetSgLayout(),
627 getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
628 if (!status.succeeded())
629 return status;
630
631 // Find first user op to define insertion point for layout conversion.
632 if (value.use_empty())
633 return emitSilenceableFailure(getLoc())
634 << "Value has no users to insert layout conversion.";
635 Operation *userOp = *value.getUsers().begin();
636
637 // Emit convert_layout op.
638 rewriter.setInsertionPoint(userOp);
639 auto convLayoutOp =
640 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
641 value, inputLayoutAttr, targetLayoutAttr);
642 // Replace load op result with the converted layout.
643 rewriter.replaceUsesWithIf(
644 value, convLayoutOp.getResult(), [&](OpOperand &use) {
645 return use.getOwner() != convLayoutOp.getOperation();
646 });
647
648 results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
650}
651
652void transform::ConvertLayoutOp::getEffects(
654 onlyReadsHandle(getTargetMutable(), effects);
655 onlyReadsHandle(getInputSgLayoutMutable(), effects);
656 onlyReadsHandle(getInputSgDataMutable(), effects);
657 onlyReadsHandle(getInputInstDataMutable(), effects);
658 onlyReadsHandle(getTargetSgLayoutMutable(), effects);
659 onlyReadsHandle(getTargetSgDataMutable(), effects);
660 onlyReadsHandle(getTargetInstDataMutable(), effects);
661 producesHandle(getOperation()->getOpResults(), effects);
662 modifiesPayload(effects);
663}
664
665namespace {
666class XeGPUTransformDialectExtension
668 XeGPUTransformDialectExtension> {
669public:
670 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
671
672 using Base::Base;
673
674 void init();
675};
676
677void XeGPUTransformDialectExtension::init() {
678 declareGeneratedDialect<scf::SCFDialect>();
679 declareGeneratedDialect<arith::ArithDialect>();
680 declareGeneratedDialect<xegpu::XeGPUDialect>();
681
682 registerTransformOps<
683#define GET_OP_LIST
684#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
685 >();
686}
687} // namespace
688
689#define GET_OP_CLASSES
690#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
691
693 registry.addExtensions<XeGPUTransformDialectExtension>();
694}
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< T > findProducerOfType(Value val)
Find producer operation of type T for the given value.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, xegpu::CreateNdDescOp descOp, xegpu::DistributeLayoutAttr layout)
Replace xegpu.create_nd_desc op with a new one with the given layout.
DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<::mlir::OpFoldResult > mixedSgLayout, ArrayRef<::mlir::OpFoldResult > mixedSgData, ArrayRef<::mlir::OpFoldResult > mixedInstData, xegpu::LayoutAttr &layoutAttr)
Generate xegpu::LayoutAttr from op mixed layout values.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData)
Create a layout attribute from the given parameters.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIndex() const
Definition Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout, bool respectPermLayout=false)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictio...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.